#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Union

import torch
from pytext.config import ConfigBase
from pytext.models.module import create_module
from pytext.utils.usage import log_class_usage

from .bilstm_doc_attention import BiLSTMDocAttention
from .docnn import DocNNRepresentation
from .representation_base import RepresentationBase


class SeqRepresentation(RepresentationBase):
    """
    Representation for a sequence of sentences.
    Each sentence is first embedded with a DocNN model; the resulting
    sentence embeddings are then encoded with a second DocNN or BiLSTM model.
    """
    class Config(RepresentationBase.Config):
        doc_representation: DocNNRepresentation.Config = DocNNRepresentation.Config()
        seq_representation: Union[
            BiLSTMDocAttention.Config, DocNNRepresentation.Config
        ] = BiLSTMDocAttention.Config()
    def __init__(self, config: Config, embed_dim: int) -> None:
        super().__init__(config)
        # Sentence-level encoder: maps each sentence's token embeddings to a
        # single vector.
        self.doc_rep = create_module(config.doc_representation, embed_dim=embed_dim)
        self.doc_representation_dim = self.doc_rep.representation_dim
        # Sequence-level encoder: combines the per-sentence vectors.
        self.seq_rep = create_module(
            config.seq_representation, embed_dim=self.doc_representation_dim
        )
        self.representation_dim = self.seq_rep.representation_dim
        log_class_usage(__class__)
    def forward(
        self, embedded_seqs: torch.Tensor, seq_lengths: torch.Tensor, *args
    ) -> torch.Tensor:
        # embedded_seqs: (bsz, max_num_sen, max_seq_len, dim)
        (bsz, max_num_sen, max_seq_len, dim) = torch.onnx.operators.shape_as_tensor(
            embedded_seqs
        )
        # Flatten the batch and sentence dimensions so every sentence is
        # embedded independently by the document model.
        rep = self.doc_rep(
            torch.onnx.operators.reshape_from_tensor_shape(
                embedded_seqs,
                torch.cat(
                    ((bsz * max_num_sen).view(1), max_seq_len.view(1), dim.view(1))
                ),
            )
        )
        # Restore the (bsz, max_num_sen, ...) layout so the sequence model
        # sees one embedding per sentence.
        sentence_reps = torch.onnx.operators.reshape_from_tensor_shape(
            rep,
            torch.cat(
                (
                    bsz.view(1),
                    max_num_sen.view(1),
                    torch.tensor(self.doc_representation_dim).view(1),
                )
            ),
        )
        if isinstance(self.seq_rep, BiLSTMDocAttention):
            return self.seq_rep(embedded_tokens=sentence_reps, seq_lengths=seq_lengths)
        else:
            return self.seq_rep(embedded_tokens=sentence_reps)
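

# Minimal usage sketch (not part of the original module). The embedding size,
# batch shape, and sentence counts below are illustrative assumptions; real
# inputs come from an upstream token-embedding layer.
if __name__ == "__main__":
    # Default config: DocNN for sentences, BiLSTMDocAttention for the sequence.
    model = SeqRepresentation(SeqRepresentation.Config(), embed_dim=32)

    # Batch of 2 sequences, each padded to 4 sentences of 10 tokens, with
    # 32-dimensional token embeddings.
    embedded_seqs = torch.rand(2, 4, 10, 32)
    seq_lengths = torch.tensor([4, 3])  # true sentence count per sequence

    # With the default BiLSTMDocAttention, the call yields the sequence
    # representation together with the final LSTM state; a DocNN
    # seq_representation would instead return a plain tensor of shape
    # (2, model.representation_dim).
    output = model(embedded_seqs, seq_lengths)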