Source code for ada.models.losses

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import grad


def cross_entropy_logits(linear_output, label, weights=None):
    """Cross-entropy loss on raw logits; returns the loss together with a
    per-sample correctness mask. When ``weights`` is given, per-sample losses
    are combined as a weighted average instead of a plain mean."""
    class_output = F.log_softmax(linear_output, dim=1)
    max_class = class_output.max(1)
    y_hat = max_class[1]  # get the index of the max log-probability
    correct = y_hat.eq(label.view(label.size(0)).type_as(y_hat))
    if weights is None:
        loss = nn.NLLLoss()(class_output, label.type_as(y_hat).view(label.size(0)))
    else:
        losses = nn.NLLLoss(reduction="none")(
            class_output, label.type_as(y_hat).view(label.size(0))
        )
        loss = torch.sum(weights * losses) / torch.sum(weights)
    return loss, correct

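# Illustrative usage sketch (not part of the original module); the batch size
# and class count below are arbitrary assumptions.
def _example_cross_entropy_usage():
    logits = torch.randn(8, 4)          # 8 samples, 4 classes (raw scores)
    labels = torch.randint(0, 4, (8,))  # integer class labels in [0, 4)
    loss, correct = cross_entropy_logits(logits, labels)
    accuracy = correct.float().mean()   # fraction of argmax predictions matching labels
    return loss, accuracy
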
def entropy_logits(linear_output):
    p = F.softmax(linear_output, dim=1)
    loss_ent = -torch.sum(p * (torch.log(p + 1e-5)), dim=1)
    return loss_ent

def entropy_logits_loss(linear_output):
    return torch.mean(entropy_logits(linear_output))

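# Illustrative sketch (not part of the original module): the mean entropy is
# typically minimized on unlabeled target-domain logits as a regularizer; the
# tensor shape here is an assumption.
def _example_entropy_minimization():
    target_logits = torch.randn(8, 4)              # unlabeled target predictions
    ent_loss = entropy_logits_loss(target_logits)  # scalar mean Shannon entropy
    return ent_loss
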
def gradient_penalty(critic, h_s, h_t):
    """WGAN-GP gradient penalty, computed on interpolates between source and
    target hidden features (plus the endpoints themselves)."""
    # based on: https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py#L116
    alpha = torch.rand(h_s.size(0), 1)
    alpha = alpha.expand(h_s.size()).type_as(h_s)
    try:
        differences = h_t - h_s
        interpolates = h_s + (alpha * differences)
        interpolates = torch.cat((interpolates, h_s, h_t), dim=0).requires_grad_()
        preds = critic(interpolates)
        gradients = grad(
            preds,
            interpolates,
            grad_outputs=torch.ones_like(preds),
            retain_graph=True,
            create_graph=True,
        )[0]
        gradient_norm = gradients.norm(2, dim=1)
        gradient_penalty = ((gradient_norm - 1) ** 2).mean()
    except RuntimeError:  # e.g. source/target batches of mismatched size
        gradient_penalty = 0
    return gradient_penalty

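# Illustrative sketch (not part of the original module): a toy linear critic on
# random features; any module mapping features to a scalar score would do.
def _example_gradient_penalty():
    critic = nn.Linear(16, 1)
    h_s = torch.randn(8, 16)                 # source hidden features
    h_t = torch.randn(8, 16)                 # target hidden features, same shape
    gp = gradient_penalty(critic, h_s, h_t)  # scalar; pushes critic gradient norms toward 1
    return gp
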
def gaussian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """
    Code from XLearn: computes the full kernel matrix, which is less than
    optimal since we don't use all of it with the linear MMD estimate.
    """
    n_samples = int(source.size()[0]) + int(target.size()[0])
    total = torch.cat([source, target], dim=0)
    total0 = total.unsqueeze(0).expand(
        int(total.size(0)), int(total.size(0)), int(total.size(1))
    )
    total1 = total.unsqueeze(1).expand(
        int(total.size(0)), int(total.size(0)), int(total.size(1))
    )
    L2_distance = ((total0 - total1) ** 2).sum(2)
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
    kernel_val = [
        torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list
    ]
    return sum(kernel_val)  # /len(kernel_val)

def compute_mmd_loss(kernel_values, batch_size):
    """Linear MMD estimate from the joint kernel matrix produced by
    ``gaussian_kernel``: source samples occupy indices 0..batch_size-1 and
    target samples the following batch_size indices."""
    loss = 0
    for i in range(batch_size):
        s1, s2 = i, (i + 1) % batch_size
        t1, t2 = s1 + batch_size, s2 + batch_size
        loss += kernel_values[s1, s2] + kernel_values[t1, t2]
        loss -= kernel_values[s1, t2] + kernel_values[s2, t1]
    return loss / float(batch_size)

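# Illustrative sketch (not part of the original module): the two-step MMD
# pipeline pairing gaussian_kernel with compute_mmd_loss; equal source/target
# batch sizes and the feature dimension are assumptions.
def _example_mmd():
    batch_size = 8
    source = torch.randn(batch_size, 16)         # source-domain features
    target = torch.randn(batch_size, 16)         # target-domain features
    kernels = gaussian_kernel(source, target)    # (2*batch_size, 2*batch_size) matrix
    mmd = compute_mmd_loss(kernels, batch_size)  # scalar MMD estimate
    return mmd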