#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from enum import Enum

import torch
import torch.nn as nn
from pytext.models.module import Module
from pytext.utils.usage import log_class_usage


class RnnType(Enum):
    RNN = "rnn"
    LSTM = "lstm"
    GRU = "gru"


RNN_TYPE_DICT = {RnnType.RNN: nn.RNN, RnnType.LSTM: nn.LSTM, RnnType.GRU: nn.GRU}
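
# Illustrative note (not part of the original module): RNN_TYPE_DICT maps the
# RnnType enum onto the corresponding torch.nn module, e.g.
#
#     rnn_module = RNN_TYPE_DICT[RnnType.GRU]  # -> nn.GRU
#     layer = rnn_module(16, 32, num_layers=1, bidirectional=True)
#
# which mirrors how StackedBidirectionalRNN.__init__ below builds each layer.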


class StackedBidirectionalRNN(Module):
    """
    `StackedBidirectionalRNN` implements a multi-layer bidirectional RNN with
    an option to return the outputs from all the layers of the stacked RNN.

    Args:
        config (Config): Configuration object of type
            StackedBidirectionalRNN.Config.
        input_size (int): The number of expected features in the input.
        padding_value (float): Value for the padded elements. Defaults to 0.0.

    Attributes:
        padding_value (float): Value for the padded elements.
        dropout (nn.Dropout): Dropout layer applied to the input of each RNN
            layer.
        rnns (nn.ModuleList): The stacked RNN layers that operate on the
            inputs.
        representation_dim (int): The calculated dimension of the output
            features of the stacked RNN.
    """

    class Config(Module.Config):
        """
        Configuration class for `StackedBidirectionalRNN`.

        Attributes:
            hidden_size (int): Number of features in the hidden state of the
                RNN. Defaults to 32.
            num_layers (int): Number of recurrent layers. E.g. setting
                `num_layers=2` would mean stacking two RNNs together to form
                a stacked RNN, with the second RNN taking in the outputs of
                the first RNN and computing the final result. Defaults to 1.
            dropout (float): Dropout probability applied to the input of each
                layer. Defaults to 0.0.
            bidirectional (bool): If `True`, becomes a bidirectional RNN.
                Defaults to `True`.
            rnn_type (RnnType): Which RNN type to use. Options: "rnn", "lstm",
                "gru". Defaults to `RnnType.LSTM`.
            concat_layers (bool): Whether to concatenate the outputs of each
                layer of the stacked RNN. Defaults to `True`.
        """

        hidden_size: int = 32
        num_layers: int = 1
        dropout: float = 0.0
        bidirectional: bool = True
        rnn_type: RnnType = RnnType.LSTM
        concat_layers: bool = True

    def __init__(self, config: Config, input_size: int, padding_value: float = 0.0):
        super().__init__()
        self.num_layers = config.num_layers
        self.dropout = nn.Dropout(config.dropout)
        self.concat_layers = config.concat_layers
        self.padding_value = padding_value
        self.rnns = nn.ModuleList()

        rnn_module = RNN_TYPE_DICT.get(config.rnn_type)
        assert rnn_module is not None, "rnn_cell cannot be None"

        num_directions = 2 if config.bidirectional else 1
        for i in range(config.num_layers):
            # Layers after the first consume the previous layer's output,
            # whose feature size depends on directionality.
            layer_input_size = (
                input_size if i == 0 else num_directions * config.hidden_size
            )
            self.rnns.append(
                rnn_module(
                    layer_input_size,
                    config.hidden_size,
                    num_layers=1,
                    bidirectional=config.bidirectional,
                )
            )
        self.representation_dim = (
            (config.num_layers if config.concat_layers else 1)
            * config.hidden_size
            * num_directions
        )
        log_class_usage(__class__)
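
    # Worked example (illustrative): with the defaults hidden_size=32,
    # num_layers=1, bidirectional=True, concat_layers=True, the output width
    # is representation_dim = 1 * 32 * 2 = 64; with num_layers=2 it becomes
    # 2 * 32 * 2 = 128, since every layer's bidirectional output is
    # concatenated along the feature dimension.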

    def forward(self, tokens, tokens_mask):
        """
        Args:
            tokens: batch, max_seq_len, input_size
            tokens_mask: batch, max_seq_len (1 for padding, 0 for true tokens)

        Output:
            tokens_encoded: batch, max_seq_len, representation_dim, i.e.
                hidden_size * num_directions * num_layers if concat_layers is
                True, else hidden_size * num_directions
        """
        # Sort in descending order of sequence lengths.
        seq_lengths = tokens_mask.eq(0).long().sum(1)
        seq_lengths_sorted, idx_of_sorted = torch.sort(
            seq_lengths, dim=0, descending=True
        )

        # Sort tokens and pack into a padded sequence. Lengths are moved to
        # CPU, as required by pack_padded_sequence in recent PyTorch versions.
        tokens_sorted = tokens.index_select(0, idx_of_sorted)
        packed_tokens = nn.utils.rnn.pack_padded_sequence(
            tokens_sorted, seq_lengths_sorted.cpu(), batch_first=True
        )

        # Pass tokens through the stacked RNN and collect each layer's output.
        outputs = [packed_tokens]
        for i in range(self.num_layers):
            rnn_input = outputs[-1]
            rnn_input = nn.utils.rnn.PackedSequence(
                self.dropout(rnn_input.data), rnn_input.batch_sizes
            )
            outputs.append(self.rnns[i](rnn_input)[0])
        # Drop the first element, which is just the input tokens.
        outputs = outputs[1:]

        # Unpack the packed RNN outputs and concatenate them if necessary.
        for i in range(len(outputs)):
            outputs[i] = nn.utils.rnn.pad_packed_sequence(
                outputs[i], padding_value=self.padding_value, batch_first=True
            )[0]
        output = torch.cat(outputs, 2) if self.concat_layers else outputs[-1]

        # Restore the original order of examples.
        _, idx_of_original = torch.sort(idx_of_sorted, dim=0)
        output = output.index_select(0, idx_of_original)

        # Pad up to the original batch's maximum sequence length.
        max_seq_len = tokens_mask.size(1)
        batch_size, output_seq_len, output_dim = output.size()
        if output_seq_len != max_seq_len:
            padding = torch.zeros(
                batch_size, max_seq_len - output_seq_len, output_dim
            ).type(output.data.type())
            output = torch.cat([output, padding], 1)
        return output
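

# Minimal usage sketch (illustrative, not part of the original module). It
# assumes the standard PyText ConfigBase keyword-argument constructor; the
# sizes and config values below are arbitrary choices for demonstration.
if __name__ == "__main__":
    config = StackedBidirectionalRNN.Config(
        hidden_size=32,
        num_layers=2,
        dropout=0.0,
        bidirectional=True,
        rnn_type=RnnType.LSTM,
        concat_layers=True,
    )
    model = StackedBidirectionalRNN(config, input_size=16)

    batch_size, max_seq_len, input_size = 4, 10, 16
    tokens = torch.randn(batch_size, max_seq_len, input_size)
    # Mask convention from forward(): 1 marks padding, 0 marks real tokens.
    tokens_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.long)
    tokens_mask[:, 8:] = 1  # last two positions of every example are padding

    output = model(tokens, tokens_mask)
    # concat_layers=True -> 2 layers * 32 hidden * 2 directions = 128 features.
    assert output.shape == (batch_size, max_seq_len, model.representation_dim)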