Source code for pytext.loss.structured_loss

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from enum import Enum
from typing import Union

import torch
import torch.nn.functional as F
from pytext.config import ConfigBase
from pytext.config.component import create_loss

from .loss import Loss, NLLLoss, HingeLoss


class CostFunctionType(Enum):
    """Cost functions available for cost-augmented structured losses."""

    # Per-position Hamming cost (resolved to `hamming_distance` by `get_cost_fn`).
    HAMMING = "hamming"
def hamming_distance(logits, targets, cost_scale=1.0):
    """
    Computes Hamming distance (https://en.wikipedia.org/wiki/Hamming_distance),
    which is defined as the number of positions where two sequences of equal
    length differ. We apply Hamming distance locally: the returned cost tensor
    charges `cost_scale` to every non-gold label and 0 to the gold label, so
    adding it to `logits` increments every non-gold score by `cost_scale`.

    The label dimension is the last one, so inputs of any rank are supported
    (e.g. ``logits`` of shape B x T x V with ``targets`` of shape B x T).

    ```
    Example: Given targets = [0, 1] and cost_scale = 1.0, we have the following:
        logits         = [[-1.0,  1.0, 2.0], [-2.0, -1.0, 1.0]]
        cost (return)  = [[ 0.0,  1.0, 1.0], [ 1.0,  0.0, 1.0]]
        logits + cost  = [[-1.0,  2.0, 3.0], [-1.0, -1.0, 2.0]]
    ```

    Args:
        logits: Score tensor whose last dimension indexes labels. It is only
            used for its shape, dtype and device; it is not modified.
        targets: int64 tensor of gold label indices with shape
            ``logits.shape[:-1]``.
        cost_scale: Cost charged to each non-gold label.

    Returns:
        A new tensor with the same shape, dtype and device as ``logits``.
    """
    # Start by charging the full cost to every label ...
    hamming_cost = torch.full_like(logits, cost_scale)
    # ... then zero out the cost of the gold label at each position.
    hamming_cost.scatter_(-1, targets.unsqueeze(-1), 0.0)
    return hamming_cost
def get_cost_fn(cost_fn_type: CostFunctionType):
    """Retrieves a cost function corresponding to `cost_fn_type`.

    Args:
        cost_fn_type: A ``CostFunctionType`` member, or its raw value
            (e.g. ``"hamming"``).

    Returns:
        The matching cost function, called as ``fn(logits, targets, cost_scale)``.

    Raises:
        RuntimeError: If ``cost_fn_type`` does not name a supported cost function.
    """
    try:
        # Normalize to an enum member so raw values are accepted too, and so the
        # comparison below goes through the class rather than the argument.
        cost_fn_type = CostFunctionType(cost_fn_type)
    except ValueError:
        raise RuntimeError("invalid cost type provided") from None
    if cost_fn_type == CostFunctionType.HAMMING:
        return hamming_distance
    raise RuntimeError("invalid cost type provided")
class StructuredLoss(Loss):
    """Base class for loss functions applied to structured outputs.

    Subclasses are expected to override ``__call__``.
    """

    def __init__(self, config, ignore_index=1):
        # Label index kept for subclasses (presumably the padding index —
        # confirm with callers). `config` is not used by this base class.
        self.ignore_index = ignore_index

    def __call__(self, logits, targets, reduce=True):
        """Abstract: compute the loss of ``logits`` against ``targets``."""
        raise NotImplementedError()
class StructuredMarginLoss(StructuredLoss):
    """
    Margin-based loss which requires a gold structure Y to score at least
    `cost(Y, Y')` above a hypothesis structure `Y'`. The cost function used is
    variable, but should reflect the underlying semantics of the task (e.g.,
    BLEU in machine translation).

    Implemented by adding ``cost_fn(logits, targets, cost_scale)`` to the logits
    (cost augmentation) and then applying the configured label loss.
    """

    class Config(ConfigBase):
        # Scale passed to the cost function.
        cost_scale: float = 1.0
        # Which cost function to use for cost augmentation.
        cost_fn: CostFunctionType = CostFunctionType.HAMMING
        # Loss applied to the (flattened) cost-augmented logits.
        label_loss: Union[NLLLoss.Config, HingeLoss.Config] = NLLLoss.Config()

    def __init__(self, config, ignore_index=1, *args, **kwargs):
        super().__init__(config, ignore_index)
        self.cost_scale = config.cost_scale
        self.cost_fn = get_cost_fn(config.cost_fn)
        # The label loss receives the same ignore_index as this loss.
        self.label_loss_fn = create_loss(config.label_loss, ignore_index=ignore_index)

    def __call__(self, logits, targets, reduce=True):
        """Computes the label loss on cost-augmented logits.

        Args:
            logits: Unnormalized scores whose last dimension indexes labels
                (presumably B x T x V — confirm with callers).
            targets: Gold label indices with shape ``logits.shape[:-1]``.
            reduce: Forwarded unchanged to the label loss.

        Returns:
            The label loss evaluated on the flattened (N x V) logits and (N,)
            targets.
        """
        # Get cost-augmented logits. `+` allocates a new tensor, so the caller's
        # logits are left untouched without an explicit clone.
        logits = logits + self.cost_fn(logits, targets, self.cost_scale)
        # NLLLoss expects log normalized logits.
        if isinstance(self.label_loss_fn, NLLLoss):
            logits = F.log_softmax(logits, dim=-1)
        # Flatten logits and targets. `reshape` (unlike `view`) also handles
        # non-contiguous inputs such as sliced or transposed targets.
        logits = logits.reshape(-1, logits.size(-1))
        targets = targets.reshape(-1)
        return self.label_loss_fn(logits, targets, reduce)