Source code for bioimageloader.base

"""Define the base dataset classes and their interface

``DatasetInterface`` declares the abstract interface of a dataset.

``Dataset`` is the base of all datasets.

``MaskDataset`` is the base of datasets that have mask annotations.
"""

import abc
import inspect
import random
from pathlib import Path
from typing import Any, Dict, Iterator, Optional, Sequence, Union

import albumentations
import cv2
import numpy as np


class DatasetInterface(metaclass=abc.ABCMeta):
    """Dataset interface

    Attributes
    ----------
    acronym
    root_dir
    file_list

    Methods
    -------
    __repr__
    __len__
    __getitem__
    get_image
    """

    @abc.abstractmethod
    def __repr__(self):  # common
        """Return a string describing the dataset"""
        ...

    @abc.abstractmethod
    def __len__(self):  # common
        """Return the number of items in the dataset"""
        ...

    @abc.abstractmethod
    def __getitem__(self, ind):  # common
        """Given index returns a dictionary of key(s) and array"""
        ...

    @abc.abstractmethod
    def get_image(self, key):  # required
        """Get an image"""
        ...

    # Plain abstract property. Stacking ``@classmethod`` on ``@property`` was
    # deprecated in Python 3.11 and removed in 3.13. Subclasses may override
    # this with a class attribute (e.g. ``acronym = 'XYZ'``) or a property.
    @property
    @abc.abstractmethod
    def acronym(self):  # required
        """Acronym assigned to a subclass dataset"""
        ...

    @property
    @abc.abstractmethod
    def root_dir(self):  # required
        """Path to root directory"""
        ...

    @property
    def file_list(self):  # required
        """A list of paths to image files

        Not abstract: this base implementation returns ``None`` and is meant
        to be overridden by subclasses.
        """
        ...
class Dataset(DatasetInterface):
    """Base to define common attributes and methods for [`MaskDataset`, ...]

    Attributes
    ----------
    root_dir
    output
    transforms
    num_samples
    grayscale : optional
    grayscale_mode : optional
    num_channels : optional

    Methods
    -------
    __repr__
    __len__
    __iter__
    __getitem__
    _drop_missing_pairs
    to_gray

    Notes
    -----
    Required attributes in subclass
        - ``acronym``
        - ``_root_dir`` (read by ``root_dir``)
        - ``file_list``
        - ``get_image()``
        - ``anno_dict`` (only when ``_drop_missing_pairs()`` is used)
    """

    def __repr__(self):
        """Return ``acronym`` followed by ``(name=value, ...)``

        Every parameter of the subclass ``__init__`` except ``kwargs`` is read
        back with ``getattr(self, name)``, so each one must be available as an
        attribute of the same name.
        """
        params = list(inspect.signature(self.__init__).parameters)
        # ``kwargs`` is skipped; it is not looked up as an attribute
        if 'kwargs' in params:
            params.remove('kwargs')
        init_args_str = '(' + ', '.join(f'{k}={getattr(self, k)}' for k in params) + ')'
        return self.acronym + init_args_str

    def __len__(self):
        """Length of dataset. Overridden by ``num_samples`` when it is set"""
        if self.num_samples is not None:
            return self.num_samples
        return len(self.file_list)

    def __iter__(self):
        return IterDataset(self)

    @property
    def root_dir(self) -> Path:
        """Root directory taken from ``_root_dir``, converted to ``Path``

        Raises
        ------
        NotImplementedError
            If the subclass does not define ``_root_dir``.
        """
        if hasattr(self, '_root_dir'):
            _root_dir = getattr(self, '_root_dir')
            if not isinstance(_root_dir, Path):
                return Path(_root_dir)
            return _root_dir
        raise NotImplementedError("Attr `_root_dir` not defined")

    @property
    def output(self) -> str:
        """Determine return(s) when called, fixed to 'image'"""
        return 'image'

    @property
    def transforms(self) -> Optional[albumentations.Compose]:
        """Transform images and masks (``_transforms``), or ``None``"""
        if hasattr(self, '_transforms'):
            return getattr(self, '_transforms')
        return None

    @property
    def num_samples(self) -> Optional[int]:
        """Number of calls that will override __len__"""
        if hasattr(self, '_num_samples'):
            return getattr(self, '_num_samples')
        return None

    @num_samples.setter
    def num_samples(self, val):
        self._num_samples = val

    @property
    def grayscale(self) -> Optional[bool]:
        """Flag for grayscale conversion"""
        if hasattr(self, '_grayscale'):
            return getattr(self, '_grayscale')
        return None

    @grayscale.setter
    def grayscale(self, val):
        self._grayscale = val

    @property
    def grayscale_mode(self) -> Optional[Union[str, Sequence[float]]]:
        """Determine grayscale mode one of {'cv2', 'equal', Sequence[float]}"""
        if hasattr(self, '_grayscale_mode'):
            return getattr(self, '_grayscale_mode')
        return None

    @grayscale_mode.setter
    def grayscale_mode(self, val):
        self._grayscale_mode = val

    @property
    def num_channels(self) -> Optional[int]:
        """Number of image channels used for `to_gray()`"""
        if hasattr(self, '_num_channels'):
            return getattr(self, '_num_channels')
        return None

    @num_channels.setter
    def num_channels(self, val):
        # Setter added for consistency with ``grayscale``/``grayscale_mode``
        self._num_channels = val

    def _apply_grayscale(self, image: np.ndarray, p) -> np.ndarray:
        """Run ``to_gray()`` on ``image`` when ``grayscale`` is set

        ``p`` is the ``file_list`` entry ``image`` was loaded from. When
        ``num_channels`` is unset, ``len(p)`` is used if ``p`` is a list,
        otherwise 3.
        """
        if not self.grayscale:
            return image
        num_channels = self.num_channels
        # exception
        if hasattr(self, 'image_ch') and len(self.image_ch) == 1:
            # BBBC020: `_num_channels=2`. When `image_ch` is set to one
            # channel, output images become grayscale and `to_gray()`
            # got `num_channels=2`.
            num_channels = 3
        if num_channels is None:
            num_channels = len(p) if isinstance(p, list) else 3
        return self.to_gray(
            image,
            grayscale_mode=self.grayscale_mode,
            num_channels=num_channels
        )

    def __getitem__(self, ind: int) -> Dict[str, np.ndarray]:
        """Get image

        Dataset does not have any annotation available. It will only load
        'image'.

        Parameters
        ----------
        ind : int
            Index to get path(s) from ``file_list`` attribute. When
            ``num_samples`` is set, ``ind`` only bounds the call count and a
            random index is drawn instead.

        Returns
        -------
        dict
            ``{'image': image}``

        Raises
        ------
        IndexError
            If ``num_samples`` is set and ``ind >= num_samples``.

        Attributes
        ----------
        self.file_list

        Other Parameters
        ----------------
        self._transforms
        self._num_samples
        self._grayscale
        self._grayscale_mode
        self._num_channels
        """
        # Randomize `ind` when `num_samples` set
        if self.num_samples is not None:
            if ind >= self.num_samples:
                raise IndexError('list index out of range')
            ind = random.randrange(0, len(self.file_list))
        # `output="image"`
        p = self.file_list[ind]
        image = self._apply_grayscale(self.get_image(p), p)
        if self.transforms:
            image = self.transforms(image=image)['image']
        return {'image': image}

    def _drop_missing_pairs(self) -> tuple:
        """Drop images and reindex the anno list (dict)

        Sometimes, not all images have annotation. For consistency, this func
        simply drops those images missing annotation. For example,

        - MurphyLab
        - BBBC018
        - BBBC020

        Returns
        -------
        tuple
            ``(file_list, anno_dict)``. ``file_list`` is the same list object,
            filtered in place. ``anno_dict`` is a new dict whose keys are
            ``0..n-1``, in ascending order of the original keys.
        """
        file_list = getattr(self, 'file_list')
        anno_dict = getattr(self, 'anno_dict')
        missing = set(range(len(file_list))).difference(anno_dict)
        # logger.info(f'{self.acronym}:Dropping indices: {sorted(missing)}')
        # Slice assignment keeps the in-place mutation of the original
        # ``pop`` loop, but in a single O(n) pass.
        file_list[:] = [f for i, f in enumerate(file_list) if i not in missing]
        # Reindex by sorted key, not dict insertion order, so annotation ``i``
        # stays paired with ``file_list[i]``.
        anno_dict = {i: anno_dict[k] for i, k in enumerate(sorted(anno_dict))}
        return file_list, anno_dict

    @staticmethod
    def to_gray(
        arr: np.ndarray,
        grayscale_mode: Optional[Union[str, Sequence[float]]] = None,
        num_channels: int = 3,
    ) -> np.ndarray:
        """Convert bioimage to grayscale

        Parameters
        ----------
        arr : image array
            Numpy image array whose shape is (h, w, ch). 'cv2' requires
            ch == 3.
        grayscale_mode : str or sequence of float, optional
            Choose a strategy for gray conversion. Three options are
            available: 'cv2', 'equal', or a sequence of float numbers giving
            the linear weight of each channel (one weight per channel).
        num_channels : int
            Explicitly set number of channels for `grayscale_mode='equal'`.

        Returns
        -------
        np.ndarray
            Gray values replicated into 3 channels with ``COLOR_GRAY2RGB``,
            in the dtype of ``arr``.

        Raises
        ------
        ValueError
            Unknown mode string, non-RGB input for 'cv2', or weights that do
            not match the channel axis.
        NotImplementedError
            If ``grayscale_mode`` is neither a str nor a sequence (e.g. None).
        """
        if isinstance(grayscale_mode, str):
            if grayscale_mode == 'cv2':
                if arr.shape[-1] != 3:
                    raise ValueError("Image arr should have RGB channels")
                arr = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
                arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
            elif grayscale_mode == 'equal':
                # Expect (h, w, ch) shape of array. `num_channels` may differ
                # from ch on purpose (e.g. an all-zero channel).
                arr = (arr.sum(axis=-1) / num_channels).astype(arr.dtype)
                arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
            else:
                raise ValueError(f"Wrong `grayscale_mode={grayscale_mode}`")
        elif isinstance(grayscale_mode, np.ndarray) or isinstance(grayscale_mode, Sequence):
            weights = np.asarray(grayscale_mode, dtype=np.float64)
            if arr.ndim != 3 or weights.ndim != 1 or weights.shape[0] != arr.shape[-1]:
                raise ValueError(
                    f"`grayscale_mode` weights of shape {weights.shape} do not "
                    f"match image arr of shape {arr.shape}"
                )
            # Weighted sum over the channel axis: (h, w, ch) @ (ch,) -> (h, w)
            gray = arr @ weights
            if np.issubdtype(arr.dtype, np.integer):
                # Clip so weights summing above 1 cannot wrap around on cast
                info = np.iinfo(arr.dtype)
                gray = np.clip(gray, info.min, info.max)
            # Truncating cast, same as the 'equal' mode
            arr = cv2.cvtColor(gray.astype(arr.dtype), cv2.COLOR_GRAY2RGB)
        else:
            raise NotImplementedError("`grayscale_mode`")
        return arr
class MaskDataset(Dataset):
    """Base for datasets with mask annotation

    Define ``__getitem__`` method to load mask annotation paired with image.

    Pre-defined attributes are prefixed with a single underscore to
    distinguish them from those specific to a dataset. Each subclass must
    implement ``get_image()`` and define ``acronym`` and ``_root_dir``.
    ``get_mask()`` and ``anno_dict`` must also be implemented to use
    ``output`` 'mask' or 'both'.

    Attributes
    ----------
    output
    anno_dict

    Methods
    -------
    __getitem__
    get_mask

    Notes
    -----
    Required attributes in subclass
        - ``acronym``
        - ``_root_dir``
        - ``_output``
        - ``_grayscale`` (optional)
        - ``_grayscale_mode`` (optional)
        - ``_num_channels`` (optional)
        - ``get_image()``
        - ``get_mask()`` (required for ``output`` 'mask' or 'both')

    See Also
    --------
    Dataset : super class
    """

    @property
    def output(self) -> str:
        """Determine return(s) when called: 'image', 'mask' or 'both'"""
        return self._output

    @output.setter
    def output(self, val):
        self._output = val

    @property
    def anno_dict(self) -> Dict[int, Any]:
        """Dictionary of paths to annotation files"""
        raise NotImplementedError

    def _apply_grayscale(self, image: np.ndarray, p) -> np.ndarray:
        """Run ``to_gray()`` on ``image`` when ``grayscale`` is set

        ``p`` is the ``file_list`` entry ``image`` was loaded from. When
        ``num_channels`` is unset, ``len(p)`` is used if ``p`` is a list,
        otherwise 3.
        """
        if not self.grayscale:
            return image
        num_channels = self.num_channels
        # exception
        if hasattr(self, 'image_ch') and len(self.image_ch) == 1:
            # BBBC020: `_num_channels=2`. When `image_ch` is set to one
            # channel, output images become grayscale and `to_gray()`
            # got `num_channels=2`.
            num_channels = 3
        if num_channels is None:
            num_channels = len(p) if isinstance(p, list) else 3
        return self.to_gray(
            image,
            grayscale_mode=self.grayscale_mode,
            num_channels=num_channels
        )

    def __getitem__(self, ind: int) -> Dict[str, np.ndarray]:
        """Get image, mask, or both depending on ``output`` argument

        For MaskDataset, available output types are ['image', 'mask', 'both'].

        Parameters
        ----------
        ind : int
            Index to get path(s) from ``file_list`` attribute. When
            ``num_samples`` is set, ``ind`` only bounds the call count and a
            random index is drawn instead.

        Returns
        -------
        dict
            ``{'image': ...}``, ``{'mask': ...}`` or both keys.

        Raises
        ------
        IndexError
            If ``num_samples`` is set and ``ind >= num_samples``.
        NotImplementedError
            If ``output`` is not one of 'image', 'mask', 'both'.

        Attributes
        ----------
        self.output
        self.file_list
        self.anno_dict

        Other Parameters
        ----------------
        self._transforms
        self._num_samples
        self._grayscale
        self._grayscale_mode
        self._num_channels
        """
        # Randomize `ind` when `num_samples` set
        if self.num_samples is not None:
            if ind >= self.num_samples:
                raise IndexError('list index out of range')
            ind_max = len(self.file_list)
            if (self.output != 'image') and (self.anno_dict is not None):
                ind_max = len(self.anno_dict)
            ind = random.randrange(0, ind_max)

        # `output="image"`
        if self.output == 'image':
            p = self.file_list[ind]
            image = self._apply_grayscale(self.get_image(p), p)
            if self.transforms:
                image = self.transforms(image=image)['image']
            return {'image': image}

        # `output="mask"`
        if self.output == 'mask':
            mask = self.get_mask(self.anno_dict[ind])
            if self.transforms is not None:
                # Transforms are called with an image as well; pass a blank
                # one of the mask's shape.
                _image = np.zeros_like(mask, dtype=np.uint8)
                mask = self.transforms(image=_image, mask=mask)['mask']
            return {'mask': mask}

        # both image and gt
        if self.output == 'both':
            p = self.file_list[ind]
            image = self.get_image(p)
            mask = self.get_mask(self.anno_dict[ind])
            image = self._apply_grayscale(image, p)
            # Make sure to apply the same augmentation both to image and mask
            if self.transforms is not None:
                augmented = self.transforms(image=image, mask=mask)
                image, mask = augmented['image'], augmented['mask']
            return {'image': image, 'mask': mask}

        raise NotImplementedError("Choose one ['image', 'mask', 'both']")

    def get_mask(self, key) -> np.ndarray:
        """Get a mask

        Raises ``NotImplementedError`` unless overridden by a subclass.
        """
        raise NotImplementedError
class IterDataset(Iterator):
    """Iterator over a dataset by index, from 0 up to ``len(dataset)``

    ``len(dataset)`` is read once, when the iterator is created. Each call to
    ``__next__`` returns ``dataset[ind]`` and then advances ``ind``.
    """

    def __init__(self, dataset: Dataset):
        self.dataset = dataset
        self.end = len(dataset)
        self.ind = 0

    def __next__(self):
        position = self.ind
        if position != self.end:
            item = self.dataset[position]
            # Advance only after a successful fetch
            self.ind = position + 1
            return item
        raise StopIteration