import os.path
import warnings
from functools import cached_property
from typing import Dict, List, Optional, Union
import albumentations
import yaml
from .base import Dataset
from .collections import *
from .types import DatasetList
class Config(dict):
    """Construct config from a yaml file

    The config is a plain ``dict`` whose keys are dataset class names and
    whose values are dictionaries of keyword arguments used to initialize
    the corresponding dataset.

    Parameters
    ----------
    filename : str, optional
        Path to config file. When omitted an empty config is created.
    """

    def __init__(self, filename=None):
        # Always define the attribute so it also exists on empty configs.
        self.filename = filename
        if filename is not None:
            self.update(self._read_yaml())

    def _read_yaml(self) -> dict:
        """Parse ``self.filename`` and return its content as a dictionary."""
        with open(self.filename, 'r', encoding='utf-8') as f:
            cfg = yaml.safe_load(f)
        # ``safe_load`` returns None for an empty document.
        return {} if cfg is None else cfg

    @staticmethod
    def from_dict(d: dict):
        """Build a config from an already loaded dictionary."""
        return _Config(d)

    def load_datasets(
        self,
        transforms: "Optional[Union[albumentations.Compose, Dict[str, albumentations.Compose]]]" = None,
    ) -> "DatasetList":
        """Load every dataset described by this config

        Note that when you provide a dictionary for ``transforms``, keys
        should be the class names, not their acronyms.

        Parameters
        ----------
        transforms : albumentations.Compose or dictionary, optional
            Either apply a single composed transformations for every datasets
            or pass a dictionary that defines transformations for each dataset
            with keys being the class names of collections.

        Returns
        -------
        DatasetList
            One dataset per entry of the config, in config order.

        Raises
        ------
        NameError
            If a key of the config does not name anything available in this
            module.
        """
        datasets: List[Dataset] = []
        per_dataset = isinstance(transforms, dict)
        for name, kwargs in self.items():
            # Resolve the class by name instead of exec()-ing generated
            # source: keys come from a config file and must never be run as
            # code.
            try:
                dataset_cls = globals()[name]
            except KeyError:
                raise NameError(f"name '{name}' is not defined") from None

            if not per_dataset:
                datasets.append(dataset_cls(transforms=transforms, **kwargs))
            elif name in transforms:
                datasets.append(
                    dataset_cls(transforms=transforms[name], **kwargs))
            else:
                # No entry for this dataset: rely on its own default.
                datasets.append(dataset_cls(**kwargs))
        return DatasetList(datasets)

    @cached_property
    def commonpath(self):
        """Longest common sub-path of every ``root_dir`` in the config

        The value is cached on first access; ``replace_commonpath`` drops
        the cache after rewriting the paths.
        """
        return os.path.commonpath(
            [kwargs['root_dir'] for kwargs in self.values()])

    @staticmethod
    def _swap_prefix(path: str, old: str, new: str) -> str:
        """Return ``path`` with its leading ``old`` replaced by ``new``."""
        if not path.startswith(old):
            # Path is not spelled like the normalized common path.
            return path
        rest = path[len(old):]
        seps = ('/', os.sep)
        if rest and not rest.startswith(seps) and (
                not old or old.endswith(seps)):
            # ``old`` is empty or a bare root such as '/', so the separator
            # between ``new`` and the remainder has to be put back.
            return os.path.join(new, rest)
        return new + rest

    def replace_commonpath(self, new: str):
        """Replace common path for all ``root_dir`` with a new one

        All ``root_dir`` should have a commonpath. A trailing '/' at the end
        of ``new`` is ignored. Only the leading common path of each
        ``root_dir`` is replaced; the same text appearing deeper in a path is
        left alone.

        This saves editing each ``root_dir`` by hand when the datasets have
        moved under a different parent directory.

        Parameters
        ----------
        new : str
            Path that takes the place of the current common path.
        """
        if not self:
            # Nothing to rewrite, and no common path can be computed.
            return
        old = self.commonpath
        new = new.rstrip('/')
        for kwargs in self.values():
            kwargs['root_dir'] = self._swap_prefix(
                kwargs['root_dir'], old, new)
        # Drop the cached value so the next access reflects the new paths.
        self.__dict__.pop('commonpath', None)

    def _set_kwarg(self, attr: str, val) -> None:
        """Set ``attr`` to ``val`` on every entry that already defines it."""
        # stacklevel=3 reports the caller of the public ``set_*`` method.
        warnings.warn(f"This method only set those that have `{attr}` kwarg.",
                      stacklevel=3)
        for kwargs in self.values():
            if attr in kwargs:
                kwargs[attr] = val

    def set_training(self, val: bool):
        """Iterate config and set all ``training`` to given value

        It only affects those that have ``training`` kwarg.
        """
        self._set_kwarg('training', val)

    def set_output(self, val: str):
        """Iterate config and set all ``output`` to given value

        It only affects those that have ``output`` kwarg.
        """
        self._set_kwarg('output', val)

    def set_ouput(self, val: str):
        """Misspelled alias of ``set_output``, kept for existing callers"""
        self._set_kwarg('output', val)

    def set_grayscale(self, val: bool):
        """Iterate config and set all ``grayscale`` to given value

        It only affects those that have ``grayscale`` kwarg.
        """
        self._set_kwarg('grayscale', val)
class _Config(Config):
    """Config built from an in-memory dictionary, see ``Config.from_dict()``

    Parameters
    ----------
    d : dict
        Entries to copy into the new config.
    """

    def __init__(self, d: dict):
        # Copy every (key, value) pair of ``d`` into this config.
        self.update(d.items())