#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import itertools
import os
from typing import Dict, List, Tuple, Union
import torch
import torch.nn as nn
from pytext.data.tensorizers import LabelTensorizer, Tensorizer, TokenTensorizer
from pytext.models.decoders import DecoderBase
from pytext.models.decoders.mlp_decoder import MLPDecoder
from pytext.models.embeddings import EmbeddingBase, EmbeddingList, WordEmbedding
from pytext.models.model import BaseModel
from pytext.models.module import create_module
from pytext.models.output_layers import (
ClassificationOutputLayer,
OutputLayerBase,
PairwiseCosineDistanceOutputLayer,
)
from pytext.models.representations.bilstm_doc_attention import BiLSTMDocAttention
from pytext.models.representations.docnn import DocNNRepresentation
from pytext.models.representations.representation_base import RepresentationBase
from pytext.utils.file_io import PathManager
from scipy.special import comb
class BasePairwiseModel(BaseModel):
    """
    A base classification model that scores a pair of texts.

    Subclasses need to implement from_config, forward and save_modules.
    """

    __EXPANSIBLE__ = True

    class Config(BaseModel.Config):
        decoder: MLPDecoder.Config = MLPDecoder.Config()
        output_layer: Union[
            ClassificationOutputLayer.Config, PairwiseCosineDistanceOutputLayer.Config
        ] = ClassificationOutputLayer.Config()
        # When true, the decoder input is augmented with pairwise relation
        # features (elementwise |l - r| and l * r); see _encode_relations.
        encode_relations: bool = True

    def __init__(
        self,
        decoder: DecoderBase,
        output_layer: OutputLayerBase,
        encode_relations: bool,
    ) -> None:
        super().__init__()
        self.decoder = decoder
        self.output_layer = output_layer
        self.encode_relations = encode_relations
        # Cosine-distance output layers compare the two representations
        # directly rather than decoding their concatenation.
        self.use_cosine_sim = isinstance(
            output_layer, PairwiseCosineDistanceOutputLayer
        )

    @classmethod
    def from_config(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
        raise NotImplementedError

    def forward(
        self, input1: Tuple[torch.Tensor, ...], input2: Tuple[torch.Tensor, ...]
    ):
        raise NotImplementedError

    def save_modules(self, base_path: str = "", suffix: str = ""):
        raise NotImplementedError

    @classmethod
    def _create_decoder(
        cls,
        config: Config,  # model-level Config: reads .decoder and .encode_relations
        representations: nn.ModuleList,
        tensorizers: Dict[str, Tensorizer],
    ):
        """Build the MLP decoder, sizing its input to the concatenation of all
        representations (plus relation features when enabled).

        Returns None when config.decoder is falsy.
        """
        decoder = None
        if config.decoder:
            num_reps = len(representations)
            rep_dim = representations[0].representation_dim
            decoder_in_dim = num_reps * rep_dim
            if config.encode_relations:
                # _encode_relations adds two rep_dim-sized features
                # (abs-difference and product) per unordered pair.
                decoder_in_dim += 2 * comb(num_reps, 2, exact=True) * rep_dim
            decoder = create_module(
                config.decoder,
                in_dim=decoder_in_dim,
                out_dim=len(tensorizers["labels"].vocab),
            )
        return decoder

    @classmethod
    def _encode_relations(cls, encodings: List[torch.Tensor]) -> List[torch.Tensor]:
        """Append pairwise relation features — |l - r| and l * r for every
        unordered pair — to `encodings` (à la arXiv:1705.02364).

        Mutates the given list in place and returns it. Appending while
        looping is safe here because itertools.combinations snapshots its
        input into a tuple on creation.
        """
        for rep_l, rep_r in itertools.combinations(encodings, 2):
            encodings.append(torch.abs(rep_l - rep_r))
            encodings.append(rep_l * rep_r)
        return encodings

    def _save_modules(self, modules: Dict[str, nn.Module], base_path: str, suffix: str):
        """Save the base model's modules, then additionally save each module in
        `modules` to its own file when its config declares a save_path."""
        super().save_modules(base_path, suffix)
        # Special case to also save the multi-representations separately, if needed.
        for key, module in modules.items():
            if getattr(module.config, "save_path", None):
                path = f"{module.config.save_path}-{key}{suffix}"
                if base_path:
                    path = os.path.join(base_path, path)
                print(
                    f"Saving state of module {type(module).__name__} " f"to {path} ..."
                )
                with PathManager.open(path, "wb") as save_file:
                    # ScriptModules carry their own serialization format.
                    if isinstance(module, torch.jit.ScriptModule):
                        module.save(save_file)
                    else:
                        torch.save(module.state_dict(), save_file)
class PairwiseModel(BasePairwiseModel):
    """
    A classification model that scores a pair of texts, for example, a model for
    natural language inference.

    The model shares embedding space (so it doesn't support
    pairs of texts where left and right are in different languages). It uses
    bidirectional LSTM or CNN to represent the two documents, and concatenates
    them along with their absolute difference and elementwise product. This
    concatenated pair representation is passed to a multi-layer perceptron to
    decode to label/target space.

    See https://arxiv.org/pdf/1705.02364.pdf for more details.

    It can be instantiated just like any other :class:`~Model`.
    """

    EMBEDDINGS = ["embedding"]
    INPUTS_PAIR = [["tokens1"], ["tokens2"]]

    class Config(BasePairwiseModel.Config):
        """
        Attributes:
            encode_relations (bool): if `false`, return the concatenation of the two
                representations; if `true`, also concatenate their pairwise absolute
                difference and pairwise elementwise product (à la arXiv:1705.02364).
                Default: `true`.
            tied_representation: whether to use the same representation, with
                tied weights, for all the input subrepresentations. Default: `true`.
        """

        class ModelInput(BasePairwiseModel.Config.ModelInput):
            tokens1: TokenTensorizer.Config = TokenTensorizer.Config(column="text1")
            tokens2: TokenTensorizer.Config = TokenTensorizer.Config(column="text2")
            labels: LabelTensorizer.Config = LabelTensorizer.Config()

        inputs: ModelInput = ModelInput()
        embedding: WordEmbedding.Config = WordEmbedding.Config()
        representation: Union[
            BiLSTMDocAttention.Config, DocNNRepresentation.Config
        ] = BiLSTMDocAttention.Config()
        shared_representations: bool = True

    def __init__(
        self,
        embeddings: nn.ModuleList,
        representations: nn.ModuleList,
        decoder: MLPDecoder,
        output_layer: ClassificationOutputLayer,
        encode_relations: bool,
        shared_representations: bool,
    ) -> None:
        super().__init__(decoder, output_layer, encode_relations)
        self.embeddings = embeddings
        self.representations = representations
        self.shared_representations = shared_representations

    # from_config and helper functions

    @classmethod
    def _create_embedding(cls, config, tensorizer) -> EmbeddingBase:
        """Instantiate a single embedding module for one tensorizer."""
        return create_module(config, None, tensorizer)

    @classmethod
    def _create_embeddings(
        cls, config: Config, tensorizers: Dict[str, Tensorizer]
    ) -> nn.ModuleList:
        """Build one embedding (or EmbeddingList) per input side in INPUTS_PAIR."""
        embeddings = []
        for inputs in cls.INPUTS_PAIR:
            embedding_list = [
                cls._create_embedding(getattr(config, emb), tensorizers[input])
                for emb, input in zip(cls.EMBEDDINGS, inputs)
                if hasattr(config, emb) and input in tensorizers
            ]
            if len(embedding_list) == 1:
                # Single sub-embedding: use it directly, no wrapper needed.
                embeddings.append(embedding_list[0])
            else:
                embeddings.append(EmbeddingList(embeddings=embedding_list, concat=True))
        return nn.ModuleList(embeddings)

    @classmethod
    def _create_representations(cls, config: Config, embeddings: nn.ModuleList):
        """Build one representation per embedding; with shared_representations,
        the SAME module object (tied weights) is reused for every input."""
        if config.shared_representations:
            # create representation once and used for all embeddings
            embedding_dim = embeddings[0].embedding_dim
            representations = nn.ModuleList(
                itertools.repeat(
                    create_module(config.representation, embed_dim=embedding_dim),
                    len(embeddings),
                )
            )
        else:
            representations = nn.ModuleList(
                [
                    create_module(
                        config.representation, embed_dim=embedding.embedding_dim
                    )
                    for embedding in embeddings
                ]
            )
        return representations

    @classmethod
    def from_config(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
        """Assemble embeddings, representations, decoder and output layer from config."""
        embeddings = cls._create_embeddings(config, tensorizers)
        representations = cls._create_representations(config, embeddings)
        decoder = cls._create_decoder(config, representations, tensorizers)
        output_layer = create_module(
            config.output_layer, labels=tensorizers["labels"].vocab
        )
        return cls(
            embeddings,
            representations,
            decoder,
            output_layer,
            config.encode_relations,
            config.shared_representations,
        )

    def arrange_targets(self, tensor_dict):
        """Pull the label tensor out of the batch's tensor dict."""
        return tensor_dict["labels"]

    # _encode and helper functions

    @classmethod
    def _represent_helper(
        cls, rep: RepresentationBase, embs: torch.Tensor, lens: torch.Tensor
    ) -> torch.Tensor:
        """Run one representation module; some return (output, state) tuples,
        in which case only the output is kept."""
        representation = rep(embs, lens)
        if isinstance(representation, tuple):
            return representation[0]
        return representation

    @classmethod
    def _represent_sort(
        cls,
        embeddings: List[torch.Tensor],
        lengths: List[torch.Tensor],
        represention_modules: nn.ModuleList,  # (sic) name kept for compat
    ) -> List[torch.Tensor]:
        """
        Apply the representations computations in `represention_modules` to the
        sentence representations in `embeddings`.

        Internally, it sorts the sentences in `embeddings` by the number
        of tokens for packing efficiency, where the number of tokens is in `lengths`,
        and undoes the sort after applying the representations to preserve the
        original ordering of sentences. Assumes that the leftmost sentences are
        already sorted by number of tokens.
        """
        if isinstance(represention_modules[0], BiLSTMDocAttention):
            # The leftmost inputs already come sorted by length. The others need to
            # be sorted as well, for packing. We do it manually.
            sorted_inputs = [(embeddings[0], lengths[0])]
            sorted_indices = [None]  # leftmost input needs no unsort
            for embs, lens in zip(embeddings[1:], lengths[1:]):
                lens_sorted, sorted_idx = lens.sort(descending=True)
                embs_sorted = embs[sorted_idx]
                sorted_inputs.append((embs_sorted, lens_sorted))
                sorted_indices.append(sorted_idx)
            representations = [
                cls._represent_helper(rep, embs, lens)
                for rep, (embs, lens) in zip(represention_modules, sorted_inputs)
            ]
            # Put the inputs back in the original order, so they still match up to
            # each other as well as the targets.
            unsorted_representations = [representations[0]]
            for sorted_idx, rep in zip(sorted_indices[1:], representations[1:]):
                # Inverting the permutation: argsort of the sort indices.
                _, unsorted_idx = sorted_idx.sort()
                unsorted_representations.append(rep[unsorted_idx])
            return unsorted_representations
        else:
            # Non-LSTM representations (e.g. DocNN) don't pack, so no sorting.
            return [
                cls._represent_helper(rep, embs, lens)
                for rep, (embs, lens) in zip(
                    represention_modules, zip(embeddings, lengths)
                )
            ]

    def _represent(self, embeddings: List[torch.Tensor], seq_lens: List[torch.Tensor]):
        """Encode both sides, optionally append relation features, and
        concatenate everything along the last dimension."""
        representations = self._represent_sort(
            embeddings, seq_lens, self.representations
        )
        if self.encode_relations:
            representations = self._encode_relations(representations)
        return torch.cat(representations, -1)

    def forward(
        self, input1: Tuple[torch.Tensor, ...], input2: Tuple[torch.Tensor, ...]
    ) -> torch.Tensor:
        """Score a pair of texts; each input tuple is (token tensors..., seq_lens)."""
        token_tups, seq_lens = (input1[:-1], input2[:-1]), (input1[-1], input2[-1])
        embeddings = [
            emb(*token_tup) for emb, token_tup in zip(self.embeddings, token_tups)
        ]
        representation = self._represent(embeddings, seq_lens)
        return self.decoder(representation)

    def save_modules(self, base_path: str = "", suffix: str = ""):
        """Save model modules; untied representations are saved individually."""
        modules = {}
        if not self.shared_representations:
            # need to save all representations
            modules = {f"rep{i + 1}": rep for i, rep in enumerate(self.representations)}
        self._save_modules(modules, base_path, suffix)