import logging
from enum import Enum
import numpy as np
from sklearn.utils import check_random_state
import torch.utils.data
from ada.datasets.sampler import get_labels, MultiDataLoader, SamplingConfig
from ada.datasets.dataset_access import DatasetAccess
class DatasetSizeType(Enum):
    Max = "max"  # size of the biggest dataset
    Source = "source"  # size of the source dataset

    @staticmethod
    def get_size(size_type, source_dataset, *other_datasets):
        if size_type is DatasetSizeType.Max:
            return max(list(map(len, other_datasets)) + [len(source_dataset)])
        elif size_type is DatasetSizeType.Source:
            return len(source_dataset)
        else:
            raise ValueError(
                f"Size type must be 'max' or 'source', had '{size_type}'"
            )
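

# Illustrative sketch (hypothetical sizes): with a 1,000-sample source and a
# 1,500-sample target,
#   DatasetSizeType.get_size(DatasetSizeType.Max, source, target)     -> 1500
#   DatasetSizeType.get_size(DatasetSizeType.Source, source, target)  -> 1000
# Max sizes an epoch by the largest dataset, while Source ties the epoch
# length to the source dataset alone.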
class DomainsDatasetBase:
    def prepare_data_loaders(self):
        """
        Handles the train/validation/test split to have 3 datasets, each with
        data from all domains.
        """
        raise NotImplementedError()

    def get_domain_loaders(self, split="train", batch_size=32):
        """
        Handles the sampling of a dataset containing multiple domains.

        Args:
            split (string, optional): ["train"|"valid"|"test"]. Which dataset to iterate on. Defaults to "train".
            batch_size (int, optional): Defaults to 32.

        Returns:
            MultiDataLoader: A dataloader with an API similar to torch.utils.data.DataLoader,
                but returning batches from several domains at each iteration.
        """
        raise NotImplementedError()
class MultiDomainDatasets(DomainsDatasetBase):
    def __init__(
        self,
        source_access: DatasetAccess,
        target_access: DatasetAccess,
        val_split_ratio=0.1,
        source_sampling_config=None,
        target_sampling_config=None,
        size_type=DatasetSizeType.Max,
        n_fewshot=None,
        random_state=None,
    ):
        """The class controlling how the source and target domains are
        iterated over.

        Args:
            source_access (DatasetAccess): accessor for the source dataset.
            target_access (DatasetAccess): accessor for the target dataset.
            val_split_ratio (float, optional): ratio of the train dataset reserved for validation. Defaults to 0.1.
            source_sampling_config (SamplingConfig, optional): how to sample from the source. Defaults to None (=> RandomSampler).
            target_sampling_config (SamplingConfig, optional): how to sample from the target. Defaults to None (=> RandomSampler).
            size_type (DatasetSizeType, optional): which dataset size defines the number of batches per epoch for a given batch_size. Defaults to DatasetSizeType.Max.
            n_fewshot (int, optional): number of target samples whose labels may be used,
                defining the few-shot, semi-supervised setting. Defaults to None.
            random_state ([int|np.random.RandomState], optional): used for deterministic sampling and few-shot label selection. Defaults to None.
        """
        self._source_access = source_access
        self._target_access = target_access
        self._val_split_ratio = val_split_ratio
        self._source_sampling_config = (
            source_sampling_config
            if source_sampling_config is not None
            else SamplingConfig()
        )
        self._target_sampling_config = (
            target_sampling_config
            if target_sampling_config is not None
            else SamplingConfig()
        )
        self._size_type = size_type
        self._n_fewshot = n_fewshot
        self._random_state = check_random_state(random_state)
        self._source_by_split = {}
        self._labeled_target_by_split = None
        self._target_by_split = {}
    def is_semi_supervised(self):
        return self._n_fewshot is not None and self._n_fewshot > 0
    def prepare_data_loaders(self):
        logging.debug("Load source")
        (
            self._source_by_split["train"],
            self._source_by_split["valid"],
        ) = self._source_access.get_train_val(self._val_split_ratio)
        logging.debug("Load target")
        (
            self._target_by_split["train"],
            self._target_by_split["valid"],
        ) = self._target_access.get_train_val(self._val_split_ratio)
        logging.debug("Load source Test")
        self._source_by_split["test"] = self._source_access.get_test()
        logging.debug("Load target Test")
        self._target_by_split["test"] = self._target_access.get_test()

        if self.is_semi_supervised():
            # semi-supervised target domain: split off n_fewshot labeled
            # samples per class; pass the stored random_state so the few-shot
            # label selection is deterministic, as documented in __init__
            self._labeled_target_by_split = {}
            for part in ["train", "valid", "test"]:
                (
                    self._labeled_target_by_split[part],
                    self._target_by_split[part],
                ) = _split_dataset_few_shot(
                    self._target_by_split[part],
                    self._n_fewshot,
                    random_state=self._random_state,
                )
    def get_domain_loaders(self, split="train", batch_size=32):
        source_ds = self._source_by_split[split]
        source_loader = self._source_sampling_config.create_loader(
            source_ds, batch_size
        )
        target_ds = self._target_by_split[split]

        if self._labeled_target_by_split is None:
            # unsupervised target domain
            target_loader = self._target_sampling_config.create_loader(
                target_ds, batch_size
            )
            n_dataset = DatasetSizeType.get_size(self._size_type, source_ds, target_ds)
            return MultiDataLoader(
                dataloaders=[source_loader, target_loader],
                n_batches=max(n_dataset // batch_size, 1),
            )
        else:
            # semi-supervised target domain
            target_labeled_ds = self._labeled_target_by_split[split]
            target_unlabeled_ds = target_ds
            # labeled domain: always balanced
            target_labeled_loader = SamplingConfig(
                balance=True, class_weights=None
            ).create_loader(
                target_labeled_ds, batch_size=min(len(target_labeled_ds), batch_size)
            )
            target_unlabeled_loader = self._target_sampling_config.create_loader(
                target_unlabeled_ds, batch_size
            )
            n_dataset = DatasetSizeType.get_size(
                self._size_type, source_ds, target_labeled_ds, target_unlabeled_ds
            )
            return MultiDataLoader(
                dataloaders=[
                    source_loader,
                    target_labeled_loader,
                    target_unlabeled_loader,
                ],
                n_batches=max(n_dataset // batch_size, 1),
            )
    def __len__(self):
        source_ds = self._source_by_split["train"]
        target_ds = self._target_by_split["train"]
        if self._labeled_target_by_split is None:
            return DatasetSizeType.get_size(self._size_type, source_ds, target_ds)
        else:
            labeled_target_ds = self._labeled_target_by_split["train"]
            return DatasetSizeType.get_size(
                self._size_type, source_ds, labeled_target_ds, target_ds
            )
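

# Typical usage (illustrative sketch; `source_access` and `target_access`
# stand in for concrete DatasetAccess implementations):
#
#   dataset = MultiDomainDatasets(source_access, target_access, n_fewshot=5)
#   dataset.prepare_data_loaders()
#   loader = dataset.get_domain_loaders(split="train", batch_size=32)
#   for batch in loader:
#       ...  # one sub-batch per domain at each iteration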
def _split_dataset_few_shot(dataset, n_fewshot, random_state=None):
    if n_fewshot <= 0:
        raise ValueError(f"n_fewshot should be > 0, not '{n_fewshot}'")
    # np.asarray ensures the elementwise comparison below works even if
    # get_labels returns a plain list
    labels = np.asarray(get_labels(dataset))
    classes = sorted(set(labels))
    if n_fewshot < 1:
        # a fractional n_fewshot is interpreted as a ratio of the
        # (balanced) per-class sample count
        max_few = len(dataset) // len(classes)
        n_fewshot = round(max_few * n_fewshot)
    n_fewshot = int(round(n_fewshot))
    random_state = check_random_state(random_state)
    # sample n_fewshot items per class from the dataset
    tindices = []
    uindices = []
    for class_ in classes:
        indices = np.where(labels == class_)[0]
        random_state.shuffle(indices)
        head, tail = np.split(indices, [n_fewshot])
        assert len(head) == n_fewshot, f"class {class_} has fewer than {n_fewshot} samples"
        tindices.append(head)
        uindices.append(tail)
    tindices = np.concatenate(tindices)
    uindices = np.concatenate(uindices)
    assert len(tindices) == len(classes) * n_fewshot
    labeled_dataset = torch.utils.data.Subset(dataset, tindices)
    unlabeled_dataset = torch.utils.data.Subset(dataset, uindices)
    return labeled_dataset, unlabeled_dataset
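

# Illustrative sketch (hypothetical balanced dataset: 1,000 samples,
# 10 classes):
#   labeled, unlabeled = _split_dataset_few_shot(dataset, n_fewshot=5)
#   # len(labeled) == 50 (5 per class), len(unlabeled) == 950
# A fractional n_fewshot is treated as a per-class ratio: n_fewshot=0.1
# keeps round(0.1 * 100) == 10 labeled samples per class here.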