Source code for ada.datasets.dataset_mnistm

"""
Dataset setting and data loader for MNIST-M.

Modified from
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py

CREDIT: https://github.com/corenel
amt: changed train_data and test_data to data, and train_labels and test_labels to targets like MNIST
"""

from __future__ import print_function

import errno
import os
import logging

import torch
import torch.utils.data as data
from PIL import Image


[docs]class MNISTM(data.Dataset): """ MNIST-M Dataset. Auto-downloads the dataset and provide the torch Dataset API. Args: root (str): path to directory where the MNISTM folder will be created (or exists.) train (bool, optional): defaults to True. If True, loads the training data. Otherwise, loads the test data. transform (callable, optional): defaults to None. A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` This preprocessing function applied to all images (whether source or target) target_transform (callable, optional): default toNone, similar to transform. This preprocessing function applied to all target images, after `transform` download (bool optional): defaults to False. Whether to allow downloading the data if not found on disk. """ url = "https://github.com/VanushVaswani/keras_mnistm/releases/download/1.0/keras_mnistm.pkl.gz" raw_folder = "raw" processed_folder = "processed" training_file = "mnist_m_train.pt" test_file = "mnist_m_test.pt" def __init__( self, root, train=True, transform=None, target_transform=None, download=False, ): """Init MNIST-M dataset.""" super(MNISTM, self).__init__() self.root = os.path.join(root, "MNISTM") self.mnist_root = root self.transform = transform self.target_transform = target_transform self.train = train # training set or test set if download: self.download() if not self._check_exists(): raise RuntimeError( "Dataset not found." + " You can use download=True to download it" ) if self.train: self.data, self.targets = torch.load( os.path.join(self.root, self.processed_folder, self.training_file) ) else: self.data, self.targets = torch.load( os.path.join(self.root, self.processed_folder, self.test_file) ) def __getitem__(self, index): """Get images and target for data loader. Args: index (int): Index Returns: tuple: (image, target) where target is index of the target class. """ img, target = self.data[index], self.targets[index] # doing this so that it is consistent with all other datasets # to return a PIL Image img = Image.fromarray(img.squeeze().numpy(), mode="RGB") if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target def __len__(self): """Return size of dataset.""" return len(self.data) def _check_exists(self): return os.path.exists( os.path.join(self.root, self.processed_folder, self.training_file) ) and os.path.exists( os.path.join(self.root, self.processed_folder, self.test_file) )
[docs] def download(self): """Download the MNISTM data.""" # import essential packages from six.moves import urllib import gzip import pickle from torchvision import datasets # check if dataset already exists if self._check_exists(): return # make data dirs try: os.makedirs(os.path.join(self.root, self.raw_folder)) os.makedirs(os.path.join(self.root, self.processed_folder)) except OSError as e: if e.errno == errno.EEXIST: pass else: raise # download pkl files logging.info("Downloading " + self.url) filename = self.url.rpartition("/")[2] file_path = os.path.join(self.root, self.raw_folder, filename) if not os.path.exists(file_path.replace(".gz", "")): data = urllib.request.urlopen(self.url) with open(file_path, "wb") as f: f.write(data.read()) with open(file_path.replace(".gz", ""), "wb") as out_f, gzip.GzipFile( file_path ) as zip_f: out_f.write(zip_f.read()) os.unlink(file_path) # process and save as torch files logging.info("Processing...") # load MNIST-M images from pkl file with open(file_path.replace(".gz", ""), "rb") as f: mnist_m_data = pickle.load(f, encoding="bytes") mnist_m_train_data = torch.ByteTensor(mnist_m_data[b"train"]) mnist_m_test_data = torch.ByteTensor(mnist_m_data[b"test"]) # get MNIST labels mnist_train_labels = datasets.MNIST( root=self.mnist_root, train=True, download=True ).targets mnist_test_labels = datasets.MNIST( root=self.mnist_root, train=False, download=True ).targets # save MNIST-M dataset training_set = (mnist_m_train_data, mnist_train_labels) test_set = (mnist_m_test_data, mnist_test_labels) with open( os.path.join(self.root, self.processed_folder, self.training_file), "wb" ) as f: torch.save(training_set, f) with open( os.path.join(self.root, self.processed_folder, self.test_file), "wb" ) as f: torch.save(test_set, f) logging.info("[DONE]")