Source code for pytext.models.representations.bilstm_doc_slot_attention

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from pytext.config import ConfigBase
from pytext.config.module_config import SlotAttentionType
from pytext.models.module import create_module
from pytext.utils.usage import log_class_usage

from .augmented_lstm import AugmentedLSTM
from .bilstm import BiLSTM
from .ordered_neuron_lstm import OrderedNeuronLSTM
from .pooling import MaxPool, MeanPool, SelfAttention
from .representation_base import RepresentationBase
from .slot_attention import SlotAttention


class BiLSTMDocSlotAttention(RepresentationBase):
    """
    `BiLSTMDocSlotAttention` implements a multi-layer bidirectional LSTM based
    representation with support for various attention mechanisms.

    In default mode, when attention configuration is not provided, it behaves
    like a multi-layer LSTM encoder and returns the output features from the
    last layer of the LSTM, for each t.

    When document_attention configuration is provided, it produces a fixed-sized
    document representation.

    When slot_attention configuration is provided, it attends on the output of
    each cell of the LSTM module to produce a fixed-sized word representation.

    Args:
        config (Config): Configuration object of type BiLSTMDocSlotAttention.Config.
        embed_dim (int): The number of expected features in the input.

    Attributes:
        dropout (nn.Dropout): Dropout layer preceding the LSTM.
        relu (nn.ReLU): An instance of the ReLU layer.
        lstm (nn.Module): Module that implements the LSTM.
        use_doc_attention (bool): If `True`, indicates using document attention.
        doc_attention (nn.Module): Module that implements document attention.
        projection_d (nn.Sequential): A sequence of dense layers for projection
            over the document representation.
        use_word_attention (bool): If `True`, indicates using word attention.
        word_attention (nn.Module): Module that implements word attention.
        projection_w (nn.Sequential): A sequence of dense layers for projection
            over the word representation.
        representation_dim (int): The calculated dimension of the output features
            of the `BiLSTMDocSlotAttention` representation.
    """
    class Config(RepresentationBase.Config, ConfigBase):
        dropout: float = 0.4
        lstm: Union[
            BiLSTM.Config, OrderedNeuronLSTM.Config, AugmentedLSTM.Config
        ] = BiLSTM.Config()
        pooling: Optional[
            Union[SelfAttention.Config, MaxPool.Config, MeanPool.Config]
        ] = None
        slot_attention: Optional[SlotAttention.Config] = None
        doc_mlp_layers: int = 0
        word_mlp_layers: int = 0
    def __init__(self, config: Config, embed_dim: int) -> None:
        super().__init__(config)
        self.dropout = nn.Dropout(config.dropout)
        self.relu = nn.ReLU()

        # Shared representation.
        # Max pooling pads with -inf so that padded positions never win the max;
        # every other pooling mode pads with zeros.
        padding_value = (
            float("-inf") if isinstance(config.pooling, MaxPool.Config) else 0.0
        )
        self.lstm = create_module(
            config.lstm, embed_dim=embed_dim, padding_value=padding_value
        )
        lstm_out_dim = self.lstm.representation_dim

        # Document projection and attention.
        self.use_doc_attention: bool = config.pooling is not None
        if config.pooling:
            self.doc_attention = (
                create_module(config.pooling, n_input=lstm_out_dim)
                if config.pooling
                else lambda x: x
            )
            layers = []
            for _ in range(config.doc_mlp_layers - 1):
                layers.extend(
                    [nn.Linear(lstm_out_dim, lstm_out_dim), self.relu, self.dropout]
                )
            layers.append(nn.Linear(lstm_out_dim, lstm_out_dim))
            self.projection_d = nn.Sequential(*layers)

        # Word projection and attention.
        self.use_word_attention = config.slot_attention is not None
        if config.slot_attention:
            word_out_dim = lstm_out_dim
            self.word_attention = lambda x: x
            if config.slot_attention.attention_type != SlotAttentionType.NO_ATTENTION:
                self.word_attention = SlotAttention(
                    config.slot_attention, lstm_out_dim, batch_first=True
                )
                # SlotAttention concatenates the attended context to the LSTM
                # output, doubling the input dimension of the word projection.
                word_out_dim += lstm_out_dim
            layers = [nn.Linear(word_out_dim, lstm_out_dim), self.relu, self.dropout]
            for _ in range(config.word_mlp_layers - 2):
                layers.extend(
                    [nn.Linear(lstm_out_dim, lstm_out_dim), self.relu, self.dropout]
                )
            layers.append(nn.Linear(lstm_out_dim, lstm_out_dim))
            self.projection_w = nn.Sequential(*layers)

        # Set the representation dimension attribute.
        self.representation_dim = (
            self.doc_representation_dim
        ) = self.word_representation_dim = lstm_out_dim
        log_class_usage(__class__)
    def forward(
        self,
        embedded_tokens: torch.Tensor,
        seq_lengths: torch.Tensor,
        *args,
        states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Given an input batch of sequential data such as word embeddings, produces
        a bidirectional LSTM representation with the appropriate attention.

        Args:
            embedded_tokens (torch.Tensor): Input tensor of shape
                (bsize x seq_len x input_dim).
            seq_lengths (torch.Tensor): Sequence lengths of each batch element.
            states (Tuple[torch.Tensor, torch.Tensor]): Tuple of tensors containing
                the initial hidden state and the cell state of each element in
                the batch. Each of these tensors has a dimension of
                (bsize x num_layers * num_directions x nhid). Defaults to `None`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
            Tensors containing the document and the word representation of the
            input, along with the final LSTM state.
        """
        # Shared layers
        lstm_output, new_state = self.lstm(embedded_tokens, seq_lengths, states)

        # Default doc representation is hidden state of last cell of LSTM.
        # Default word representation is the output state of each cell of LSTM.
        outputs = [
            new_state[0].contiguous().view(-1, self.doc_representation_dim),
            lstm_output,
        ]
        if self.use_doc_attention:
            outputs[0] = self.projection_d(self.doc_attention(lstm_output))
        if self.use_word_attention:
            outputs[1] = self.projection_w(self.word_attention(lstm_output))
        return outputs[0], outputs[1], new_state
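
The sketch below is not part of the library source. It is a minimal usage example under assumptions drawn only from this file: that `BiLSTMDocSlotAttention.Config`, `SelfAttention.Config`, and `SlotAttention.Config` are constructible with their defaults, and the tensor shapes and `embed_dim` value are purely illustrative.

# Usage sketch (illustrative, not part of the module): enable document pooling
# via SelfAttention and keep slot_attention at its defaults; set
# slot_attention.attention_type to a non-NO_ATTENTION SlotAttentionType value
# to activate the word-attention path.
if __name__ == "__main__":
    config = BiLSTMDocSlotAttention.Config(
        pooling=SelfAttention.Config(),          # fixed-size document representation
        slot_attention=SlotAttention.Config(),   # per-token (slot) projection
    )
    model = BiLSTMDocSlotAttention(config, embed_dim=100)

    embedded_tokens = torch.rand(2, 7, 100)      # (bsize x seq_len x embed_dim)
    seq_lengths = torch.tensor([7, 5])           # lengths, sorted descending
    doc_rep, word_rep, state = model(embedded_tokens, seq_lengths)
    # doc_rep:  (bsize x representation_dim) document representation
    # word_rep: (bsize x seq_len x representation_dim) word representation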