Source code for pytext.loss.regularized_loss

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Union

import torch
from pytext.config import ConfigBase
from pytext.config.component import create_loss

from .loss import (
    Loss,
    NLLLoss,
    HingeLoss,
    maybe_log_normalize,
    SourceType,
)
from .regularizer import UniformRegularizer, EntropyRegularizer, AdaptiveRegularizer
from .structured_loss import StructuredLoss, StructuredMarginLoss


class LabelSmoothingLoss(Loss):
    """Label loss with an optional regularizer for smoothing."""

    class Config(ConfigBase):
        beta: float = 0.1
        label_loss: Union[
            NLLLoss.Config, StructuredMarginLoss.Config, HingeLoss.Config
        ] = NLLLoss.Config()
        smoothing_loss: Union[
            UniformRegularizer.Config,
            EntropyRegularizer.Config,
            AdaptiveRegularizer.Config,
        ] = UniformRegularizer.Config()

    def __init__(self, config, ignore_index=1):
        self.beta = config.beta
        self.label_loss_fn = create_loss(config.label_loss, ignore_index=ignore_index)
        self.smoothing_loss_fn = create_loss(
            config.smoothing_loss, ignore_index=ignore_index
        )
        self.ignore_index = ignore_index

        # Tracking variables.
        self.label_loss = 0
        self.smoothing_loss = 0

    def __call__(self, logits, targets, reduce=True):
        label_loss = self.label_loss_fn(logits, targets, reduce)

        # Flatten logits if we're using a structured label loss.
        if isinstance(self.label_loss_fn, StructuredLoss):
            logits = logits.reshape(-1, logits.size(-1))
            targets = targets.view(-1)

        smoothing_loss = self.smoothing_loss_fn(logits, targets, reduce)

        # Set tracking variables.
        self.label_loss = label_loss
        self.smoothing_loss = smoothing_loss

        loss = label_loss + self.beta * smoothing_loss
        return loss
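
# Usage sketch (editor's note, not part of the library): a minimal example of the
# combined loss, built with the create_loss factory imported above and the default
# NLLLoss label loss plus UniformRegularizer smoothing loss from the Config. The
# shapes, values, and ignore_index below are illustrative assumptions.
#
#   loss_fn = create_loss(LabelSmoothingLoss.Config(beta=0.1), ignore_index=-100)
#   logits = torch.randn(8, 5).log_softmax(dim=-1)  # (B * T) x V log-probabilities
#   targets = torch.randint(0, 5, (8,))             # (B * T) flat target indices
#   loss = loss_fn(logits, targets)                 # label_loss + beta * smoothing_loss
#   # Both components remain accessible as loss_fn.label_loss and loss_fn.smoothing_loss.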


class SamplewiseLabelSmoothingLoss(LabelSmoothingLoss):
    """Label smoothing loss with sample-wise logging."""

    def __init__(self, config, ignore_index=-1):
        super().__init__(config, ignore_index)

        # Sample-wise tracking variables.
        self.samplewise_label_loss = 0
        self.samplewise_smoothing_loss = 0

    def _reduce_mean(
        self, logits, targets, batch_size, label_loss, smoothing_loss, reduce=True
    ):
        """
        Class-specific reduction function to extract sample-wise losses.
        Currently, passing in reduce="mean" averages over all samples without
        providing access to sample-wise losses.
        """
        # Save original losses.
        orig_label_loss = label_loss.clone()
        orig_smoothing_loss = smoothing_loss.clone()

        # Create target mask for pad tokens.
        mask = targets.ne(self.ignore_index)

        if mask.any():
            # Guarantee ignored tokens have zero contribution to loss.
            label_loss[~mask] = 0
            smoothing_loss[~mask] = 0

            # Lengths after masking.
            lengths = torch.sum(mask.reshape(batch_size, -1), dim=1)

            # Sample-wise losses (we do not consider masked tokens in this loss).
            samplewise_label_loss = (
                torch.sum(label_loss.reshape(batch_size, -1), dim=-1) / lengths
            )
            samplewise_smoothing_loss = (
                torch.sum(smoothing_loss.reshape(batch_size, -1), dim=-1) / lengths
            )

            # Replace NaNs with zero (only happens with zero-length samples).
            samplewise_label_loss[torch.isnan(samplewise_label_loss)] = 0
            samplewise_smoothing_loss[torch.isnan(samplewise_smoothing_loss)] = 0

            # Update original loss to use non-masked samples.
            label_loss = label_loss[mask]
            smoothing_loss = smoothing_loss[mask]
        else:
            samplewise_label_loss = torch.zeros(batch_size, device=logits.device)
            samplewise_smoothing_loss = torch.zeros(batch_size, device=logits.device)
            label_loss = torch.zeros(mask.shape, device=logits.device)
            smoothing_loss = torch.zeros(mask.shape, device=logits.device)

        # If `reduce` is enabled, compute mean loss over sequence. Otherwise,
        # revert values before masking.
        label_loss = torch.mean(label_loss) if reduce else orig_label_loss
        smoothing_loss = torch.mean(smoothing_loss) if reduce else orig_smoothing_loss

        return (
            samplewise_label_loss,
            samplewise_smoothing_loss,
            label_loss,
            smoothing_loss,
        )

    def __call__(self, logits, targets, reduce=True, batch_size=None):
        label_loss = self.label_loss_fn(logits, targets, reduce=False)
        smoothing_loss = self.smoothing_loss_fn(logits, targets, reduce=False)

        # Unless specified, batch_size is equal to the length of logits.
        if batch_size is None:
            batch_size = logits.shape[0]

        # Extract sample-wise losses and reduce regular losses.
        (
            samplewise_label_loss,
            samplewise_smoothing_loss,
            label_loss,
            smoothing_loss,
        ) = self._reduce_mean(
            logits=logits,
            targets=targets,
            batch_size=batch_size,
            label_loss=label_loss,
            smoothing_loss=smoothing_loss,
            reduce=reduce,
        )

        # Set sample-wise tracking variables.
        self.samplewise_label_loss = samplewise_label_loss
        self.samplewise_smoothing_loss = samplewise_smoothing_loss
        self.samplewise_total_loss = (
            (samplewise_label_loss + self.beta * samplewise_smoothing_loss)
            if samplewise_label_loss is not None
            and samplewise_smoothing_loss is not None
            else None
        )

        # Set tracking variables.
        self.label_loss = label_loss
        self.smoothing_loss = smoothing_loss

        loss = label_loss + self.beta * smoothing_loss
        return loss
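
# Worked example (editor's note, illustrative numbers): with batch_size=2 and
# per-token label losses [1.0, 3.0, 2.0, 5.0] reshaped to (2, 2), suppose the last
# token of the second sample is padding (equal to ignore_index). That token is
# zeroed out, the per-sample lengths become [2, 1], and
#   samplewise_label_loss = [(1.0 + 3.0) / 2, 2.0 / 1] = [2.0, 2.0],
# while the reduced scalar label_loss averages only the unmasked tokens:
#   (1.0 + 3.0 + 2.0) / 3 = 2.0.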


class NARSequenceLoss(Loss):
    """Joint loss over labels and lengths of sequences for non-autoregressive modeling."""

    class Config(ConfigBase):
        beta: float = 0.1
        assert_valid_targets: bool = True
        label_type: SourceType = SourceType.LOG_PROBS
        length_type: SourceType = SourceType.LOG_PROBS
        label_loss: LabelSmoothingLoss.Config = LabelSmoothingLoss.Config()
        length_loss: LabelSmoothingLoss.Config = LabelSmoothingLoss.Config()

    def __init__(self, config, ignore_index=1):
        self.beta = config.beta
        self.assert_valid_targets = config.assert_valid_targets
        self.label_type = config.label_type
        self.length_type = config.length_type

        # We can't use a structured loss for optimizing lengths.
        if isinstance(config.length_loss.label_loss, StructuredLoss.Config):
            raise ValueError("StructuredLoss can't be used as a length loss")

        self.label_loss_fn = create_loss(config.label_loss, ignore_index=ignore_index)
        self.length_loss_fn = create_loss(
            config.length_loss, ignore_index=ignore_index
        )

    def __call__(
        self,
        label_logits,
        label_targets,
        length_logits,
        length_targets,
        reduce=True,
    ):
        """
        label_logits: (B x T) x V_1
        label_targets: (B x T)
        length_logits: B x V_2
        length_targets: B
        """
        label_logits = maybe_log_normalize(
            logits=label_logits, logits_type=self.label_type, dim=-1
        )
        length_logits = maybe_log_normalize(
            logits=length_logits, logits_type=self.length_type, dim=-1
        )

        max_supported_dim = length_logits.size(1)
        length_targets = length_targets.unsqueeze(-1)  # B x 1

        if self.assert_valid_targets:
            if torch.any(length_targets >= max_supported_dim):
                total_violations = str(
                    length_targets[length_targets >= max_supported_dim]
                    .flatten()
                    .tolist()
                )
                raise RuntimeError(
                    f"max_supported_dim: {max_supported_dim}, "
                    f"total violations: {total_violations}"
                )
        else:
            length_targets[length_targets >= max_supported_dim] = max_supported_dim - 1

        label_loss = self.label_loss_fn(label_logits, label_targets, reduce)
        length_loss = self.length_loss_fn(
            length_logits, length_targets.squeeze(-1), reduce
        )

        loss = label_loss + self.beta * length_loss
        return (
            loss,
            {
                "label_loss": label_loss,
                "length_loss": length_loss,
                "label_label_loss": self.label_loss_fn.label_loss,
                "label_smoothing_loss": self.label_loss_fn.smoothing_loss,
                "length_label_loss": self.length_loss_fn.label_loss,
                "length_smoothing_loss": self.length_loss_fn.smoothing_loss,
            },
        )
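
# Usage sketch (editor's note, not part of the library): assembling the joint
# label/length loss for B=2 sequences padded to T=3 tokens, a label vocabulary of
# size V_1=5, and a maximum predictable length V_2=4. All shapes and values are
# illustrative assumptions; the default LOG_PROBS config expects log-probabilities.
#
#   loss_fn = create_loss(NARSequenceLoss.Config(beta=0.1), ignore_index=-100)
#   label_logits = torch.randn(6, 5).log_softmax(dim=-1)   # (B * T) x V_1
#   label_targets = torch.randint(0, 5, (6,))               # (B * T)
#   length_logits = torch.randn(2, 4).log_softmax(dim=-1)   # B x V_2
#   length_targets = torch.tensor([3, 2])                   # B, each < V_2
#   loss, details = loss_fn(label_logits, label_targets, length_logits, length_targets)
#   # details holds label_loss, length_loss, and their label/smoothing components.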


class NARSamplewiseSequenceLoss(NARSequenceLoss):
    """Non-autoregressive sequence loss with sample-wise logging."""

    class Config(NARSequenceLoss.Config):
        label_loss: SamplewiseLabelSmoothingLoss.Config = (
            SamplewiseLabelSmoothingLoss.Config()
        )
        length_loss: SamplewiseLabelSmoothingLoss.Config = (
            SamplewiseLabelSmoothingLoss.Config()
        )

    def __call__(
        self,
        label_logits,
        label_targets,
        length_logits,
        length_targets,
        reduce=True,
    ):
        """
        label_logits: (B x T) x V_1
        label_targets: (B x T)
        length_logits: B x V_2
        length_targets: B
        """
        label_logits = maybe_log_normalize(
            logits=label_logits, logits_type=self.label_type, dim=-1
        )
        length_logits = maybe_log_normalize(
            logits=length_logits, logits_type=self.length_type, dim=-1
        )

        max_length = int(torch.max(length_targets))
        batch_size = label_logits.shape[0] // max_length

        max_supported_dim = length_logits.size(1)
        length_targets = length_targets.unsqueeze(-1)  # B x 1

        if self.assert_valid_targets:
            if torch.any(length_targets >= max_supported_dim):
                total_violations = str(
                    length_targets[length_targets >= max_supported_dim]
                    .flatten()
                    .tolist()
                )
                raise RuntimeError(
                    f"max_supported_dim: {max_supported_dim}, "
                    f"total violations: {total_violations}"
                )
        else:
            length_targets[length_targets >= max_supported_dim] = max_supported_dim - 1

        label_loss = self.label_loss_fn(
            label_logits, label_targets, reduce, batch_size
        )
        length_loss = self.length_loss_fn(
            length_logits, length_targets.squeeze(-1), reduce
        )

        loss = label_loss + self.beta * length_loss

        # Log sample-wise losses.
        samplewise_losses = {
            "samplewise_label_loss": self.label_loss_fn.samplewise_total_loss,
            "samplewise_length_loss": self.length_loss_fn.samplewise_total_loss,
            "samplewise_label_label_loss": self.label_loss_fn.samplewise_label_loss,
            "samplewise_label_smoothing_loss": self.label_loss_fn.samplewise_smoothing_loss,
            "samplewise_length_label_loss": self.length_loss_fn.samplewise_label_loss,
            "samplewise_length_smoothing_loss": self.length_loss_fn.samplewise_smoothing_loss,
        }

        return (
            loss,
            {
                "label_loss": label_loss,
                "length_loss": length_loss,
                "label_label_loss": self.label_loss_fn.label_loss,
                "label_smoothing_loss": self.label_loss_fn.smoothing_loss,
                "length_label_loss": self.length_loss_fn.label_loss,
                "length_smoothing_loss": self.length_loss_fn.smoothing_loss,
                **samplewise_losses,
            },
        )
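
# Editor's note: the call mirrors NARSequenceLoss above; the differences are that
# batch_size is inferred as label_logits.shape[0] // max(length_targets) and
# forwarded to the sample-wise label loss, and that the returned dict additionally
# carries the samplewise_* entries logged above (one value per sample rather than
# a single scalar).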