"""
Domain adaptation architectures should be indifferent to the task at hand: digits, toy datasets, recsys etc.
Take modules as input and organise them into an architecture.
"""
from enum import Enum
import numpy as np
import torch
from ada.models.layers import ReverseLayerF
import ada.models.losses as losses
import pytorch_lightning as pl
def set_requires_grad(model, requires_grad=True):
    """Switch gradient tracking on or off for every parameter of ``model``.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``);
            its parameters are modified in place.
        requires_grad (bool, optional): value written to each parameter's
            ``requires_grad`` flag. Defaults to True.
    """
    for param in model.parameters():
        param.requires_grad = requires_grad
def get_aggregated_metrics(metric_name_list, metric_outputs):
    """Average each requested metric over a list of per-step outputs.

    Args:
        metric_name_list (iterable of str): names of the metrics to aggregate.
        metric_outputs (list of dict): one dictionary per step, mapping each
            metric name to a tensor. Must contain at least one element.

    Returns:
        dict: metric name -> Python number (obtained with ``.item()``).
    """
    metric_dict = {}
    for metric_name in metric_name_list:
        values = [output[metric_name] for output in metric_outputs]
        # The rank of the first output decides how the values are combined:
        # scalars are stacked, batched tensors are concatenated.
        if values[0].dim() == 0:
            metric_value = torch.stack(values).mean()
        else:
            # Cast to double so that boolean/integer tensors can be averaged.
            metric_value = torch.cat(values).double().mean()
        metric_dict[metric_name] = metric_value.item()
    return metric_dict
def get_aggregated_metrics_from_dict(input_metric_dict):
    """Reduce every metric of a single step to a scalar tensor.

    Args:
        input_metric_dict (dict): metric name -> tensor.

    Returns:
        dict: same keys; 0-dim tensors are passed through untouched, any other
        tensor is replaced by its mean computed in double precision.
    """
    metric_dict = {}
    for metric_name, metric_value in input_metric_dict.items():
        if metric_value.dim() == 0:
            metric_dict[metric_name] = metric_value
        else:
            # Cast to double so that boolean/integer tensors can be averaged.
            metric_dict[metric_name] = metric_value.double().mean()
    return metric_dict
# multi-GPUs: mandatory to convert float values into tensors
def get_metrics_from_parameter_dict(parameter_dict, device):
    """Wrap every value of ``parameter_dict`` into a tensor placed on ``device``.

    Args:
        parameter_dict (dict): parameter name -> Python number.
        device: device on which the tensors are created.

    Returns:
        dict: parameter name -> tensor.
    """
    return {
        name: torch.tensor(value, device=device)
        for name, value in parameter_dict.items()
    }
class Method(Enum):
    """
    Lists the available methods.
    Provides a few methods that group the methods by type.
    """

    Source = "Source"
    DANN = "DANN"
    CDAN = "CDAN"
    CDAN_E = "CDAN-E"
    FSDANN = "FSDANN"
    MME = "MME"
    WDGRL = "WDGRL"  # Wasserstein Distance Guided Representation Learning
    WDGRLMod = "WDGRLMod"
    DAN = "DAN"  # Deep Adaptation Networks
    JAN = "JAN"  # Joint Adaptation Networks

    def is_mmd_method(self):
        """Return True for the MMD-based methods (DAN, JAN)."""
        return self in (Method.DAN, Method.JAN)

    def is_dann_method(self):
        """Return True for DANN and for the source-only baseline."""
        return self in (Method.DANN, Method.Source)

    def is_cdan_method(self):
        """Return True for CDAN and its entropy variant CDAN-E."""
        return self in (Method.CDAN, Method.CDAN_E)

    def is_fewshot_method(self):
        """Return True for FSDANN, MME and the source-only baseline."""
        return self in (Method.FSDANN, Method.MME, Method.Source)

    def allow_supervised(self):
        """Return True when the method is one of the few-shot methods."""
        return self.is_fewshot_method()
def create_mmd_based(
    method: Method, dataset, feature_extractor, task_classifier, **train_params
):
    """Build the trainer implementing an MMD-based method.

    Args:
        method (Method): ``Method.DAN`` or ``Method.JAN``.
        dataset: multi-domain dataset handed to the trainer.
        feature_extractor: feature extractor network handed to the trainer.
        task_classifier: task classifier network handed to the trainer.
        **train_params: extra keyword arguments forwarded to the trainer.

    Returns:
        DANtrainer or JANtrainer, depending on ``method``.

    Raises:
        ValueError: if ``method`` is not a supported MMD method.
    """
    if not method.is_mmd_method():
        raise ValueError(f"Unsupported MMD method: {method}")
    if method is Method.DAN:
        return DANtrainer(
            dataset, feature_extractor, task_classifier, method=method, **train_params
        )
    if method is Method.JAN:
        # One (kernel_mul, kernel_num) pair per kernel used by JAN.
        return JANtrainer(
            dataset,
            feature_extractor,
            task_classifier,
            method=method,
            kernel_mul=[2.0, 2.0],
            kernel_num=[5, 1],
            **train_params,
        )
    # Reached only if is_mmd_method() accepts a method without a trainer here;
    # fail loudly instead of silently returning None.
    raise ValueError(f"Unsupported MMD method: {method}")
def create_dann_like(
    method: Method, dataset, feature_extractor, task_classifier, critic, **train_params
):
    """Build the trainer implementing an adversarial (DANN-like) method.

    Semi-supervised datasets are delegated to ``create_fewshot_trainer``.

    Args:
        method (Method): the method to instantiate.
        dataset: multi-domain dataset; must expose ``is_semi_supervised()``.
        feature_extractor: feature extractor network handed to the trainer.
        task_classifier: task classifier network handed to the trainer.
        critic: domain classifier / critic network handed to the trainer.
        **train_params: extra keyword arguments forwarded to the trainer.

    Returns:
        The trainer matching ``method``.

    Raises:
        ValueError: if ``method`` has no matching trainer.
    """
    if dataset.is_semi_supervised():
        return create_fewshot_trainer(
            method, dataset, feature_extractor, task_classifier, critic, **train_params
        )

    # Arguments shared by every trainer built below.
    common = dict(
        dataset=dataset,
        feature_extractor=feature_extractor,
        task_classifier=task_classifier,
        critic=critic,
        method=method,
    )

    if method.is_dann_method():
        # The source-only baseline is a DANN whose alpha is set to 0.
        alpha = 0 if method is Method.Source else 1
        return DANNtrainer(alpha=alpha, **common, **train_params)
    if method.is_cdan_method():
        return CDANtrainer(
            use_entropy=method is Method.CDAN_E, **common, **train_params
        )
    if method is Method.WDGRL:
        return WDGRLtrainer(**common, **train_params)
    if method is Method.WDGRLMod:
        return WDGRLtrainerMod(**common, **train_params)
    raise ValueError(f"Unsupported method: {method}")
def create_fewshot_trainer(
    method: Method, dataset, feature_extractor, task_classifier, critic, **train_params
):
    """Build the trainer for a semi-supervised (few-shot) method.

    Args:
        method (Method): one of the few-shot methods.
        dataset: multi-domain dataset; must be semi-supervised.
        feature_extractor: feature extractor network handed to the trainer.
        task_classifier: task classifier network handed to the trainer.
        critic: domain classifier network handed to the trainer.
        **train_params: extra keyword arguments forwarded to the trainer.

    Returns:
        FewShotDANNtrainer: the configured trainer.

    Raises:
        ValueError: if the dataset is not semi-supervised or if ``method`` is
            not a few-shot method.
    """
    if not dataset.is_semi_supervised():
        raise ValueError("Dataset must be semi-supervised for few-shot methods.")
    if not method.is_fewshot_method():
        raise ValueError(f"Unsupported semi-supervised method: {method}")

    # The source-only baseline is obtained by setting alpha to 0.
    alpha = 0 if method is Method.Source else 1
    return FewShotDANNtrainer(
        alpha=alpha,
        dataset=dataset,
        feature_extractor=feature_extractor,
        task_classifier=task_classifier,
        critic=critic,
        method=method,
        **train_params,
    )
class BaseAdaptTrainer(pl.LightningModule):
    def __init__(
        self,
        dataset,
        feature_extractor,
        task_classifier,
        method=None,
        lambda_init=1.0,
        adapt_lambda=True,
        adapt_lr=True,
        nb_init_epochs=10,
        nb_adapt_epochs=50,
        batch_size=32,
        init_lr=1e-3,
        optimizer=None,
    ):
        # Raw docstring: the LaTeX markup below contains backslash sequences
        # (e.g. \frac) that must not be interpreted as string escapes.
        r"""Base class for all domain adaptation architectures.

        This class implements the classic building blocks used in all the derived architectures
        for domain adaptation.
        If you inherit from this class, you will have to implement only:

        - a forward pass
        - a `compute_loss` function that returns the task loss :math:`\mathcal{L}_c` and adaptation loss :math:`\mathcal{L}_a`, as well as
          a dictionary for summary statistics and other metrics you may want to have access to.

        The default training step uses only the task loss :math:`\mathcal{L}_c` during warmup,
        then uses the loss defined as:

        :math:`\mathcal{L} = \mathcal{L}_c + \lambda \mathcal{L}_a`,

        where :math:`\lambda` will follow the schedule defined by the DANN paper:

        :math:`\lambda_p = \frac{2}{1 + \exp{(-\gamma \cdot p)}} - 1` where $p$ the learning progress
        changes linearly from 0 to 1.

        Args:
            dataset (ada.datasets.MultiDomainDatasets): the multi-domain datasets to be used
                for train, validation, and tests.
            feature_extractor (torch.nn.Module): the feature extractor network (mapping inputs :math:`x\in\mathcal{X}` to
                a latent space :math:`\mathcal{Z}`,)
            task_classifier (torch.nn.Module): the task classifier network that learns to predict labels
                :math:`y \in \mathcal{Y}` from latent vectors,
            method (Method, optional): the method implemented by the class. Defaults to None.
                Mostly useful when several methods may be implemented using the same class.
            lambda_init (float, optional): Weight attributed to the adaptation part of the loss. Defaults to 1.0.
            adapt_lambda (bool, optional): Whether to make lambda grow from 0 to 1 following the schedule from
                the DANN paper. Defaults to True.
            adapt_lr (bool, optional): Whether to use the schedule for the learning rate as defined
                in the DANN paper. Defaults to True.
            nb_init_epochs (int, optional): Number of warmup epochs (during which lambda=0, training only on the source). Defaults to 10.
            nb_adapt_epochs (int, optional): Number of training epochs, warmup included: the adaptation
                phase lasts ``nb_adapt_epochs - nb_init_epochs`` epochs, which must be positive. Defaults to 50.
            batch_size (int, optional): Defaults to 32.
            init_lr (float, optional): Initial learning rate. Defaults to 1e-3.
            optimizer (dict, optional): Optimizer parameters, a dictionary with 2 keys:
                "type": a string in ("SGD", "Adam", "AdamW")
                "optim_params": kwargs for the above PyTorch optimizer.
                Defaults to None.
        """
        super().__init__()
        self._method = method

        self._init_lambda = lambda_init
        self.lamb_da = lambda_init
        self._adapt_lambda = adapt_lambda
        self._adapt_lr = adapt_lr

        self._init_epochs = nb_init_epochs
        self._non_init_epochs = nb_adapt_epochs - self._init_epochs
        assert (
            self._non_init_epochs > 0
        ), "nb_adapt_epochs must be greater than nb_init_epochs"
        self._batch_size = batch_size
        self._init_lr = init_lr
        self._lr_fact = 1.0
        self._grow_fact = 0.0
        self._dataset = dataset
        self.feat = feature_extractor
        self.classifier = task_classifier
        self._dataset.prepare_data_loaders()
        self._nb_training_batches = None  # to be set by method train_dataloader
        self._optimizer_params = optimizer

    @property
    def method(self):
        """The method implemented by this trainer."""
        return self._method

    def _update_batch_epoch_factors(self, batch_id):
        """Refresh the growth factor, learning-rate factor and lambda.

        Nothing but lambda changes during the warmup epochs. Afterwards the
        progress ``p`` is the fraction of adaptation batches already seen.
        """
        if self.current_epoch >= self._init_epochs:
            delta_epoch = self.current_epoch - self._init_epochs
            p = (batch_id + delta_epoch * self._nb_training_batches) / (
                self._non_init_epochs * self._nb_training_batches
            )
            self._grow_fact = 2.0 / (1.0 + np.exp(-10 * p)) - 1

            if self._adapt_lr:
                self._lr_fact = 1.0 / ((1.0 + 10 * p) ** 0.75)

        # _grow_fact starts at 0.0, hence lambda is 0 while warming up.
        if self._adapt_lambda:
            self.lamb_da = self._init_lambda * self._grow_fact

    def get_parameters_watch_list(self):
        """
        Update this list for parameters to watch while training (ie log with MLFlow)
        """
        return {
            "lambda": self.lamb_da,
            "last_epoch": self.current_epoch,
        }

    def forward(self, x):
        raise NotImplementedError("Forward pass needs to be defined.")

    def compute_loss(self, batch, split_name="V"):
        """Define the loss of the model

        Args:
            batch (tuple): batches returned by the MultiDomainLoader.
            split_name (str, optional): learning stage (one of ["T", "V", "Te"]).
                Defaults to "V" for validation. "T" is for training and "Te" for test.
                This is currently used only for naming the metrics used for logging.

        Returns:
            a 3-element tuple with task_loss, adv_loss, log_metrics.
            log_metrics should be a dictionary.

        Raises:
            NotImplementedError: children of this classes should implement this method.
        """
        raise NotImplementedError("Loss needs to be defined.")

    def training_step(self, batch, batch_nb):
        """The most generic of training steps

        Args:
            batch (tuple): the batch as returned by the MultiDomainLoader dataloader iterator:
                2 tuples: (x_source, y_source), (x_target, y_target) in the unsupervised setting
                3 tuples: (x_source, y_source), (x_target_labeled, y_target_labeled), (x_target_unlabeled, y_target_unlabeled) in the semi-supervised setting
            batch_nb (int): id of the current batch.

        Returns:
            dict: must contain a "loss" key with the loss to be used for back-propagation.
                see pytorch-lightning for more details.
        """
        self._update_batch_epoch_factors(batch_nb)

        task_loss, adv_loss, log_metrics = self.compute_loss(batch, split_name="T")
        if self.current_epoch < self._init_epochs:
            # init phase doesn't use few-shot learning
            # ad-hoc decision but makes models more comparable between each other
            loss = task_loss
        else:
            loss = task_loss + self.lamb_da * adv_loss

        log_metrics = get_aggregated_metrics_from_dict(log_metrics)
        log_metrics.update(
            get_metrics_from_parameter_dict(
                self.get_parameters_watch_list(), loss.device
            )
        )
        log_metrics["T_total_loss"] = loss
        log_metrics["T_task_loss"] = task_loss

        return {
            "loss": loss,  # required, for backward pass
            "progress_bar": {"class_loss": task_loss},
            "log": log_metrics,
        }

    def validation_step(self, batch, batch_nb):
        """Compute the metrics of one validation batch, plus its "val_loss"."""
        task_loss, adv_loss, log_metrics = self.compute_loss(batch, split_name="V")
        loss = task_loss + self.lamb_da * adv_loss
        log_metrics["val_loss"] = loss
        return log_metrics

    def _validation_epoch_end(self, outputs, metrics_at_valid):
        """Aggregate the listed validation metrics and add the watched parameters."""
        log_dict = get_aggregated_metrics(metrics_at_valid, outputs)
        device = outputs[0].get("val_loss").device
        log_dict.update(
            get_metrics_from_parameter_dict(self.get_parameters_watch_list(), device)
        )
        avg_loss = log_dict["val_loss"]
        return {
            "val_loss": avg_loss,  # for callbacks (eg early stopping)
            "progress_bar": {"val_loss": avg_loss},
            "log": log_dict,
        }

    def validation_epoch_end(self, outputs):
        """Aggregate the validation loss and the source/target accuracies."""
        metrics_to_log = (
            "val_loss",
            "V_source_acc",
            "V_target_acc",
        )
        return self._validation_epoch_end(outputs, metrics_to_log)

    def test_step(self, batch, batch_nb):
        """Compute the metrics of one test batch, plus its "test_loss"."""
        task_loss, adv_loss, log_metrics = self.compute_loss(batch, split_name="Te")
        loss = task_loss + self.lamb_da * adv_loss
        log_metrics["test_loss"] = loss
        return log_metrics

    def test_epoch_end(self, outputs):
        """Aggregate the test loss and the source/target accuracies."""
        metrics_at_test = (
            "test_loss",
            "Te_source_acc",
            "Te_target_acc",
        )
        log_dict = get_aggregated_metrics(metrics_at_test, outputs)

        return {
            "avg_test_loss": log_dict["test_loss"],
            "progress_bar": log_dict,
            "log": log_dict,
        }

    def _configure_optimizer(self, parameters):
        """Create the optimizer (and scheduler, for SGD) over ``parameters``.

        Returns:
            ``[optimizer]``, or ``([optimizer], [scheduler])`` when SGD is used
            together with ``adapt_lr``.

        Raises:
            NotImplementedError: if the configured optimizer type is unknown.
        """
        if self._optimizer_params is None:
            optimizer = torch.optim.Adam(
                parameters,
                lr=self._init_lr,
                betas=(0.8, 0.999),
                weight_decay=1e-5,
            )
            return [optimizer]

        optim_type = self._optimizer_params["type"]
        if optim_type == "Adam":
            optimizer = torch.optim.Adam(
                parameters,
                lr=self._init_lr,
                **self._optimizer_params["optim_params"],
            )
            return [optimizer]
        if optim_type == "AdamW":
            # Advertised in the constructor docstring; handled like Adam.
            optimizer = torch.optim.AdamW(
                parameters,
                lr=self._init_lr,
                **self._optimizer_params["optim_params"],
            )
            return [optimizer]
        if optim_type == "SGD":
            optimizer = torch.optim.SGD(
                parameters,
                lr=self._init_lr,
                **self._optimizer_params["optim_params"],
            )
            if self._adapt_lr:
                # The scheduler reads the factor maintained by
                # _update_batch_epoch_factors, whatever the epoch argument is.
                feature_sched = torch.optim.lr_scheduler.LambdaLR(
                    optimizer, lr_lambda=lambda epoch: self._lr_fact
                )
                return [optimizer], [feature_sched]
            return [optimizer]
        raise NotImplementedError(f"Unknown optimizer type {optim_type}")

    def train_dataloader(self):
        """Return the training loader and record its number of batches."""
        dataloader = self._dataset.get_domain_loaders(
            split="train", batch_size=self._batch_size
        )
        self._nb_training_batches = len(dataloader)
        return dataloader

    def val_dataloader(self):
        """Return the loader of the "valid" split."""
        return self._dataset.get_domain_loaders(
            split="valid", batch_size=self._batch_size
        )

    def test_dataloader(self):
        """Return the loader of the "test" split."""
        return self._dataset.get_domain_loaders(
            split="test", batch_size=self._batch_size
        )
class BaseDANNLike(BaseAdaptTrainer):
    """Common base of the adversarial trainers, which own a domain critic.

    Subclasses are expected to provide a ``forward`` returning a 3-tuple whose
    last two items are the class output and the domain (critic) output.
    """

    def __init__(
        self,
        dataset,
        feature_extractor,
        task_classifier,
        critic,
        alpha=1.0,
        entropy_reg=0.0,  # not used
        adapt_reg=True,  # not used
        batch_reweighting=False,  # not used
        **base_params,
    ):
        """
        Args:
            dataset: forwarded to ``BaseAdaptTrainer``.
            feature_extractor: forwarded to ``BaseAdaptTrainer``.
            task_classifier: forwarded to ``BaseAdaptTrainer``.
            critic: network stored as ``self.domain_classifier``.
            alpha (float, optional): stored as ``self.alpha`` and passed by
                subclasses to the gradient reversal layer. Defaults to 1.0.
            entropy_reg, adapt_reg, batch_reweighting: marked as not used.
            **base_params: extra keyword arguments for ``BaseAdaptTrainer``.
        """
        super().__init__(dataset, feature_extractor, task_classifier, **base_params)

        self.alpha = alpha

        self._entropy_reg_init = entropy_reg  # not used
        self._entropy_reg = entropy_reg  # not used
        self._adapt_reg = adapt_reg  # not used

        self._reweight_beta = 4  # not used
        self._do_dynamic_batch_weight = batch_reweighting  # not used

        self.domain_classifier = critic

    def get_parameters_watch_list(self):
        """
        Update this list for parameters to watch while training (ie log with MLFlow)
        """
        param_list = super().get_parameters_watch_list()
        param_list.update({"alpha": self.alpha, "entropy_reg": self._entropy_reg})
        return param_list

    def _update_batch_epoch_factors(self, batch_id):
        """Update the base factors, then rescale the entropy regularisation."""
        super()._update_batch_epoch_factors(batch_id)
        if self._adapt_reg:
            self._entropy_reg = self._entropy_reg_init * self._grow_fact

    def compute_loss(self, batch, split_name="V"):
        """Compute the task loss, the domain-classification loss and metrics.

        Args:
            batch (tuple): ``((x_source, y_source), (x_target, y_target))``.
            split_name (str, optional): prefix of the metric names.

        Returns:
            tuple: ``(task_loss, adv_loss, log_metrics)``.

        Raises:
            NotImplementedError: if the batch has 3 elements (semi-supervised).
        """
        if len(batch) == 3:
            raise NotImplementedError("DANN does not support semi-supervised setting.")
        (x_s, y_s), (x_tu, y_tu) = batch
        batch_size = len(y_s)

        _, y_hat, d_hat = self.forward(x_s)
        _, y_t_hat, d_t_hat = self.forward(x_tu)

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        # Target labels only serve to monitor the accuracy, not the loss.
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)

        # Domain labels: 0 for the source samples, 1 for the target samples.
        loss_dmn_src, dok_src = losses.cross_entropy_logits(
            d_hat, torch.zeros(batch_size)
        )
        loss_dmn_tgt, dok_tgt = losses.cross_entropy_logits(
            d_t_hat, torch.ones(len(d_t_hat))
        )
        adv_loss = loss_dmn_src + loss_dmn_tgt
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": torch.cat((dok_src, dok_tgt)),
            f"{split_name}_source_domain_acc": dok_src,
            f"{split_name}_target_domain_acc": dok_tgt,
        }
        return task_loss, adv_loss, log_metrics

    def validation_epoch_end(self, outputs):
        """Aggregate the validation loss, task and domain accuracies."""
        metrics_to_log = (
            "val_loss",
            "V_source_acc",
            "V_target_acc",
            "V_source_domain_acc",
            "V_target_domain_acc",
            "V_domain_acc",
        )
        return self._validation_epoch_end(outputs, metrics_to_log)

    def test_epoch_end(self, outputs):
        """Aggregate the test loss, task accuracies and domain accuracy."""
        metrics_at_test = (
            "test_loss",
            "Te_source_acc",
            "Te_target_acc",
            "Te_domain_acc",
        )
        log_dict = get_aggregated_metrics(metrics_at_test, outputs)

        return {
            "avg_test_loss": log_dict["test_loss"],
            "progress_bar": log_dict,
            "log": log_dict,
        }
class DANNtrainer(BaseDANNLike):
    """
    This class implements the DANN architecture from
    Ganin, Yaroslav, et al.
    "Domain-adversarial training of neural networks."
    The Journal of Machine Learning Research (2016)
    https://arxiv.org/abs/1505.07818
    """

    def __init__(
        self,
        dataset,
        feature_extractor,
        task_classifier,
        critic,
        method=None,
        **base_params,
    ):
        """
        Args:
            dataset, feature_extractor, task_classifier, critic: forwarded to
                ``BaseDANNLike``.
            method (Method or str, optional): a DANN-type method; defaults to
                ``Method.DANN`` when None.
            **base_params: extra keyword arguments for ``BaseDANNLike``.
        """
        super().__init__(
            dataset, feature_extractor, task_classifier, critic, **base_params
        )

        if method is None:
            self._method = Method.DANN
        else:
            self._method = Method(method)
            assert self._method.is_dann_method()

    def forward(self, x):
        """Return ``(feature, class_output, adversarial_output)`` for ``x``.

        The domain classifier receives the features through the gradient
        reversal layer, parameterised by ``self.alpha``.
        """
        if self.feat is not None:
            x = self.feat(x)
        feature = x.view(x.size(0), -1)  # flatten to (batch, features)

        reverse_feature = ReverseLayerF.apply(feature, self.alpha)
        class_output = self.classifier(feature)
        adversarial_output = self.domain_classifier(reverse_feature)
        return feature, class_output, adversarial_output
class CDANtrainer(BaseDANNLike):
    """
    Implements CDAN: Long, Mingsheng, et al. "Conditional adversarial domain adaptation."
    Advances in Neural Information Processing Systems. 2018.
    https://papers.nips.cc/paper/7436-conditional-adversarial-domain-adaptation.pdf
    """

    def __init__(
        self,
        dataset,
        feature_extractor,
        task_classifier,
        critic,
        use_entropy=False,
        use_random=False,
        random_dim=1024,
        **base_params,
    ):
        """
        Args:
            dataset, feature_extractor, task_classifier, critic: forwarded to
                ``BaseDANNLike``.
            use_entropy (bool, optional): weight the domain losses with
                entropy-based weights. Defaults to False.
            use_random (bool, optional): project the critic input through a
                fixed random linear layer. Defaults to False.
            random_dim (int, optional): output size of that random layer.
                Defaults to 1024.
            **base_params: extra keyword arguments for ``BaseDANNLike``.
        """
        super().__init__(
            dataset, feature_extractor, task_classifier, critic, **base_params
        )
        self.random_layer = None
        self.random_dim = random_dim
        self.entropy = use_entropy
        if use_random:
            nb_inputs = self.feat.output_size() * self.classifier.n_classes()
            self.random_layer = torch.nn.Linear(
                in_features=nb_inputs, out_features=self.random_dim, bias=False
            )
            torch.nn.init.normal_(self.random_layer.weight, mean=0, std=1)
            # The random projection is fixed: it is never trained.
            for param in self.random_layer.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Return ``(features, class_output, adversarial_output)`` for ``x``.

        The critic input is the flattened outer product of the softmax
        predictions and the features, optionally randomly projected.
        """
        if self.feat is not None:
            x = self.feat(x)
        x = x.view(x.size(0), -1)
        class_output = self.classifier(x)

        # The GRL hook is applied to all inputs to the adversary
        reverse_feature = ReverseLayerF.apply(x, self.alpha)

        softmax_output = torch.nn.Softmax(dim=1)(class_output)
        reverse_out = ReverseLayerF.apply(softmax_output, self.alpha)

        # Batched outer product: (batch, classes, 1) x (batch, 1, features).
        feature = torch.bmm(reverse_out.unsqueeze(2), reverse_feature.unsqueeze(1))
        feature = feature.view(-1, reverse_out.size(1) * reverse_feature.size(1))
        if self.random_layer is not None:
            # Call the module rather than .forward() so that hooks are run.
            random_out = self.random_layer(feature)
            adversarial_output = self.domain_classifier(
                random_out.view(-1, random_out.size(1))
            )
        else:
            adversarial_output = self.domain_classifier(feature)

        return x, class_output, adversarial_output

    def _compute_entropy_weights(self, logits):
        """Return the per-sample weights ``1 + exp(-entropy(logits))``."""
        entropy = losses.entropy_logits(logits)
        entropy = ReverseLayerF.apply(entropy, self.alpha)
        entropy_w = 1.0 + torch.exp(-entropy)
        return entropy_w

    def compute_loss(self, batch, split_name="V"):
        """Compute the task loss, the (optionally weighted) domain loss and metrics.

        Args:
            batch (tuple): ``((x_source, y_source), (x_target, y_target))``.
            split_name (str, optional): prefix of the metric names.

        Returns:
            tuple: ``(task_loss, adv_loss, log_metrics)``.

        Raises:
            NotImplementedError: if the batch has 3 elements (semi-supervised).
        """
        if len(batch) == 3:
            raise NotImplementedError("CDAN does not support semi-supervised setting.")
        (x_s, y_s), (x_tu, y_tu) = batch
        batch_size = len(y_s)

        _, y_hat, d_hat = self.forward(x_s)
        _, y_t_hat, d_t_hat = self.forward(x_tu)

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        # Target labels only serve to monitor the accuracy, not the loss.
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)

        if self.entropy:
            # Weights are normalised to sum to 1 within each domain.
            e_s = self._compute_entropy_weights(y_hat)
            e_t = self._compute_entropy_weights(y_t_hat)
            source_weight = e_s / torch.sum(e_s)
            target_weight = e_t / torch.sum(e_t)
        else:
            source_weight = None
            target_weight = None

        # Domain labels: 0 for the source samples, 1 for the target samples.
        loss_dmn_src, dok_src = losses.cross_entropy_logits(
            d_hat, torch.zeros(batch_size), source_weight
        )
        loss_dmn_tgt, dok_tgt = losses.cross_entropy_logits(
            d_t_hat, torch.ones(len(d_t_hat)), target_weight
        )

        adv_loss = loss_dmn_src + loss_dmn_tgt
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": torch.cat((dok_src, dok_tgt)),
            f"{split_name}_source_domain_acc": dok_src,
            f"{split_name}_target_domain_acc": dok_tgt,
        }
        return task_loss, adv_loss, log_metrics
class WDGRLtrainer(BaseDANNLike):
    r"""
    Implements WDGRL as described in
    Shen, Jian, et al.
    "Wasserstein distance guided representation learning for domain adaptation."
    Thirty-Second AAAI Conference on Artificial Intelligence. 2018.
    https://arxiv.org/pdf/1707.01217.pdf

    This class also implements the asymmetric ($\beta$) variant described in:
    Wu, Yifan, et al.
    "Domain adaptation with asymmetrically-relaxed distribution alignment."
    ICML (2019)
    https://arxiv.org/pdf/1903.01689.pdf
    """

    def __init__(
        self,
        dataset,
        feature_extractor,
        task_classifier,
        critic,
        k_critic=5,
        gamma=10,
        beta_ratio=0,
        **base_params,
    ):
        """
        parameters:
            k_critic: number of steps to train critic (called n in Algorithm 1 of the paper)
            gamma: weight of the gradient penalty in the critic cost
            beta_ratio: the target critic mean is scaled by (1 + beta_ratio)
        """
        super().__init__(
            dataset, feature_extractor, task_classifier, critic, **base_params
        )
        self._k_critic = k_critic
        self._beta_ratio = beta_ratio
        self._gamma = gamma

    def forward(self, x):
        """Return ``(features, class_output, critic_output)`` for ``x``.

        No gradient reversal here: the critic is fed the features directly.
        """
        if self.feat is not None:
            x = self.feat(x)
        x = x.view(x.size(0), -1)

        class_output = self.classifier(x)
        adversarial_output = self.domain_classifier(x)
        return x, class_output, adversarial_output

    def compute_loss(self, batch, split_name="V"):
        """Compute the task loss, the Wasserstein distance and metrics.

        Args:
            batch (tuple): ``((x_source, y_source), (x_target, y_target))``.
            split_name (str, optional): prefix of the metric names.

        Returns:
            tuple: ``(task_loss, adv_loss, log_metrics)`` where ``adv_loss`` is
            the Wasserstein distance estimated by the critic.

        Raises:
            NotImplementedError: if the batch has 3 elements (semi-supervised).
        """
        if len(batch) == 3:
            raise NotImplementedError("WDGRL does not support semi-supervised setting.")
        (x_s, y_s), (x_tu, y_tu) = batch
        batch_size = len(y_s)

        _, y_hat, d_hat = self.forward(x_s)
        _, y_t_hat, d_t_hat = self.forward(x_tu)

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)

        # Only the domain accuracies are kept: the adaptation loss is the
        # Wasserstein distance, not a cross-entropy.
        _, dok_src = losses.cross_entropy_logits(d_hat, torch.zeros(batch_size))
        _, dok_tgt = losses.cross_entropy_logits(d_t_hat, torch.ones(len(d_t_hat)))

        wasserstein_distance = d_hat.mean() - (1 + self._beta_ratio) * d_t_hat.mean()
        adv_loss = wasserstein_distance
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": torch.cat((dok_src, dok_tgt)),
            f"{split_name}_source_domain_acc": dok_src,
            f"{split_name}_target_domain_acc": dok_tgt,
            f"{split_name}_wasserstein_dist": wasserstein_distance,
        }
        return task_loss, adv_loss, log_metrics

    def critic_update_steps(self, batch):
        """Train the critic for ``k_critic`` steps on frozen features.

        Does nothing during the warmup epochs.
        """
        if self.current_epoch < self._init_epochs:
            return

        set_requires_grad(self.feat, requires_grad=False)
        set_requires_grad(self.domain_classifier, requires_grad=True)

        (x_s, y_s), (x_tu, _) = batch
        # Features are computed once, without gradient, and reused below.
        with torch.no_grad():
            h_s = self.feat(x_s).data.view(x_s.shape[0], -1)
            h_t = self.feat(x_tu).data.view(x_tu.shape[0], -1)

        # NOTE(review): self.critic_opt and self.critic_sched are not defined
        # in the visible part of this file - confirm where they are created.
        for _ in range(self._k_critic):
            gp = losses.gradient_penalty(self.domain_classifier, h_s, h_t)

            critic_s = self.domain_classifier(h_s)
            critic_t = self.domain_classifier(h_t)
            wasserstein_distance = (
                critic_s.mean() - (1 + self._beta_ratio) * critic_t.mean()
            )

            # The critic maximises the distance, hence the minus sign.
            critic_cost = -wasserstein_distance + self._gamma * gp

            self.critic_opt.zero_grad()
            critic_cost.backward()
            self.critic_opt.step()

            if self.critic_sched:
                self.critic_sched.step()

        set_requires_grad(self.feat, requires_grad=True)
        set_requires_grad(self.domain_classifier, requires_grad=False)

    def training_step(self, batch, batch_id):
        """Update the critic, then run the generic training step.

        The inherited step refreshes the schedule factors itself; the critic
        update only reads the current epoch, so it can safely run first.
        """
        self.critic_update_steps(batch)
        return super().training_step(batch, batch_id)
class WDGRLtrainerMod(WDGRLtrainer):
    r"""
    Implements a modified version WDGRL as described in
    Shen, Jian, et al.
    "Wasserstein distance guided representation learning for domain adaptation."
    Thirty-Second AAAI Conference on Artificial Intelligence. 2018.
    https://arxiv.org/pdf/1707.01217.pdf

    This class also implements the asymmetric ($\beta$) variant described in:
    Wu, Yifan, et al.
    "Domain adaptation with asymmetrically-relaxed distribution alignment."
    ICML (2019)
    https://arxiv.org/pdf/1903.01689.pdf
    """

    def __init__(
        self,
        dataset,
        feature_extractor,
        task_classifier,
        critic,
        k_critic=5,
        gamma=10,
        beta_ratio=0,
        **base_params,
    ):
        """
        parameters:
            k_critic: number of steps to train critic (called n in Algorithm 1 of the paper)
            gamma: weight of the gradient penalty in the critic cost
            beta_ratio: the target critic mean is scaled by (1 + beta_ratio)
        """
        # The parent stores k_critic, gamma and beta_ratio itself.
        super().__init__(
            dataset,
            feature_extractor,
            task_classifier,
            critic,
            k_critic=k_critic,
            gamma=gamma,
            beta_ratio=beta_ratio,
            **base_params,
        )

    def critic_update_steps(self, batch):
        """Compute the critic cost of one batch and return it as a step output.

        Unlike the parent implementation, no optimizer is stepped here: the
        returned "loss" is left to the training loop.
        """
        (x_s, y_s), (x_tu, _) = batch

        # The features are treated as constants for the critic update.
        with torch.no_grad():
            h_s = self.feat(x_s).data.view(x_s.shape[0], -1)
            h_t = self.feat(x_tu).data.view(x_tu.shape[0], -1)

        gp = losses.gradient_penalty(self.domain_classifier, h_s, h_t)

        critic_s = self.domain_classifier(h_s)
        critic_t = self.domain_classifier(h_t)
        wasserstein_distance = (
            critic_s.mean() - (1 + self._beta_ratio) * critic_t.mean()
        )

        # The critic maximises the distance, hence the minus sign.
        critic_cost = -wasserstein_distance + self._gamma * gp

        log_metrics = {"T_critic_loss": critic_cost}

        return {
            "loss": critic_cost,  # required, for backward pass
            "progress_bar": {"critic loss": critic_cost},
            "log": log_metrics,
        }

    def training_step(self, batch, batch_id, optimizer_idx):
        """Return the critic cost for optimizer 0, the model loss otherwise."""
        self._update_batch_epoch_factors(batch_id)

        if optimizer_idx == 0:
            return self.critic_update_steps(batch)

        task_loss, adv_loss, log_metrics = self.compute_loss(batch, split_name="T")
        if self.current_epoch < self._init_epochs:
            # init phase doesn't use few-shot learning
            # ad-hoc decision but makes models more comparable between each other
            loss = task_loss
        else:
            loss = task_loss + self.lamb_da * adv_loss

        log_metrics = get_aggregated_metrics_from_dict(log_metrics)
        log_metrics.update(
            get_metrics_from_parameter_dict(
                self.get_parameters_watch_list(), loss.device
            )
        )
        log_metrics["T_total_loss"] = loss
        log_metrics["T_task_loss"] = task_loss

        return {
            "loss": loss,  # required, for backward pass
            "progress_bar": {"class_loss": task_loss},
            "log": log_metrics,
        }

    def optimizer_step(
        self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None
    ):
        """Step the optimizers on the schedule of this variant.

        Optimizer 0 is the one whose loss is the critic cost and optimizer 1
        the one whose loss is the task/adaptation loss (see ``training_step``).
        """
        if current_epoch < self._init_epochs:
            if optimizer_i == 0:
                # Do not update the critic during warmup, but discard its
                # gradient so it cannot pile up until the first real step.
                optimizer.zero_grad()
            if optimizer_i == 1:
                optimizer.step()
                optimizer.zero_grad()
        else:
            if optimizer_i == 0:
                optimizer.step()
                optimizer.zero_grad()

            # Optimizer 1 is stepped only once every k_critic batches.
            # NOTE(review): its gradient is not cleared in between, so it
            # accumulates over those batches - confirm this is intended.
            if optimizer_i == 1:
                if (batch_nb + 1) % self._k_critic == 0:
                    optimizer.step()
                    optimizer.zero_grad()
class FewShotDANNtrainer(BaseDANNLike):
    """Implements adaptations of DANN to the semi-supervised setting

    naive: task classifier is trained on labeled target data, in addition to source
    data.
    MME: implements Saito, Kuniaki, et al.
    "Semi-supervised domain adaptation via minimax entropy."
    Proceedings of the IEEE International Conference on Computer Vision. 2019
    https://arxiv.org/pdf/1904.06487.pdf
    """

    def __init__(
        self, dataset, feature_extractor, task_classifier, critic, method, **base_params
    ):
        """
        Args:
            dataset, feature_extractor, task_classifier, critic: forwarded to
                ``BaseDANNLike``.
            method (Method or str): the method to run, converted to ``Method``.
            **base_params: extra keyword arguments for ``BaseDANNLike``.
        """
        super().__init__(
            dataset, feature_extractor, task_classifier, critic, **base_params
        )
        self._method = Method(method)

    def forward(self, x):
        """Return ``(features, class_output, adversarial_output)`` for ``x``.

        The domain classifier receives the features through the gradient
        reversal layer, parameterised by ``self.alpha``.
        """
        if self.feat is not None:
            x = self.feat(x)
        x = x.view(x.size(0), -1)

        reverse_feature = ReverseLayerF.apply(x, self.alpha)
        class_output = self.classifier(x)
        adversarial_output = self.domain_classifier(reverse_feature)
        return x, class_output, adversarial_output

    def compute_loss(self, batch, split_name="V"):
        """Compute the task loss, the adaptation loss and metrics.

        Args:
            batch (tuple): 3 tuples, ``(x_source, y_source)``,
                ``(x_target_labeled, y_target_labeled)`` and
                ``(x_target_unlabeled, y_target_unlabeled)``.
            split_name (str, optional): prefix of the metric names.

        Returns:
            tuple: ``(task_loss, adv_loss, log_metrics)``.
        """
        assert len(batch) == 3
        (x_s, y_s), (x_tl, y_tl), (x_tu, y_tu) = batch
        batch_size = len(y_s)

        _, y_hat, d_hat = self.forward(x_s)
        _, y_tl_hat, d_tl_hat = self.forward(x_tl)
        _, y_tu_hat, d_tu_hat = self.forward(x_tu)
        d_target_pred = torch.cat((d_tl_hat, d_tu_hat))

        loss_cls_s, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        loss_cls_tl, ok_tl = losses.cross_entropy_logits(y_tl_hat, y_tl)
        # Unlabeled-target labels only serve to monitor the accuracy.
        _, ok_tu = losses.cross_entropy_logits(y_tu_hat, y_tu)
        ok_tgt = torch.cat((ok_tl, ok_tu))

        if self.current_epoch < self._init_epochs:
            # init phase doesn't use few-shot learning
            # ad-hoc decision but makes models more comparable between each other
            task_loss = loss_cls_s
        else:
            # Average of both classification losses, weighted by sample count.
            task_loss = (batch_size * loss_cls_s + len(y_tl) * loss_cls_tl) / (
                batch_size + len(y_tl)
            )

        # Domain labels: 0 for the source samples, 1 for the target samples.
        loss_dmn_src, dok_src = losses.cross_entropy_logits(
            d_hat, torch.zeros(batch_size)
        )
        loss_dmn_tgt, dok_tgt = losses.cross_entropy_logits(
            d_target_pred, torch.ones(len(d_target_pred))
        )

        if self._method is Method.MME:
            # only keep accuracy, overwrite "domain" loss
            loss_dmn_src = 0
            loss_dmn_tgt = losses.entropy_logits_loss(y_tu_hat)

        adv_loss = loss_dmn_src + loss_dmn_tgt

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": torch.cat((dok_src, dok_tgt)),
            f"{split_name}_source_domain_acc": dok_src,
            f"{split_name}_target_domain_acc": dok_tgt,
        }
        return task_loss, adv_loss, log_metrics
class BaseMMDLike(BaseAdaptTrainer):
    """Common base of the MMD-based trainers.

    Subclasses only have to implement ``_compute_mmd``.
    """

    def __init__(
        self,
        dataset,
        feature_extractor,
        task_classifier,
        kernel_mul=2.0,
        kernel_num=5,
        **base_params,
    ):
        """
        Args:
            dataset, feature_extractor, task_classifier: forwarded to
                ``BaseAdaptTrainer``.
            kernel_mul: stored for the subclasses, which pass it to
                ``losses.gaussian_kernel``. Defaults to 2.0.
            kernel_num: stored for the subclasses, which pass it to
                ``losses.gaussian_kernel``. Defaults to 5.
            **base_params: extra keyword arguments for ``BaseAdaptTrainer``.
        """
        super().__init__(dataset, feature_extractor, task_classifier, **base_params)

        self._kernel_mul = kernel_mul
        self._kernel_num = kernel_num

    def forward(self, x):
        """Return ``(features, class_output)`` for ``x``."""
        if self.feat is not None:
            x = self.feat(x)
        x = x.view(x.size(0), -1)
        class_output = self.classifier(x)
        return x, class_output

    def _compute_mmd(self, phi_s, phi_t, y_hat, y_t_hat):
        """Return the MMD loss from features and class outputs of both domains."""
        raise NotImplementedError("You need to implement a MMD-loss")

    def compute_loss(self, batch, split_name="V"):
        """Compute the task loss, the MMD loss and metrics.

        Args:
            batch (tuple): ``((x_source, y_source), (x_target, y_target))``.
            split_name (str, optional): prefix of the metric names.

        Returns:
            tuple: ``(task_loss, mmd, log_metrics)``.

        Raises:
            NotImplementedError: if the batch has 3 elements (semi-supervised).
        """
        if len(batch) == 3:
            raise NotImplementedError("MMD does not support semi-supervised setting.")
        (x_s, y_s), (x_tu, y_tu) = batch

        phi_s, y_hat = self.forward(x_s)
        phi_t, y_t_hat = self.forward(x_tu)

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        # Target labels only serve to monitor the accuracy, not the loss.
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)

        mmd = self._compute_mmd(phi_s, phi_t, y_hat, y_t_hat)
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_mmd": mmd,
        }
        return task_loss, mmd, log_metrics

    def validation_epoch_end(self, outputs):
        """Aggregate the validation loss, the accuracies and the MMD."""
        metrics_to_log = (
            "val_loss",
            "V_source_acc",
            "V_target_acc",
            "V_mmd",
        )
        return self._validation_epoch_end(outputs, metrics_to_log)

    def test_epoch_end(self, outputs):
        """Aggregate the test loss, the accuracies and the MMD."""
        metrics_at_test = (
            "test_loss",
            "Te_source_acc",
            "Te_target_acc",
            "Te_mmd",
        )
        log_dict = get_aggregated_metrics(metrics_at_test, outputs)

        return {
            "avg_test_loss": log_dict["test_loss"],
            "progress_bar": log_dict,
            "log": log_dict,
        }
class DANtrainer(BaseMMDLike):
    """
    This is an implementation of DAN
    Long, Mingsheng, et al.
    "Learning Transferable Features with Deep Adaptation Networks."
    International Conference on Machine Learning. 2015.
    http://proceedings.mlr.press/v37/long15.pdf
    code based on https://github.com/thuml/Xlearn.
    """

    def __init__(self, dataset, feature_extractor, task_classifier, **base_params):
        """Forward every argument to ``BaseMMDLike`` unchanged."""
        super().__init__(dataset, feature_extractor, task_classifier, **base_params)

    def _compute_mmd(self, phi_s, phi_t, y_hat, y_t_hat):
        """Return the MMD loss computed on the features only.

        The class outputs ``y_hat`` and ``y_t_hat`` are not used.
        """
        batch_size = int(phi_s.size()[0])
        kernels = losses.gaussian_kernel(
            phi_s,
            phi_t,
            kernel_mul=self._kernel_mul,
            kernel_num=self._kernel_num,
        )
        return losses.compute_mmd_loss(kernels, batch_size)
class JANtrainer(BaseMMDLike):
    """
    This is an implementation of JAN
    Long, Mingsheng, et al.
    "Deep transfer learning with joint adaptation networks."
    International Conference on Machine Learning, 2017.
    https://arxiv.org/pdf/1605.06636.pdf
    code based on https://github.com/thuml/Xlearn.
    """

    def __init__(
        self,
        dataset,
        feature_extractor,
        task_classifier,
        kernel_mul=(2.0, 2.0),
        kernel_num=(5, 1),
        **base_params,
    ):
        """
        Args:
            dataset, feature_extractor, task_classifier: forwarded to
                ``BaseMMDLike``.
            kernel_mul (sequence, optional): one value per kernel, the first
                for the features and the second for the softmax predictions.
            kernel_num (sequence, optional): one value per kernel, in the same
                order.
            **base_params: extra keyword arguments for ``BaseMMDLike``.
        """
        super().__init__(
            dataset,
            feature_extractor,
            task_classifier,
            kernel_mul=kernel_mul,
            kernel_num=kernel_num,
            **base_params,
        )

    def _compute_mmd(self, phi_s, phi_t, y_hat, y_t_hat):
        """Return the MMD loss of the product of the feature and prediction kernels."""
        softmax_layer = torch.nn.Softmax(dim=-1)
        source_list = [phi_s, softmax_layer(y_hat)]
        target_list = [phi_t, softmax_layer(y_t_hat)]
        batch_size = int(phi_s.size()[0])

        joint_kernels = None
        # The features use no fixed sigma; the predictions use 1.68.
        for source, target, k_mul, k_num, sigma in zip(
            source_list, target_list, self._kernel_mul, self._kernel_num, [None, 1.68]
        ):
            kernels = losses.gaussian_kernel(
                source, target, kernel_mul=k_mul, kernel_num=k_num, fix_sigma=sigma
            )
            if joint_kernels is not None:
                joint_kernels = joint_kernels * kernels
            else:
                joint_kernels = kernels

        return losses.compute_mmd_loss(joint_kernels, batch_size)