ada.datasets package¶
Submodules¶
ada.datasets.dataset_access module¶
ada.datasets.dataset_factory module¶
-
class
ada.datasets.dataset_factory.DatasetFactory(data_config, data_path=None, download=True, n_fewshot=0)[source]¶ Bases:
object
This class takes a configuration dictionary and generates a MultiDomainDatasets object with the appropriate data.
-
get_data_args()[source]¶ Returns the dataset-specific arguments necessary to build the network. The first returned item is the number of classes; the second is a tuple of arguments to be passed to all network_factory functions.
Returns: - tuple containing:
- int: the number of classes in the dataset
- int or None: the input dimension
- int or None: the number of channels for images
Return type: tuple
-
ada.datasets.dataset_mnistm module¶
Dataset setting and data loader for MNIST-M.
Modified from https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
CREDIT: https://github.com/corenel
amt: changed train_data and test_data to data, and train_labels and test_labels to targets, like MNIST
-
class
ada.datasets.dataset_mnistm.MNISTM(root, train=True, transform=None, target_transform=None, download=False)[source]¶ Bases:
torch.utils.data.dataset.Dataset
MNIST-M Dataset. Auto-downloads the dataset and provides the torch Dataset API.
Parameters: - root (str) – path to the directory where the MNISTM folder will be created (or already exists).
- train (bool, optional) – defaults to True. If True, loads the training data. Otherwise, loads the test data.
- transform (callable, optional) – defaults to None. A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.RandomCrop. This preprocessing function is applied to all images (whether source or target).
- target_transform (callable, optional) – defaults to None. Similar to transform. This preprocessing function is applied to all target images, after transform.
- download (bool, optional) – defaults to False. Whether to allow downloading the data if it is not found on disk.
-
processed_folder= 'processed'¶
-
raw_folder= 'raw'¶
-
test_file= 'mnist_m_test.pt'¶
-
training_file= 'mnist_m_train.pt'¶
-
url= 'https://github.com/VanushVaswani/keras_mnistm/releases/download/1.0/keras_mnistm.pkl.gz'¶
ada.datasets.dataset_office31 module¶
Dataset setting and data loader for Office31. See Domain Adaptation Project at Berkeley.
-
class
ada.datasets.dataset_office31.Office31(root, domain=None, train=True, transform=None, download=False)[source]¶ Bases:
torch.utils.data.dataset.Dataset
Office31 Domain Adaptation Dataset from the Domain Adaptation Project at Berkeley.
Parameters: - root (string) – Root directory of the dataset, where the dataset file exists.
- train (bool, optional) – If True, resample from dataset randomly.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.RandomCrop.
-
url= 'https://docs.google.com/uc?export=download&id=0B4IapRTv9pJ1WGZVd1VDMmhwdlE'¶
ada.datasets.dataset_usps module¶
Dataset setting and data loader for USPS. Modified from https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py
amt: changed train_data and test_data to data, and train_labels and test_labels to targets like MNIST
-
class
ada.datasets.dataset_usps.USPS(root, train=True, transform=None, download=False)[source]¶ Bases:
torch.utils.data.dataset.Dataset
USPS Dataset.
Parameters: - root (string) – Root directory of the dataset, where the dataset file exists.
- train (bool, optional) – If True, resample from dataset randomly.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.RandomCrop.
-
url= 'https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl'¶
ada.datasets.digits_dataset_access module¶
-
class
ada.datasets.digits_dataset_access.DigitDataset[source]¶ Bases:
enum.Enum
An enumeration.
-
MNIST= 'MNIST'¶
-
MNISTM= 'MNISTM'¶
-
SVHN= 'SVHN'¶
-
USPS= 'USPS'¶
-
-
class
ada.datasets.digits_dataset_access.MNISTDatasetAccess(data_path, transform_kind)[source]¶ Bases:
ada.datasets.digits_dataset_access.DigitDatasetAccess
-
class
ada.datasets.digits_dataset_access.MNISTMDatasetAccess(data_path, transform_kind)[source]¶ Bases:
ada.datasets.digits_dataset_access.DigitDatasetAccess
-
class
ada.datasets.digits_dataset_access.SVHNDatasetAccess(data_path, transform_kind)[source]¶ Bases:
ada.datasets.digits_dataset_access.DigitDatasetAccess
ada.datasets.multisource module¶
-
class
ada.datasets.multisource.DatasetSizeType[source]¶ Bases:
enum.Enum
An enumeration.
-
Max= 'max'¶
-
Source= 'source'¶
-
-
class
ada.datasets.multisource.DomainsDatasetBase[source]¶ Bases:
object
-
get_domain_loaders(split='train', batch_size=32)[source]¶ Handles the sampling of a dataset containing multiple domains.
Parameters: - split (string, optional) – ["train"|"valid"|"test"]. Which dataset to iterate on. Defaults to "train".
- batch_size (int, optional) – Defaults to 32.
Returns: - A dataloader with an API similar to torch.utils.data.DataLoader, but returning batches from several domains at each iteration.
Return type:
-
-
class
ada.datasets.multisource.MultiDomainDatasets(source_access: ada.datasets.dataset_access.DatasetAccess, target_access: ada.datasets.dataset_access.DatasetAccess, val_split_ratio=0.1, source_sampling_config=None, target_sampling_config=None, size_type=<DatasetSizeType.Max: 'max'>, n_fewshot=None, random_state=None)[source]¶ Bases:
ada.datasets.multisource.DomainsDatasetBase
-
get_domain_loaders(split='train', batch_size=32)[source]¶ Handles the sampling of a dataset containing multiple domains.
Parameters: - split (string, optional) – ["train"|"valid"|"test"]. Which dataset to iterate on. Defaults to "train".
- batch_size (int, optional) – Defaults to 32.
Returns: - A dataloader with an API similar to torch.utils.data.DataLoader, but returning batches from several domains at each iteration.
Return type:
-
ada.datasets.office_dataset_access module¶
ada.datasets.sampler module¶
-
class
ada.datasets.sampler.BalancedBatchSampler(dataset, batch_size)[source]¶ Bases:
torch.utils.data.sampler.BatchSampler
BatchSampler: from a MNIST-like dataset, samples n_samples for each of the n_classes. Returns batches of size n_classes * (batch_size // n_classes).
Adapted from https://github.com/adambielski/siamese-triplet/blob/master/datasets.py
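The batch-size formula above can be illustrated with a minimal, dependency-free sketch. It is not the ada implementation: `balanced_batch` is a hypothetical helper, and drawing without replacement inside each class is an assumption made for the illustration.

```python
import random
from collections import Counter, defaultdict


def balanced_batch(labels, batch_size, rng):
    """Hypothetical helper: draw one class-balanced batch of dataset indices."""
    # Group dataset indices by their class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    n_classes = len(by_class)
    n_samples = batch_size // n_classes  # samples drawn per class

    batch = []
    for label in sorted(by_class):
        # Assumption: sample without replacement within each class.
        batch.extend(rng.sample(by_class[label], n_samples))
    return batch


# Toy MNIST-like labels: 10 classes, 10 items each.
labels = [i % 10 for i in range(100)]
batch = balanced_batch(labels, batch_size=32, rng=random.Random(0))

# 32 // 10 = 3 samples per class, so the batch holds 10 * 3 = 30 indices.
print(len(batch))
print(Counter(labels[i] for i in batch))
```

Note that the effective batch size is smaller than the requested one whenever batch_size is not a multiple of the number of classes.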
-
class
ada.datasets.sampler.MultiDataLoader(dataloaders, n_batches)[source]¶ Bases:
object
Batch sampler for a MultiDataset. Iterates in parallel over different batch samplers for each dataset. Yields batches [(x_1, y_1), …, (x_s, y_s)] for s datasets.
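As a rough sketch of the parallel iteration described above, the hypothetical class below walks several loaders in lockstep and yields one list of batches per step. The name `ParallelLoaders` and the choice to restart an exhausted loader are assumptions for illustration; the actual behaviour of MultiDataLoader may differ.

```python
class ParallelLoaders:
    """Hypothetical stand-in that iterates over several loaders in lockstep."""

    def __init__(self, dataloaders, n_batches):
        self.dataloaders = dataloaders  # any non-empty re-iterable objects
        self.n_batches = n_batches

    def __len__(self):
        return self.n_batches

    def __iter__(self):
        iterators = [iter(dl) for dl in self.dataloaders]
        for _ in range(self.n_batches):
            batches = []
            for i, it in enumerate(iterators):
                try:
                    batch = next(it)
                except StopIteration:
                    # Assumption: restart an exhausted loader so that
                    # shorter datasets are cycled.
                    iterators[i] = iter(self.dataloaders[i])
                    batch = next(iterators[i])
                batches.append(batch)
            yield batches  # [(x_1, y_1), ..., (x_s, y_s)]


# Plain lists of (x, y) pairs stand in for torch data loaders here.
source = [("xs0", "ys0"), ("xs1", "ys1"), ("xs2", "ys2")]
target = [("xt0", "yt0"), ("xt1", "yt1")]

steps = list(ParallelLoaders([source, target], n_batches=3))
for (x_s, y_s), (x_t, y_t) in steps:
    print(x_s, y_s, x_t, y_t)
```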
-
class
ada.datasets.sampler.ReweightedBatchSampler(dataset, batch_size, class_weights)[source]¶ Bases:
torch.utils.data.sampler.BatchSampler
BatchSampler: from a MNIST-like dataset, samples batch_size items according to a given input distribution, assuming multi-class labels.
Adapted from https://github.com/adambielski/siamese-triplet/blob/master/datasets.py
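One way to picture class-reweighted sampling is the sketch below, which uses only the standard library. `reweighted_batch` is a hypothetical helper, and the two-stage draw (a class first, then an item of that class, with replacement) is an assumption rather than the documented behaviour of ReweightedBatchSampler.

```python
import random
from collections import defaultdict


def reweighted_batch(labels, batch_size, class_weights, rng):
    """Hypothetical helper: draw indices so classes follow class_weights."""
    # Group dataset indices by their class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    classes = sorted(by_class)
    weights = [class_weights[c] for c in classes]

    # Stage 1: draw a class for every slot in the batch.
    drawn_classes = rng.choices(classes, weights=weights, k=batch_size)
    # Stage 2: draw one item from each selected class (with replacement).
    return [rng.choice(by_class[c]) for c in drawn_classes]


# Toy dataset: indices 0..49 have label 0, indices 50..99 have label 1.
labels = [0] * 50 + [1] * 50
rng = random.Random(0)

skewed = reweighted_batch(labels, 32, {0: 0.9, 1: 0.1}, rng)
only_zero = reweighted_batch(labels, 32, {0: 1.0, 1: 0.0}, rng)

print(sum(labels[i] == 0 for i in skewed), "of", len(skewed), "items are class 0")
```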
ada.datasets.toys module¶
-
class
ada.datasets.toys.CausalBlobs(data_path, train=True, transform=None, download=True, cluster_params=None, n_samples=300)[source]¶ Bases:
torch.utils.data.dataset.Dataset
CausalGaussianBlobs Dataset. MNIST-like dataset that generates blobs in a given environment setting:
- original cluster params are set by the cluster_params dictionary
- environment and cluster generation params are given by the transform dictionary
-
raw_folder= 'BlobsData'¶
-
-
class
ada.datasets.toys.CausalBlobsDataAccess(data_path, transform, download, cluster_params, n_samples)[source]¶
-
class
ada.datasets.toys.CausalClusterGenerator(dim=2, n_clusters=2, radius=0.05, proba_classes=0.5, centers='fixed', shape='blobs', data_seed=None)[source]¶ Bases:
object
Generate blobs from a Gaussian distribution following given causal parameters relating environment/domain, X and Y:
- Y –> X: select class Y, then distribution X|Y
-
generate_sample(nb_samples, shift_y=False, shift_x=False, shift_conditional_x=False, shift_conditional_y=False, y_cause_x=True, ye=0.5, te=0.3, se=None, re=None)[source]¶ Generate a sample and apply a given shift:
- shift_x = change p(x), i.e. x_e = f(x, env)
- shift_y = change p(y), i.e. y_e = f(y, env)
- shift_conditional_x = change p(x|y), i.e. x_e = f(y, x, env)
- shift_conditional_y = change p(y|x), i.e. y_e = f(x, y, env)
env_parameters control the change in the data:
- ye = proportion of class 0 labels
- te = translation value (uniform on all dimensions!)
- se = scaling factor
- re = rotation in radians
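To make the Y –> X direction and the role of ye concrete, here is a small sketch for two classes. It is not the CausalClusterGenerator code: `sample_y_then_x` is a hypothetical function, and treating radius as the standard deviation of each blob is an assumption.

```python
import random


def sample_y_then_x(nb_samples, means, radius, ye, rng):
    """Hypothetical two-class sketch: draw y first, then x given y."""
    xs, ys = [], []
    for _ in range(nb_samples):
        # ye is the proportion of class 0 labels, i.e. p(y = 0).
        y = 0 if rng.random() < ye else 1
        # Assumption: x | y is an isotropic Gaussian centred on means[y],
        # with radius used as the standard deviation.
        x = tuple(rng.gauss(m, radius) for m in means[y])
        xs.append(x)
        ys.append(y)
    return xs, ys


means = {0: (0.0, 0.0), 1: (1.0, 1.0)}
rng = random.Random(0)

# Reference environment: balanced classes.
x_ref, y_ref = sample_y_then_x(300, means, radius=0.05, ye=0.5, rng=rng)
# Shifted environment: p(y) changes while p(x|y) stays the same.
x_env, y_env = sample_y_then_x(300, means, radius=0.05, ye=0.9, rng=rng)

print("class 0 share, reference:", y_ref.count(0) / len(y_ref))
print("class 0 share, shifted:  ", y_env.count(0) / len(y_env))
```

Changing only ye alters the label marginal p(y) between environments, which corresponds to the shift_y case listed above.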
-
means¶
-
-
ada.datasets.toys.gen_cluster_distributions(dim, n_clusters, radius, random_state=None, centers='normal')[source]¶
-
ada.datasets.toys.get_datashift_params(data_shift=None, ye=0.5, te=None, se=None, re=None)[source]¶ This factory simplifies the parameter generation process for a number of use cases. The parameters generated can be used with CausalClusterGenerator.generate_sample
-
ada.datasets.toys.shift_data(x_in, ti=None, ri=None, si=None)[source]¶ This function applies scaling, translation and/or rotation to 2D data points, in that order only.
Args:
- x_in (np.ndarray): data, input feature array of shape (n, d)
- ti (float, optional): translation (scalar or vector of compatible dimension). Defaults to None.
- ri (float, optional): rotation angle in radians (scalar, for 2D points only). Defaults to None.
- si (float, optional): scaling factor (scalar). Defaults to None.
Returns: transformed feature array of shape (n, d), same as x_in.
Return type: np.ndarray
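The fixed order of operations can be sketched in plain Python for 2D points. This is an illustration under stated assumptions, not the library function: `shift_points` is a hypothetical name, it works on a list of (x, y) tuples instead of an np.ndarray, it supports scalar translation only, and it assumes a counterclockwise rotation about the origin.

```python
import math


def shift_points(points, ti=None, ri=None, si=None):
    """Hypothetical sketch: scale, then translate, then rotate 2D points."""
    shifted = []
    for x, y in points:
        if si is not None:
            # 1. Scaling by a scalar factor.
            x, y = x * si, y * si
        if ti is not None:
            # 2. Translation by the same scalar on every dimension.
            x, y = x + ti, y + ti
        if ri is not None:
            # 3. Rotation by ri radians (assumed counterclockwise, about the origin).
            c, s = math.cos(ri), math.sin(ri)
            x, y = c * x - s * y, s * x + c * y
        shifted.append((x, y))
    return shifted


# (1, 0) is scaled to (2, 0), translated to (3, 1), then rotated by 90 degrees.
result = shift_points([(1.0, 0.0)], ti=1.0, ri=math.pi / 2, si=2.0)
print(result)  # approximately [(-1.0, 3.0)]
```

Because the three steps do not commute, applying them in a different order would generally give a different result, which is why the fixed order matters.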