#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, List, Sized, Tuple
import torch
import torch.nn as nn
from pytext.torchscript.utils import reverse_tensor_list
from pytext.utils.cuda import FloatTensor
from pytext.utils.tensor import xaviervar


class Element:
    """
    Generic element representing a token / non-terminal / sub-tree on a stack.
    Used to compute valid actions in the RNNG parser.
    """

    def __init__(self, node: Any) -> None:
        self.node = node

    def __eq__(self, other) -> bool:
        return self.node == other.node

    def __repr__(self) -> str:
        return str(self.node)


class StackLSTM(Sized):
    """
    The Stack LSTM from Dyer et al.: https://arxiv.org/abs/1505.08075
    """

    def __init__(self, lstm: nn.LSTM):
        """
        Shapes:
            initial_state: (lstm_layers, 1, lstm_hidden_dim) each
        """
        self.lstm = lstm
        initial_state = (
            FloatTensor(lstm.num_layers, 1, lstm.hidden_size).fill_(0),
            FloatTensor(lstm.num_layers, 1, lstm.hidden_size).fill_(0),
        )
        # Stack of (state, (embedding, element))
        self.stack = [
            (initial_state, (self._lstm_output(initial_state), Element("Root")))
        ]

    def _lstm_output(self, state: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """
        Shapes:
            state: (lstm_layers, 1, lstm_hidden_dim) each
            return value: (1, lstm_hidden_dim)
        """
        return state[0][-1]

    def push(self, expression: torch.Tensor, element: Element) -> None:
        """
        Shapes:
            expression: (1, lstm_input_dim)
        """
        old_top_state = self.stack[-1][0]
        # Unsqueezing expression for sequence_length = 1
        _, new_top_state = self.lstm(expression.unsqueeze(0), old_top_state)
        # Push in (state, (embedding, element))
        self.stack.append((new_top_state, (self._lstm_output(new_top_state), element)))

    def pop(self) -> Tuple[torch.Tensor, Element]:
        """
        Pops and returns tuple of output embedding (1, lstm_hidden_dim) and element
        """
        return self.stack.pop()[1]

    def embedding(self) -> torch.Tensor:
        """
        Shapes:
            return value: (1, lstm_hidden_dim)
        """
        assert len(self.stack) > 0, "stack size must be greater than 0"
        top_state = self.stack[-1][0]
        return self._lstm_output(top_state)

    def element_from_top(self, index: int) -> Element:
        # Element `index` entries below the top of the stack (0 = top)
        return self.stack[-(index + 1)][1][1]

    def __len__(self) -> int:
        # The initial "Root" sentinel entry is not counted
        return len(self.stack) - 1

    def __str__(self) -> str:
        return "->".join([str(x[1][1]) for x in self.stack])

    def copy(self):
        # Shallow copy: the copied stack shares states and embeddings with this one
        other = StackLSTM(self.lstm)
        other.stack = list(self.stack)
        return other
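
# Illustrative sketch (not part of the library): how the RNNG parser might drive a
# StackLSTM. The hidden size (8) and the pushed inputs are arbitrary assumptions;
# only push / pop / embedding / __len__ are exercised.
#
#     lstm = nn.LSTM(8, 8, num_layers=1)
#     stack = StackLSTM(lstm)
#     stack.push(torch.randn(1, 8), Element("where"))
#     stack.push(torch.randn(1, 8), Element("am"))
#     assert len(stack) == 2              # the "Root" sentinel is not counted
#     summary = stack.embedding()         # (1, 8) summary of the current stack
#     emb, elem = stack.pop()             # emb: (1, 8), elem == Element("am")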


class CompositionalNN(torch.jit.ScriptModule):
    """
    Combines a list / sequence of embeddings into one using a biLSTM
    """

    __constants__ = ["lstm_dim", "linear_seq"]

    def __init__(self, lstm_dim: int):
        super().__init__()
        self.lstm_dim = lstm_dim
        self.lstm_fwd = nn.LSTM(lstm_dim, lstm_dim, num_layers=1)
        self.lstm_rev = nn.LSTM(lstm_dim, lstm_dim, num_layers=1)
        self.linear_seq = nn.Sequential(nn.Linear(2 * lstm_dim, lstm_dim), nn.Tanh())

    @torch.jit.script_method
    def forward(self, x: List[torch.Tensor], device: str = "cpu") -> torch.Tensor:
        """
        Embed the sequence. If the input corresponds to [IN:GL where am I at]:
        - x will contain the embeddings of [at I am where IN:GL] in that order.
        - Forward LSTM will embed the sequence [IN:GL where am I at].
        - Backward LSTM will embed the sequence [IN:GL at I am where].
        The final hidden states are concatenated and then projected.

        Args:
            x: Embeddings of the input tokens in *reversed* order
        Shapes:
            x: (1, lstm_dim) each
            return value: (1, lstm_dim)
        """
        # reset hidden state every time
        lstm_hidden_fwd = (
            xaviervar([1, 1, self.lstm_dim], device=device),
            xaviervar([1, 1, self.lstm_dim], device=device),
        )
        lstm_hidden_rev = (
            xaviervar([1, 1, self.lstm_dim], device=device),
            xaviervar([1, 1, self.lstm_dim], device=device),
        )
        nonterminal_element = x[-1]
        reversed_rest = x[:-1]
        # Always put nonterminal_element at the front
        fwd_input = [nonterminal_element] + reverse_tensor_list(reversed_rest)
        rev_input = [nonterminal_element] + reversed_rest
        # Take the final hidden state (last time step) of each direction's LSTM,
        # as described in the docstring above
        stacked_fwd = self.lstm_fwd(torch.stack(fwd_input), lstm_hidden_fwd)[0][-1]
        stacked_rev = self.lstm_rev(torch.stack(rev_input), lstm_hidden_rev)[0][-1]
        combined = torch.cat([stacked_fwd, stacked_rev], dim=1)
        subtree_embedding = self.linear_seq(combined)
        return subtree_embedding
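
# Illustrative sketch (not part of the library): composing a finished subtree with
# CompositionalNN. The dimension (16) is an arbitrary assumption. As in the
# docstring above, the inputs are the popped token embeddings in reversed order,
# with the non-terminal embedding last.
#
#     composer = CompositionalNN(16)
#     reversed_popped = [torch.randn(1, 16) for _ in range(3)]  # tokens, reversed
#     nonterminal = torch.randn(1, 16)                          # e.g. the IN:* embedding
#     subtree = composer(reversed_popped + [nonterminal])       # -> (1, 16)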


class CompositionalSummationNN(torch.jit.ScriptModule):
    """
    Simpler version of CompositionalNN
    """

    __constants__ = ["lstm_dim", "linear_seq"]

    def __init__(self, lstm_dim: int):
        super().__init__()
        self.lstm_dim = lstm_dim
        self.linear_seq = nn.Sequential(nn.Linear(lstm_dim, lstm_dim), nn.Tanh())

    @torch.jit.script_method
    def forward(self, x: List[torch.Tensor], device: str = "cpu") -> torch.Tensor:
        # Sum the (1, lstm_dim) embeddings, then project through Linear + Tanh
        combined = torch.sum(torch.cat(x, dim=0), dim=0, keepdim=True)
        subtree_embedding = self.linear_seq(combined)
        return subtree_embedding
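
# Illustrative sketch (not part of the library): CompositionalSummationNN takes the
# same inputs as CompositionalNN and returns the same shape, so it can be swapped
# in where a cheaper composition function suffices. The dimension (16) is an
# arbitrary assumption.
#
#     composer = CompositionalSummationNN(16)
#     subtree = composer([torch.randn(1, 16) for _ in range(4)])  # -> (1, 16)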


class ParserState:
    """
    Maintains state of the Parser. Useful for beam search.
    """

    def __init__(self, parser=None):
        if not parser:
            return
        self.buffer_stackrnn = StackLSTM(parser.buff_rnn)
        self.stack_stackrnn = StackLSTM(parser.stack_rnn)
        self.action_stackrnn = StackLSTM(parser.action_rnn)

        self.predicted_actions_idx = []
        self.action_scores = []

        self.num_open_NT = 0
        self.is_open_NT: List[bool] = []
        self.found_unsupported = False
        self.action_p = torch.Tensor()

        # Negative cumulative log probability, so that sorting states in
        # ascending order yields hypotheses in descending order of probability
        self.neg_prob = 0

    def finished(self):
        return len(self.stack_stackrnn) == 1 and len(self.buffer_stackrnn) == 0

    def copy(self):
        other = ParserState()
        other.buffer_stackrnn = self.buffer_stackrnn.copy()
        other.stack_stackrnn = self.stack_stackrnn.copy()
        other.action_stackrnn = self.action_stackrnn.copy()
        other.predicted_actions_idx = self.predicted_actions_idx.copy()
        other.action_scores = self.action_scores.copy()
        other.num_open_NT = self.num_open_NT
        other.is_open_NT = self.is_open_NT.copy()
        other.neg_prob = self.neg_prob
        other.found_unsupported = self.found_unsupported
        # detach to avoid making copies; only called in inference to share data
        other.action_p = self.action_p.detach()
        return other

    def __gt__(self, other):
        return self.neg_prob > other.neg_prob

    def __eq__(self, other):
        return self.neg_prob == other.neg_prob
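
# Illustrative sketch (not part of the library): because ParserState stores a
# *negative* cumulative log probability and defines __gt__/__eq__ on it, a plain
# sort puts the most probable hypotheses first during beam search. The neg_prob
# values below are hypothetical.
#
#     beam = [ParserState(), ParserState(), ParserState()]
#     beam[0].neg_prob, beam[1].neg_prob, beam[2].neg_prob = 2.5, 0.3, 1.1
#     beam.sort()                      # uses the comparison operators defined above
#     assert [s.neg_prob for s in beam] == [0.3, 1.1, 2.5]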