Source code for pytext.models.representations.pair_rep

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from pytext.models.module import create_module
from pytext.utils.usage import log_class_usage

from .bilstm_doc_attention import BiLSTMDocAttention
from .docnn import DocNNRepresentation
from .representation_base import RepresentationBase


SubRepresentation = Union[BiLSTMDocAttention.Config, DocNNRepresentation.Config]


class PairRepresentation(RepresentationBase):
    """Wrapper representation for a pair of inputs.

    Takes a tuple of inputs: the left sentence, and the right sentence(s).
    Returns a representation of the pair of sentences, either as a
    concatenation of the two sentence embeddings or as a "siamese"
    representation which also includes their difference and elementwise
    product (arXiv:1705.02364).
    If more than two inputs are provided, the extra inputs are assumed to be
    extra "right" sentences, and the output will be the stacked pair
    representations of the left sentence together with all right sentences.
    This is more efficient than separately computing all these pair
    representations, because the left sentence will not need to be re-embedded
    multiple times.
    """

    class Config(RepresentationBase.Config):
        """
        Attributes:
            encode_relations (bool): if `false`, return the concatenation of
                the two representations; if `true`, also concatenate their
                pairwise absolute difference and pairwise elementwise product
                (à la arXiv:1705.02364). Default: `true`.
            subrepresentation (SubRepresentation): the sub-representation used
                for the inputs. If `subrepresentation_right` is not given,
                then this representation is used for both inputs with tied
                weights.
            subrepresentation_right (Optional[SubRepresentation]): the
                sub-representation used for the right input. Optional. If
                missing, `subrepresentation` is used with tied weights.
                Default: `None`.
        """

        subrepresentation: SubRepresentation = BiLSTMDocAttention.Config()
        subrepresentation_right: Optional[SubRepresentation] = None
        encode_relations: bool = True

    def __init__(self, config: Config, embed_dim: Tuple[int, ...]) -> None:
        super().__init__(config)
        assert len(embed_dim) == 2

        if config.subrepresentation_right is not None:
            self.subrepresentations = nn.ModuleList(
                [
                    create_module(config.subrepresentation, embed_dim=embed_dim[0]),
                    create_module(
                        config.subrepresentation_right, embed_dim=embed_dim[1]
                    ),
                ]
            )
            if config.encode_relations:
                assert (
                    self.subrepresentations[0].representation_dim
                    == self.subrepresentations[1].representation_dim
                ), (
                    "Representations must have the same dimension"
                    ", because `encode_relations` involves elementwise operations."
                )
        else:
            assert embed_dim[0] == embed_dim[1], (
                "Embeddings must have the same dimension"
                ", because subrepresentation weights are tied."
            )
            subrep = create_module(config.subrepresentation, embed_dim=embed_dim[0])
            self.subrepresentations = nn.ModuleList([subrep, subrep])

        self.encode_relations = config.encode_relations
        self.representation_dim = self.subrepresentations[0].representation_dim
        if self.encode_relations:
            self.representation_dim *= 4
        else:
            self.representation_dim += self.subrepresentations[1].representation_dim
        log_class_usage(__class__)

    # Takes care of dropping the extra return value of LSTM-based rep's (state).
    @staticmethod
    def _represent(
        rep: RepresentationBase, embs: torch.Tensor, lens: torch.Tensor
    ) -> torch.Tensor:
        representation = rep(embs, lens)
        if isinstance(representation, tuple):
            return representation[0]
        return representation
    def forward(
        self, embeddings: Tuple[torch.Tensor, ...], *lengths: torch.Tensor
    ) -> torch.Tensor:
        """Computes the pair representations.

        Arguments:
            embeddings: token embeddings of the left sentence, followed by the
                token embeddings of the right sentence(s).
            lengths: the corresponding sequence lengths.

        Returns:
            A tensor of shape `(num_right_inputs, batch_size, rep_size)`, with
            the first dimension squeezed if one.
        """
        left_rep = self._represent(
            self.subrepresentations[0], embeddings[0], lengths[0]
        )

        assert len(embeddings) == len(lengths) and len(embeddings) >= 2
        # The leftmost inputs already came sorted by length. The others need to
        # be sorted as well, for packing. We do it manually.
        sorted_right_inputs = []
        sorted_right_indices = []
        for embs, lens in zip(embeddings[1:], lengths[1:]):
            lens_sorted, sorted_idx = lens.sort(descending=True)
            embs_sorted = embs[sorted_idx]
            sorted_right_inputs.append((embs_sorted, lens_sorted))
            sorted_right_indices.append(sorted_idx)
        sorted_right_reps = [
            self._represent(self.subrepresentations[1], embs, lens)
            for (embs, lens) in sorted_right_inputs
        ]
        # Put the right inputs back in the original order, so they still match
        # up within the batch to the left inputs
        right_reps = []
        for idx, rep in zip(sorted_right_indices, sorted_right_reps):
            _, desorted_idx = idx.sort()
            right_reps.append(rep[desorted_idx])

        final_reps = []
        for right_rep in right_reps:
            this_rep = []
            this_rep.append(left_rep)
            this_rep.append(right_rep)
            if self.encode_relations:
                this_rep.append(torch.abs(left_rep - right_rep))
                this_rep.append(left_rep * right_rep)
            final_reps.append(torch.cat(this_rep, -1))
        return torch.stack(final_reps).squeeze(0)
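
A minimal usage sketch (not part of the module source): it builds a PairRepresentation with the default config (tied BiLSTMDocAttention sub-representations, encode_relations=True) and runs it on random embeddings. It assumes a working pytext installation; the embedding size, batch size, and sequence lengths below are illustrative.

# Illustrative usage sketch; shapes and hyperparameters are assumptions.
import torch

from pytext.models.representations.pair_rep import PairRepresentation

embed_dim = 32                 # illustrative embedding size
batch_size, max_len = 4, 7     # illustrative batch geometry

rep = PairRepresentation(
    PairRepresentation.Config(),       # defaults: BiLSTMDocAttention, encode_relations=True
    embed_dim=(embed_dim, embed_dim),
)

left_embs = torch.rand(batch_size, max_len, embed_dim)
right_embs = torch.rand(batch_size, max_len, embed_dim)
left_lens = torch.tensor([7, 6, 5, 3])    # left inputs are expected to arrive sorted by length
right_lens = torch.tensor([4, 7, 2, 6])   # right inputs are re-sorted internally for packing

out = rep((left_embs, right_embs), left_lens, right_lens)

# With encode_relations=True each pair is encoded as
# [left, right, |left - right|, left * right], so the last dimension is
# 4 * subrepresentation.representation_dim; with a single right input the
# leading num_right_inputs dimension is squeezed away.
print(out.shape)  # expected: (batch_size, rep.representation_dim)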