Source code for ada.datasets.sampler

import torchvision
import torch.utils.data
import logging
import numpy as np
from torch.utils.data.sampler import RandomSampler, BatchSampler


class SamplingConfig:
    def __init__(self, balance=False, class_weights=None):
        if balance and class_weights is not None:
            raise ValueError("Params 'balance' and 'class_weights' are incompatible")
        self._balance = balance
        self._class_weights = class_weights

    def create_loader(self, dataset, batch_size):
        if self._balance:
            sampler = BalancedBatchSampler(dataset, batch_size=batch_size)
        elif self._class_weights is not None:
            sampler = ReweightedBatchSampler(
                dataset, batch_size=batch_size, class_weights=self._class_weights
            )
        else:
            if len(dataset) < batch_size:
                # Dataset smaller than one batch: sample with replacement so a
                # full batch can still be formed.
                sub_sampler = RandomSampler(
                    dataset, replacement=True, num_samples=batch_size
                )
            else:
                sub_sampler = RandomSampler(dataset)
            sampler = BatchSampler(sub_sampler, batch_size=batch_size, drop_last=True)
        return torch.utils.data.DataLoader(dataset=dataset, batch_sampler=sampler)
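
# --- Usage sketch (added for illustration, not part of the original module) ---
# SamplingConfig selects between plain random batching, class-balanced batching
# and class-reweighted batching. The MNIST dataset, download path and batch size
# below are assumptions made for the example only.
#
#   from torchvision import datasets, transforms
#
#   mnist = datasets.MNIST("/tmp/data", train=True, download=True,
#                          transform=transforms.ToTensor())
#   loader = SamplingConfig(balance=True).create_loader(mnist, batch_size=50)
#   x, y = next(iter(loader))  # y holds 5 samples of each of the 10 digits
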
# TODO: deterministic shuffle?
class MultiDataLoader:
    """
    Batch Sampler for a MultiDataset. Iterates in parallel over different batch
    samplers for each dataset.
    Yields batches [(x_1, y_1), ..., (x_s, y_s)] for s datasets.
    """

    def __init__(self, dataloaders, n_batches):
        if n_batches <= 0:
            raise ValueError("n_batches should be > 0")
        self._dataloaders = dataloaders
        self._n_batches = np.maximum(1, n_batches)
        self._init_iterators()

    def _init_iterators(self):
        self._iterators = [iter(dl) for dl in self._dataloaders]

    def _get_nexts(self):
        def _get_next_dl_batch(di, dl):
            try:
                batch = next(dl)
            except StopIteration:
                logging.debug(f"reinit loader {di} of type {type(dl)}")
                new_dl = iter(self._dataloaders[di])
                self._iterators[di] = new_dl
                batch = next(new_dl)
            return batch

        return [_get_next_dl_batch(di, dl) for di, dl in enumerate(self._iterators)]

    def __iter__(self):
        for _ in range(self._n_batches):
            yield self._get_nexts()
        self._init_iterators()

    def __len__(self):
        return self._n_batches
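
# --- Usage sketch (added for illustration, not part of the original module) ---
# MultiDataLoader zips several DataLoaders of possibly different lengths; an
# exhausted loader is restarted so that exactly n_batches joint batches are
# produced per epoch. 'source_loader' and 'target_loader' stand for any two
# DataLoaders and are assumptions made for the example only.
#
#   multi_loader = MultiDataLoader([source_loader, target_loader], n_batches=100)
#   for (x_src, y_src), (x_tgt, y_tgt) in multi_loader:
#       ...  # one batch from each dataset at every step, 100 steps per epoch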

class BalancedBatchSampler(torch.utils.data.sampler.BatchSampler):
    """
    BatchSampler - from a MNIST-like dataset, samples n_samples for each of the
    n_classes.
    Returns batches of size n_classes * (batch_size // n_classes)
    adapted from https://github.com/adambielski/siamese-triplet/blob/master/datasets.py
    """

    def __init__(self, dataset, batch_size):
        labels = get_labels(dataset)
        classes = sorted(set(labels))
        n_classes = len(classes)
        self._n_samples = batch_size // n_classes
        if self._n_samples == 0:
            raise ValueError(
                f"batch_size should be bigger than the number of classes, got {batch_size}"
            )

        self._class_iters = [
            InfiniteSliceIterator(np.where(labels == class_)[0], class_=class_)
            for class_ in classes
        ]

        batch_size = self._n_samples * n_classes
        self.n_dataset = len(labels)
        self._n_batches = self.n_dataset // batch_size
        if self._n_batches == 0:
            raise ValueError(
                f"Dataset is not big enough to generate batches with size {batch_size}"
            )
        logging.debug(f"K={n_classes}, nk={self._n_samples}")
        logging.debug(f"Batch size = {batch_size}")

    def __iter__(self):
        for _ in range(self._n_batches):
            indices = []
            for class_iter in self._class_iters:
                indices.extend(class_iter.get(self._n_samples))
            np.random.shuffle(indices)
            yield indices
        for class_iter in self._class_iters:
            class_iter.reset()

    def __len__(self):
        return self._n_batches
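
# --- Usage sketch (added for illustration, not part of the original module) ---
# BalancedBatchSampler yields lists of indices, so it plugs into DataLoader via
# the batch_sampler argument. With 10 classes and batch_size=64, each batch holds
# 64 // 10 = 6 samples per class, i.e. an effective batch size of 60. The 'mnist'
# dataset is the same assumption as in the SamplingConfig sketch above.
#
#   sampler = BalancedBatchSampler(mnist, batch_size=64)
#   loader = torch.utils.data.DataLoader(mnist, batch_sampler=sampler)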

class ReweightedBatchSampler(torch.utils.data.sampler.BatchSampler):
    """
    BatchSampler - from a MNIST-like dataset, samples batch_size elements according
    to the given class distribution, assuming multi-class labels
    adapted from https://github.com/adambielski/siamese-triplet/blob/master/datasets.py
    """

    # /!\ 'class_weights' should be provided in the "natural order" of the classes
    # (i.e. sorted(classes)) /!\
    def __init__(self, dataset, batch_size, class_weights):
        labels = get_labels(dataset)
        self._classes = sorted(set(labels))
        n_classes = len(self._classes)
        if n_classes > len(class_weights):
            # Fewer weights than classes: spread the remaining mass uniformly
            # over the classes that received no explicit weight.
            k = len(class_weights)
            sum_w = np.sum(class_weights)
            if sum_w >= 1:
                # normalize, attributing equal weight to the weighted part and the remaining part
                class_weights /= sum_w * k / n_classes + (n_classes - k) / n_classes
            krem = n_classes - k
            wrem = 1 - np.sum(class_weights)
            logging.warning(
                f"will assume uniform distribution for labels > {len(class_weights)}"
            )
            self._class_weights = np.ones(n_classes, dtype=float)
            self._class_weights[:k] = class_weights
            self._class_weights[k:] = wrem / krem
        else:
            self._class_weights = class_weights[:n_classes]

        if np.sum(self._class_weights) != 1:
            self._class_weights = self._class_weights / np.sum(self._class_weights)

        logging.debug(f"Using weights={self._class_weights}")
        if batch_size == 0:
            raise ValueError(f"batch_size should be > 0, got {batch_size}")

        self._class_to_iter = {
            class_: InfiniteSliceIterator(np.where(labels == class_)[0], class_=class_)
            for class_ in self._classes
        }

        self.n_dataset = len(labels)
        self._batch_size = batch_size
        self._n_batches = self.n_dataset // self._batch_size
        if self._n_batches == 0:
            raise ValueError(
                f"Dataset is not big enough to generate batches with size {self._batch_size}"
            )
        logging.debug(f"K={n_classes}, nk={self._batch_size}")
        logging.debug(f"Batch size = {self._batch_size}")

    def __iter__(self):
        for _ in range(self._n_batches):
            # sample batch_size class labels according to the class weights
            class_idx = np.random.choice(
                self._classes,
                p=self._class_weights,
                replace=True,
                size=self._batch_size,
            )
            indices = []
            for class_, num in zip(*np.unique(class_idx, return_counts=True)):
                indices.extend(self._class_to_iter[class_].get(num))
            np.random.shuffle(indices)
            yield indices
        for class_iter in self._class_to_iter.values():
            class_iter.reset()

    def __len__(self):
        return self._n_batches
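
# --- Usage sketch (added for illustration, not part of the original module) ---
# ReweightedBatchSampler draws the class of each batch element from class_weights
# (given in sorted class order), then picks an index of that class. With the
# weights below, roughly half of each MNIST batch would be zeros; 'mnist' is the
# same assumption as in the earlier sketches.
#
#   weights = np.array([0.5] + [0.5 / 9] * 9)
#   sampler = ReweightedBatchSampler(mnist, batch_size=64, class_weights=weights)
#   loader = torch.utils.data.DataLoader(mnist, batch_sampler=sampler)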

def get_labels(dataset):
    dataset_type = type(dataset)
    if dataset_type is torchvision.datasets.SVHN:
        return dataset.labels
    if dataset_type is torchvision.datasets.ImageFolder:
        # dataset.imgs is a list of (path, class_index) tuples; keep the labels only.
        return np.array([label for _, label in dataset.imgs])
    # Handle Subset by recursing into the wrapped dataset.
    if dataset_type is torch.utils.data.Subset:
        indices = dataset.indices
        all_labels = get_labels(dataset.dataset)
        logging.debug(f"data subset of len {len(indices)} from {len(all_labels)}")
        labels = all_labels[indices]
        if isinstance(labels, torch.Tensor):
            return labels.numpy()
        return labels
    try:
        targets = dataset.targets
    except AttributeError:
        logging.error(type(dataset))
        return None
    logging.debug(f"targets of type {type(targets)}")
    if isinstance(targets, torch.Tensor):
        return targets.numpy()
    return np.asarray(targets)
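
# --- Usage sketch (added for illustration, not part of the original module) ---
# get_labels returns the full label vector without iterating over the images,
# which is what the samplers above need in order to group indices per class;
# Subsets are unwrapped recursively. 'mnist' is the same assumption as above.
#
#   subset = torch.utils.data.Subset(mnist, indices=list(range(1000)))
#   labels = get_labels(subset)  # numpy array of shape (1000,)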

class InfiniteSliceIterator:
    def __init__(self, array, class_):
        assert type(array) is np.ndarray
        self.array = array
        self.i = 0
        self.class_ = class_

    def reset(self):
        self.i = 0

    def get(self, n):
        len_ = len(self.array)
        # not enough elements in 'array'
        if len_ < n:
            logging.debug(f"there are really few items in class {self.class_}")
            self.reset()
            np.random.shuffle(self.array)
            mul = n // len_
            rest = n - mul * len_
            return np.concatenate((np.tile(self.array, mul), self.array[:rest]))

        # not enough elements left in the array's tail
        if len_ - self.i < n:
            self.reset()

        if self.i == 0:
            np.random.shuffle(self.array)

        i = self.i
        self.i += n
        return self.array[i : self.i]
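
# --- Usage sketch (added for illustration, not part of the original module) ---
# InfiniteSliceIterator hands out n indices of a single class per call, reshuffling
# and starting a fresh pass whenever the underlying index array runs out. The index
# values below are arbitrary.
#
#   it = InfiniteSliceIterator(np.array([3, 8, 15, 42]), class_=0)
#   it.get(3)  # e.g. array([15, 3, 42]): three indices in shuffled order
#   it.get(3)  # only one index left in this pass, so it reshuffles and starts over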