Source code for pytext.models.output_layers.lm_output_layer

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, Dict, List, Optional, Tuple

import torch
from pytext.config.component import create_loss
from pytext.data.utils import Vocabulary
from pytext.fields import FieldMeta
from pytext.loss import CrossEntropyLoss, Loss
from pytext.utils.usage import log_class_usage

from .output_layer_base import OutputLayerBase


class LMOutputLayer(OutputLayerBase):
    """
    Output layer for language models. It supports `CrossEntropyLoss` per word.

    Args:
        target_names (List[str]): Names of the prediction targets (the
            vocabulary entries); forwarded to `OutputLayerBase`.
        loss_fn (CrossEntropyLoss): Cross-entropy loss component.
            Defaults to None.
        config: Forwarded to `OutputLayerBase`. Defaults to None.
        pad_token_idx (int): Index of the padding token. Defaults to -100.

    Attributes:
        loss_fn: Cross-entropy loss component for computing loss.
        pad_token_idx: Index of the padding token.
    """

    class Config(OutputLayerBase.Config):
        # Configuration of the cross-entropy loss built in `from_config`.
        loss: CrossEntropyLoss.Config = CrossEntropyLoss.Config()

    @classmethod
    def from_config(
        cls,
        config: Config,
        metadata: Optional[FieldMeta] = None,
        labels: Optional[Vocabulary] = None,
    ):
        """Build the output layer from its config.

        The vocabulary and the padding index are read from `labels` when it
        is given and from `metadata` otherwise, so at least one of the two
        must be provided.

        Args:
            config (Config): Output layer config; `config.loss` is used to
                create the loss.
            metadata (Optional[FieldMeta]): Source of `vocab.itos` and
                `pad_token_idx` when `labels` is None.
            labels (Optional[Vocabulary]): Preferred source of the vocabulary
                and the padding index.

        Returns:
            An instance of `cls` whose loss ignores the padding index.
        """
        if labels is not None:
            vocab = list(labels)
            pad_token_idx = labels.get_pad_index()
        else:
            vocab = metadata.vocab.itos
            pad_token_idx = metadata.pad_token_idx
        return cls(
            vocab,
            # Padding positions are excluded from the loss via ignore_index.
            create_loss(config.loss, ignore_index=pad_token_idx),
            pad_token_idx=pad_token_idx,
        )

    def __init__(
        self,
        target_names: List[str],
        loss_fn: Optional[Loss] = None,
        config=None,
        pad_token_idx=-100,
    ):
        super().__init__(target_names, loss_fn, config)
        self.pad_token_idx = pad_token_idx
        log_class_usage(__class__)

    def get_loss(
        self,
        logit: torch.Tensor,
        target: torch.Tensor,
        context: Dict[str, Any],
        reduce=True,
    ) -> torch.Tensor:
        """Compute word prediction loss by comparing prediction of each word
        in the sentence with the true word.

        Args:
            logit (torch.Tensor): Logit returned by
                :class:`~pytext.models.language_models.lmlstm.LMLSTM`.
            target (torch.Tensor): True words. If a tuple is passed, only its
                first element is used.
            context (Dict[str, Any]): Not used by this method.
            reduce (bool): Forwarded to the loss function; whether to reduce
                loss over the batch. Defaults to True.

        Returns:
            torch.Tensor: Word prediction loss.
        """
        if isinstance(target, tuple):
            target = target[0]
        # flatten the logit from [batch_size, seq_lens, dim] to
        # [batch_size * seq_lens, dim]
        # reshape (rather than view) also accepts non-contiguous tensors; for
        # contiguous input it returns the same view as before.
        flat_logit = logit.reshape(-1, logit.size()[-1])
        flat_target = target.reshape(-1)
        return self.loss_fn(flat_logit, flat_target, reduce)

    def get_pred(
        self, logits: torch.Tensor, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the model output as the prediction.

        No argmax or softmax is applied here: the logits are returned
        unchanged as the prediction, and no scores are computed.

        Args:
            logits (torch.Tensor): Logits returned
                :class:`~pytext.models.language_models.lmlstm.LMLSTM`.
            *args, **kwargs: Accepted and ignored.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The unchanged logits and None
            in place of scores.
        """
        return (logits, None)

    @staticmethod
    def calculate_perplexity(sequence_loss: torch.Tensor) -> torch.Tensor:
        """Return the perplexity, i.e. the exponential of `sequence_loss`.

        `torch.exp` does not raise on overflow; values too large to
        represent come back as `inf` inside the returned tensor.

        Args:
            sequence_loss (torch.Tensor): Loss to exponentiate.

        Returns:
            torch.Tensor: `exp(sequence_loss)`.
        """
        return torch.exp(sequence_loss)