More: Split training/test set
Some datasets (actually, many) do not provide a training/test split, and it is up to us how to split them. We will look at two functions, namely bioimageloader.utils.random_split_dataset() and bioimageloader.utils.split_dataset_by_indices().
Random split
Let's pick a dataset that does not provide a training/test split, for instance bioimageloader.collections.ComputationalPathology. We will split it into three parts (training/validation/test) using bioimageloader.utils.random_split_dataset().
import random
from bioimageloader.collections import ComputationalPathology
from bioimageloader.utils import random_split_dataset

# set random seed
SEED = 42
random.seed(SEED)
# load dataset
dset = ComputationalPathology('./data/ComputationalPathology')

# define ratios and numbers
r_train = 0.6
r_val = 0.2
# r_test = 0.2  (the rest)
# convert ratios into sample counts; the test set takes the remainder
n_train = int(r_train * len(dset))
n_val = int(r_val * len(dset))
n_test = len(dset) - n_train - n_val

# SPLIT!
dset_train, dset_val, dset_test = random_split_dataset(
    dset,
    [n_train, n_val, n_test]
)
# these assertions will not raise AssertionError
assert len(dset_train) == n_train
assert len(dset_val) == n_val
assert len(dset_test) == n_test
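To make the idea behind a random split concrete, here is a minimal pure-Python sketch that works on indices only. It is not bioimageloader's implementation; `compute_lengths` and `random_partition` are hypothetical helpers, and the sketch assumes that the split lengths must add up to the dataset size, as in the example above.

```python
import random
from typing import List, Sequence


def compute_lengths(n: int, ratios: Sequence[float]) -> List[int]:
    """Turn ratios for all but the last part into integer lengths.

    The last part takes whatever is left, so the lengths always sum to n.
    """
    lengths = [int(r * n) for r in ratios]
    lengths.append(n - sum(lengths))
    return lengths


def random_partition(n: int, lengths: Sequence[int], seed: int) -> List[List[int]]:
    """Shuffle the indices 0..n-1 and cut them into consecutive chunks."""
    if sum(lengths) != n:
        raise ValueError("lengths must sum to the number of samples")
    rng = random.Random(seed)  # local generator; the global random state is untouched
    indices = list(range(n))
    rng.shuffle(indices)
    parts = []
    start = 0
    for length in lengths:
        parts.append(indices[start:start + length])
        start += length
    return parts


# Example with a hypothetical dataset of 101 samples
lengths = compute_lengths(101, [0.6, 0.2])
idx_train, idx_val, idx_test = random_partition(101, lengths, seed=42)
print(lengths)  # [60, 20, 21]
```

Because the last part absorbs the remainder, truncation by `int()` never drops a sample, and the three index lists are disjoint. With a fixed seed, the same partition comes back on every run (for a given Python version).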
Manual split
A manual split uses hard-coded indices. I found it useful to keep hard-coded training/test indices for every dataset and to save them somewhere, so that experiments and analyses always use the same split. We will use bioimageloader.utils.split_dataset_by_indices().
import random
from bioimageloader.collections import ComputationalPathology
from bioimageloader.utils import split_dataset_by_indices

# set random seed so that random.shuffle() below is reproducible
SEED = 42
random.seed(SEED)
# load dataset
dset = ComputationalPathology('./data/ComputationalPathology')

# define ratios and numbers
r_train = 0.6
r_val = 0.2
# r_test = 0.2  (the rest)
# convert ratios into sample counts; the test set takes the remainder
n_train = int(r_train * len(dset))
n_val = int(r_val * len(dset))
n_test = len(dset) - n_train - n_val

# shuffle all indices and draw three disjoint lists (save them if you want)
indices = list(range(len(dset)))
random.shuffle(indices)
idx_train = [indices.pop() for _ in range(n_train)]
idx_val = [indices.pop() for _ in range(n_val)]
idx_test = [indices.pop() for _ in range(n_test)]

# SPLIT!
dset_train = split_dataset_by_indices(dset, idx_train)
dset_val = split_dataset_by_indices(dset, idx_val)
dset_test = split_dataset_by_indices(dset, idx_test)
# these assertions will not raise AssertionError
assert len(dset_train) == n_train
assert len(dset_val) == n_val
assert len(dset_test) == n_test
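Since the point of hard-coded indices is to reuse them across experiments, one option is to write them to a JSON file once and read them back later. The sketch below uses only the standard library; the file name and the helpers `save_split_indices` and `load_split_indices` are hypothetical, not part of bioimageloader.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, List


def save_split_indices(path: Path, splits: Dict[str, List[int]]) -> None:
    """Write named index lists (e.g. 'train', 'val', 'test') to a JSON file."""
    path.write_text(json.dumps(splits, indent=2))


def load_split_indices(path: Path) -> Dict[str, List[int]]:
    """Read the named index lists back from a JSON file."""
    return json.loads(path.read_text())


# Round trip in a temporary directory; in practice, use a fixed path
# such as the hypothetical 'splits/ComputationalPathology.json'
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "ComputationalPathology.json"
    save_split_indices(path, {"train": [4, 0, 2], "val": [1], "test": [3]})
    splits = load_split_indices(path)

print(splits["train"])  # [4, 0, 2]
```

Once the file exists, passing `load_split_indices(path)["train"]` to `split_dataset_by_indices(dset, ...)` gives the same training subset in every experiment. If your indices come from NumPy, convert them with `.tolist()` first, because `json` cannot serialize NumPy integer types.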