Source code for pytext.models.representations.bilstm

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.onnx
from pytext.config import ConfigBase
from pytext.utils import cuda
from pytext.utils.usage import log_class_usage
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from .representation_base import RepresentationBase


class BiLSTM(RepresentationBase):
    """
    `BiLSTM` implements a multi-layer bidirectional LSTM representation layer
    preceded by a dropout layer.

    Args:
        config (Config): Configuration object of type BiLSTM.Config.
        embed_dim (int): The number of expected features in the input.
        padding_value (float): Value for the padded elements. Defaults to 0.0.

    Attributes:
        padding_value (float): Value for the padded elements.
        dropout (nn.Dropout): Dropout layer preceding the LSTM.
        lstm (nn.LSTM): LSTM layer that operates on the inputs.
        representation_dim (int): The calculated dimension of the output features
            of BiLSTM.
    """
    class Config(RepresentationBase.Config, ConfigBase):
        """
        Configuration class for `BiLSTM`.

        Attributes:
            dropout (float): Dropout probability to use. Defaults to 0.4.
            lstm_dim (int): Number of features in the hidden state of the LSTM.
                Defaults to 32.
            num_layers (int): Number of recurrent layers. E.g. setting
                `num_layers=2` would mean stacking two LSTMs together to form a
                stacked LSTM, with the second LSTM taking in the outputs of the
                first LSTM and computing the final result. Defaults to 1.
            bidirectional (bool): If `True`, becomes a bidirectional LSTM.
                Defaults to `True`.
            pack_sequence (bool): If `True`, the input is packed with
                `pack_padded_sequence` before being passed to the LSTM.
                Defaults to `True`.
            disable_sort_in_jit (bool): If `True`, disable sort in
                `pack_padded_sequence` to allow inference on GPU.
                Defaults to `False`.
        """

        dropout: float = 0.4
        lstm_dim: int = 32
        num_layers: int = 1
        bidirectional: bool = True
        pack_sequence: bool = True
        disable_sort_in_jit: bool = False
    def __init__(
        self, config: Config, embed_dim: int, padding_value: float = 0.0
    ) -> None:
        super().__init__(config)

        self.padding_value: float = padding_value
        self.dropout = nn.Dropout(config.dropout)
        self.lstm = nn.LSTM(
            embed_dim,
            config.lstm_dim,
            num_layers=config.num_layers,
            bidirectional=config.bidirectional,
            batch_first=True,
        )
        self.representation_dim: int = config.lstm_dim * (
            2 if config.bidirectional else 1
        )
        self.pack_sequence = config.pack_sequence
        self.disable_sort_in_jit = config.disable_sort_in_jit
        log_class_usage(__class__)
    def forward(
        self,
        embedded_tokens: torch.Tensor,
        seq_lengths: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Given an input batch of sequential data such as word embeddings, produces
        a bidirectional LSTM representation of the sequential input and new state
        tensors.

        Args:
            embedded_tokens (torch.Tensor): Input tensor of shape
                (bsize x seq_len x input_dim).
            seq_lengths (torch.Tensor): List of sequence lengths of each batch
                element.
            states (Tuple[torch.Tensor, torch.Tensor]): Tuple of tensors containing
                the initial hidden state and the cell state of each element in
                the batch. Each of these tensors has a dimension of
                (bsize x num_layers * num_directions x nhid). Defaults to `None`.

        Returns:
            Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: Bidirectional
                LSTM representation of input and the state of the LSTM at
                `t = seq_len`. Shape of representation is
                (bsize x seq_len x representation_dim). Shape of each state is
                (bsize x num_layers * num_directions x nhid).
        """
        if self.dropout.p > 0.0:
            embedded_tokens = self.dropout(embedded_tokens)

        if states is not None:
            # Convert (h0, c0) from (bsz x num_layers*num_directions x nhid) to
            # (num_layers*num_directions x bsz x nhid).
            states = (
                states[0].transpose(0, 1).contiguous(),
                states[1].transpose(0, 1).contiguous(),
            )
        else:
            # We need to send in a zero state that matches the batch size, because
            # torch.jit tracing currently traces this as a constant and therefore
            # locks the traced model into a static batch size.
            # See https://github.com/pytorch/pytorch/issues/16664
            state = torch.zeros(
                self.config.num_layers * (2 if self.config.bidirectional else 1),
                embedded_tokens.size(0),  # batch size
                self.config.lstm_dim,
                device=torch.cuda.current_device() if cuda.CUDA_ENABLED else None,
                dtype=embedded_tokens.dtype,
            )
            states = (state, state)

        if torch.onnx.is_in_onnx_export():
            # During ONNX export, lower the LSTM to Caffe2's InferenceLSTM op
            # rather than running nn.LSTM directly.
            lstm_in = [embedded_tokens.contiguous(), states[0], states[1]] + [
                param.detach() for param in self.lstm._flat_weights
            ]
            rep, new_state_0, new_state_1 = torch.ops._caffe2.InferenceLSTM(
                lstm_in,
                self.lstm.num_layers,
                self.lstm.bias,
                True,
                self.lstm.bidirectional,
            )
            new_state = (new_state_0, new_state_1)
        else:
            if self.pack_sequence:
                # We need to disable sorting when jit tracing because it introduces
                # issues with sorted indices not being on the right device in
                # pack_padded_sequence during GPU inference.
                rnn_input = pack_padded_sequence(
                    embedded_tokens,
                    seq_lengths.cpu(),
                    batch_first=True,
                    enforce_sorted=True
                    if self.disable_sort_in_jit
                    and torch._C._get_tracing_state() is not None
                    else False,
                )
            else:
                rnn_input = embedded_tokens

            rep, new_state = self.lstm(rnn_input, states)

            if self.pack_sequence:
                # Make sure the output from the LSTM is padded back to the input's
                # sequence length.
                rep, _ = pad_packed_sequence(
                    rep,
                    padding_value=self.padding_value,
                    batch_first=True,
                    total_length=embedded_tokens.size(1),
                )

        # Convert states back to (bsz x num_layers*num_directions x nhid) to be
        # used in a data parallel model.
        new_state = (new_state[0].transpose(0, 1), new_state[1].transpose(0, 1))
        return rep, new_state
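
For reference, a minimal usage sketch (not part of the module source), assuming CPU execution and the default Config values; the embedding dimension, batch size, and sequence lengths below are purely illustrative:

    # Minimal sketch: instantiate BiLSTM with its default Config and run forward.
    # Assumes pytext is installed; tensor sizes here are arbitrary examples.
    import torch
    from pytext.models.representations.bilstm import BiLSTM

    config = BiLSTM.Config()              # lstm_dim=32, num_layers=1, bidirectional=True
    bilstm = BiLSTM(config, embed_dim=16)

    embedded_tokens = torch.randn(2, 5, 16)   # (bsize x seq_len x input_dim)
    seq_lengths = torch.tensor([5, 3])        # per-example lengths in the batch

    rep, (h_n, c_n) = bilstm(embedded_tokens, seq_lengths)
    print(rep.shape)   # (2, 5, 64): bsize x seq_len x representation_dim (32 * 2 directions)
    print(h_n.shape)   # (2, 2, 32): bsize x num_layers*num_directions x nhid

Because `pack_sequence` defaults to `True` and `enforce_sorted` is `False` outside of jit tracing, the sequence lengths do not need to be pre-sorted in this sketch.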