#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from fairseq import utils as fairseq_utils
from pytext.config import ConfigBase
from pytext.models.seq_models.base import PyTextSeq2SeqModule
from pytext.utils.usage import log_class_usage
from torch import nn

from .attention import DotAttention
from .base import PlaceholderIdentity, PyTextIncrementalDecoderComponent

class DecoderWithLinearOutputProjection(PyTextSeq2SeqModule):
    """
    Base class for decoders whose last step is a linear map to the vocabulary.

    Subclasses implement ``forward_unprojected``, which returns decoder
    features (last dimension ``out_embed_dim``) plus a dict of auxiliary
    outputs. ``forward`` then turns those features into ``out_vocab_size``
    logits with ``self.linear_projection``.
    """

    def __init__(self, out_vocab_size, out_embed_dim=512):
        super().__init__()
        self.linear_projection = nn.Linear(out_embed_dim, out_vocab_size)
        self.reset_parameters()
        log_class_usage(__class__)

    def reset_parameters(self):
        """Re-draw projection weights from U(-0.1, 0.1) and zero its bias."""
        projection = self.linear_projection
        nn.init.uniform_(projection.weight, -0.1, 0.1)
        nn.init.zeros_(projection.bias)

    def forward(
        self,
        input_tokens,
        encoder_out: Dict[str, torch.Tensor],
        incremental_state: Optional[Dict[str, torch.Tensor]] = None,
        timestep: int = 0,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Decode and project to vocabulary logits.

        ``timestep`` is accepted for interface compatibility and is not used
        here.

        Returns:
            ``(logits, extras)`` where ``extras`` is the auxiliary dict
            produced by ``forward_unprojected``, passed through unchanged.
        """
        features, extras = self.forward_unprojected(
            input_tokens, encoder_out, incremental_state
        )
        return self.linear_projection(features), extras

    def forward_unprojected(self, input_tokens, encoder_out, incremental_state=None):
        """Forward pass through the decoder without output projection."""
        raise NotImplementedError()
class RNNDecoderBase(PyTextIncrementalDecoderComponent):
    """
    LSTM decoder with dot-product attention over the encoder outputs and
    input feeding.

    At every time step, the attention context from the previous step is
    concatenated to the current token embedding before the first layer.

    If ``first_layer_attention`` is True, attention is computed from the
    output of the first RNN layer. That context is also concatenated to the
    input of every subsequent layer. Otherwise, attention is computed once
    per step from the output of the last layer.

    Note: ``__init__`` does not call ``super().__init__()``. A concrete
    subclass (e.g. ``RNNDecoder``) must initialize the module before calling
    ``RNNDecoderBase.__init__``.
    """

    class Config(ConfigBase):
        # Size of the encoder outputs / final states (clamped to >= 1).
        encoder_hidden_dim: int = 512
        # Input size of the first RNN layer (before the attention context).
        embed_dim: int = 512
        # Hidden/cell size of every decoder RNN layer.
        hidden_dim: int = 512
        # Size of the features returned by forward_unprojected.
        out_embed_dim: int = 512
        # Only "lstm" is supported.
        cell_type: str = "lstm"
        num_layers: int = 1
        # Dropout applied to the embedded target tokens.
        dropout_in: float = 0.1
        # Dropout applied to each layer's output and to the final features.
        dropout_out: float = 0.1
        # Only "dot" is supported.
        attention_type: str = "dot"
        # Stored on the module; not passed to DotAttention by this class.
        attention_heads: int = 8
        # Attend after the first layer (True) or after the last layer (False).
        first_layer_attention: bool = False
        # Initialize decoder hidden states from the mean of the encoder
        # outputs instead of the encoder's final hidden states.
        averaging_encoder: bool = False

    @classmethod
    def from_config(cls, config, out_vocab_size, target_embedding):
        """Build the decoder from ``config``.

        ``out_vocab_size`` and ``target_embedding`` are passed positionally,
        ahead of the config fields. This matches the ``RNNDecoder``
        constructor (``out_vocab_size, embed_tokens, ...``).
        ``RNNDecoderBase.__init__`` itself does not accept ``out_vocab_size``.
        """
        return cls(out_vocab_size, target_embedding, **config._asdict())

    def __init__(
        self,
        embed_tokens,
        encoder_hidden_dim,
        embed_dim,
        hidden_dim,
        out_embed_dim,
        cell_type,
        num_layers,
        dropout_in,
        dropout_out,
        attention_type,
        attention_heads,
        first_layer_attention,
        averaging_encoder,
    ):
        """Build the RNN layers, attention and output bottleneck.

        Raises:
            RuntimeError: if ``cell_type`` is not ``"lstm"`` or
                ``attention_type`` is not ``"dot"``.
        """
        # Clamp so the Linear/attention sizes built below are always valid.
        encoder_hidden_dim = max(1, encoder_hidden_dim)
        self.encoder_hidden_dim = encoder_hidden_dim
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.out_embed_dim = out_embed_dim
        self.dropout_in = dropout_in
        self.dropout_out = dropout_out
        self.attention_type = attention_type
        self.attention_heads = attention_heads
        self.first_layer_attention = first_layer_attention
        self.embed_tokens = embed_tokens
        self.averaging_encoder = averaging_encoder

        if cell_type == "lstm":
            cell_class = torch.nn.LSTMCell
        else:
            raise RuntimeError("Cell type not supported")

        # When encoder and decoder sizes differ, each layer's initial
        # hidden/cell state is projected from the encoder's final state.
        self.change_hidden_dim = hidden_dim != encoder_hidden_dim
        if self.change_hidden_dim:
            hidden_init_fc_list = []
            cell_init_fc_list = []
            # Interleaved construction preserves parameter-initialization
            # order (and therefore RNG consumption).
            for _ in range(num_layers):
                hidden_init_fc_list.append(nn.Linear(encoder_hidden_dim, hidden_dim))
                cell_init_fc_list.append(nn.Linear(encoder_hidden_dim, hidden_dim))
            self.hidden_init_fc_list = nn.ModuleList(hidden_init_fc_list)
            self.cell_init_fc_list = nn.ModuleList(cell_init_fc_list)
        else:
            # Empty module lists to appease Torchscript
            self.hidden_init_fc_list = nn.ModuleList([])
            self.cell_init_fc_list = nn.ModuleList([])

        if attention_type == "dot":
            self.attention = DotAttention(
                decoder_hidden_state_dim=hidden_dim, context_dim=encoder_hidden_dim
            )
        else:
            raise RuntimeError(f"Attention type {attention_type} not supported")

        # Each step's output is [last hidden state ; attention context].
        self.combined_output_and_context_dim = self.attention.context_dim + hidden_dim

        layers = []
        for layer in range(num_layers):
            if layer == 0:
                cell_input_dim = embed_dim
            else:
                cell_input_dim = hidden_dim
            # The attention context always feeds layer 0 (input feeding).
            # It feeds every other layer only with first_layer_attention.
            if self.first_layer_attention or layer == 0:
                cell_input_dim += self.attention.context_dim
            layers.append(cell_class(input_size=cell_input_dim, hidden_size=hidden_dim))
        self.layers = nn.ModuleList(layers)
        self.num_layers = len(layers)

        if self.combined_output_and_context_dim != out_embed_dim:
            self.additional_fc = nn.Linear(
                self.combined_output_and_context_dim, out_embed_dim
            )
        else:
            # Using identity layer in place of the bottleneck simplifies torchscript
            # compatibility.
            self.additional_fc = PlaceholderIdentity()

        log_class_usage(__class__)

    def forward_unprojected(
        self,
        input_tokens,
        encoder_out: Dict[str, torch.Tensor],
        incremental_state: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Run the decoder and return features before the vocabulary projection.

        Args:
            input_tokens: 2-D token tensor, batch x tgt_len. When
                ``incremental_state`` is non-empty, only the last position
                is decoded.
            encoder_out: must contain ``"unpacked_output"``,
                ``"src_lengths"`` and ``"src_tokens"``. It must also contain
                ``"final_hiddens"`` and ``"final_cells"``, which are read
                only when no cached state exists yet.
            incremental_state: cache of the previous step's hidden states,
                cell states and attention context. Updated in place when
                given.

        Returns:
            ``(features, extras)``. ``features`` is
            batch x tgt_len x out_embed_dim. ``extras`` holds
            ``"attn_scores"`` (batch x tgt_len x src_len, assuming
            DotAttention returns src_len x batch scores per step),
            ``"src_tokens"`` and ``"src_lengths"``.
        """
        if incremental_state is not None and len(incremental_state) > 0:
            input_tokens = input_tokens[:, -1:]
        bsz, seqlen = input_tokens.size()

        # get outputs from encoder
        encoder_outs = encoder_out["unpacked_output"]
        src_lengths = encoder_out["src_lengths"]

        # embed tokens
        x = self.embed_tokens([[input_tokens]])
        # training=self.training: F.dropout defaults to training=True, which
        # would keep dropout active at inference time.
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = self._get_cached_state(incremental_state)
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            # first time step, initialize previous states
            if incremental_state is None:
                incremental_state = {}
            self._init_prev_states(encoder_out, incremental_state)
            init_state = self._get_cached_state(incremental_state)
            assert init_state is not None
            prev_hiddens, prev_cells, input_feed = init_state

        outs = []
        attn_scores_per_step: List[torch.Tensor] = []
        next_hiddens: List[torch.Tensor] = []
        next_cells: List[torch.Tensor] = []
        for j in range(seqlen):
            # input feeding: concatenate context vector from previous time step
            step_input = torch.cat((x[j, :, :], input_feed), dim=1)
            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(step_input, (prev_hiddens[i], prev_cells[i]))
                if self.first_layer_attention and i == 0:
                    # tgt_len is 1 in decoder and squeezed for both matrices
                    # input_feed.shape = tgt_len X bsz X embed_dim
                    # step_attn_scores.shape = src_len X tgt_len X bsz
                    input_feed, step_attn_scores = self.attention(
                        hidden, encoder_outs, src_lengths
                    )

                # hidden state becomes the input to the next layer
                layer_output = F.dropout(
                    hidden, p=self.dropout_out, training=self.training
                )

                step_input = layer_output
                if self.first_layer_attention:
                    step_input = torch.cat((step_input, input_feed), dim=1)

                # save state for next time step
                next_hiddens.append(hidden)
                next_cells.append(cell)

            if not self.first_layer_attention:
                # Attend from the last layer's hidden state.
                input_feed, step_attn_scores = self.attention(
                    hidden, encoder_outs, src_lengths
                )

            attn_scores_per_step.append(step_attn_scores)
            combined_output_and_context = torch.cat((hidden, input_feed), dim=1)
            # save final output
            outs.append(combined_output_and_context)

            # update hidden states for next timestep
            prev_hiddens = torch.stack(next_hiddens, 0)
            prev_cells = torch.stack(next_cells, 0)
            next_hiddens = []
            next_cells = []

        attn_scores = torch.stack(attn_scores_per_step, dim=1)
        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        attn_scores = attn_scores.transpose(0, 2)

        # cache previous states
        self._set_cached_state(
            incremental_state, (prev_hiddens, prev_cells, input_feed)
        )

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(
            seqlen, bsz, self.combined_output_and_context_dim
        )

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # bottleneck layer
        x = self.additional_fc(x)
        x = F.dropout(x, p=self.dropout_out, training=self.training)
        return (
            x,
            {
                "attn_scores": attn_scores,
                "src_tokens": encoder_out["src_tokens"],
                "src_lengths": encoder_out["src_lengths"],
            },
        )

    def reorder_incremental_state(
        self, incremental_state: Dict[str, torch.Tensor], new_order
    ):
        """Reorder buffered internal state (for incremental generation).

        Hidden and cell caches are num_layers x batch x hidden, so they are
        reordered along dim 1. The attention-context cache is
        batch x context, so it is reordered along dim 0.
        """
        assert incremental_state is not None
        hiddens = self.get_incremental_state(incremental_state, "cached_hiddens")
        assert hiddens is not None
        cells = self.get_incremental_state(incremental_state, "cached_cells")
        assert cells is not None
        feeds = self.get_incremental_state(incremental_state, "cached_feeds")
        assert feeds is not None

        self.set_incremental_state(
            incremental_state, "cached_hiddens", hiddens.index_select(1, new_order)
        )
        self.set_incremental_state(
            incremental_state, "cached_cells", cells.index_select(1, new_order)
        )
        self.set_incremental_state(
            incremental_state, "cached_feeds", feeds.index_select(0, new_order)
        )

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return int(1e5)  # an arbitrary large number

    def _init_prev_states(
        self,
        encoder_out: Dict[str, torch.Tensor],
        incremental_state: Dict[str, torch.Tensor],
    ) -> None:
        """Seed ``incremental_state`` with the decoder's initial states.

        Hidden states come from the encoder's final hidden states, or from
        the mean of the encoder outputs over dim 0 when ``averaging_encoder``
        is set. Cell states come from the encoder's final cells. Both are
        projected per layer when the encoder and decoder sizes differ. The
        initial attention context is all zeros, batch x context_dim.
        """
        encoder_output = encoder_out["unpacked_output"]
        final_hiddens = encoder_out["final_hiddens"]
        prev_cells = encoder_out["final_cells"]

        if self.averaging_encoder:
            # Use mean encoder hidden states
            # NOTE(review): the mean includes padded positions; src_lengths
            # is not used here.
            prev_hiddens = torch.stack(
                [torch.mean(encoder_output, 0)] * self.num_layers, dim=0
            )
        else:
            # Simply return the final state of each layer
            prev_hiddens = final_hiddens

        if self.change_hidden_dim:
            transformed_hiddens: List[torch.Tensor] = []
            transformed_cells: List[torch.Tensor] = []
            i: int = 0
            for hidden_init_fc, cell_init_fc in zip(
                self.hidden_init_fc_list, self.cell_init_fc_list
            ):
                transformed_hiddens.append(hidden_init_fc(prev_hiddens[i]))
                transformed_cells.append(cell_init_fc(prev_cells[i]))
                i += 1
            use_hiddens = torch.stack(transformed_hiddens, dim=0)
            use_cells = torch.stack(transformed_cells, dim=0)
        else:
            use_hiddens = prev_hiddens
            use_cells = prev_cells

        assert self.attention.context_dim
        initial_attn_context = torch.zeros(
            self.attention.context_dim, device=encoder_output.device
        )
        batch_size = encoder_output.size(1)
        self.set_incremental_state(incremental_state, "cached_hiddens", use_hiddens)
        self.set_incremental_state(incremental_state, "cached_cells", use_cells)
        self.set_incremental_state(
            incremental_state,
            "cached_feeds",
            initial_attn_context.expand(batch_size, self.attention.context_dim),
        )

    def get_normalized_probs(self, net_output, log_probs, sample):
        """Get normalized probabilities (or log probs) from a net's output.

        Only ``net_output[0]`` (the logits) is used; ``sample`` is ignored.
        """
        logits = net_output[0]
        if log_probs:
            return fairseq_utils.log_softmax(logits, dim=-1)
        else:
            return fairseq_utils.softmax(logits, dim=-1)

    def _get_cached_state(
        self, incremental_state: Optional[Dict[str, torch.Tensor]]
    ) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Return cached ``(hiddens, cells, feeds)``, or None if nothing is cached."""
        if incremental_state is None or len(incremental_state) == 0:
            return None

        hiddens = self.get_incremental_state(incremental_state, "cached_hiddens")
        assert hiddens is not None
        cells = self.get_incremental_state(incremental_state, "cached_cells")
        assert cells is not None
        feeds = self.get_incremental_state(incremental_state, "cached_feeds")
        assert feeds is not None
        return (hiddens, cells, feeds)

    def _set_cached_state(
        self,
        incremental_state: Optional[Dict[str, torch.Tensor]],
        state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> None:
        """Store ``(hiddens, cells, feeds)``; a no-op when there is no cache."""
        if incremental_state is None:
            return

        (hiddens, cells, feeds) = state
        self.set_incremental_state(incremental_state, "cached_hiddens", hiddens)
        self.set_incremental_state(incremental_state, "cached_cells", cells)
        self.set_incremental_state(incremental_state, "cached_feeds", feeds)
class RNNDecoder(RNNDecoderBase, DecoderWithLinearOutputProjection):
    """
    RNN decoder that ends in a linear projection to the target vocabulary.

    Combines ``RNNDecoderBase`` (recurrent layers and attention) with
    ``DecoderWithLinearOutputProjection`` (vocabulary projection).
    """

    def __init__(
        self,
        out_vocab_size,
        embed_tokens,
        encoder_hidden_dim,
        embed_dim,
        hidden_dim,
        out_embed_dim,
        cell_type,
        num_layers,
        dropout_in,
        dropout_out,
        attention_type,
        attention_heads,
        first_layer_attention,
        averaging_encoder,
    ):
        # The projection base goes first. Its __init__ calls super().__init__()
        # (presumably initializing the underlying nn.Module), which
        # RNNDecoderBase.__init__ does not do before registering submodules.
        DecoderWithLinearOutputProjection.__init__(
            self, out_vocab_size, out_embed_dim=out_embed_dim
        )
        RNNDecoderBase.__init__(
            self,
            embed_tokens=embed_tokens,
            encoder_hidden_dim=encoder_hidden_dim,
            embed_dim=embed_dim,
            hidden_dim=hidden_dim,
            out_embed_dim=out_embed_dim,
            cell_type=cell_type,
            num_layers=num_layers,
            dropout_in=dropout_in,
            dropout_out=dropout_out,
            attention_type=attention_type,
            attention_heads=attention_heads,
            first_layer_attention=first_layer_attention,
            averaging_encoder=averaging_encoder,
        )
        log_class_usage(__class__)