Source code for pytext.models.output_layers.squad_output_layer

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from pytext.config.component import create_loss
from pytext.fields import FieldMeta
from pytext.loss import CrossEntropyLoss, KLDivergenceCELoss, Loss
from pytext.models.output_layers import OutputLayerBase
from pytext.utils.usage import log_class_usage


class SquadOutputLayer(OutputLayerBase):
    class Config(OutputLayerBase.Config):
        loss: Union[
            CrossEntropyLoss.Config, KLDivergenceCELoss.Config
        ] = CrossEntropyLoss.Config()
        ignore_impossible: bool = True
        pos_loss_weight: float = 0.5
        has_answer_loss_weight: float = 0.5
        false_label: str = "False"
        max_answer_len: int = 30
        # For knowledge distillation we have soft and hard labels. This specifies
        # the weight on loss against hard labels.
        hard_weight: float = 0.0

    @classmethod
    def from_config(
        cls,
        config,
        metadata: Optional[FieldMeta] = None,
        labels: Optional[Iterable[str]] = None,
        is_kd: bool = False,
    ):
        return cls(
            loss_fn=create_loss(config.loss, ignore_index=-100),
            ignore_impossible=config.ignore_impossible,
            pos_loss_weight=config.pos_loss_weight,
            has_answer_loss_weight=config.has_answer_loss_weight,
            has_answer_labels=labels,
            false_label=config.false_label,
            max_answer_len=config.max_answer_len,
            hard_weight=config.hard_weight,
            is_kd=is_kd,
        )
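
    # A hedged usage sketch (not part of the original source): building the
    # layer from its config with default settings, as a trainer would. The
    # label pair here is illustrative; any two has-answer labels work.
    #
    #   >>> layer = SquadOutputLayer.from_config(
    #   ...     SquadOutputLayer.Config(), labels=("False", "True")
    #   ... )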

    def __init__(
        self,
        loss_fn: Loss,
        ignore_impossible: bool = Config.ignore_impossible,
        pos_loss_weight: float = Config.pos_loss_weight,
        has_answer_loss_weight: float = Config.has_answer_loss_weight,
        has_answer_labels: Iterable[str] = ("False", "True"),
        false_label: str = Config.false_label,
        max_answer_len: int = Config.max_answer_len,
        hard_weight: float = Config.hard_weight,
        is_kd: bool = False,
    ) -> None:
        super().__init__(loss_fn=loss_fn)
        self.pos_loss_weight = pos_loss_weight
        self.has_answer_loss_weight = has_answer_loss_weight
        self.has_answer_labels = has_answer_labels
        self.ignore_impossible = ignore_impossible
        self.max_answer_len = max_answer_len
        if not ignore_impossible:
            self.false_idx = 1 if has_answer_labels[1] == false_label else 0
            self.true_idx = 1 - self.false_idx
        self.is_kd = is_kd
        self.hard_weight = hard_weight
        log_class_usage(__class__)

    def get_position_preds(
        self,
        start_pos_logits: torch.Tensor,
        end_pos_logits: torch.Tensor,
        max_span_length: int,
    ):
        # The following enforces end_pos >= start_pos. We create a matrix of
        # start_position X end_position, fill it with the summed logits, then
        # mask it to be upper-triangular.
        # e.g. start_pos_logits = [1, 3, 0, 5, 2]
        #      end_pos_logits   = [2, 4, 6, 3, 5]
        # The max indices should be (3, 4) with values (5, 5). (5, 6) would
        # have a higher score, but its end_pos comes before its start_pos, so
        # it is not feasible.
        #
        # To calculate this, first create a matrix whose (i, j) entry contains
        # start_pos_logits[i] + end_pos_logits[j]:
        #   [[3, 5,  7, 4,  6],
        #    [4, 7,  9, 6,  8],
        #    [2, 4,  6, 3,  5],
        #    [7, 9, 11, 8, 10],
        #    [4, 6,  8, 5,  7]]
        # Then mask it to be upper-triangular:
        #   logit_sum_matrix = [[3, 5, 7, 4,  6],
        #                       [0, 7, 9, 6,  8],
        #                       [0, 0, 6, 3,  5],
        #                       [0, 0, 0, 8, 10],
        #                       [0, 0, 0, 0,  7]]
        # Then we use argmax to retrieve the indices of the max value.
        size = start_pos_logits.size() + (start_pos_logits.size()[-1],)
        # The +10 offset keeps feasible entries above the zeroed-out cells
        # even when raw logits are negative.
        start_pos_logits = start_pos_logits.unsqueeze(-1).expand(size) + 10
        end_pos_logits = (
            end_pos_logits.unsqueeze(-1).expand(size).transpose(-2, -1) + 10
        )
        logit_sum_matrix = (start_pos_logits + end_pos_logits).triu()
        # Zero out spans longer than max_span_length.
        for i in range(logit_sum_matrix.size()[1]):
            logit_sum_matrix[:, i, i + max_span_length :] = 0
        vals, ids = logit_sum_matrix.max(-1)
        _, start_position = vals.max(-1)
        end_position = ids.gather(-1, start_position.unsqueeze(-1)).squeeze(-1)
        return start_position, end_position
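
    # A minimal sketch (not part of the original source) reproducing the
    # worked example from the comment above, given any instance `layer` of
    # this class. With max_span_length >= 2 the feasible span (3, 4) wins
    # even though (start=3, end=2) has a larger raw logit sum:
    #
    #   >>> start = torch.tensor([[1., 3., 0., 5., 2.]])
    #   >>> end = torch.tensor([[2., 4., 6., 3., 5.]])
    #   >>> layer.get_position_preds(start, end, max_span_length=5)
    #   (tensor([3]), tensor([4]))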

    def get_pred(
        self,
        logits: Tuple[torch.Tensor, ...],
        targets: torch.Tensor,
        contexts: Dict[str, List[Any]],
    ) -> Tuple[
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ]:
        start_pos_logits, end_pos_logits, has_answer_logits, _, _ = logits
        start_pos_preds, end_pos_preds = self.get_position_preds(
            start_pos_logits, end_pos_logits, self.max_answer_len
        )
        has_answer_preds = has_answer_logits.float().argmax(-1)
        has_answer_scores = torch.zeros(has_answer_logits.size())
        if not self.ignore_impossible:
            has_answer_scores = F.softmax(has_answer_logits, 1)
        # Compute the softmax scores corresponding to the predicted start
        # and end positions.
        start_pos_scores = (
            F.softmax(start_pos_logits, 1)
            .gather(1, start_pos_preds.view(-1, 1))
            .squeeze(-1)
        )
        end_pos_scores = (
            F.softmax(end_pos_logits, 1)
            .gather(1, end_pos_preds.view(-1, 1))
            .squeeze(-1)
        )
        return (
            (start_pos_preds, end_pos_preds, has_answer_preds),
            (start_pos_scores, end_pos_scores, has_answer_scores),
        )
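
    # A small sketch (not part of the original source) of the score lookup
    # used above: softmax over positions, then gather picks the probability
    # at each predicted index (batch of one, predicted position 3):
    #
    #   >>> logits = torch.tensor([[1., 3., 0., 5., 2.]])
    #   >>> preds = torch.tensor([3])
    #   >>> F.softmax(logits, 1).gather(1, preds.view(-1, 1)).squeeze(-1)
    #   tensor([0.8263])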

    def get_loss(
        self,
        logits: Tuple[torch.Tensor, ...],
        targets: Tuple[torch.Tensor, ...],
        contexts: Optional[Dict[str, Any]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Compute and return the loss given logits and targets.

        Args:
            logits (torch.Tensor): Logits returned by
                :class:`~pytext.models.Model`.
            targets (torch.Tensor): True label/target to compute loss against.
            contexts (Optional[Dict[str, Any]]): Context is a dictionary of
                items that's passed as additional metadata by the
                :class:`~pytext.data.DataHandler`. Defaults to None.

        Returns:
            torch.Tensor: Model loss.
        """
        return (
            self._get_soft_hard_loss(logits, targets)
            if self.is_kd
            else self._get_hard_loss(logits, targets)
        )

    def _get_hard_loss(
        self,
        logits: Tuple[torch.Tensor, ...],
        targets: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ):
        start_pos_logits, end_pos_logits, has_answer_logits, _, _ = logits
        start_pos_target, end_pos_target, has_answer_target = targets
        num_answers = start_pos_target.size()[-1]
        if num_answers == 0:
            start_loss = torch.tensor(0.0, dtype=torch.float).type_as(end_pos_logits)
            end_loss = torch.tensor(0.0, dtype=torch.float).type_as(end_pos_logits)
        else:
            start_loss = self.loss_fn(
                start_pos_logits.repeat((num_answers, 1)),
                start_pos_target.transpose(1, 0).flatten(),
                reduce=False,
            )
            end_loss = self.loss_fn(
                end_pos_logits.repeat((num_answers, 1)),
                end_pos_target.transpose(1, 0).flatten(),
                reduce=False,
            )
        loss = (start_loss + end_loss).mean()
        if not self.ignore_impossible:
            has_answer_mask = (
                has_answer_target.repeat((num_answers,)) == self.true_idx
            ).float()
            position_loss = (has_answer_mask * (start_loss + end_loss)).mean()
            has_answer_loss = self.loss_fn(has_answer_logits, has_answer_target)
            loss = (
                self.has_answer_loss_weight * has_answer_loss
                + self.pos_loss_weight * position_loss
            )
        return loss

    def _get_soft_hard_loss(
        self,
        logits: Tuple[torch.Tensor, ...],
        targets: Tuple[
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
        ],
    ):
        start_pos_logits, end_pos_logits, has_answer_logits, _, _ = logits
        (
            start_pos_target,
            end_pos_target,
            has_answer_target,
            start_pos_target_logits,
            end_pos_target_logits,
            has_answer_target_logits,
        ) = targets
        num_answers = start_pos_target.size()[-1]
        # Start and end position losses
        start_soft_loss, start_hard_loss = self.loss_fn(
            start_pos_logits.repeat((num_answers, 1)),
            (
                start_pos_target.transpose(1, 0).flatten(),
                None,
                start_pos_target_logits.repeat((num_answers, 1)),
            ),
            reduce=False,
            combine_loss=False,
        )
        end_soft_loss, end_hard_loss = self.loss_fn(
            end_pos_logits.repeat((num_answers, 1)),
            (
                end_pos_target.transpose(1, 0).flatten(),
                None,
                end_pos_target_logits.repeat((num_answers, 1)),
            ),
            reduce=False,
            combine_loss=False,
        )
        # Sum up along the sequence length dimension.
        # Example for KL-divergence: we need to sum up p_i * log(q_i) over i.
        start_soft_loss = torch.sum(start_soft_loss, dim=1)
        end_soft_loss = torch.sum(end_soft_loss, dim=1)
        # Weighted sum of soft and hard loss of start and end positions.
        start_loss = self._weighted_loss(start_soft_loss, start_hard_loss)
        end_loss = self._weighted_loss(end_soft_loss, end_hard_loss)
        loss = (start_loss + end_loss).mean()
        if not self.ignore_impossible:
            has_answer_mask = (
                has_answer_target.repeat((num_answers,)) == self.true_idx
            ).float()
            position_loss = (has_answer_mask * (start_loss + end_loss)).mean()
            has_answer_soft_loss, has_answer_hard_loss = self.loss_fn(
                has_answer_logits,
                (has_answer_target, None, has_answer_target_logits),
                reduce=False,
                combine_loss=False,
            )
            has_answer_loss = self._weighted_loss(
                has_answer_soft_loss.mean(), has_answer_hard_loss.mean()
            )
            loss = (
                self.has_answer_loss_weight * has_answer_loss
                + self.pos_loss_weight * position_loss
            )
        return loss

    def _weighted_loss(self, soft_loss, hard_loss):
        return (1.0 - self.hard_weight) * soft_loss + self.hard_weight * hard_loss
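

# A minimal, runnable sketch (not part of the original source) of the
# knowledge-distillation blend computed by _weighted_loss above. With the
# Config default hard_weight = 0.0 only the soft (teacher) term survives;
# hard_weight = 1.0 keeps only the hard (gold-label) term. The values below
# are hypothetical.
if __name__ == "__main__":
    hard_weight = 0.25
    soft_loss = torch.tensor(2.0)  # e.g. KL divergence against teacher logits
    hard_loss = torch.tensor(4.0)  # e.g. cross entropy against gold labels
    blended = (1.0 - hard_weight) * soft_loss + hard_weight * hard_loss
    print(blended.item())  # 0.75 * 2.0 + 0.25 * 4.0 = 2.5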