# Source code for pytext.models.seq_models.nar_output_layer

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, Dict, Optional, Tuple, Union

import torch
from pytext.config import ConfigBase
from pytext.config.component import create_loss
from pytext.data.utils import Vocabulary
from pytext.loss import (
    StructuredLoss,
    NARSequenceLoss,
    NARSamplewiseSequenceLoss,
)
from pytext.models.output_layers import OutputLayerBase


class NARSeq2SeqOutputLayer(OutputLayerBase):
    """Non-autoregressive seq2seq output layer.

    Computes a single loss over two predictions made by a non-autoregressive
    model: per-position target labels and the target sequence length. Both
    are handed to the configured NAR sequence loss.
    """

    class Config(ConfigBase):
        # Loss applied jointly to label and length predictions; defaults to
        # NARSequenceLoss.
        loss: Union[
            NARSequenceLoss.Config, NARSamplewiseSequenceLoss.Config
        ] = NARSequenceLoss.Config()

    @classmethod
    def from_config(cls, config: Config, vocab: Vocabulary):
        """Build the output layer from its config and the target vocabulary.

        Args:
            config: Layer config; ``config.loss`` selects the loss to create.
            vocab: Target vocabulary. Its pad index is passed to the loss as
                ``ignore_index``, and ``vocab._vocab`` is the first constructor
                argument (presumably the target names expected by
                ``OutputLayerBase`` -- confirm against the base class).

        Returns:
            A new ``NARSeq2SeqOutputLayer``.
        """
        return cls(
            vocab._vocab,
            create_loss(config.loss, ignore_index=vocab.get_pad_index()),
        )

    def get_loss(
        self,
        model_outputs: Tuple[torch.Tensor, Dict[str, torch.Tensor]],
        targets: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor],
        context: Optional[Dict[str, Any]] = None,
        reduce: bool = True,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the combined label and length loss.

        Shapes:
            label_logits: B x T x V_1
            label_targets: B x T
            length_logits: B x V_2
            length_targets: B

        Args:
            model_outputs: ``(label_logits, output_dict)``. ``output_dict``
                must contain ``"predicted_tgt_lengths"`` (the length logits).
            targets: ``((_, label_targets), length_targets)``. The first
                element of the inner tuple is not used here.
            context: Not used by this method.
            reduce: Forwarded unchanged to the loss function.

        Returns:
            ``(loss, two_losses)`` exactly as returned by ``self.loss_fn``.
        """
        label_logits, output_dict = model_outputs
        length_logits = output_dict["predicted_tgt_lengths"]
        (_, label_targets), length_targets = targets

        # Structured losses require access to sequences in each batch, so
        # don't flatten logits and targets for these.
        # NOTE(review): relies on self.loss_fn exposing a nested
        # ``label_loss_fn.label_loss_fn`` attribute chain -- confirm for every
        # loss type allowed by Config.loss.
        if not isinstance(self.loss_fn.label_loss_fn.label_loss_fn, StructuredLoss):
            # reshape (not view): the decoder may hand back a non-contiguous
            # tensor, on which view() raises. reshape returns a view whenever
            # possible, so contiguous inputs behave exactly as before.
            label_logits = label_logits.reshape(-1, label_logits.size(-1))  # (B x T) x V
            label_targets = label_targets.reshape(-1)  # (B x T)

        loss, two_losses = self.loss_fn(
            label_logits,
            label_targets,
            length_logits,
            length_targets,
            reduce,
        )
        return loss, two_losses