Source code for bioimageloader.batch

"""Batch loader module
"""

import bisect
import concurrent.futures
import random
import warnings
from math import ceil
from typing import Dict, Iterator, List, Optional

import numpy as np

from .base import Dataset


class ConcatDataset:
    """Concatenate Datasets

    Todo
    ----
    - Typing datasets with covariant class
        Lose param hints because of Generic type
    - Intermediate class linking DatasetInterface and
      [MaskDataset, BBoxDataset, ...]

    References
    ----------
    .. [1] https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset

    """
    def __init__(self, datasets: List[Dataset]):
        self.datasets = datasets
        self.acronym = [dset.acronym for dset in self.datasets]
        self.sizes = [len(dset) for dset in self.datasets]
        self.cumulative_sizes = np.cumsum(self.sizes)
        # Check and warn
        if any(s == 0 for s in self.sizes):
            i = self.sizes.index(0)
            warnings.warn(
                f"ind={i} {self.datasets[i].acronym} is empty",
                stacklevel=2,
            )
        if len(set(outputs := [dset.output for dset in self.datasets])) != 1:
            warnings.warn(
                f"output types do not match {outputs}",
                stacklevel=2,
            )

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, ind):
        # Find which member dataset `ind` falls into, then offset into it
        ind_dataset = bisect.bisect_right(self.cumulative_sizes, ind)
        ind_sample = (
            ind if ind_dataset == 0
            else ind - self.cumulative_sizes[ind_dataset - 1]
        )
        return self.datasets[ind_dataset][ind_sample]
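

# Usage sketch (illustrative only): concatenating two collections and indexing
# across their boundary. ``DSB2018`` and ``TNBC`` below stand in for any
# concrete ``Dataset`` subclasses from ``bioimageloader.collections``; the
# class names and root directories are assumptions, not a prescription.
#
#   >>> from bioimageloader.collections import DSB2018, TNBC
#   >>> cat = ConcatDataset([DSB2018('./data/DSB2018'), TNBC('./data/TNBC')])
#   >>> len(cat)                    # sum of the member dataset lengths
#   >>> cat.acronym                 # one acronym per member dataset
#   >>> sample = cat[0]             # comes from the first dataset
#   >>> sample = cat[len(cat) - 1]  # comes from the last dataset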


class BatchDataloader:
    """Batch loader with multi-processing
    """
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 16,
        shuffle: bool = False,
        drop_last: bool = False,
        num_workers: Optional[int] = None,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.num_workers = num_workers
        self.executor = concurrent.futures.ProcessPoolExecutor(
            max_workers=self.num_workers
        )
        if shuffle:
            # Shuffle by permuting the dataset's file list in place (and the
            # matching annotation dict, if the dataset defines one)
            idx = list(range(len(self.dataset)))
            random.shuffle(idx)
            dataset.file_list = [dataset.file_list[ind] for ind in idx]
            if hasattr(dataset, 'anno_dict'):
                dataset.anno_dict = dict(
                    (i, dataset.anno_dict[ind]) for i, ind in enumerate(idx)
                )

    def __len__(self):
        if self.drop_last:
            return len(self.dataset) // self.batch_size
        return ceil(len(self.dataset) / self.batch_size)

    @property
    def _last_size(self):
        if self.drop_last:
            return self.batch_size
        remainder = len(self.dataset) - self.batch_size * (len(self) - 1)
        return remainder

    def __iter__(self):
        return IterBatchDataloader(self)

    def __del__(self):
        # Shutdown mp
        self.executor.shutdown(wait=True)
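

# Usage sketch (illustrative only): ``dataset`` is assumed to be any
# ``Dataset`` instance, e.g. the ``ConcatDataset`` built above. Each iteration
# yields a dict of arrays stacked along a new leading batch axis; the final
# batch is smaller unless ``drop_last=True``.
#
#   >>> loader = BatchDataloader(dataset, batch_size=8, shuffle=True,
#   ...                          num_workers=4)
#   >>> len(loader)                  # number of batches, last one included
#   >>> for batch in loader:
#   ...     batch['image'].shape     # (8, ...) except possibly the last batch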


class IterBatchDataloader(Iterator):
    """Iterate BatchDataloader
    """
    def __init__(self, dataloader: BatchDataloader):
        self.dataloader = dataloader
        self.dataset = dataloader.dataset
        self.batch_size = self.dataloader.batch_size
        self.drop_last = self.dataloader.drop_last
        self.num_workers = self.dataloader.num_workers
        self.executor = self.dataloader.executor
        self._last_size = self.dataloader._last_size
        self.batch_ind = 0
        self.batch_end = len(dataloader)

    def __next__(self) -> Dict[str, np.ndarray]:
        if self.batch_ind == self.batch_end:
            raise StopIteration
        if self.batch_ind == self.batch_end - 1 and not self.drop_last:
            batch_size = self._last_size
        else:
            batch_size = self.batch_size
        ind_start = self.batch_size * self.batch_ind
        jobs = [
            self.executor.submit(_mp_getitem, self.dataset, ind)
            for ind in range(ind_start, ind_start + batch_size)
        ]
        concurrent.futures.wait(jobs)
        self.batch_ind += 1
        # return [j.result() for j in jobs]
        # bundle with the same keys
        d0 = jobs[0].result()
        keys = d0.keys()
        ret = {}
        for k in keys:
            ret[k] = np.stack(
                [d0[k]] + list(map(lambda j: j.result()[k], jobs[1:]))
            )
        return ret


def _mp_getitem(dataset, ind):
    # Module-level function so it can be pickled and dispatched to
    # ProcessPoolExecutor worker processes
    return dataset[ind]
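

# A minimal end-to-end sketch, guarded by ``__main__`` so that worker
# processes spawned by ``ProcessPoolExecutor`` do not re-execute it when they
# import this module. The dataset classes and root paths are assumptions for
# illustration; substitute datasets you actually have on disk.
if __name__ == '__main__':
    from bioimageloader.collections import DSB2018, TNBC

    cat = ConcatDataset([DSB2018('./data/DSB2018'), TNBC('./data/TNBC')])
    loader = BatchDataloader(cat, batch_size=16, drop_last=True, num_workers=2)
    for batch in loader:
        print({k: v.shape for k, v in batch.items()})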