import os
from enum import Enum
from pathlib import Path
from sklearn.utils import check_random_state
import numpy as np
import ada.datasets.toys as toys
import ada.datasets.digits_dataset_access as digits
import ada.datasets.office_dataset_access as office
from ada.datasets.multisource import MultiDomainDatasets, DatasetSizeType
from ada.datasets.sampler import SamplingConfig
import ada.utils.experimentation as xp
class WeightingType(Enum):
    """Class-weighting strategies for sampling source/target batches.

    - NATURAL: keep the dataset's native class distribution.
    - BALANCED: resample so every class is equally represented.
    - PRESET0: a fixed, deliberately imbalanced weighting (for experiments).
    """

    NATURAL = "natural"
    BALANCED = "balanced"
    PRESET0 = "preset0"
class DatasetFactory:
    """This class takes a configuration dictionary
    and generates a MultiDomainDataset class
    with the appropriate data.
    """

    def __init__(self, data_config, data_path=None, download=True, n_fewshot=0):
        """
        Args:
            data_config (dict): parameters to factor the right dataset
            data_path (str, optional): where data is stored/downloaded/created.
                if no data_path given, creates one in your home/.ada
            download (bool, optional): download (or generate for toy) the data.
            n_fewshot (int, optional): Number of target samples for which the label may be used,
                for batch sampling & train/val/test splits.
        """
        self._data_config = data_config
        if data_path is None:
            self._data_path = f"{Path.home()}/.ada"
        else:
            self._data_path = data_path
        self._download = download
        self._n_fewshot = n_fewshot
        self._long_name = self._data_config["dataset_name"]
        os.makedirs(self._data_path, exist_ok=True)

    def is_semi_supervised(self):
        """Return True when a positive number of labelled target samples is configured."""
        return self._n_fewshot is not None and self._n_fewshot > 0

    def get_multi_domain_dataset(self, random_state):
        """Build (as a side effect) and return the MultiDomainDatasets for this config.

        Args:
            random_state: seed/state accepted by sklearn's check_random_state.
        """
        self._create_dataset(random_state)
        return self.domain_datasets

    def get_data_args(self):
        """Returns dataset specific arguments necessary to build the network
        first returned item is number of classes
        second is a tuple of arguments to be passed to all network_factory functions.

        Returns:
            tuple: tuple containing:
                - int: the number of classes in the dataset
                - int or None: the input dimension
                - tuple: extra args for network factories (e.g. number of channels)

        Raises:
            NotImplementedError: if the configured dataset_group is unknown.
        """
        group = self._data_config["dataset_group"]
        if group == "toy":
            return (
                self._data_config["cluster"]["n_clusters"],
                self._data_config["cluster"]["dim"],
                (),
            )
        if group == "digits":
            # 10 digit classes; 28x28 images flattened to 784.
            # NOTE: self._num_channels is only set once _create_digits_access
            # has run (via get_multi_domain_dataset).
            return 10, 784, (self._num_channels,)
        if group == "office31":
            return 31, None, ()
        # Fix: previously fell through and returned None silently for an
        # unknown group; raise to stay consistent with _create_dataset.
        raise NotImplementedError(f"Unknown dataset group: {group}")

    def get_data_short_name(self):
        """Return the bare dataset name from the configuration."""
        return self._data_config["dataset_name"]

    def get_data_long_name(self):
        """Return the descriptive name (may be extended while building the dataset)."""
        return self._long_name

    def get_data_hash(self):
        """Return a hash uniquely identifying this data configuration."""
        return xp.param_to_hash(self._data_config)

    def _create_dataset(self, random_state):
        """Dispatch on dataset_group and build self.domain_datasets."""
        random_state = check_random_state(random_state)
        if self._data_config["dataset_group"] == "toy":
            src, tgt = self._create_toy_access()
        elif self._data_config["dataset_group"] == "digits":
            src, tgt = self._create_digits_access()
        elif self._data_config["dataset_group"] == "office31":
            src, tgt = self._create_office31_access()
        else:
            # Fix: reworded previously garbled error message.
            raise NotImplementedError(
                f"Unknown dataset type, you may need to add your own dataset here: {__file__}"
            )
        self._create_domain_dataset(src, tgt, random_state)

    def _create_domain_dataset(self, source_access, target_access, random_state):
        """Combine source & target accessors into a MultiDomainDatasets.

        The weighting type drives the per-domain sampling configuration;
        the dataset size type and weighting are appended to the long name.
        """
        # WeightingType(...) raises ValueError itself on an unknown value,
        # so no extra validation branch is needed below.
        weight_type = WeightingType(self._data_config.get("weight_type", "natural"))
        size_type = DatasetSizeType(self._data_config.get("size_type", "source"))
        self._long_name = f"{self._long_name}_{weight_type.value}_{size_type.value}"
        if weight_type is WeightingType.PRESET0:
            # Fixed, decreasing class weights on source; small random integer
            # weights on target, drawn from the provided random_state.
            source_sampling_config = SamplingConfig(
                class_weights=np.arange(source_access.n_classes(), 0, -1)
            )
            target_sampling_config = SamplingConfig(
                class_weights=random_state.randint(1, 4, size=target_access.n_classes())
            )
        elif weight_type is WeightingType.BALANCED:
            source_sampling_config = SamplingConfig(balance=True)
            target_sampling_config = SamplingConfig(balance=True)
        else:
            # NATURAL: keep each domain's native class distribution.
            source_sampling_config = SamplingConfig()
            target_sampling_config = SamplingConfig()
        self.domain_datasets = MultiDomainDatasets(
            source_access=source_access,
            target_access=target_access,
            source_sampling_config=source_sampling_config,
            target_sampling_config=target_sampling_config,
            size_type=size_type,
            n_fewshot=self._n_fewshot,
        )

    def _create_toy_access(self):
        """Build source/target accessors for the synthetic blobs dataset.

        Source keeps the unshifted distribution; target applies the
        configured shift. Both share cluster parameters and sample count.
        """
        blob_args = self._data_config["cluster"]
        shift_params = toys.get_datashift_params(**self._data_config["shift"])
        self._long_name = (
            f"blobs_{xp.param_to_str(blob_args)}_{xp.param_to_str(shift_params)}"
        )
        n_samples = self._data_config["n_samples"]
        source_access = toys.CausalBlobsDataAccess(
            data_path=self._data_path,
            transform=toys.get_datashift_params("no_shift"),
            download=self._download,
            cluster_params=blob_args,
            n_samples=n_samples,
        )
        target_access = toys.CausalBlobsDataAccess(
            data_path=self._data_path,
            transform=shift_params,
            download=self._download,
            cluster_params=blob_args,
            n_samples=n_samples,
        )
        return source_access, target_access

    def _create_digits_access(self):
        """Build source/target accessors for digit datasets.

        Also records the number of image channels for get_data_args.
        Dataset names in the config are upper-cased to match the enum.
        """
        (
            source_access,
            target_access,
            self._num_channels,
        ) = digits.DigitDataset.get_accesses(
            digits.DigitDataset(self._data_config["source"].upper()),
            digits.DigitDataset(self._data_config["target"].upper()),
            data_path=self._data_path,
        )
        return source_access, target_access

    def _create_office31_access(self):
        """Build source/target accessors for the Office-31 dataset.

        Dataset names in the config are lower-cased to match the enum.
        """
        source_access, target_access = office.Office31Dataset.get_accesses(
            office.Office31Dataset(self._data_config["source"].lower()),
            office.Office31Dataset(self._data_config["target"].lower()),
            data_path=self._data_path,
        )
        return source_access, target_access