# Source code for pytext.models.seq_models.seq2seq_output_layer

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, Dict, Tuple, Union

import torch
from pytext.config import ConfigBase
from pytext.config.component import create_loss
from pytext.data.utils import Vocabulary
from pytext.loss import CrossEntropyLoss, LabelSmoothedCrossEntropyLoss, NLLLoss
from pytext.models.output_layers import OutputLayerBase


class Seq2SeqOutputLayer(OutputLayerBase):
    """Output layer for sequence-to-sequence models.

    Flattens the decoder logits across the batch and time dimensions and
    applies a configurable token-level loss (cross-entropy by default).
    """

    class Config(ConfigBase):
        # Loss applied to the flattened decoder logits; any of the three
        # supported token-level losses may be selected via config.
        loss: Union[
            CrossEntropyLoss.Config,
            LabelSmoothedCrossEntropyLoss.Config,
            NLLLoss.Config,
        ] = CrossEntropyLoss.Config()

    @classmethod
    def from_config(cls, config: Config, vocab: Vocabulary):
        """Build the output layer, wiring the vocab's pad index into the loss
        so padding tokens are ignored by the loss computation."""
        loss_fn = create_loss(config.loss, vocab.get_pad_index())
        return cls(vocab._vocab, loss_fn)

    def get_loss(
        self,
        model_outputs: Tuple[torch.Tensor, Dict[str, torch.Tensor]],
        targets: Tuple[torch.Tensor, torch.Tensor],
        context: Dict[str, Any] = None,
        reduce=True,
    ) -> torch.Tensor:
        """Compute the token-level loss for a batch of decoder outputs.

        Args:
            model_outputs: tuple whose first element is the logit tensor
                (assumed [batch_size, seq_lens, dim] — per the original
                comment) plus a dict of auxiliary outputs (unused here).
            targets: tuple whose first element holds the target token ids.
            context: unused; kept for interface compatibility.
            reduce: forwarded to the loss function to control reduction.

        Returns:
            The loss tensor produced by ``self.loss_fn``.
        """
        logits = model_outputs[0]
        # Collapse [batch_size, seq_lens, dim] -> [batch_size * seq_lens, dim]
        # so the loss sees one prediction row per target token.
        flat_logits = logits.view(-1, logits.size()[-1])
        flat_targets = targets[0].view(-1)
        return self.loss_fn(flat_logits, flat_targets, reduce)