Source code for pytext.loss.loss

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from enum import Enum

import torch
import torch.nn.functional as F
from pytext.config import ConfigBase
from pytext.config.component import Component, ComponentType
from pytext.utils import loss as loss_utils, precision
from pytext.utils.cuda import FloatTensor
from torch import nn


def maybe_log_normalize(logits, logits_type, dim=-1):
    """Optionally log normalizes logits on the given dimension."""
    if logits_type == SourceType.LOGITS:
        return F.log_softmax(logits, dim)
    elif logits_type == SourceType.PROBS:
        return logits.log()
    elif logits_type == SourceType.LOG_PROBS:
        return logits
    else:
        raise NotImplementedError

class SourceType(Enum):
    LOG_PROBS = "log_probs"
    LOGITS = "logits"
    PROBS = "probs"

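# Illustrative sketch (not part of the pytext API): a hypothetical helper showing
# that `maybe_log_normalize` yields the same log-probabilities whether the scores
# arrive as raw logits, probabilities, or log-probabilities.
def _example_maybe_log_normalize():
    logits = torch.randn(2, 5)
    probs = F.softmax(logits, dim=-1)
    a = maybe_log_normalize(logits, SourceType.LOGITS)
    b = maybe_log_normalize(probs, SourceType.PROBS)
    c = maybe_log_normalize(probs.log(), SourceType.LOG_PROBS)
    # All three are (numerically) the same log-distribution.
    assert torch.allclose(a, b, atol=1e-6) and torch.allclose(b, c, atol=1e-6)
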
class Loss(Component):
    """Base class for loss functions"""

    __COMPONENT_TYPE__ = ComponentType.LOSS

    def __init__(self, config=None, *args, **kwargs):
        super().__init__(config)

    def __call__(self, logit, targets, reduce=True):
        raise NotImplementedError

class CrossEntropyLoss(Loss):
    class Config(ConfigBase):
        pass

    def __init__(self, config, ignore_index=-100, weight=None, *args, **kwargs):
        self.ignore_index = ignore_index
        self.weight = weight

    def __call__(self, logits, targets, reduce=True):
        # Don't change to F.cross_entropy() because @barlaso suggested not doing so.
        # There's some wisdom from fairseq folks that this is the preferred way.
        # Needs more testing before we can change to using F.cross_entropy().
        return F.nll_loss(
            F.log_softmax(logits, 1, dtype=torch.float32),
            targets,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction="mean" if reduce else "none",
        )

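# Illustrative sketch (not part of the pytext API): a hypothetical check showing
# that the nll_loss(log_softmax(...)) combination used above matches F.cross_entropy
# for ordinary float inputs.
def _example_cross_entropy_equivalence():
    logits = torch.randn(4, 3)            # [batch_size, n_classes]
    targets = torch.tensor([0, 2, 1, 1])  # gold class indices
    via_nll = F.nll_loss(F.log_softmax(logits, 1, dtype=torch.float32), targets)
    via_ce = F.cross_entropy(logits, targets)
    assert torch.allclose(via_nll, via_ce, atol=1e-6)
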
class NLLLoss(Loss):
    class Config(ConfigBase):
        pass

    def __init__(self, config, ignore_index=-100, weight=None, *args, **kwargs):
        self.ignore_index = ignore_index
        self.weight = weight

    def __call__(self, log_probs, targets, reduce=True):
        return F.nll_loss(
            log_probs,
            targets,
            ignore_index=self.ignore_index,
            reduction="mean" if reduce else "none",
            weight=self.weight,
        )

class BinaryCrossEntropyWithLogitsLoss(Loss):
    class Config(ConfigBase):
        reduce: bool = True

    def __call__(self, logits, targets, reduce=True):
        """
        Computes 1-vs-all binary cross entropy loss for multiclass
        classification. Unlike BinaryCrossEntropyLoss, targets are required to
        already be a one-hot vector.
        """
        target_labels = targets[0].float()
        # `F.binary_cross_entropy_with_logits` requires its input to already be
        # a FloatTensor.
        loss = F.binary_cross_entropy_with_logits(
            precision.maybe_float(logits), target_labels, reduction="none"
        )
        return loss.sum(-1).mean() if reduce else loss.sum(-1)

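# Illustrative sketch (not part of the pytext API): the expected target format for
# BinaryCrossEntropyWithLogitsLoss is a tuple whose first element is already a
# one-hot (or multi-hot) matrix. Constructing the loss from a bare Config() here is
# an assumption; in normal use PyText's config system builds Loss components.
def _example_bce_with_logits_targets():
    loss_fn = BinaryCrossEntropyWithLogitsLoss(
        BinaryCrossEntropyWithLogitsLoss.Config()
    )
    logits = torch.randn(2, 3)
    one_hot = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    return loss_fn(logits, (one_hot,))  # scalar: per-example sums, then mean
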
class BinaryCrossEntropyLoss(Loss):
    class Config(ConfigBase):
        reweight_negative: bool = True
        reduce: bool = True

    def __call__(self, logits, targets, reduce=True):
        """
        Computes 1-vs-all binary cross entropy loss for multiclass
        classification.
        """
        # Converts targets to one-hot representation. Dim: [batch, n_classes]
        targets = (
            (
                FloatTensor(targets.size(0), logits.size(1))
                .zero_()
                .scatter_(1, targets.unsqueeze(1).data, 1)
            )
            if len(logits.size()) > 1  # If multi-class classification.
            else targets.float()
        )
        # `F.binary_cross_entropy` (or `torch.nn.BCELoss`) requires its input
        # to already be a FloatTensor.
        #
        # This alternative weighting applies uniform class weights:
        # examples_per_class = one_hot_target.sum(0).clamp(min=1)
        # total_positive = examples_per_class.sum()
        # weights = total_positive.unsqueeze(0) / examples_per_class
        loss = F.binary_cross_entropy_with_logits(
            precision.maybe_float(logits), targets, reduction="none"
        )
        if self.config.reweight_negative:
            # This makes sure we have the same total weight for all negative
            # classes as for the single positive class. The weight is 1 for the
            # correct class and 1 / (n - 1) for the other ones.
            weights = targets + (1.0 - targets) / max(1, targets.size(1) - 1.0)
            loss = loss * weights
        return loss.sum(-1).mean() if reduce else loss.sum(-1)

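# Illustrative sketch (not part of the pytext API): a hypothetical walk-through of
# the reweight_negative scheme above, with plain torch ops. Each row gets weight 1
# on the true class and 1 / (n - 1) on every other class, so all negatives together
# weigh as much as the single positive.
def _example_reweight_negative():
    logits = torch.randn(2, 4)
    targets = torch.tensor([1, 3])
    one_hot = torch.zeros(2, 4).scatter_(1, targets.unsqueeze(1), 1.0)
    weights = one_hot + (1.0 - one_hot) / max(1, one_hot.size(1) - 1.0)
    loss = F.binary_cross_entropy_with_logits(logits, one_hot, reduction="none")
    return (loss * weights).sum(-1).mean()
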
class HingeLoss(Loss):
    class Config(ConfigBase):
        margin: float = 1.0

    def __init__(self, config, ignore_index=-100, weight=None, *args, **kwargs):
        self.margin = config.margin
        self.ignore_index = ignore_index
        self.weight = weight

    def __call__(self, logits, targets, reduce=True):
        return F.multi_margin_loss(
            logits,
            targets,
            margin=self.margin,
            weight=self.weight,
            reduction="mean" if reduce else "none",
        )

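# Illustrative sketch (not part of the pytext API): multi-class hinge loss penalizes
# any wrong class whose score comes within `margin` of the correct class's score.
def _example_multi_margin():
    logits = torch.tensor([[2.0, 0.5, 0.1], [0.2, 0.1, 1.5]])
    targets = torch.tensor([0, 2])
    return F.multi_margin_loss(logits, targets, margin=1.0)  # mean over the batch
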
class CosineEmbeddingLoss(Loss):
    class Config(ConfigBase):
        margin: float = 0.0

    def __init__(self, config, *args, **kwargs):
        self.margin = config.margin

    def __call__(self, embeddings, targets, reduce=True):
        if len(embeddings) != 2:
            raise ValueError(
                f"Number of embeddings must be 2. Found {len(embeddings)} embeddings."
            )
        return F.cosine_embedding_loss(
            embeddings[0],
            embeddings[1],
            targets,
            margin=self.margin,
            reduction="mean" if reduce else "none",
        )

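# Illustrative sketch (not part of the pytext API): CosineEmbeddingLoss expects a
# pair of embedding tensors and +1/-1 targets marking similar vs. dissimilar pairs.
def _example_cosine_embedding():
    emb_a = torch.randn(3, 8)
    emb_b = torch.randn(3, 8)
    targets = torch.tensor([1.0, -1.0, 1.0])  # +1 = similar, -1 = dissimilar
    return F.cosine_embedding_loss(emb_a, emb_b, targets, margin=0.0)
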
class MultiLabelSoftMarginLoss(Loss):
    class Config(ConfigBase):
        pass

    def __call__(self, m_out, targets, reduce=True):
        """
        Computes multi-label classification loss.
        See details in torch.nn.MultiLabelSoftMarginLoss.
        """
        num_classes = m_out.size()[1]
        target_labels = targets[0]
        # Each label list is padded with -1 so that every observation has the
        # same number of labels. Since -1 is outside the valid index range,
        # temporarily add 1 to target_labels before scattering.
        tmp_target_labels = target_labels + 1
        # The idea is similar to one-hot targets, but the following encoding
        # supports the multi-label case. The first column is dropped because it
        # corresponds to the padding label -1.
        n_hot_targets = (
            FloatTensor(target_labels.size(0), num_classes + 1)
            .zero_()
            .scatter_(1, tmp_target_labels, 1)
        )[:, 1:]
        # `F.multilabel_soft_margin_loss` (or `torch.nn.MultiLabelSoftMarginLoss`)
        # requires its input to already be a FloatTensor.
        #
        # Default: equal weight for each class; the losses are averaged over
        # observations for each mini-batch.
        loss = F.multilabel_soft_margin_loss(
            precision.maybe_float(m_out), n_hot_targets, reduction="mean"
        )
        return loss

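# Illustrative sketch (not part of the pytext API): a hypothetical example of the
# -1-padded label lists handled above and their n-hot encoding. Shifting by +1 maps
# the padding label -1 into column 0, which is then dropped.
def _example_n_hot_targets():
    num_classes = 4
    # Two examples: labels {0, 2} and label {3}, padded with -1 to equal length.
    target_labels = torch.tensor([[0, 2], [3, -1]])
    shifted = target_labels + 1
    n_hot = (
        torch.zeros(target_labels.size(0), num_classes + 1).scatter_(1, shifted, 1)
    )[:, 1:]
    # n_hot == [[1, 0, 1, 0], [0, 0, 0, 1]]
    return n_hot
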
class AUCPRHingeLoss(nn.Module, Loss):
    """Area under the precision-recall curve loss.

    Reference: "Scalable Learning of Non-Decomposable Objectives", Section 5.
    TensorFlow implementation:
    https://github.com/tensorflow/models/tree/master/research/global_objectives
    """

    class Config(ConfigBase):
        """
        Attributes:
            precision_range_lower (float): the lower range of precision values over
                which to compute AUC. Must be nonnegative, `\leq precision_range_upper`,
                and `\leq 1.0`.
            precision_range_upper (float): the upper range of precision values over
                which to compute AUC. Must be nonnegative, `\geq precision_range_lower`,
                and `\leq 1.0`.
            num_classes (int): number of classes (aka labels)
            num_anchors (int): the number of grid points used to approximate the
                Riemann sum.
        """

        precision_range_lower: float = 0.0
        precision_range_upper: float = 1.0
        num_classes: int = 1
        num_anchors: int = 20

    def __init__(self, config, weights=None, *args, **kwargs):
        """Args:
        config: Config containing `precision_range_lower`, `precision_range_upper`,
            `num_classes`, `num_anchors`
        """
        nn.Module.__init__(self)
        Loss.__init__(self, config)

        self.num_classes = self.config.num_classes
        self.num_anchors = self.config.num_anchors
        self.precision_range = (
            self.config.precision_range_lower,
            self.config.precision_range_upper,
        )

        # Create precision anchor values and the distance between anchors,
        # corresponding to [alpha_t] and [delta_t] in the paper.
        # precision_values: 1D `Tensor` of shape [K], where `K = num_anchors`
        # delta: scalar (since we use equal distance between anchors)
        self.precision_values, self.delta = loss_utils.range_to_anchors_and_delta(
            self.precision_range, self.num_anchors
        )

        # Notation is [b_k] in the paper: Parameter of shape [C, K],
        # where `C = number of classes` and `K = num_anchors`.
        self.biases = nn.Parameter(
            FloatTensor(self.config.num_classes, self.config.num_anchors).zero_()
        )
        self.lambdas = nn.Parameter(
            FloatTensor(self.config.num_classes, self.config.num_anchors).data.fill_(
                1.0
            )
        )

    def forward(self, logits, targets, reduce=True, size_average=True, weights=None):
        """
        Args:
            logits: Variable :math:`(N, C)` where `C = number of classes`
            targets: Variable :math:`(N)` where each value is
                `0 <= targets[i] <= C-1`
            weights: Coefficients for the loss. Must be a `Tensor` of shape
                [N] or [N, C], where `N = batch_size`, `C = number of classes`.
            size_average (bool, optional): By default, the losses are averaged
                over observations for each minibatch. However, if the field
                sizeAverage is set to False, the losses are instead summed for
                each minibatch. Default: ``True``
            reduce (bool, optional): By default, the losses are averaged or
                summed over observations for each minibatch depending on
                size_average. When reduce is False, returns a loss per
                input/target element instead and ignores size_average.
                Default: ``True``
        """
        C = 1 if logits.dim() == 1 else logits.size(1)

        if self.num_classes != C:
            raise ValueError(
                "num classes is %d while logits width is %d" % (self.num_classes, C)
            )

        labels, weights = AUCPRHingeLoss._prepare_labels_weights(
            logits, targets, weights=weights
        )

        # Lagrange multipliers are required to be nonnegative.
        # Their gradient is reversed so that they are maximized
        # (rather than minimized) by the optimizer.
        # 1D `Tensor` of shape [K], where `K = num_anchors`
        lambdas = loss_utils.lagrange_multiplier(self.lambdas)
        # print("lambdas: {}".format(lambdas))

        # A `Tensor` of shape [N, C, K]
        hinge_loss = loss_utils.weighted_hinge_loss(
            labels.unsqueeze(-1),
            logits.unsqueeze(-1) - self.biases,
            positive_weights=1.0 + lambdas * (1.0 - self.precision_values),
            negative_weights=lambdas * self.precision_values,
        )

        # 1D tensor of shape [C]
        class_priors = loss_utils.build_class_priors(labels, weights=weights)

        # lambda_term: Tensor[C, K]
        # According to the paper, lambda_term = lambda * (1 - precision) * |Y^+|,
        # where |Y^+| is the number of positive examples = N * class_priors.
        lambda_term = class_priors.unsqueeze(-1) * (
            lambdas * (1.0 - self.precision_values)
        )

        per_anchor_loss = weights.unsqueeze(-1) * hinge_loss - lambda_term

        # Riemann sum over anchors, normalized by the precision range.
        # loss: Tensor[N, C]
        loss = per_anchor_loss.sum(2) * self.delta
        loss /= self.precision_range[1] - self.precision_range[0]

        if not reduce:
            return loss
        elif size_average:
            return loss.mean()
        else:
            return loss.sum()

    @staticmethod
    def _prepare_labels_weights(logits, targets, weights=None):
        """
        Args:
            logits: Variable :math:`(N, C)` where `C = number of classes`
            targets: Variable :math:`(N)` where each value is
                `0 <= targets[i] <= C-1`
            weights: Coefficients for the loss. Must be a `Tensor` of shape
                [N] or [N, C], where `N = batch_size`, `C = number of classes`.
        Returns:
            labels: Tensor of shape [N, C], one-hot representation
            weights: Tensor of shape broadcastable to labels
        """
        N, C = logits.size()
        # Converts targets to one-hot representation. Dim: [N, C]
        labels = FloatTensor(N, C).zero_().scatter(1, targets.unsqueeze(1).data, 1)

        if weights is None:
            weights = FloatTensor(N).data.fill_(1.0)

        if weights.dim() == 1:
            weights.unsqueeze_(-1)

        return labels, weights

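# Illustrative sketch (not part of the pytext API): a plain-torch version of the
# label/weight preparation performed by _prepare_labels_weights, shown here to make
# the expected shapes concrete ([N, C] one-hot labels, [N, 1] weights).
def _example_prepare_labels_weights():
    logits = torch.randn(4, 3)
    targets = torch.tensor([0, 2, 1, 1])
    N, C = logits.size()
    labels = torch.zeros(N, C).scatter(1, targets.unsqueeze(1), 1.0)
    weights = torch.ones(N).unsqueeze(-1)
    return labels, weights
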
class KLDivergenceBCELoss(Loss):
    class Config(ConfigBase):
        temperature: float = 1.0
        hard_weight: float = 0.0

    def __init__(self, config, ignore_index=-100, weight=None, *args, **kwargs):
        assert 0.0 <= config.hard_weight < 1.0

        self.ignore_index = ignore_index
        self.weight = weight
        self.t = config.temperature
        self.hard_weight = config.hard_weight

    def __call__(self, logits, targets, reduce=True):
        """
        Computes Kullback-Leibler divergence loss for multiclass classification,
        with the probability distribution computed as in BinaryCrossEntropyLoss.
        """
        hard_targets, _, soft_targets_logits = targets
        # Clamp the probabilities to (1e-20, 1 - 1e-20) to avoid the log(0)
        # problem in the KL-divergence calculation.
        soft_targets = F.sigmoid(FloatTensor(soft_targets_logits) / self.t).clamp(
            1e-20, 1 - 1e-20
        )
        probs = F.sigmoid(logits / self.t).clamp(1e-20, 1 - 1e-20)
        probs_neg = probs.neg().add(1).clamp(1e-20, 1 - 1e-20)
        soft_targets_neg = soft_targets.neg().add(1).clamp(1e-20, 1 - 1e-20)
        if self.weight is not None:
            soft_loss = (
                F.kl_div(probs.log(), soft_targets, reduction="none") * self.weight
                + F.kl_div(probs_neg.log(), soft_targets_neg, reduction="none")
                * self.weight
            )
            if reduce:
                soft_loss = soft_loss.mean()
        else:
            soft_loss = F.kl_div(
                probs.log(), soft_targets, reduction="mean" if reduce else "none"
            ) + F.kl_div(
                probs_neg.log(),
                soft_targets_neg,
                reduction="mean" if reduce else "none",
            )
        soft_loss *= self.t ** 2  # See https://arxiv.org/pdf/1503.02531.pdf

        hard_loss = 0.0
        if self.hard_weight > 0.0:
            one_hot_targets = (
                FloatTensor(hard_targets.size(0), logits.size(1))
                .zero_()
                .scatter_(1, hard_targets.unsqueeze(1).data, 1)
            )
            hard_loss = F.binary_cross_entropy_with_logits(
                logits,
                one_hot_targets,
                reduction="mean" if reduce else "none",
                weight=self.weight,
            )

        return (1.0 - self.hard_weight) * soft_loss + self.hard_weight * hard_loss

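# Illustrative sketch (not part of the pytext API): the soft (distillation) part of
# the loss above, written out with plain torch ops. Temperature `t` softens both the
# teacher and student sigmoids, and the result is scaled by t^2 as in Hinton et al.
def _example_sigmoid_distillation(t=2.0):
    student_logits = torch.randn(4, 3)
    teacher_logits = torch.randn(4, 3)
    eps = 1e-20
    p = torch.sigmoid(student_logits / t).clamp(eps, 1 - eps)
    q = torch.sigmoid(teacher_logits / t).clamp(eps, 1 - eps)
    # KL on the "on" probabilities plus KL on their complements.
    soft_loss = F.kl_div(p.log(), q, reduction="mean") + F.kl_div(
        (1 - p).clamp(eps, 1 - eps).log(),
        (1 - q).clamp(eps, 1 - eps),
        reduction="mean",
    )
    return soft_loss * t ** 2
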
class KLDivergenceCELoss(Loss):
    class Config(ConfigBase):
        temperature: float = 1.0
        hard_weight: float = 0.0

    def __init__(self, config, ignore_index=-100, weight=None, *args, **kwargs):
        # ignore_index is not easily added to kl_div loss; don't support it
        # until needed.
        assert ignore_index < 0
        assert 0.0 <= config.hard_weight < 1.0

        self.weight = weight
        self.t = config.temperature
        self.hard_weight = config.hard_weight

    def __call__(self, logits, targets, reduce=True, combine_loss=True):
        """
        Computes Kullback-Leibler divergence loss for multiclass classification,
        with the probability distribution computed as in CrossEntropyLoss. For
        KL-divergence, "batchmean" is the correct reduction, not plain "mean".
        """
        hard_targets, _, soft_targets_logits = targets
        soft_targets = F.softmax(soft_targets_logits.float() / self.t, dim=1)
        soft_targets = soft_targets.clamp(1e-10, 1 - 1e-10)
        log_probs = F.log_softmax(logits / self.t, 1)

        if self.weight is not None:
            soft_loss = (
                F.kl_div(log_probs, soft_targets, reduction="none") * self.weight
            )
            # soft_loss dim is batch_size * num_labels, while hard_loss is just
            # batch_size; we still have to reduce soft_loss over the labels
            # dimension in order to be able to add the two losses.
            soft_loss = (
                torch.sum(soft_loss, dim=1).mean()
                if reduce
                else torch.sum(soft_loss, dim=1)
            )
        else:
            soft_loss = F.kl_div(
                log_probs, soft_targets, reduction="batchmean" if reduce else "none"
            )
        soft_loss *= self.t ** 2  # See https://arxiv.org/pdf/1503.02531.pdf
        hard_loss = F.nll_loss(
            F.log_softmax(logits, 1, dtype=torch.float32),
            hard_targets,
            weight=self.weight,
            reduction="mean" if reduce else "none",
        )

        return (
            (1.0 - self.hard_weight) * soft_loss + self.hard_weight * hard_loss
            if combine_loss
            else (soft_loss, hard_loss)
        )

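# Illustrative sketch (not part of the pytext API): the softmax-distillation term
# computed above, using "batchmean", which divides the summed KL by the batch size
# (the mathematically correct reduction for KL divergence).
def _example_softmax_distillation(t=2.0):
    student_logits = torch.randn(4, 3)
    teacher_logits = torch.randn(4, 3)
    soft_targets = F.softmax(teacher_logits / t, dim=1).clamp(1e-10, 1 - 1e-10)
    log_probs = F.log_softmax(student_logits / t, dim=1)
    return F.kl_div(log_probs, soft_targets, reduction="batchmean") * t ** 2
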
class PairwiseRankingLoss(Loss):
    """
    Given embeddings for a query, a positive response and a negative response,
    computes a pairwise ranking hinge loss.
    """

    class Config(ConfigBase):
        margin: float = 1.0

    @staticmethod
    def get_similarities(embeddings):
        pos_embed, neg_embed, query_embed = embeddings
        pos_similarity = F.cosine_similarity(query_embed, pos_embed)
        neg_similarity = F.cosine_similarity(query_embed, neg_embed)
        return pos_similarity, neg_similarity, query_embed.size(0)

    def __call__(self, logits, targets, reduce=True):
        pos_similarity, neg_similarity, batch_size = self.get_similarities(logits)
        targets_local = FloatTensor(batch_size)
        # 1: pos_similarity should be higher than neg_similarity
        targets_local.fill_(1)
        return F.margin_ranking_loss(
            pos_similarity, neg_similarity, targets_local, self.config.margin
        )

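# Illustrative sketch (not part of the pytext API): the ranking objective above asks
# that cos(query, positive) exceed cos(query, negative) by at least the margin.
def _example_pairwise_ranking(margin=1.0):
    query = torch.randn(5, 16)
    pos = torch.randn(5, 16)
    neg = torch.randn(5, 16)
    pos_sim = F.cosine_similarity(query, pos)
    neg_sim = F.cosine_similarity(query, neg)
    ones = torch.ones(query.size(0))  # 1: pos_sim should rank higher than neg_sim
    return F.margin_ranking_loss(pos_sim, neg_sim, ones, margin=margin)
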
class MAELoss(Loss):
    """
    Mean absolute error or L1 loss, for regression tasks.
    """

    class Config(ConfigBase):
        pass

    def __call__(self, predictions, targets, reduce=True):
        return F.l1_loss(predictions, targets, reduction="mean" if reduce else "none")

class MSELoss(Loss):
    """
    Mean squared error or L2 loss, for regression tasks.
    """

    class Config(ConfigBase):
        pass

    def __call__(self, predictions, targets, reduce=True):
        return F.mse_loss(predictions, targets, reduction="mean" if reduce else "none")

class LabelSmoothedCrossEntropyLoss(Loss):
    class Config(ConfigBase):
        beta: float = 0.1
        source: SourceType = SourceType.LOGITS
        use_entropy: bool = False

    def __init__(self, config, ignore_index=-100, weight=None, *args, **kwargs):
        # Weight values other than 1.0 give inconsistent behavior.
        # Refer: https://github.com/pytorch/pytorch/issues/17577
        if weight is not None:
            assert torch.sum(torch.abs(weight - 1.0)) < 1e-7

        self.ignore_index = ignore_index
        self.weight = weight
        self.beta = config.beta
        self.source = config.source
        self.use_entropy = config.use_entropy
        self.cross_entropy_loss = None
        self.label_smoothing_loss = None

    def __call__(self, logits, targets, reduce=True):
        """
        If use_entropy is False, returns the cross-entropy loss along with the
        KL divergence between the discrete uniform distribution and the logits
        (refer to section 3.2). If use_entropy is True, uses the entropy of the
        output distribution as the smoothing loss (i.e., higher entropy is
        better; refer to section 3).
        https://arxiv.org/pdf/1701.06548.pdf
        """
        if self.use_entropy:
            # The loss is the negative of the entropy.
            probs = F.softmax(logits, dim=1)
            log_probs = torch.log(probs)
            label_smoothing_loss = torch.sum(log_probs * probs, dim=1)
        else:
            # The negative KL-divergence has an additional log(num_classes)
            # term, ignored here because it doesn't contribute to optimization.
            if self.source == SourceType.LOGITS:
                log_probs = F.log_softmax(logits, dim=1)
            elif self.source == SourceType.PROBS:
                log_probs = logits.log()
            else:
                log_probs = logits
            label_smoothing_loss = -1 * log_probs.mean(dim=1)

        if reduce:
            non_ignored = targets != self.ignore_index
            if non_ignored.any():
                label_smoothing_loss = torch.mean(label_smoothing_loss[non_ignored])
            else:
                label_smoothing_loss = torch.tensor(0.0, device=logits.device)

        cross_entropy_loss = F.nll_loss(
            log_probs,
            targets,
            ignore_index=self.ignore_index,
            reduction="mean" if reduce else "none",
            weight=self.weight,
        )

        self.cross_entropy_loss = cross_entropy_loss
        self.label_smoothing_loss = label_smoothing_loss

        return (
            1.0 - self.beta
        ) * cross_entropy_loss + self.beta * label_smoothing_loss

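# Illustrative sketch (not part of the pytext API): label smoothing as a convex
# combination of cross-entropy and a uniform-prior term, matching the
# use_entropy=False path above. The -mean(log_probs) term is the KL divergence to
# the uniform distribution, up to a constant log(num_classes).
def _example_label_smoothing(beta=0.1):
    logits = torch.randn(4, 5)
    targets = torch.tensor([0, 3, 2, 4])
    log_probs = F.log_softmax(logits, dim=1)
    cross_entropy = F.nll_loss(log_probs, targets)
    smoothing = (-log_probs.mean(dim=1)).mean()
    return (1.0 - beta) * cross_entropy + beta * smoothing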