Source code for ada.datasets.office_dataset_access
from enum import Enum
import ada.datasets.preprocessing as proc
from ada.datasets.dataset_access import DatasetAccess
from ada.datasets.dataset_office31 import Office31
class Office31Dataset(Enum):
    """Enumeration of the Office-31 domains; each member's value is its lowercase name."""

    Amazon = "amazon"
    DSLR = "dslr"
    Webcam = "webcam"

    @staticmethod
    def get_accesses(source: "Office31Dataset", target: "Office31Dataset", data_path):
        """Build the dataset accessors for a source/target domain pair.

        Args:
            source: Domain wrapped by the first accessor.
            target: Domain wrapped by the second accessor.
            data_path: Forwarded unchanged to both accessors.

        Returns:
            A 2-tuple ``(source_access, target_access)`` of
            ``Office31DatasetAccess`` objects.
        """
        # The source accessor is constructed first, then the target one,
        # which is also the order of the returned tuple.
        return tuple(
            Office31DatasetAccess(domain, data_path) for domain in (source, target)
        )
class Office31DatasetAccess(DatasetAccess):
    """Accessor that builds ``Office31`` train and test datasets for one domain."""

    def __init__(self, domain, data_path):
        """Store the settings used to build the datasets.

        Args:
            domain: Object whose ``value`` attribute names the domain
                (for example an ``Office31Dataset`` member).
            data_path: Passed as the first positional argument to ``Office31``.
        """
        super().__init__(n_classes=31)
        self._data_path = data_path
        self._transform = proc.get_transform("office")
        # Only the underlying value is kept, not the object it came from.
        self._domain = domain.value

    def _office31_split(self, train):
        """Build the ``Office31`` dataset for this domain with the given ``train`` flag."""
        return Office31(
            self._data_path,
            domain=self._domain,
            train=train,
            transform=self._transform,
            download=True,
        )

    def get_train(self):
        """Return the ``Office31`` dataset built with ``train=True``."""
        return self._office31_split(train=True)

    def get_test(self):
        """Return the ``Office31`` dataset built with ``train=False``."""
        return self._office31_split(train=False)