Source code for pytext.models.seq_models.rnn_encoder_decoder

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, List, Optional

import torch.jit
from pytext.config import ConfigBase
from pytext.models.module import create_module
from pytext.utils.usage import log_class_usage

from .base import PyTextSeq2SeqModule
from .rnn_decoder import RNNDecoder
from .rnn_encoder import LSTMSequenceEncoder


class RNNModel(PyTextSeq2SeqModule):
    """Sequence-to-sequence model: LSTM sequence encoder + RNN decoder.

    Source tokens (together with any additional feature tensors) are passed
    through ``source_embeddings``; the result is fed to ``encoder``, and
    ``decoder`` consumes the encoder output along with the previous output
    tokens to produce the model output.
    """

    class Config(ConfigBase):
        # Configuration used to build the LSTM sequence encoder.
        encoder: LSTMSequenceEncoder.Config = LSTMSequenceEncoder.Config()
        # Configuration used to build the RNN decoder.
        decoder: RNNDecoder.Config = RNNDecoder.Config()

    def __init__(self, encoder, decoder, source_embeddings):
        """Store the embedding, encoder and decoder submodules.

        Args:
            encoder: module that encodes the embedded source sequence.
            decoder: module that decodes from the encoder output.
            source_embeddings: module that embeds source tokens and features.
        """
        super().__init__()
        # NOTE: assignment order is intentional; it determines the submodule
        # (and therefore parameter) registration order on this module.
        self.source_embeddings = source_embeddings
        self.encoder = encoder
        self.decoder = decoder
        log_class_usage(__class__)

    def forward(
        self,
        src_tokens: torch.Tensor,
        additional_features: List[List[torch.Tensor]],
        src_lengths,
        prev_output_tokens,
        incremental_state: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """Embed and encode the source, then run the decoder.

        Args:
            src_tokens: source token tensor (shape not fixed here — presumably
                (batch, src_len); confirm against the embedding module).
            additional_features: extra feature groups appended after the
                token group when building the embedding input.
            src_lengths: source lengths, forwarded to the encoder as the
                ``src_lengths`` keyword.
            prev_output_tokens: previously emitted target tokens, forwarded
                to the decoder.
            incremental_state: optional decoder cache, forwarded unchanged.

        Returns:
            Whatever ``self.decoder`` returns for these inputs.
        """
        # The token tensor must sit at position [0][0] of the embedding
        # input; any additional feature groups follow it.
        embedding_inputs = [[src_tokens]] + additional_features
        source_embedded = self.source_embeddings(embedding_inputs)
        encoded = self.encoder(src_tokens, source_embedded, src_lengths=src_lengths)
        return self.decoder(prev_output_tokens, encoded, incremental_state)

    @classmethod
    def from_config(
        cls,
        config: Config,
        source_vocab,
        source_embedding,
        target_vocab,
        target_embedding,
    ):
        """Build an ``RNNModel`` from its config and prepared vocab/embeddings.

        The decoder is created with the target vocabulary size and the target
        embedding; ``source_embedding`` is used as-is for the source side.
        ``source_vocab`` is accepted for interface compatibility but is not
        used by this method.
        """
        target_vocab_size = len(target_vocab)
        built_encoder = create_module(config.encoder)
        built_decoder = create_module(
            config.decoder, target_vocab_size, target_embedding
        )
        return cls(built_encoder, built_decoder, source_embedding)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Delegate probability normalization of ``net_output`` to the decoder."""
        decoder = self.decoder
        return decoder.get_normalized_probs(net_output, log_probs, sample)

    def max_decoder_positions(self):
        """Return the larger of the encoder's and decoder's ``max_positions()``.

        NOTE(review): this reports the max over encoder and decoder rather
        than the decoder's own limit — appears intentional; confirm.
        """
        encoder_limit = self.encoder.max_positions()
        decoder_limit = self.decoder.max_positions()
        return max(encoder_limit, decoder_limit)