Source code for pytext.torchscript.seq2seq.encoder

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Dict, List, Optional, Tuple

import torch
import torch.jit
import torch.jit.quantized
from torch import Tensor, nn


class EncoderEnsemble(nn.Module):
    """
    This class will call the encoders from all the models in the ensemble.
    It processes the encoder outputs to prepare the input for each decoder
    step.
    """

    def __init__(self, models, beam_size):
        super().__init__()
        self.models = nn.ModuleList(models)
        self.beam_size = beam_size

    def forward(
        self,
        src_tokens: Tensor,
        src_lengths: Tensor,
        dict_feat: Optional[Tuple[Tensor, Tensor, Tensor]] = None,
        contextual_token_embedding: Optional[Tensor] = None,
    ) -> List[Dict[str, Tensor]]:
        # (seq_length, batch_size) for compatibility with Caffe2
        src_tokens_seq_first = src_tokens.t()

        # Model level parallelism cannot be implemented because of lack of
        # support for future type. https://github.com/pytorch/pytorch/issues/26578
        futures = torch.jit.annotate(List[Dict[str, Tensor]], [])
        for model in self.models:
            embedding_input = [[src_tokens_seq_first]]
            if dict_feat is not None:
                embedding_input.append(list(dict_feat))
            if contextual_token_embedding is not None:
                embedding_input.append([contextual_token_embedding])
            embeddings = model.source_embeddings(embedding_input)

            futures.append(
                # torch.jit._fork(model.encoder, src_tokens_seq_first, src_lengths)
                model.encoder(src_tokens_seq_first, embeddings, src_lengths)
            )

        return self.prepare_decoderstep_ip(futures)

    def prepare_decoderstep_ip(
        self, futures: List[Dict[str, Tensor]]
    ) -> List[Dict[str, Tensor]]:
        outputs = torch.jit.annotate(List[Dict[str, Tensor]], [])
        for idx, model in enumerate(self.models):
            # encoder_out = torch.jit._wait(future)
            encoder_out = futures[idx]
            tiled_encoder_out = model.encoder.tile_encoder_out(
                self.beam_size, encoder_out
            )
            outputs.append(tiled_encoder_out)
        return outputs
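
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes hypothetical pre-built PyText seq2seq models `model_a` and
# `model_b` whose `source_embeddings`, `encoder`, and `encoder.tile_encoder_out`
# match the interfaces used above, and that `src_tokens` is already a
# numericalized (batch_size, seq_length) tensor:
#
#     ensemble = EncoderEnsemble([model_a, model_b], beam_size=5)
#     src_tokens = torch.tensor([[4, 8, 15, 2]])   # one sentence of 4 token ids
#     src_lengths = torch.tensor([4])
#     encoder_outs = ensemble(src_tokens, src_lengths)
#     # encoder_outs[i] holds model i's encoder states, tiled beam_size times
#     # so each beam hypothesis can attend to them at every decoder step.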