Source code for ada.datasets.dataset_usps

"""Dataset setting and data loader for USPS.
Modified from
https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py

amt: renamed `train_data`/`test_data` to `data` and `train_labels`/`test_labels` to `targets`, matching torchvision's MNIST.
"""
import gzip
import os
import pickle
import urllib
import logging

import numpy as np
import torch
import torch.utils.data as data


class USPS(data.Dataset):
    """USPS Dataset (28x28 version distributed with CoGAN).

    Args:
        root (string): Root directory in which the dataset file
            ``usps_28x28.pkl`` exists (or will be downloaded to).
        train (bool, optional): If True, load the training split and shuffle
            its samples randomly; otherwise load the test split in file order.
        transform (callable, optional): A function/transform applied to each
            image. It receives a numpy array in HWC layout (not a PIL image)
            and returns a transformed version.
        download (bool, optional): If True, downloads the dataset from the
            internet and puts it in the root directory. If the dataset is
            already downloaded, it is not downloaded again.
    """

    url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl"

    def __init__(self, root, train=True, transform=None, download=False):
        """Init USPS dataset."""
        self.root = os.path.expanduser(root)
        self.filename = "usps_28x28.pkl"
        self.train = train
        # Per the original notes: 7438 training samples, 1860 test samples.
        self.transform = transform
        # Set by load_samples() to the number of samples in the chosen split.
        self.dataset_size = None

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError(
                "Dataset not found." + " You can use download=True to download it"
            )

        self.data, self.targets = self.load_samples()
        self.targets = torch.LongTensor(self.targets)
        if self.train:
            # Shuffle the training split once. Images and targets are indexed
            # with the same permutation, so (image, label) pairs stay aligned.
            total_num_samples = self.data.shape[0]
            indices = np.arange(total_num_samples)
            np.random.shuffle(indices)
            self.data = self.data[indices[0 : self.dataset_size], ::]
            self.targets = self.targets[indices[0 : self.dataset_size]]
        # TODO(review): the original code kept `self.train_data *= 255.0`
        # disabled with "TODO check bug"; confirm the intended pixel range.
        # Move the channel axis last: (N, C, H, W) -> (N, H, W, C). Assumes the
        # pickle stores images channel-first, as the original "convert to HWC"
        # comment implies.
        self.data = self.data.transpose((0, 2, 3, 1))

    def __getitem__(self, index):
        """Get images and target for data loader.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where image is the HWC array (after
            ``transform``, if any) and target is a LongTensor of shape [1]
            holding the class index.
        """
        img, label = self.data[index, ::], self.targets[index]
        if self.transform is not None:
            img = self.transform(img)
        label = torch.LongTensor([np.int64(label).item()])
        return img, label

    def __len__(self):
        """Return size of dataset."""
        return self.dataset_size

    def _check_exists(self):
        """Check if the dataset is downloaded and in the right place."""
        return os.path.exists(os.path.join(self.root, self.filename))

    def download(self):
        """Download the dataset file into ``root`` unless it is already there."""
        # `import urllib` alone does not load the `urllib.request` submodule.
        import urllib.request

        filename = os.path.join(self.root, self.filename)
        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        if os.path.isfile(filename):
            return
        logging.info("Download %s to %s", self.url, os.path.abspath(filename))
        # Download under a temporary name and rename only on success, so an
        # interrupted download never leaves a truncated file at `filename`
        # (which _check_exists() would otherwise accept as valid).
        tmp_filename = filename + ".part"
        try:
            urllib.request.urlretrieve(self.url, tmp_filename)
            os.replace(tmp_filename, filename)
        finally:
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)
        logging.info("[DONE]")
        return

    def load_samples(self):
        """Load the images and labels of the selected split from disk.

        Also sets ``self.dataset_size`` to the number of labels in the split.

        Returns:
            tuple: (images, labels) as stored in the pickle for the split
            (index 0 = train, index 1 = test).
        """
        filename = os.path.join(self.root, self.filename)
        # NOTE(review): pickle.load can execute arbitrary code; only load this
        # file from a trusted source.
        with gzip.open(filename, "rb") as f:
            data_set = pickle.load(f, encoding="bytes")
        split = data_set[0] if self.train else data_set[1]
        images, labels = split[0], split[1]
        self.dataset_size = labels.shape[0]
        return images, labels