import os.path
from functools import cached_property
from pathlib import Path
from typing import Dict, List, Optional
import albumentations
import cv2
import numpy as np
import tifffile
from skimage.util import img_as_float32
from ..base import MaskDataset
class FRUNet(MaskDataset):
    """FRU-Net: Robust Segmentation of Small Extracellular Vesicles [1]_

    TEM images

    Parameters
    ----------
    root_dir : str
        Path to root directory. Images and annotations are searched for under
        ``<root_dir>/code/data``.
    output : {'both', 'image', 'mask'}, default: 'both'
        Change outputs. 'both' returns {'image': image, 'mask': mask}.
    transforms : albumentations.Compose, optional
        An instance of Compose (albumentations pkg) that defines augmentation in
        sequence.
    num_samples : int, optional
        Useful when ``transforms`` is set. Define the total length of the
        dataset. If it is set, it overwrites ``__len__``.
    normalize : bool, default: True
        Divide each image by its own maximum value (so that the maximum
        becomes 1.0, as float32) and expand it to three channels (RGB).
    **kwargs
        Extra keyword arguments are accepted and ignored by this constructor.

    Notes
    -----
    - Originally, dtype is UINT16
    - Max value is 20444, but contrast varies a lot. For example, some images
      have value less than 0.05 of 2^16, which makes images not visible.
      Normalization may be needed. Init param ``normalize`` is set to True by
      default for this reason. For each image, it calculates maximum value and
      divide by it.
    - With ``normalize=False`` the image is only converted with
      ``skimage.util.img_as_float32``; no grayscale-to-RGB conversion is
      applied in that case.

    References
    ----------
    .. [1] E. Gómez-de-Mariscal, M. Maška, A. Kotrbová, V. Pospíchalová, P.
       Matula, and A. Muñoz-Barrutia, “Deep-Learning-Based Segmentation of Small
       Extracellular Vesicles in Transmission Electron Microscopy Images,”
       Scientific Reports, vol. 9, no. 1, Art. no. 1, Sep. 2019, doi:
       10.1038/s41598-019-49431-3.

    See Also
    --------
    MaskDataset : Super class
    Dataset : Base class
    DatasetInterface : Interface
    """

    # Dataset's acronym
    acronym = 'FRUNet'

    def __init__(
        self,
        root_dir: str,
        *,
        output: str = 'both',
        transforms: Optional[albumentations.Compose] = None,
        num_samples: Optional[int] = None,
        # specific to this dataset
        normalize: bool = True,
        **kwargs
    ):
        # The data lives two levels below the given root: <root_dir>/code/data
        self._root_dir = os.path.join(root_dir, 'code', 'data')
        self._output = output
        self._transforms = transforms
        self._num_samples = num_samples
        # specific to this dataset
        self.normalize = normalize

    def get_image(self, p: Path) -> np.ndarray:
        """Read a TIFF image and return it as a float array.

        Parameters
        ----------
        p : Path
            Path to a ``.tif`` image file.

        Returns
        -------
        np.ndarray
            If ``self.normalize`` is set: the image divided by its own maximum
            value, as float32, converted from grayscale to 3-channel RGB. An
            all-zero image is returned as zeros rather than divided by zero.
            Otherwise: the raw array converted with ``img_as_float32``, with
            no channel conversion.
        """
        tif = tifffile.imread(p)
        if not self.normalize:
            return img_as_float32(tif)

        peak = tif.max()
        if peak == 0:
            # Blank image: dividing by zero would fill the result with NaN.
            scaled = np.zeros(tif.shape, dtype=np.float32)
        else:
            # Keep the result float32; cv2.cvtColor does not accept float64.
            scaled = (tif / np.float32(peak)).astype(np.float32, copy=False)
        return cv2.cvtColor(scaled, cv2.COLOR_GRAY2RGB)

    def get_mask(self, p: Path) -> np.ndarray:
        """Read a TIFF annotation file and return it cast to int16.

        Parameters
        ----------
        p : Path
            Path to a ``.tif`` annotation file.

        Returns
        -------
        np.ndarray
            The array read from ``p`` with dtype ``np.int16``.
        """
        mask = tifffile.imread(p)
        return mask.astype(np.int16)

    @cached_property
    def file_list(self) -> List[Path]:
        """Sorted paths matching ``dataset_*/*.tif`` under the data root.

        Computed once per instance and cached.
        """
        # NOTE(review): ``root_dir`` is not defined in this class; presumably
        # inherited from the base dataset and path-like (must support
        # ``.glob``) -- confirm against the base class.
        return sorted(self.root_dir.glob('dataset_*/*.tif'))

    @cached_property
    def anno_dict(self) -> Dict[int, Path]:
        """Map positional index to annotation path.

        Annotation files matching ``annotations_*/*.tif`` under the data root
        are sorted and keyed by their position in that sorted order. Computed
        once per instance and cached.
        """
        # NOTE(review): keys are presumably meant to line up with indices of
        # ``file_list`` (both are sorted) -- verify against the caller.
        anno_list = sorted(self.root_dir.glob('annotations_*/*.tif'))
        return dict(enumerate(anno_list))