More: Split training/test set

Many datasets do not come with a predefined training/test split, so it is up to us to decide how to split them. We will look at two functions: bioimageloader.utils.random_split_dataset() and bioimageloader.utils.split_dataset_by_indices().

Random split

Let’s pick a dataset that does not provide a training/test split, for instance bioimageloader.collections.ComputationalPathology. We will split it into three parts (training/validation/test) using bioimageloader.utils.random_split_dataset().

import random
from bioimageloader.collections import ComputationalPathology
from bioimageloader.utils import random_split_dataset

# set random seed
SEED = 42
random.seed(SEED)
# load dataset
dset = ComputationalPathology('./data/ComputationalPathology')

# define ratios and numbers
r_train = 0.6
r_val = 0.2
#r_test = 0.2  # the rest
# get real numbers
n_train = int(r_train * len(dset))
n_val = int(r_val * len(dset))
n_test = len(dset) - n_train - n_val

# SPLIT!
dset_train, dset_val, dset_test = random_split_dataset(
    dset,
    [n_train, n_val, n_test]
)
# these assertions will not throw AssertionError
assert len(dset_train) == n_train
assert len(dset_val) == n_val
assert len(dset_test) == n_test
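
To make the arithmetic above concrete, here is a small, library-free sketch of what splitting into [n_train, n_val, n_test] amounts to: shuffle the indices once and cut them into consecutive chunks. The helpers `split_sizes` and `random_split_indices` are hypothetical and this is not bioimageloader's implementation; it only illustrates that the parts are disjoint, that together they cover the whole dataset, and that `int()` truncation pushes any leftover items into the test part.

```python
import random
from typing import List


def split_sizes(n: int, r_train: float = 0.6, r_val: float = 0.2) -> List[int]:
    """Turn ratios into integer sizes; the test part takes the remainder."""
    n_train = int(r_train * n)  # int() truncates toward zero
    n_val = int(r_val * n)
    n_test = n - n_train - n_val  # leftovers from truncation land here
    return [n_train, n_val, n_test]


def random_split_indices(n: int, lengths: List[int], seed: int = 42) -> List[List[int]]:
    """Hypothetical helper: shuffle range(n) and cut it into chunks of `lengths`."""
    if sum(lengths) != n:
        raise ValueError("lengths must sum to n")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # local RNG, leaves the global state alone
    parts, start = [], 0
    for length in lengths:
        parts.append(indices[start:start + length])
        start += length
    return parts


sizes = split_sizes(101)
print(sizes)  # [60, 20, 21]
idx_train, idx_val, idx_test = random_split_indices(101, sizes)
```

Using a local `random.Random(seed)` is a choice made for this sketch only; the example above seeds the global `random` module instead.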

Manual split

A manual split means splitting by hard-coded indices. I have found it useful to fix the training/test indices for every dataset and save them somewhere, so that later experiments and analyses use exactly the same split. We will use bioimageloader.utils.split_dataset_by_indices().

import random
from bioimageloader.collections import ComputationalPathology
from bioimageloader.utils import split_dataset_by_indices

# set random seed
SEED = 42
random.seed(SEED)
# load dataset
dset = ComputationalPathology('./data/ComputationalPathology')

# define ratios and numbers
r_train = 0.6
r_val = 0.2
#r_test = 0.2  # the rest
# get real numbers
n_train = int(r_train * len(dset))
n_val = int(r_val * len(dset))
n_test = len(dset) - n_train - n_val

# get indices and save them if you want
indices = list(range(len(dset)))
random.shuffle(indices)
idx_train = [indices.pop() for _ in range(n_train)]
idx_val = [indices.pop() for _ in range(n_val)]
idx_test = [indices.pop() for _ in range(n_test)]

# SPLIT!
dset_train = split_dataset_by_indices(dset, idx_train)
dset_val = split_dataset_by_indices(dset, idx_val)
dset_test = split_dataset_by_indices(dset, idx_test)
# these assertions will not throw AssertionError
assert len(dset_train) == n_train
assert len(dset_val) == n_val
assert len(dset_test) == n_test
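
One simple way to keep hard-coded indices around is to write them once to a file and read them back in later runs, then pass them to split_dataset_by_indices() as before. The sketch below assumes a JSON file is acceptable storage; the file name `split_indices.json` and the helpers `save_split`/`load_split` are hypothetical, and a toy dataset length of 10 stands in for len(dset).

```python
import json
import random
from pathlib import Path
from typing import Dict, List


def save_split(path: Path, splits: Dict[str, List[int]]) -> None:
    """Write the index lists to a JSON file (hypothetical helper)."""
    # sanity check: no index may appear in more than one split
    seen = set()
    for name, idx in splits.items():
        overlap = seen.intersection(idx)
        if overlap:
            raise ValueError(f"{name} overlaps other splits: {sorted(overlap)[:5]}")
        seen.update(idx)
    path.write_text(json.dumps(splits, indent=2))


def load_split(path: Path) -> Dict[str, List[int]]:
    """Read the index lists back; JSON keeps ints as ints."""
    return json.loads(path.read_text())


# same shuffle-and-pop procedure as above, on a toy length of 10
random.seed(42)
indices = list(range(10))
random.shuffle(indices)
idx_train = [indices.pop() for _ in range(6)]
idx_val = [indices.pop() for _ in range(2)]
idx_test = [indices.pop() for _ in range(2)]

path = Path("split_indices.json")
save_split(path, {"train": idx_train, "val": idx_val, "test": idx_test})
splits = load_split(path)
# later: dset_train = split_dataset_by_indices(dset, splits["train"])
```

JSON keeps the file human-readable and diffable, which makes it easy to check that two experiments really used the same split.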