ada.datasets package

Submodules

ada.datasets.dataset_access module

class ada.datasets.dataset_access.DatasetAccess(n_classes)[source]

Bases: object

This class ensures a uniform API is used to access the training, validation, and test splits of any dataset.

get_test()[source]
get_train()[source]

returns: a torch.utils.data.Dataset

get_train_val(val_ratio)[source]
n_classes()[source]
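
The interface above can be illustrated with a minimal in-memory sketch. The ToyAccess class and its list-backed "datasets" below are hypothetical stand-ins for illustration only; the real subclasses return torch.utils.data.Dataset objects, and the actual split logic inside get_train_val may differ.

```python
import random

class DatasetAccess:
    """Minimal re-statement of the interface described above (a sketch)."""

    def __init__(self, n_classes):
        self._n_classes = n_classes

    def n_classes(self):
        return self._n_classes

    def get_train(self):
        raise NotImplementedError

    def get_test(self):
        raise NotImplementedError

    def get_train_val(self, val_ratio):
        # Split the training data into train/validation subsets.
        train = self.get_train()
        n_val = int(len(train) * val_ratio)
        shuffled = random.Random(0).sample(train, len(train))
        return shuffled[n_val:], shuffled[:n_val]

class ToyAccess(DatasetAccess):
    """Hypothetical access class backed by plain lists of (x, y) pairs."""

    def get_train(self):
        return [(i, i % 2) for i in range(100)]

    def get_test(self):
        return [(i, i % 2) for i in range(100, 120)]

access = ToyAccess(n_classes=2)
train, val = access.get_train_val(val_ratio=0.1)
print(len(train), len(val))  # 90 10
```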

ada.datasets.dataset_factory module

class ada.datasets.dataset_factory.DatasetFactory(data_config, data_path=None, download=True, n_fewshot=0)[source]

Bases: object

This class takes a configuration dictionary and generates a MultiDomainDataset class with the appropriate data.

get_data_args()[source]

Returns the dataset-specific arguments necessary to build the network. The first returned item is the number of classes; the remaining items are arguments to be passed to all network_factory functions.

Returns:
tuple containing:
  • int: the number of classes in the dataset
  • int or None: the input dimension
  • int or None: the number of channels for images
Return type:tuple
get_data_hash()[source]
get_data_long_name()[source]
get_data_short_name()[source]
get_multi_domain_dataset(random_state)[source]
is_semi_supervised()[source]
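
As a sketch of how the tuple described under get_data_args might be consumed when building a network: the StubFactory below and its values are hypothetical, standing in for a factory configured for a 32x32 RGB dataset with 10 classes.

```python
class StubFactory:
    """Hypothetical stand-in for a configured DatasetFactory."""

    def get_data_args(self):
        # (number of classes, input dimension, number of image channels)
        return 10, 32, 3

# Unpack the tuple and pass the pieces on to network construction.
n_classes, input_dim, n_channels = StubFactory().get_data_args()
print(n_classes, input_dim, n_channels)  # 10 32 3
```
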
class ada.datasets.dataset_factory.WeightingType[source]

Bases: enum.Enum

An enumeration.

BALANCED = 'balanced'
NATURAL = 'natural'
PRESET0 = 'preset0'

ada.datasets.dataset_mnistm module

Dataset setting and data loader for MNIST-M.

Modified from https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py

CREDIT: https://github.com/corenel. amt: changed train_data and test_data to data, and train_labels and test_labels to targets, as in MNIST.

class ada.datasets.dataset_mnistm.MNISTM(root, train=True, transform=None, target_transform=None, download=False)[source]

Bases: torch.utils.data.dataset.Dataset

MNIST-M Dataset. Auto-downloads the dataset and provides the torch Dataset API.

Parameters:
  • root (str) – path to the directory where the MNISTM folder will be created (or already exists).
  • train (bool, optional) – defaults to True. If True, loads the training data; otherwise, loads the test data.
  • transform (callable, optional) – defaults to None. A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.RandomCrop. This preprocessing function is applied to all images (whether source or target).
  • target_transform (callable, optional) – defaults to None; similar to transform. This preprocessing function is applied to all targets, after transform.
  • download (bool, optional) – defaults to False. Whether to allow downloading the data if it is not found on disk.
download()[source]

Download the MNISTM data.

processed_folder = 'processed'
raw_folder = 'raw'
test_file = 'mnist_m_test.pt'
training_file = 'mnist_m_train.pt'
url = 'https://github.com/VanushVaswani/keras_mnistm/releases/download/1.0/keras_mnistm.pkl.gz'

ada.datasets.dataset_office31 module

Dataset setting and data loader for Office31. See Domain Adaptation Project at Berkeley.

class ada.datasets.dataset_office31.Office31(root, domain=None, train=True, transform=None, download=False)[source]

Bases: torch.utils.data.dataset.Dataset

Office31 Domain Adaptation Dataset from the Domain Adaptation Project at Berkeley.

Parameters:
  • root (string) – Root directory of the dataset, where the dataset files exist.
  • train (bool, optional) – If True, loads the training data; otherwise, the test data.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.RandomCrop.
download()[source]

Download dataset.

load_samples()[source]

Load sample images from dataset.

url = 'https://docs.google.com/uc?export=download&id=0B4IapRTv9pJ1WGZVd1VDMmhwdlE'

ada.datasets.dataset_usps module

Dataset setting and data loader for USPS. Modified from https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py

amt: changed train_data and test_data to data, and train_labels and test_labels to targets, as in MNIST.

class ada.datasets.dataset_usps.USPS(root, train=True, transform=None, download=False)[source]

Bases: torch.utils.data.dataset.Dataset

USPS Dataset.

Parameters:
  • root (string) – Root directory of the dataset, where the dataset files exist.
  • train (bool, optional) – If True, loads the training data; otherwise, the test data.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.RandomCrop.
download()[source]

Download dataset.

load_samples()[source]

Load sample images from dataset.

url = 'https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl'

ada.datasets.digits_dataset_access module

class ada.datasets.digits_dataset_access.DigitDataset[source]

Bases: enum.Enum

An enumeration.

MNIST = 'MNIST'
MNISTM = 'MNISTM'
SVHN = 'SVHN'
USPS = 'USPS'
get_accesses = <function DigitDataset.get_accesses>[source]
class ada.datasets.digits_dataset_access.DigitDatasetAccess(data_path, transform_kind)[source]

Bases: ada.datasets.dataset_access.DatasetAccess

class ada.datasets.digits_dataset_access.MNISTDatasetAccess(data_path, transform_kind)[source]

Bases: ada.datasets.digits_dataset_access.DigitDatasetAccess

get_test()[source]
get_train()[source]

returns: a torch.utils.data.Dataset

class ada.datasets.digits_dataset_access.MNISTMDatasetAccess(data_path, transform_kind)[source]

Bases: ada.datasets.digits_dataset_access.DigitDatasetAccess

get_test()[source]
get_train()[source]

returns: a torch.utils.data.Dataset

class ada.datasets.digits_dataset_access.SVHNDatasetAccess(data_path, transform_kind)[source]

Bases: ada.datasets.digits_dataset_access.DigitDatasetAccess

get_test()[source]
get_train()[source]

returns: a torch.utils.data.Dataset

class ada.datasets.digits_dataset_access.USPSDatasetAccess(data_path, transform_kind)[source]

Bases: ada.datasets.digits_dataset_access.DigitDatasetAccess

get_test()[source]
get_train()[source]

returns: a torch.utils.data.Dataset

ada.datasets.multisource module

class ada.datasets.multisource.DatasetSizeType[source]

Bases: enum.Enum

An enumeration.

Max = 'max'
Source = 'source'
get_size = <function DatasetSizeType.get_size>[source]
class ada.datasets.multisource.DomainsDatasetBase[source]

Bases: object

get_domain_loaders(split='train', batch_size=32)[source]

Handles the sampling of a dataset containing multiple domains.

Parameters:
  • split (string, optional) – [“train”|”valid”|”test”]. Which dataset to iterate on. Defaults to “train”.
  • batch_size (int, optional) – Defaults to 32.
Returns:
A dataloader with an API similar to the torch DataLoader, but returning batches from several domains at each iteration.
Return type:
MultiDataLoader

prepare_data_loaders()[source]

Handles the train/validation/test split to produce 3 datasets, each containing data from all domains.

class ada.datasets.multisource.MultiDomainDatasets(source_access: ada.datasets.dataset_access.DatasetAccess, target_access: ada.datasets.dataset_access.DatasetAccess, val_split_ratio=0.1, source_sampling_config=None, target_sampling_config=None, size_type=<DatasetSizeType.Max: 'max'>, n_fewshot=None, random_state=None)[source]

Bases: ada.datasets.multisource.DomainsDatasetBase

get_domain_loaders(split='train', batch_size=32)[source]

Handles the sampling of a dataset containing multiple domains.

Parameters:
  • split (string, optional) – [“train”|”valid”|”test”]. Which dataset to iterate on. Defaults to “train”.
  • batch_size (int, optional) – Defaults to 32.
Returns:
A dataloader with an API similar to the torch DataLoader, but returning batches from several domains at each iteration.
Return type:
MultiDataLoader

is_semi_supervised()[source]
prepare_data_loaders()[source]

Handles the train/validation/test split to produce 3 datasets, each containing data from all domains.

ada.datasets.office_dataset_access module

class ada.datasets.office_dataset_access.Office31Dataset[source]

Bases: enum.Enum

An enumeration.

Amazon = 'amazon'
DSLR = 'dslr'
Webcam = 'webcam'
get_accesses = <function Office31Dataset.get_accesses>[source]
class ada.datasets.office_dataset_access.Office31DatasetAccess(domain, data_path)[source]

Bases: ada.datasets.dataset_access.DatasetAccess

get_test()[source]
get_train()[source]

returns: a torch.utils.data.Dataset

ada.datasets.preprocessing module

ada.datasets.preprocessing.get_transform(kind)[source]

ada.datasets.sampler module

class ada.datasets.sampler.BalancedBatchSampler(dataset, batch_size)[source]

Bases: torch.utils.data.sampler.BatchSampler

BatchSampler - from an MNIST-like dataset, samples n_samples for each of the n_classes, returning batches of size n_classes * (batch_size // n_classes). Adapted from https://github.com/adambielski/siamese-triplet/blob/master/datasets.py
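
The balancing scheme can be sketched in plain Python. The balanced_batches generator below is a hypothetical illustration of the described behavior (equal per-class draws per batch), not the class's actual implementation.

```python
import random
from collections import Counter, defaultdict

def balanced_batches(labels, batch_size, n_classes, seed=0):
    """Yield batches of indices with batch_size // n_classes draws per class,
    so each batch has size n_classes * (batch_size // n_classes)."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    per_class = batch_size // n_classes
    n_batches = min(len(v) for v in by_class.values()) // per_class
    for _ in range(n_batches):
        batch = []
        for y in by_class:
            batch.extend(rng.sample(by_class[y], per_class))
        yield batch

labels = [i % 3 for i in range(30)]  # 3 classes, 10 samples each
batch = next(balanced_batches(labels, batch_size=8, n_classes=3))
counts = Counter(labels[i] for i in batch)
print(len(batch), dict(counts))  # batch of size 3 * (8 // 3) = 6, 2 per class
```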

class ada.datasets.sampler.InfiniteSliceIterator(array, class_)[source]

Bases: object

get(n)[source]
reset()[source]
class ada.datasets.sampler.MultiDataLoader(dataloaders, n_batches)[source]

Bases: object

Batch Sampler for a MultiDataset. Iterates in parallel over different batch samplers for each dataset. Yields batches [(x_1, y_1), …, (x_s, y_s)] for s datasets.
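
The parallel iteration can be sketched with zip over plain lists standing in for per-domain loaders; the toy batches below are illustrative only.

```python
# Each "loader" is a plain list of (x, y) batches, one list per domain.
source_loader = [([1, 2], [0, 1]), ([3, 4], [1, 0])]
target_loader = [([5, 6], [0, 0]), ([7, 8], [1, 1])]

# zip yields one tuple of per-domain batches per step, like
# [(x_1, y_1), ..., (x_s, y_s)] for s domains. (zip stops at the shortest
# loader; MultiDataLoader instead takes n_batches explicitly.)
steps = list(zip(source_loader, target_loader))
for (x_s, y_s), (x_t, y_t) in steps:
    print(x_s, x_t)  # one source batch and one target batch per step
```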

class ada.datasets.sampler.ReweightedBatchSampler(dataset, batch_size, class_weights)[source]

Bases: torch.utils.data.sampler.BatchSampler

BatchSampler - from an MNIST-like dataset, samples batch_size elements according to the given input distribution, assuming multi-class labels. Adapted from https://github.com/adambielski/siamese-triplet/blob/master/datasets.py
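
The reweighting idea can be sketched as two-stage sampling: first draw a class from the target distribution, then draw an index with that label. The reweighted_batch function is a hypothetical illustration, not the sampler's actual implementation.

```python
import random
from collections import Counter

def reweighted_batch(labels, batch_size, class_weights, seed=0):
    """Sample batch_size indices so that class frequencies in the batch
    follow class_weights rather than the dataset's natural distribution."""
    rng = random.Random(seed)
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    classes = sorted(by_class)
    weights = [class_weights[c] for c in classes]
    # Stage 1: pick a class per slot; stage 2: pick an index of that class.
    drawn = rng.choices(classes, weights=weights, k=batch_size)
    return [rng.choice(by_class[c]) for c in drawn]

labels = [0] * 90 + [1] * 10  # heavily imbalanced dataset
batch = reweighted_batch(labels, 100, class_weights={0: 0.5, 1: 0.5})
counts = Counter(labels[i] for i in batch)
print(dict(counts))  # roughly 50/50 despite the 90/10 imbalance
```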

class ada.datasets.sampler.SamplingConfig(balance=False, class_weights=None)[source]

Bases: object

create_loader(dataset, batch_size)[source]
ada.datasets.sampler.get_labels(dataset)[source]

ada.datasets.toys module

class ada.datasets.toys.CausalBlobs(data_path, train=True, transform=None, download=True, cluster_params=None, n_samples=300)[source]

Bases: torch.utils.data.dataset.Dataset

CausalGaussianBlobs Dataset. An MNIST-like dataset that generates blobs in a given environment setting: the original cluster parameters are set by the cluster_params dictionary, and the environment and cluster generation parameters are given by the transform dictionary.

create_on_disk()[source]
delete_from_disk()[source]
raw_folder = 'BlobsData'
class ada.datasets.toys.CausalBlobsDataAccess(data_path, transform, download, cluster_params, n_samples)[source]

Bases: ada.datasets.dataset_access.DatasetAccess

get_test()[source]
get_train()[source]

returns: a torch.utils.data.Dataset

class ada.datasets.toys.CausalClusterGenerator(dim=2, n_clusters=2, radius=0.05, proba_classes=0.5, centers='fixed', shape='blobs', data_seed=None)[source]

Bases: object

Generates blobs from a Gaussian distribution following given causal parameters relating environment/domain, X and Y: Y –> X: select class Y, then sample X from the conditional distribution X|Y.

generate_sample(nb_samples, shift_y=False, shift_x=False, shift_conditional_x=False, shift_conditional_y=False, y_cause_x=True, ye=0.5, te=0.3, se=None, re=None)[source]

Generate a sample and apply a given shift:
  • shift_x – change p(x), i.e. x_e = f(x, env)
  • shift_y – change p(y), i.e. y_e = f(y, env)
  • shift_conditional_x – change p(x|y), i.e. x_e = f(y, x, env)
  • shift_conditional_y – change p(y|x), i.e. y_e = f(x, y, env)

The env_parameters control the change in the data:
  • ye – proportion of class 0 labels
  • te – translation value (uniform on all dimensions!)
  • se – scaling factor
  • re – rotation in radians

means
ada.datasets.toys.gen_cluster_distributions(dim, n_clusters, radius, random_state=None, centers='normal')[source]
ada.datasets.toys.get_datashift_params(data_shift=None, ye=0.5, te=None, se=None, re=None)[source]

This factory simplifies the parameter generation process for a number of use cases. The generated parameters can be used with CausalClusterGenerator.generate_sample.

ada.datasets.toys.shift_data(x_in, ti=None, ri=None, si=None)[source]

This function applies scaling, translation, and/or rotation to 2D data points, in that order.

Args:
  • x_in (np.ndarray) – data, input feature array of shape (n, d)
  • ti (float, optional) – translation (scalar or vector of compatible dimension). Defaults to None.
  • ri (float, optional) – rotation angle in radians (scalar, for 2D points only). Defaults to None.
  • si (float, optional) – scaling factor (scalar). Defaults to None.
Returns: transformed feature array of shape (n, d), same as x_in.
Return type: np.ndarray
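
The documented order of operations (scale, then translate, then rotate) can be sketched in pure Python. shift_2d below is a hypothetical stand-in for shift_data, operating on (x, y) tuples instead of an np.ndarray.

```python
import math

def shift_2d(points, ti=None, ri=None, si=None):
    """Apply scaling, then translation, then rotation to 2D points."""
    out = list(points)
    if si is not None:
        out = [(si * x, si * y) for x, y in out]
    if ti is not None:
        out = [(x + ti, y + ti) for x, y in out]
    if ri is not None:
        c, s = math.cos(ri), math.sin(ri)
        out = [(c * x - s * y, s * x + c * y) for x, y in out]
    return out

# Rotate (1, 0) by 90 degrees: lands (approximately) on (0, 1).
x, y = shift_2d([(1.0, 0.0)], ri=math.pi / 2)[0]
print(round(x, 6), round(y, 6))  # 0.0 1.0

# Scale before translate: (1, 1) -> scaled to (2, 2) -> translated to (3, 3).
x2, y2 = shift_2d([(1.0, 1.0)], si=2.0, ti=1.0)[0]
print(x2, y2)  # 3.0 3.0
```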

Module contents