Source code for pytext.models.semantic_parsers.rnng.rnng_data_structures

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, List, Sized, Tuple

import torch
import torch.nn as nn
from pytext.torchscript.utils import reverse_tensor_list
from pytext.utils.cuda import FloatTensor
from pytext.utils.tensor import xaviervar


class Element:
    """
    Generic element representing a token / non-terminal / sub-tree on a stack.
    Used to compute valid actions in the RNNG parser.
    """

    def __init__(self, node: Any) -> None:
        self.node = node

    def __eq__(self, other) -> bool:
        return self.node == other.node

    def __repr__(self) -> str:
        return str(self.node)
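
# Illustrative note (not part of the original module): Element equality is
# defined purely by the wrapped node, which lets the parser compare stack
# entries when computing valid actions. With hypothetical string nodes:
#
#     Element("IN:GET_LOCATION") == Element("IN:GET_LOCATION")   # True
#     Element("IN:GET_LOCATION") == Element("SL:LOCATION")       # False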

class StackLSTM(Sized):
    """
    The Stack LSTM from Dyer et al.: https://arxiv.org/abs/1505.08075
    """

    def __init__(self, lstm: nn.LSTM):
        """
        Shapes:
            initial_state: (lstm_layers, 1, lstm_hidden_dim) each
        """
        self.lstm = lstm
        initial_state = (
            FloatTensor(lstm.num_layers, 1, lstm.hidden_size).fill_(0),
            FloatTensor(lstm.num_layers, 1, lstm.hidden_size).fill_(0),
        )
        # Stack of (state, (embedding, element))
        self.stack = [
            (initial_state, (self._lstm_output(initial_state), Element("Root")))
        ]

    def _lstm_output(self, state: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """
        Shapes:
            state: (lstm_layers, 1, lstm_hidden_dim) each
            return value: (1, lstm_hidden_dim)
        """
        return state[0][-1]

    def push(self, expression: torch.Tensor, element: Element) -> None:
        """
        Shapes:
            expression: (1, lstm_input_dim)
        """
        old_top_state = self.stack[-1][0]
        # Unsqueezing expression for sequence_length = 1
        _, new_top_state = self.lstm(expression.unsqueeze(0), old_top_state)
        # Push in (state, (embedding, element))
        self.stack.append((new_top_state, (self._lstm_output(new_top_state), element)))

    def pop(self) -> Tuple[torch.Tensor, Element]:
        """
        Pops the top entry and returns a tuple of its output embedding
        (1, lstm_hidden_dim) and its element.
        """
        return self.stack.pop()[1]

    def embedding(self) -> torch.Tensor:
        """
        Shapes:
            return value: (1, lstm_hidden_dim)
        """
        assert len(self.stack) > 0, "stack size must be greater than 0"
        top_state = self.stack[-1][0]
        return self._lstm_output(top_state)

    def element_from_top(self, index: int) -> Element:
        return self.stack[-(index + 1)][1][1]

    def __len__(self) -> int:
        return len(self.stack) - 1

    def __str__(self) -> str:
        return "->".join([str(x[1][1]) for x in self.stack])

    def copy(self):
        other = StackLSTM(self.lstm)
        other.stack = list(self.stack)
        return other
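
# Usage sketch (illustrative, not part of the original module), assuming a
# 1-layer nn.LSTM whose input_size and hidden_size match the pushed embeddings;
# the size 8 below is an arbitrary assumption:
#
#     lstm = nn.LSTM(input_size=8, hidden_size=8, num_layers=1)
#     stack = StackLSTM(lstm)
#     len(stack)                                # 0: the "Root" sentinel is not counted
#     stack.push(torch.randn(1, 8), Element("where"))
#     stack.push(torch.randn(1, 8), Element("am"))
#     stack.embedding()                         # (1, 8) summary of the current stack
#     emb, elem = stack.pop()                   # pops "am"; the LSTM state rolls back too
#     len(stack)                                # 1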

class CompositionalNN(torch.jit.ScriptModule):
    """
    Combines a list / sequence of embeddings into one using a biLSTM
    """

    __constants__ = ["lstm_dim", "linear_seq"]

    def __init__(self, lstm_dim: int):
        super().__init__()
        self.lstm_dim = lstm_dim
        self.lstm_fwd = nn.LSTM(lstm_dim, lstm_dim, num_layers=1)
        self.lstm_rev = nn.LSTM(lstm_dim, lstm_dim, num_layers=1)
        self.linear_seq = nn.Sequential(nn.Linear(2 * lstm_dim, lstm_dim), nn.Tanh())

    @torch.jit.script_method
    def forward(self, x: List[torch.Tensor], device: str = "cpu") -> torch.Tensor:
        """
        Embed the sequence. If the input corresponds to [IN:GL where am I at]:
        - x will contain the embeddings of [at I am where IN:GL] in that order.
        - Forward LSTM will embed the sequence [IN:GL where am I at].
        - Backward LSTM will embed the sequence [IN:GL at I am where].
        The final hidden states are concatenated and then projected.

        Args:
            x: Embeddings of the input tokens in *reversed* order
        Shapes:
            x: (1, lstm_dim) each
            return value: (1, lstm_dim)
        """
        # reset hidden state every time
        lstm_hidden_fwd = (
            xaviervar([1, 1, self.lstm_dim], device=device),
            xaviervar([1, 1, self.lstm_dim], device=device),
        )
        lstm_hidden_rev = (
            xaviervar([1, 1, self.lstm_dim], device=device),
            xaviervar([1, 1, self.lstm_dim], device=device),
        )
        nonterminal_element = x[-1]
        reversed_rest = x[:-1]
        # Always put nonterminal_element at the front
        fwd_input = [nonterminal_element] + reverse_tensor_list(reversed_rest)
        rev_input = [nonterminal_element] + reversed_rest
        stacked_fwd = self.lstm_fwd(torch.stack(fwd_input), lstm_hidden_fwd)[0][0]
        stacked_rev = self.lstm_rev(torch.stack(rev_input), lstm_hidden_rev)[0][0]
        combined = torch.cat([stacked_fwd, stacked_rev], dim=1)
        subtree_embedding = self.linear_seq(combined)
        return subtree_embedding
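
# Usage sketch (illustrative, not part of the original module). For the
# utterance in the docstring above, x holds the reversed token embeddings with
# the non-terminal embedding last; the size 8 is an arbitrary assumption:
#
#     comp = CompositionalNN(lstm_dim=8)
#     # embeddings for [at, I, am, where, IN:GL], i.e. reversed tokens + NT last
#     x = [torch.randn(1, 8) for _ in range(5)]
#     subtree_embedding = comp(x)               # shape (1, 8)
#
# The resulting embedding is what the RNNG parser pushes back onto the stack
# when it reduces the corresponding sub-tree.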

class CompositionalSummationNN(torch.jit.ScriptModule):
    """
    Simpler alternative to CompositionalNN: sums the input embeddings and
    applies a linear projection with a tanh non-linearity.
    """

    __constants__ = ["lstm_dim", "linear_seq"]

    def __init__(self, lstm_dim: int):
        super().__init__()
        self.lstm_dim = lstm_dim
        self.linear_seq = nn.Sequential(nn.Linear(lstm_dim, lstm_dim), nn.Tanh())

    @torch.jit.script_method
    def forward(self, x: List[torch.Tensor], device: str = "cpu") -> torch.Tensor:
        combined = torch.sum(torch.cat(x, dim=0), dim=0, keepdim=True)
        subtree_embedding = self.linear_seq(combined)
        return subtree_embedding
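
# Illustrative note (not part of the original module): this module exposes the
# same forward signature as CompositionalNN, so either can be plugged in as
# the parser's composition function; sizes below are arbitrary assumptions:
#
#     comp = CompositionalSummationNN(lstm_dim=8)
#     subtree_embedding = comp([torch.randn(1, 8) for _ in range(5)])   # (1, 8)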

class ParserState:
    """
    Maintains the state of the parser. Useful for beam search.
    """

    def __init__(self, parser=None):
        if not parser:
            return
        self.buffer_stackrnn = StackLSTM(parser.buff_rnn)
        self.stack_stackrnn = StackLSTM(parser.stack_rnn)
        self.action_stackrnn = StackLSTM(parser.action_rnn)

        self.predicted_actions_idx = []
        self.action_scores = []

        self.num_open_NT = 0
        self.is_open_NT: List[bool] = []
        self.found_unsupported = False
        self.action_p = torch.Tensor()

        # negative cumulative log prob so sort(states) is in descending order
        self.neg_prob = 0

    def finished(self):
        return len(self.stack_stackrnn) == 1 and len(self.buffer_stackrnn) == 0

    def copy(self):
        other = ParserState()
        other.buffer_stackrnn = self.buffer_stackrnn.copy()
        other.stack_stackrnn = self.stack_stackrnn.copy()
        other.action_stackrnn = self.action_stackrnn.copy()
        other.predicted_actions_idx = self.predicted_actions_idx.copy()
        other.action_scores = self.action_scores.copy()
        other.num_open_NT = self.num_open_NT
        other.is_open_NT = self.is_open_NT.copy()
        other.neg_prob = self.neg_prob
        other.found_unsupported = self.found_unsupported
        # detach to avoid making copies, only called in inference to share data
        other.action_p = self.action_p.detach()
        return other

    def __gt__(self, other):
        return self.neg_prob > other.neg_prob

    def __eq__(self, other):
        return self.neg_prob == other.neg_prob
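
# Usage sketch (illustrative, not part of the original module): beam search
# forks candidate states with copy() and relies on __gt__/__eq__ so that
# sorting by the negative cumulative log prob puts the most likely states
# first. `beam_size` and the scoring step are hypothetical placeholders:
#
#     beam = [state.copy() for _ in range(beam_size)]   # fork the current state
#     ...                                               # score actions, update neg_prob
#     beam = sorted(beam)[:beam_size]                   # keep the best candidates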