Source code for pytext.models.seq_models.conv_model

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List, Optional, Dict, Union, Tuple

from pytext.config import ConfigBase
from pytext.models.module import create_module
from torch import Tensor

from .base import PyTextSeq2SeqModule
from .conv_decoder import LightConvDecoder, LightConvDecoupledDecoder
from .conv_encoder import LightConvEncoder


class CNNModel(PyTextSeq2SeqModule):
    """Seq2seq model pairing a ``LightConvEncoder`` with a lightweight-conv decoder.

    Source tokens (plus any additional feature tensors) are embedded by
    ``source_embeddings`` and encoded by ``encoder``. ``decoder`` then consumes
    the previous output tokens together with the encoder output.
    """

    class Config(ConfigBase):
        # Encoder configuration; instantiated via ``create_module`` in ``from_config``.
        encoder: LightConvEncoder.Config = LightConvEncoder.Config()
        # Either the coupled or the decoupled decoder variant may be configured.
        decoder: Union[
            LightConvDecoder.Config, LightConvDecoupledDecoder.Config
        ] = LightConvDecoder.Config()

    @classmethod
    def from_config(
        cls,
        config: Config,
        src_dict,
        source_embedding,
        tgt_dict,
        target_embedding,
        dict_embedding=None,
    ):
        """Validate ``config`` and build the encoder, decoder and model.

        Args:
            config: model configuration (encoder and decoder sub-configs).
            src_dict: source vocabulary, forwarded to the encoder's constructor.
            source_embedding: source embedding module. It is passed to the
                encoder and also kept on the model as ``source_embeddings``.
            tgt_dict: target vocabulary, forwarded to the decoder's constructor.
            target_embedding: target embedding module, passed to the decoder.
            dict_embedding: not used by this implementation.

        Returns:
            A new instance of ``cls``.

        Raises:
            AssertionError: if ``config`` fails ``validate_config``.
        """
        cls.validate_config(config)
        encoder = create_module(config.encoder, src_dict, source_embedding)
        decoder = create_module(config.decoder, tgt_dict, target_embedding)
        return cls(encoder, decoder, source_embedding)

    def __init__(self, encoder, decoder, source_embedding):
        """Store the already-constructed submodules."""
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Note the plural attribute name; ``get_embedding_module`` returns it.
        self.source_embeddings = source_embedding

    def forward(
        self,
        src_tokens: Tensor,
        additional_features: List[List[Tensor]],
        src_lengths,
        prev_output_tokens,
        src_subword_begin_indices: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Embed and encode the source, then decode ``prev_output_tokens``.

        Args:
            src_tokens: source token tensor.
            additional_features: extra feature tensors. They are appended after
                ``[[src_tokens]]`` in the list handed to the embedding module.
            src_lengths: source lengths, forwarded to the encoder.
            prev_output_tokens: previously emitted target tokens, forwarded to
                the decoder.
            src_subword_begin_indices: not used by this implementation.

        Returns:
            Whatever ``self.decoder`` returns.
        """
        # The embedding module receives one list per input feature; the source
        # tokens form the first feature group.
        embeddings = self.source_embeddings([[src_tokens]] + additional_features)
        encoder_out = self.encoder(src_tokens, embeddings, src_lengths=src_lengths)
        decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out)
        return decoder_out

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Delegate probability normalization to the decoder."""
        return self.decoder.get_normalized_probs(net_output, log_probs, sample)

    def max_decoder_positions(self):
        """Return the larger of the encoder's and decoder's ``max_positions()``.

        NOTE(review): this returns the larger of the two limits rather than the
        decoder's own limit; confirm this is intended.
        """
        return max(self.encoder.max_positions(), self.decoder.max_positions())

    def get_embedding_module(self):
        """Return the source embedding module."""
        return self.source_embeddings

    @classmethod
    def validate_config(cls, config):
        """Check that the encoder's and decoder's position limits are compatible.

        Raises:
            AssertionError: if the encoder's ``max_target_positions`` exceeds the
                decoder's ``max_target_positions``.
        """
        encoder_limit = config.encoder.encoder_config.max_target_positions
        decoder_limit = config.decoder.decoder_config.max_target_positions
        # Raise explicitly instead of using a bare ``assert`` so the check still
        # runs under ``python -O``. AssertionError is kept so existing callers
        # observe the same exception type as before.
        if not encoder_limit <= decoder_limit:
            raise AssertionError(
                f"encoder max_target_positions ({encoder_limit}) must not exceed "
                f"decoder max_target_positions ({decoder_limit})"
            )
class DecoupledCNNModel(CNNModel):
    """``CNNModel`` whose configuration defaults to the decoupled decoder."""

    class Config(CNNModel.Config):
        # Encoder default is the same as in the parent config.
        encoder: LightConvEncoder.Config = LightConvEncoder.Config()
        # Narrowed from the parent's Union to the decoupled decoder only.
        decoder: LightConvDecoupledDecoder.Config = (
            LightConvDecoupledDecoder.Config()
        )