Source code for ada.datasets.dataset_access
import torch
import torch.utils.data
class DatasetAccess:
    """
    Common access point for the training, validation and test splits of a dataset.

    Subclasses provide the data by overriding ``get_train`` and ``get_test``;
    ``get_train_val`` derives a train/validation split from ``get_train``.
    """

    def __init__(self, n_classes):
        """
        Args:
            n_classes: number of classes of the dataset, stored and returned
                unchanged by ``n_classes()``.
        """
        self._n_classes = n_classes

    def n_classes(self):
        """Return the number of classes given at construction time."""
        return self._n_classes

    def get_train(self):
        """
        Return the training split. Must be overridden by subclasses.

        returns: a torch.utils.data.Dataset
        """
        raise NotImplementedError()

    def get_train_val(self, val_ratio: float):
        """
        Randomly split the training set into a training and a validation part.

        Args:
            val_ratio: fraction of the training set reserved for validation,
                between 0 and 1 (inclusive).

        Returns:
            The two subsets produced by ``torch.utils.data.random_split``,
            training part first, validation part second. The training part
            holds ``floor((1 - val_ratio) * n)`` samples, the validation part
            the remainder.

        Raises:
            ValueError: if ``val_ratio`` is not within [0, 1].
        """
        if not 0 <= val_ratio <= 1:
            raise ValueError(
                f"val_ratio must be between 0 and 1, got {val_ratio!r}"
            )

        train_dataset = self.get_train()
        ntotal = len(train_dataset)
        # Round away floating-point noise before truncating: (1 - 0.8) * 10
        # evaluates to 1.9999999999999996, which a bare int() would turn into
        # 1 instead of 2. Genuinely fractional products are still floored.
        ntrain = int(round((1 - val_ratio) * ntotal, 6))

        # Re-seed the global RNG with its initial seed so that repeated calls
        # within one process yield the same split. Note that this resets the
        # global torch random state as a side effect.
        torch.manual_seed(torch.initial_seed())
        return torch.utils.data.random_split(train_dataset, [ntrain, ntotal - ntrain])

    def get_test(self):
        """
        Return the test split. Must be overridden by subclasses.

        Presumably a torch.utils.data.Dataset, like ``get_train`` -- confirm
        against the concrete subclasses.
        """
        raise NotImplementedError()