ADA: (Yet) Another Domain Adaptation library

Context

The aim of ADA is to help researchers build new methods for unsupervised and semi-supervised domain adaptation. The library is built on top of PyTorch-Lightning, enabling fast development of new models.

We built ADA with the idea of:

  • minimizing the boilerplate when developing a new method (loading data from several domains, logging errors, switching from CPU to GPU).
  • allowing fair comparison between methods by running all of them within the exact same environment.

You can find an introduction to ADA on medium.

Quick description

Methods from the main 3 groups of methods are available for unsupervised domain adaptation:

  • Adversarial methods: Domain-adversarial neural networks (DANN) and Conditional Adversarial Domain Adaptation networks (CDAN),
  • Optimal-Transport-based methods: Wasserstein distance guided representation learning (WDGRL), for which we propose two implementations, the second one being a variant better adapted to the PyTorch-Lightning, allowing for multi-GPU training.
  • MMD-based methods: Deep Adaptation Networks (DAN) and Joint Adaptation Networks (JAN)

All these methods are implemented in ada.models.architectures.

Adversarial and OT-based methods both rely on 3 networks:

  • a feature extractor network mapping inputs \(x\in\mathcal{X}\) to a latent space \(\mathcal{Z}\),
  • a task classifier network that learns to predict labels \(y \in \mathcal{Y}\) from latent vectors,
  • a domain classifier network that tries to predict whether samples in \(\mathcal{Z}\) come from the source or target domain.

MMD-based methods don’t make use of the critic network.

The full list of implemented algorithm with references can be found in the Algorithms implemented section.

Quick start

Take your favorite OS, terminal, environment and run:

git clone https://github.com/criteo-research/pytorch-ada.git

First you need to install the library. It has been tested with python 3.6+, with the latest versions of pytorch-lightning.

If you want to create a new conda environment, run:

conda env create -n adaenv python=3.7
conda activate adaenv

Install the library (with developer mode if you want to develop your own models later on, otherwise you can skip the -e):

pip install -e adalib

Note: on Windows, it could be necessary to first install pytorch and torchvision with conda:

conda install -c pytorch pytorch
conda install -c pytorch torchvision
pip install -e adalib

Run on of the scripts:

cd scripts
python run_simple.py

By default, this script launches experiments with all kinds of methods on a blobs dataset – it doesn’t take any parameter, you can change it easily from the script. It may take a few minutes to finish.

Most parameters are available and can be changed through configuration files, which are all grouped in the configs folder:

  • datasets
  • network layers and training parameters
  • methods (Source, DANN, CDAN…), and their specific parameters.

Advanced options

The script run_full_options.py runs the same kind of experiments allowing for more variants (semi-supervised, unbalanced, with gpus and MLFlow logging). You can run it without parameters or with -h to get help.

MLFlow

You can log results to MLFlow. Start a MLFlow server in another terminal:

conda activate adaenv
mlflow ui --port=31014

Streamlit application

Optionally, you can use the streamlit app. First install streamlit with

pip install streamlit

then launch the app like this:

streamlit run run_toys_app.py

This will start a web app with a default port = 8501. It should look like this in your browser:

_images/streamlit_screenshot.png

Checkout the Getting started page to get a more in-depth description of how you can use configuration files to run most of your experiments.

Benchmarks results

MNIST -> MNIST-M (5 runs)

Method source acc target acc
Source 89.0% +- 2.52 34.0% +- 1.71
DANN 94.2% +- 1.57 37.5% +- 2.85
CDAN 98.7% +- 0.19 68.4% +- 1.80
CDAN-E 98.7% +- 0.12 69.6% +- 1.51
DAN 98.0% +- 0.68 47.0% +- 1.85
JAN 96.4% +- 4.57 52.9% +- 2.16
WDGRL 93.9% +- 2.70 52.0% +- 4.82

MNIST -> USPS (5 runs)

Method source acc target acc
Source 99.2% +- 0.08 94.2% +- 1.07
DANN 99.1% +- 0.15 93.8% +- 1.06
CDAN 98.8% +- 0.17 90.7% +- 1.17
CDAN-E 98.9% +- 0.11 90.3% +- 0.98
DAN 99.0% +- 0.14 95.0% +- 0.83
JAN 98.6% +- 0.30 89.5% +- 2.00
WDGRL 98.7% +- 0.13 85.7% +- 6.57

See Benchmarks results for more complete benchmarks.

Contributing

Code

You can find the latest version on github. Before submitting code, please run black to have clean code formatting:

pip install black
black .

Documentation

First pip install sphinx, sphinx-paramlinks, recommonmark. Generate the documentation:

cd docs
sphinx-apidoc -o source/ ../adalib/ada ../scripts/
make html

Citing

If this library is useful for your research please cite:

@misc{adalib2020,
title={(Yet) Another Domain Adaptation library},
author={Tousch, Anne-Marie and Renaudin, Christophe},
url={https://github.com/criteo-research/pytorch-ada},
year={2020}
}