Source code for pytext.models.representations.augmented_lstm

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
from pytext.config import ConfigBase
from pytext.utils.usage import log_class_usage
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence

from .representation_base import RepresentationBase


class AugmentedLSTMCell(nn.Module):
    """
    `AugmentedLSTMCell` implements an AugmentedLSTM cell.

    Args:
        embed_dim (int): The number of expected features in the input.
        lstm_dim (int): Number of features in the hidden state of the LSTM.
        use_highway (bool): If `True` we append a highway network to the
            outputs of the LSTM.
        use_bias (bool): If `True` we use a bias in our LSTM calculations, otherwise
            we don't.

    Attributes:
        input_linearity (nn.Module): Fused weight matrix which
            computes a linear function over the input.
        state_linearity (nn.Module): Fused weight matrix which
            computes a linear function over the states.
    """

    def __init__(
        self, embed_dim: int, lstm_dim: int, use_highway: bool, use_bias: bool = True
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.lstm_dim = lstm_dim
        self.use_highway = use_highway
        self.use_bias = use_bias

        if use_highway:
            self._highway_inp_proj_start = 5 * self.lstm_dim
            self._highway_inp_proj_end = 6 * self.lstm_dim

            # fused linearity of input to input_gate,
            # forget_gate, memory_init, output_gate, highway_gate,
            # and the actual highway value
            self.input_linearity = nn.Linear(
                self.embed_dim, self._highway_inp_proj_end, bias=self.use_bias
            )
            # fused linearity of the hidden state to input_gate,
            # forget_gate, memory_init, output_gate, highway_gate
            self.state_linearity = nn.Linear(
                self.lstm_dim, self._highway_inp_proj_start, bias=True
            )
        else:
            # If there's no highway layer then we have a standard
            # LSTM. The 4 comes from fusing input, forget, memory, output
            # gates/inputs.
            self.input_linearity = nn.Linear(
                self.embed_dim, 4 * self.lstm_dim, bias=self.use_bias
            )
            self.state_linearity = nn.Linear(
                self.lstm_dim, 4 * self.lstm_dim, bias=True
            )
        self.reset_parameters()
        log_class_usage(__class__)

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.lstm_dim)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(
        self,
        x: torch.Tensor,
        states: Tuple[torch.Tensor, torch.Tensor],
        variational_dropout_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Warning: DO NOT USE THIS LAYER DIRECTLY, INSTEAD USE the AugmentedLSTM class

        Args:
            x (torch.Tensor): Input tensor of shape (bsize x input_dim).
            states (Tuple[torch.Tensor, torch.Tensor]): Tuple of tensors containing
                the hidden state and the cell state of each element in
                the batch. Each of these tensors has a dimension of (bsize x nhid).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                Returned states. Shape of each state is (bsize x nhid).

        """
        hidden_state, memory_state = states
        # Fused projections of the input and of the previous hidden state.
        projected_input = self.input_linearity(x)
        projected_state = self.state_linearity(hidden_state)

        input_gate = forget_gate = memory_init = output_gate = highway_gate = None
        if self.use_highway:
            fused_op = projected_input[:, : 5 * self.lstm_dim] + projected_state
            fused_chunked = torch.chunk(fused_op, 5, 1)
            (
                input_gate,
                forget_gate,
                memory_init,
                output_gate,
                highway_gate,
            ) = fused_chunked
            highway_gate = torch.sigmoid(highway_gate)
        else:
            fused_op = projected_input + projected_state
            input_gate, forget_gate, memory_init, output_gate = torch.chunk(
                fused_op, 4, 1
            )
        input_gate = torch.sigmoid(input_gate)
        forget_gate = torch.sigmoid(forget_gate)
        memory_init = torch.tanh(memory_init)
        output_gate = torch.sigmoid(output_gate)

        memory = input_gate * memory_init + forget_gate * memory_state
        timestep_output: torch.Tensor = output_gate * torch.tanh(memory)

        if self.use_highway:
            highway_input_projection = projected_input[
                :, self._highway_inp_proj_start : self._highway_inp_proj_end
            ]
            timestep_output = (
                highway_gate * timestep_output
                + (1 - highway_gate) * highway_input_projection  # noqa
            )

        if variational_dropout_mask is not None and self.training:
            timestep_output = timestep_output * variational_dropout_mask
        return timestep_output, memory
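
# Illustrative usage sketch (not part of the module): a single AugmentedLSTMCell
# step on a random batch. The dimensions below are arbitrary choices for
# demonstration.
#
#   cell = AugmentedLSTMCell(embed_dim=16, lstm_dim=32, use_highway=True)
#   x = torch.randn(4, 16)          # (bsize x embed_dim)
#   h0 = torch.zeros(4, 32)         # initial hidden state, (bsize x lstm_dim)
#   c0 = torch.zeros(4, 32)         # initial cell state, (bsize x lstm_dim)
#   h1, c1 = cell(x, (h0, c0))      # each output is (bsize x lstm_dim)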


class AugmentedLSTMUnidirectional(nn.Module):
    """
    `AugmentedLSTMUnidirectional` implements a one-layer, single-direction
    AugmentedLSTM layer. AugmentedLSTM is an LSTM which optionally appends a
    highway network to the output layer. Furthermore, the dropout parameter
    controls the amount of variational dropout applied.

    Args:
        embed_dim (int): The number of expected features in the input.
        lstm_dim (int): Number of features in the hidden state of the LSTM.
        go_forward (bool): Whether to compute features left to right (forward)
            or right to left (backward).
        recurrent_dropout_probability (float): Variational dropout probability
            to use. Defaults to 0.0.
        use_highway (bool): If `True` we append a highway network to the
            outputs of the LSTM.
        use_input_projection_bias (bool): If `True` we use a bias in
            our LSTM calculations, otherwise we don't.

    Attributes:
        cell (AugmentedLSTMCell): AugmentedLSTMCell that is applied at every timestep.
    """

    def __init__(
        self,
        embed_dim: int,
        lstm_dim: int,
        go_forward: bool = True,
        recurrent_dropout_probability: float = 0.0,
        use_highway: bool = True,
        use_input_projection_bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.lstm_dim = lstm_dim
        self.go_forward = go_forward
        self.use_highway = use_highway
        self.recurrent_dropout_probability = recurrent_dropout_probability

        self.cell = AugmentedLSTMCell(
            self.embed_dim, self.lstm_dim, self.use_highway, use_input_projection_bias
        )
        log_class_usage(__class__)

    def get_dropout_mask(
        self, dropout_probability: float, tensor_for_masking: torch.Tensor
    ) -> torch.Tensor:
        binary_mask = (torch.rand(tensor_for_masking.size()) > dropout_probability).to(
            tensor_for_masking.device
        )
        # Scale mask by 1/keep_prob to preserve output statistics.
        dropout_mask = binary_mask.float().div(1.0 - dropout_probability)
        return dropout_mask
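
    # Illustrative sketch (not part of the module): the mask above implements
    # inverted dropout. With dropout_probability=0.25, for example, each kept
    # unit is scaled by 1 / 0.75 so the expected value of the masked tensor
    # matches the unmasked one:
    #
    #   layer = AugmentedLSTMUnidirectional(embed_dim=16, lstm_dim=32)
    #   mask = layer.get_dropout_mask(0.25, torch.zeros(4, 32))
    #   # entries of mask are either 0.0 or 1 / 0.75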

    def forward(
        self,
        inputs: PackedSequence,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Warning: DO NOT USE THIS LAYER DIRECTLY, INSTEAD USE the AugmentedLSTM class

        Given an input batch of sequential data such as word embeddings, produces
        a single-layer unidirectional AugmentedLSTM representation of the
        sequential input and new state tensors.

        Args:
            inputs (PackedSequence): Input tensor of shape
                (bsize x seq_len x input_dim).
            states (Tuple[torch.Tensor, torch.Tensor]): Tuple of tensors containing
                the initial hidden state and the cell state of each element in
                the batch. Each of these tensors has a dimension of
                (1 x bsize x num_directions * nhid). Defaults to `None`.

        Returns:
            Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
                AugmentedLSTM representation of the input and the state of the
                LSTM at `t = seq_len`. Shape of the representation is
                (bsize x seq_len x representation_dim).
                Shape of each state is (1 x bsize x nhid).

        """
        sequence_tensor, batch_lengths = pad_packed_sequence(inputs, batch_first=True)
        batch_size = sequence_tensor.size()[0]
        total_timesteps = sequence_tensor.size()[1]
        output_accumulator = sequence_tensor.new_zeros(
            batch_size, total_timesteps, self.lstm_dim
        )
        if states is None:
            full_batch_previous_memory = sequence_tensor.new_zeros(
                batch_size, self.lstm_dim
            )
            full_batch_previous_state = sequence_tensor.data.new_zeros(
                batch_size, self.lstm_dim
            )
        else:
            full_batch_previous_state = states[0].squeeze(0)
            full_batch_previous_memory = states[1].squeeze(0)

        current_length_index = batch_size - 1 if self.go_forward else 0
        if self.recurrent_dropout_probability > 0.0:
            dropout_mask = self.get_dropout_mask(
                self.recurrent_dropout_probability, full_batch_previous_memory
            )
        else:
            dropout_mask = None

        for timestep in range(total_timesteps):
            index = timestep if self.go_forward else total_timesteps - timestep - 1

            if self.go_forward:
                while batch_lengths[current_length_index] <= index:
                    current_length_index -= 1
            # If we're going backwards, we are _picking up_ more indices.
            else:
                # First conditional: Are we already at the maximum
                # number of elements in the batch?
                # Second conditional: Does the next shortest
                # sequence beyond the current batch
                # index require computation at this timestep?
                while (
                    current_length_index < (len(batch_lengths) - 1)
                    and batch_lengths[current_length_index + 1] > index
                ):
                    current_length_index += 1

            previous_memory = full_batch_previous_memory[
                0 : current_length_index + 1
            ].clone()
            previous_state = full_batch_previous_state[
                0 : current_length_index + 1
            ].clone()
            timestep_input = sequence_tensor[0 : current_length_index + 1, index]
            timestep_output, memory = self.cell(
                timestep_input,
                (previous_state, previous_memory),
                dropout_mask[0 : current_length_index + 1]
                if dropout_mask is not None
                else None,
            )
            full_batch_previous_memory = full_batch_previous_memory.data.clone()
            full_batch_previous_state = full_batch_previous_state.data.clone()
            full_batch_previous_memory[0 : current_length_index + 1] = memory
            full_batch_previous_state[0 : current_length_index + 1] = timestep_output
            output_accumulator[
                0 : current_length_index + 1, index, :
            ] = timestep_output

        output_accumulator = pack_padded_sequence(
            output_accumulator, batch_lengths, batch_first=True
        )

        # Mimic the pytorch API by returning state in the following shape:
        # (num_layers * num_directions, batch_size, lstm_dim). As this
        # LSTM cannot be stacked, the first dimension here is just 1.
        final_state = (
            full_batch_previous_state.unsqueeze(0),
            full_batch_previous_memory.unsqueeze(0),
        )
        return output_accumulator, final_state
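
# Illustrative usage sketch (not part of the module): running a packed batch
# through a single unidirectional layer. Dimensions and lengths are arbitrary;
# pack_padded_sequence expects lengths sorted in decreasing order by default.
#
#   layer = AugmentedLSTMUnidirectional(embed_dim=16, lstm_dim=32)
#   embedded = torch.randn(4, 7, 16)                # (bsize x seq_len x embed_dim)
#   lengths = torch.tensor([7, 5, 3, 2])
#   packed = pack_padded_sequence(embedded, lengths, batch_first=True)
#   packed_out, (h, c) = layer(packed)              # h, c: (1 x bsize x lstm_dim)
#   out, _ = pad_packed_sequence(packed_out, batch_first=True)
#   # out: (bsize x seq_len x lstm_dim)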


class AugmentedLSTM(RepresentationBase):
    """
    `AugmentedLSTM` implements a generic AugmentedLSTM representation layer.
    AugmentedLSTM is an LSTM which optionally appends a highway network to the
    output layer. Furthermore, the dropout parameter controls the amount of
    variational dropout applied.

    Args:
        config (Config): Configuration object of type AugmentedLSTM.Config.
        embed_dim (int): The number of expected features in the input.
        padding_value (float): Value for the padded elements. Defaults to 0.0.

    Attributes:
        padding_value (float): Value for the padded elements.
        forward_layers (nn.ModuleList): A module list of unidirectional AugmentedLSTM
            layers moving forward in time.
        backward_layers (nn.ModuleList): A module list of unidirectional AugmentedLSTM
            layers moving backward in time.
        representation_dim (int): The calculated dimension of the output features
            of AugmentedLSTM.
    """

    class Config(RepresentationBase.Config, ConfigBase):
        """
        Configuration class for `AugmentedLSTM`.

        Attributes:
            dropout (float): Variational dropout probability to use.
                Defaults to 0.0.
            lstm_dim (int): Number of features in the hidden state of the LSTM.
                Defaults to 32.
            num_layers (int): Number of recurrent layers. Eg. setting `num_layers=2`
                would mean stacking two LSTMs together to form a stacked LSTM,
                with the second LSTM taking in the outputs of the first LSTM and
                computing the final result. Defaults to 1.
            bidirectional (bool): If `True`, becomes a bidirectional LSTM.
                Defaults to `False`.
            use_highway (bool): If `True` we append a highway network to the
                outputs of the LSTM. Defaults to `True`.
            use_bias (bool): If `True` we use a bias in our LSTM calculations,
                otherwise we don't. Defaults to `False`.
        """

        dropout: float = 0.0
        lstm_dim: int = 32
        use_highway: bool = True
        bidirectional: bool = False
        num_layers: int = 1
        use_bias: bool = False

    def __init__(
        self, config: Config, embed_dim: int, padding_value: float = 0.0
    ) -> None:
        super().__init__(config)

        self.config = config
        self.embed_dim = embed_dim
        self.padding_value = padding_value
        self.lstm_dim = self.config.lstm_dim
        self.num_layers = self.config.num_layers
        self.bidirectional = self.config.bidirectional
        self.dropout = self.config.dropout
        self.use_highway = self.config.use_highway
        self.use_bias = self.config.use_bias

        num_directions = int(self.bidirectional) + 1
        self.forward_layers = nn.ModuleList()
        if self.bidirectional:
            self.backward_layers = nn.ModuleList()

        lstm_embed_dim = embed_dim
        for _ in range(self.num_layers):
            self.forward_layers.append(
                AugmentedLSTMUnidirectional(
                    lstm_embed_dim,
                    self.lstm_dim,
                    go_forward=True,
                    recurrent_dropout_probability=self.dropout,
                    use_highway=self.use_highway,
                    use_input_projection_bias=self.use_bias,
                )
            )
            if self.bidirectional:
                self.backward_layers.append(
                    AugmentedLSTMUnidirectional(
                        lstm_embed_dim,
                        self.lstm_dim,
                        go_forward=False,
                        recurrent_dropout_probability=self.dropout,
                        use_highway=self.use_highway,
                        use_input_projection_bias=self.use_bias,
                    )
                )
            lstm_embed_dim = self.lstm_dim * num_directions
        self.representation_dim = lstm_embed_dim
        log_class_usage(__class__)

    def forward(
        self,
        embedded_tokens: torch.Tensor,
        seq_lengths: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Given an input batch of sequential data such as word embeddings, produces
        an AugmentedLSTM representation of the sequential input and new state
        tensors.

        Args:
            embedded_tokens (torch.Tensor): Input tensor of shape
                (bsize x seq_len x input_dim).
            seq_lengths (torch.Tensor): List of sequence lengths of each
                batch element.
            states (Tuple[torch.Tensor, torch.Tensor]): Tuple of tensors containing
                the initial hidden state and the cell state of each element in
                the batch. Each of these tensors has a dimension of
                (bsize x num_layers x num_directions * nhid). Defaults to `None`.

        Returns:
            Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
                AugmentedLSTM representation of the input and the state of the
                LSTM at `t = seq_len`. Shape of the representation is
                (bsize x seq_len x representation_dim).
                Shape of each state is (bsize x num_layers * num_directions x nhid).

        """
        rnn_input = pack_padded_sequence(
            embedded_tokens, seq_lengths.int(), batch_first=True
        )
        if states is not None:
            states = (states[0].transpose(0, 1), states[1].transpose(0, 1))
        if self.bidirectional:
            return self._forward_bidirectional(rnn_input, states)
        return self._forward_unidirectional(rnn_input, states)

    def _forward_bidirectional(
        self,
        inputs: PackedSequence,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]],
    ):
        output_sequence = inputs
        final_h = []
        final_c = []

        if not states:
            hidden_states = [None] * self.num_layers
        elif states[0].size()[0] != self.num_layers:
            raise RuntimeError(
                "Initial states were passed to forward() but the number of "
                "initial states does not match the number of layers."
            )
        else:
            hidden_states = list(  # noqa
                zip(
                    states[0].chunk(self.num_layers, 0),
                    states[1].chunk(self.num_layers, 0),
                )
            )

        for i, state in enumerate(hidden_states):
            if state:
                forward_state = state[0].chunk(2, -1)
                backward_state = state[1].chunk(2, -1)
            else:
                forward_state = backward_state = None

            forward_layer = self.forward_layers[i]
            backward_layer = self.backward_layers[i]
            # The state is duplicated to mirror the Pytorch API for LSTMs.
            forward_output, final_forward_state = forward_layer(
                output_sequence, forward_state
            )
            backward_output, final_backward_state = backward_layer(
                output_sequence, backward_state
            )

            forward_output, lengths = pad_packed_sequence(
                forward_output, batch_first=True
            )
            backward_output, _ = pad_packed_sequence(backward_output, batch_first=True)

            output_sequence = torch.cat([forward_output, backward_output], -1)
            output_sequence = pack_padded_sequence(
                output_sequence, lengths, batch_first=True
            )
            final_h.extend([final_forward_state[0], final_backward_state[0]])
            final_c.extend([final_forward_state[1], final_backward_state[1]])

        final_h = torch.cat(final_h, dim=0).transpose(0, 1)
        final_c = torch.cat(final_c, dim=0).transpose(0, 1)
        final_state_tuple = (final_h, final_c)
        output_sequence, _ = pad_packed_sequence(
            output_sequence, padding_value=self.padding_value, batch_first=True
        )
        return output_sequence, final_state_tuple

    def _forward_unidirectional(
        self,
        inputs: PackedSequence,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]],
    ):
        output_sequence = inputs
        final_h = []
        final_c = []

        if not states:
            hidden_states = [None] * self.num_layers
        elif states[0].size()[0] != self.num_layers:
            raise RuntimeError(
                "Initial states were passed to forward() but the number of "
                "initial states does not match the number of layers."
            )
        else:
            hidden_states = list(  # noqa
                zip(
                    states[0].chunk(self.num_layers, 0),
                    states[1].chunk(self.num_layers, 0),
                )
            )

        for i, state in enumerate(hidden_states):
            forward_layer = self.forward_layers[i]
            # The state is duplicated to mirror the Pytorch API for LSTMs.
            forward_output, final_forward_state = forward_layer(output_sequence, state)
            output_sequence = forward_output
            final_h.append(final_forward_state[0])
            final_c.append(final_forward_state[1])

        final_h = torch.cat(final_h, dim=0).transpose(0, 1)
        final_c = torch.cat(final_c, dim=0).transpose(0, 1)
        final_state_tuple = (final_h, final_c)
        output_sequence, _ = pad_packed_sequence(
            output_sequence, padding_value=self.padding_value, batch_first=True
        )
        return output_sequence, final_state_tuple
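
# Illustrative usage sketch (not part of the module): building a stacked
# bidirectional AugmentedLSTM from its Config and encoding a batch of token
# embeddings. The dimensions are arbitrary, and keyword construction of the
# Config is assumed to follow the usual pytext ConfigBase behaviour.
#
#   config = AugmentedLSTM.Config(lstm_dim=32, num_layers=2, bidirectional=True)
#   lstm = AugmentedLSTM(config, embed_dim=16)
#   tokens = torch.randn(4, 7, 16)                  # (bsize x seq_len x embed_dim)
#   seq_lengths = torch.tensor([7, 5, 3, 2])        # sorted, longest first
#   rep, (h, c) = lstm(tokens, seq_lengths)
#   # rep: (bsize x seq_len x lstm.representation_dim), here (4, 7, 64)
#   # h, c: (bsize x num_layers * num_directions x lstm_dim), here (4, 4, 32)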