Source code for ada.datasets.dataset_usps
"""Dataset setting and data loader for USPS.
Modified from
https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py
amt: changed train_data and test_data to data, and train_labels and test_labels to targets like MNIST
"""
import gzip
import os
import pickle
import urllib
import logging
import numpy as np
import torch
import torch.utils.data as data
class USPS(data.Dataset):
    """USPS Dataset.

    Loads the pre-processed ``usps_28x28.pkl`` archive (a gzip-compressed
    pickle) and exposes it through ``data`` / ``targets`` attributes, the
    same attribute names MNIST uses.

    Args:
        root (string): Root directory of dataset where the dataset file
            exists (``~`` is expanded).
        train (bool, optional): If True, load the training split and shuffle
            it randomly; otherwise load the test split in stored order.
        transform (callable, optional): A function/transform that takes in
            the image (an HWC array as stored in ``data``) and returns a
            transformed version. E.g, ``transforms.RandomCrop``
        download (bool, optional): If true, downloads the dataset
            from the internet and puts it in root directory.
            If dataset is already downloaded, it is not downloaded again.

    Raises:
        RuntimeError: If the dataset file is not found under ``root``
            (after the optional download).
    """

    url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl"

    def __init__(self, root, train=True, transform=None, download=False):
        """Init USPS dataset."""
        # init params
        self.root = os.path.expanduser(root)
        self.filename = "usps_28x28.pkl"
        self.train = train
        # Num of Train = 7438, Num of Test = 1860
        self.transform = transform
        self.dataset_size = None  # filled in by load_samples()

        # download dataset.
        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError(
                "Dataset not found. You can use download=True to download it"
            )

        self.data, self.targets = self.load_samples()
        self.targets = torch.LongTensor(self.targets)

        if self.train:
            # Shuffle images and labels with the same permutation so that
            # they stay aligned.
            indices = np.arange(self.data.shape[0])
            np.random.shuffle(indices)
            selected = indices[: self.dataset_size]
            self.data = self.data[selected]
            self.targets = self.targets[selected]

        # self.train_data *= 255.0  # TODO check bug
        # Stored layout is assumed to be (N, C, H, W); convert to HWC per image.
        self.data = self.data.transpose((0, 2, 3, 1))

    def __getitem__(self, index):
        """Get images and target for data loader.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a one-element
            ``LongTensor`` holding the index of the target class.
        """
        img = self.data[index]
        label = self.targets[index]
        if self.transform is not None:
            img = self.transform(img)
        # Wrap the scalar label into a 1-element LongTensor.
        label = torch.LongTensor([int(label)])
        return img, label

    def __len__(self):
        """Return size of dataset."""
        return self.dataset_size

    def _check_exists(self):
        """Check if dataset is downloaded and in the right place."""
        return os.path.exists(os.path.join(self.root, self.filename))

    def download(self):
        """Download the dataset file into ``root`` unless it already exists."""
        # ``import urllib`` alone does not guarantee that the ``request``
        # submodule is loaded, so import it explicitly here.
        import urllib.request

        filename = os.path.join(self.root, self.filename)
        dirname = os.path.dirname(filename)
        # dirname is empty when root is "" (current directory); makedirs("")
        # would raise, so only create a directory when there is one to create.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        if os.path.isfile(filename):
            return
        logging.info(f"Download {self.url} to {os.path.abspath(filename)}")
        urllib.request.urlretrieve(self.url, filename)
        logging.info("[DONE]")

    def load_samples(self):
        """Load sample images from dataset.

        Reads the gzip-compressed pickle, picks the split selected by
        ``self.train`` and records its size in ``self.dataset_size``.

        Returns:
            tuple: (images, labels) of the selected split.
        """
        filename = os.path.join(self.root, self.filename)
        # NOTE: pickle can execute arbitrary code while loading; only use
        # dataset files obtained from a trusted source.
        with gzip.open(filename, "rb") as f:
            data_set = pickle.load(f, encoding="bytes")

        # data_set[0] is the training split, data_set[1] the test split;
        # each is indexed as (images, labels).
        split = data_set[0] if self.train else data_set[1]
        images = split[0]
        labels = split[1]
        self.dataset_size = labels.shape[0]
        return images, labels