# Source code for ada.datasets.dataset_office31
"""
Dataset setting and data loader for Office31.
See `Domain Adaptation Project at Berkeley <https://people.eecs.berkeley.edu/~jhoffman/domainadapt/#datasets_code>`_.
"""
import tarfile
import os
import logging
import urllib
import urllib.parse
import glob
from PIL import Image
import requests
import numpy as np
import torch
import torch.utils.data as data
from sklearn import preprocessing
class Office31(data.Dataset):
    """Office31 Domain Adaptation Dataset from the
    `Domain Adaptation Project at Berkeley <https://people.eecs.berkeley.edu/~jhoffman/domainadapt/#datasets_code>`_.

    Args:
        root (string):
            Root directory of dataset where dataset file exist.
        domain (string, optional):
            Name of the domain sub-directory to load images from
            (joined into the image path; presumably one of the Office31
            domains such as ``amazon``/``dslr``/``webcam`` — confirm
            against the extracted archive layout).
        train (bool, optional):
            If True, use the 90% training split; otherwise use the
            held-out 10% test split (fixed-seed shuffle, reproducible).
        transform (callable, optional): A function/transform that takes in
            an PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
        download (bool, optional):
            If true, downloads the dataset
            from the internet and puts it in root directory.
            If dataset is already downloaded, it is not downloaded again.
    """

    # Google Drive export URL for the Office31 image archive.
    url = "https://docs.google.com/uc?export=download&id=0B4IapRTv9pJ1WGZVd1VDMmhwdlE"

    def __init__(self, root, domain=None, train=True, transform=None, download=False):
        """Init Office31 dataset."""
        # init params
        self.root = os.path.expanduser(root)
        self.filename = "domain_adaptation_images.tar.gz"
        self.dirname = "Office31"
        self.train = train
        self.transform = transform
        self.dataset_size = None
        self.domain = domain
        # download dataset.
        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError(
                "Dataset not found." + " You can use download=True to download it"
            )
        # Maps class-directory names to integer labels; fit in load_samples().
        self.labeler = preprocessing.LabelEncoder()
        self.data, self.targets = self.load_samples()
        self.targets = torch.LongTensor(self.targets)

    def __getitem__(self, index):
        """Get images and target for data loader.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a 1-element
            ``LongTensor`` holding the index of the target class.
        """
        path = os.path.join(self.root, self.dirname, self.domain, self.data[index])
        # Context managers guarantee the file handle is released; convert to
        # RGB so grayscale images yield a consistent channel count.
        with open(path, "rb") as f:
            with Image.open(f) as imgf:
                img = imgf.convert("RGB")
        label = self.targets[index]
        if self.transform is not None:
            img = self.transform(img)
        # NOTE: target stays a shape-(1,) LongTensor (not a 0-d scalar) to
        # preserve the original loader contract for downstream collation.
        label = torch.LongTensor([np.int64(label).item()])
        return img, label

    def __len__(self):
        """Return size of dataset (set by load_samples)."""
        return self.dataset_size

    def _check_exists(self):
        """Check if dataset is download and in right place."""
        # NOTE(review): only the tarball's presence is checked, not the
        # extracted image directory — a manually extracted dataset without
        # the archive will still be reported as missing.
        return os.path.exists(os.path.join(self.root, self.filename))

    def download(self):
        """Download the dataset archive (if absent) and extract it.

        Raises:
            RuntimeError: if Google Drive's ``download_warning`` cookie
                (the confirm token required for large files) is not found.
        """
        filename = os.path.join(self.root, self.filename)
        dirname = os.path.join(self.root, self.dirname)
        if not os.path.exists(filename):
            logging.info("Downloading " + self.url)
            with requests.Session() as session:
                # Google Drive serves large files only after a confirm
                # token, delivered via the "download_warning" cookie.
                resp = session.head(self.url)
                confirm = next(
                    (v for k, v in resp.cookies.items() if "download_warning" in k),
                    None,
                )
                if confirm is None:
                    raise RuntimeError("Could not find 'download_warning' in cookies")
                resp = session.get(f"{self.url}&confirm={urllib.parse.quote(confirm)}")
                with open(filename, "wb") as f:
                    f.write(resp.content)
        os.makedirs(dirname, exist_ok=True)
        logging.info("Extracting files to " + dirname)
        # SECURITY: extractall trusts member paths inside the archive; a
        # malicious tarball could write outside dirname (path traversal).
        with tarfile.open(filename, "r:gz") as tar:
            tar.extractall(path=dirname)
        logging.info("[DONE]")

    def load_samples(self):
        """Load sample image paths/labels for the selected split.

        Returns:
            tuple: (images, labels) — lists of image paths and integer
            class labels for the train or test split. Also sets
            ``self.dataset_size``.

        Raises:
            RuntimeError: if no images are found under the domain directory.
        """
        imgdir = os.path.join(self.root, self.dirname, self.domain, "images")
        image_list = glob.glob(f"{imgdir}/*/*.jpg")
        if not image_list:
            raise RuntimeError("Office31 dataset is empty. Maybe it was not downloaded.")
        # The class label is the name of each image's parent directory.
        labels = [os.path.split(os.path.split(p)[0])[-1] for p in image_list]
        labels = self.labeler.fit_transform(labels)
        n_total = len(image_list)
        n_test = int(0.1 * n_total)
        indices = np.arange(n_total)
        # Fixed seed keeps the train/test split reproducible across runs.
        rg = np.random.RandomState(seed=128753)
        rg.shuffle(indices)
        # Use an explicit stop index: with n_test == 0 (fewer than 10
        # images) the old ``indices[:-n_test]`` slice was EMPTY and
        # ``indices[-n_test:]`` was the whole dataset — inverted splits.
        train_indices = indices[: n_total - n_test]
        test_indices = indices[n_total - n_test:]
        selected = train_indices if self.train else test_indices
        images = np.array(image_list)[selected].tolist()
        labels = np.array(labels)[selected].tolist()
        self.dataset_size = len(images)
        return images, labels