#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
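"""Pooling modules for sequence representations.

Each module reduces token-level representations of shape (bsz, sent_len, rep_dim)
to a fixed-size vector per example; NoPool is the identity and keeps the sequence.
"""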
from typing import Optional

import torch
import torch.nn as nn
import torch.onnx.operators
from pytext.config import ConfigBase
from pytext.models.module import Module
from pytext.utils.usage import log_class_usage


class SelfAttention(Module):
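    """Self-attentive pooling with a single attention head.

    Each timestep is scored by a two-layer MLP (ws1 -> tanh -> ws2), the scores
    are softmax-normalized over the sequence, and the output is the
    attention-weighted sum of the inputs: (bsz, sent_len, rep_dim) -> (bsz, rep_dim).
    """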

    class Config(ConfigBase):
        attn_dimension: int = 64
        dropout: float = 0.4

    def __init__(self, config: Config, n_input: int) -> None:
        super().__init__(config)
        self.dropout = nn.Dropout(config.dropout)
        self.n_input = n_input
        self.n_attn = config.attn_dimension
        self.ws1 = nn.Linear(n_input, self.n_attn, bias=False)
        self.ws2 = nn.Linear(self.n_attn, 1, bias=False)
        self.tanh = nn.Tanh()
        # normalize attention scores over the sequence dimension; the explicit
        # dim avoids the implicit-dim deprecation warning
        self.softmax = nn.Softmax(dim=1)
        self.init_weights()
        log_class_usage(__class__)

    def init_weights(self, init_range: float = 0.1) -> None:
        self.ws1.weight.data.uniform_(-init_range, init_range)
        self.ws2.weight.data.uniform_(-init_range, init_range)

    def forward(
        self, inputs: torch.Tensor, seq_lengths: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # inputs size: (bsz, sent_len, rep_dim)
        size = torch.onnx.operators.shape_as_tensor(inputs)
        flat_2d_shape = torch.cat((torch.LongTensor([-1]), size[2].view(1)))
        compressed_emb = torch.onnx.operators.reshape_from_tensor_shape(
            inputs, flat_2d_shape
        )  # (bsz * sent_len, rep_dim)
        hbar = self.tanh(
            self.ws1(self.dropout(compressed_emb))
        )  # (bsz * sent_len, attn_dimension)
        alphas = self.ws2(hbar)  # (bsz * sent_len, 1)
        alphas = torch.onnx.operators.reshape_from_tensor_shape(
            alphas, size[:2]
        )  # (bsz, sent_len)
        alphas = self.softmax(alphas)  # (bsz, sent_len)
        # weighted sum over the sequence: (bsz, 1, sent_len) x (bsz, sent_len, rep_dim)
        return torch.bmm(alphas.unsqueeze(1), inputs).squeeze(1)  # (bsz, rep_dim)


class MaxPool(Module):
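    """Element-wise max over the sequence dimension (padding positions are
    included; seq_lengths is ignored)."""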

    def __init__(self, config: Module.Config, n_input: int) -> None:
        super().__init__(config)
        log_class_usage(__class__)

    def forward(
        self, inputs: torch.Tensor, seq_lengths: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # max over dim 1; [0] keeps the values and drops the argmax indices
        return torch.max(inputs, 1)[0]


class MeanPool(Module):
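    """Sum over the sequence dimension divided by each example's true length.

    Padded positions are included in the sum, so this is only an exact mean
    when padding embeddings are zero vectors.
    """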

    def __init__(self, config: Module.Config, n_input: int) -> None:
        super().__init__(config)
        log_class_usage(__class__)

    def forward(self, inputs: torch.Tensor, seq_lengths: torch.Tensor) -> torch.Tensor:
        # sum over the sequence dimension, then divide by the true length
        return torch.sum(inputs, 1) / seq_lengths.unsqueeze(1).float()


class NoPool(Module):
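    """Identity pooling: returns the input sequence unchanged."""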

    def __init__(self, config: Module.Config, n_input: int) -> None:
        super().__init__(config)
        log_class_usage(__class__)

    def forward(
        self, inputs: torch.Tensor, seq_lengths: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        return inputs


class BoundaryPool(Module):
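    """Pools the first position ("first"), the final (possibly padded)
    position ("last"), or the concatenation of both ("firstlast"),
    per Config.boundary_type."""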

    class Config(ConfigBase):
        # which boundary position(s) to pool: "first", "last", or "firstlast"
        boundary_type: str = "first"

    def __init__(self, config: Config, n_input: int) -> None:
        super().__init__(config)
        self.boundary_type = config.boundary_type
        log_class_usage(__class__)

    def forward(
        self, inputs: torch.Tensor, seq_lengths: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        max_len = inputs.size()[1]
        if self.boundary_type == "first":
            return inputs[:, 0, :]
        elif self.boundary_type == "last":
            # a length-1 sequence can only contain the BOS token, which happens
            # when add_bos or add_eos is False; "last" pooling requires the EOS
            # token, so we should never reach here without it
            assert max_len > 1
            return inputs[:, max_len - 1, :]
        elif self.boundary_type == "firstlast":
            assert max_len > 1
            # concatenate the first and last positions: embed_dim -> 2 * embed_dim
            return torch.cat((inputs[:, 0, :], inputs[:, max_len - 1, :]), dim=1)
        else:
            raise ValueError(f"Unknown boundary_type {self.boundary_type}")


class LastTimestepPool(Module):
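    """Selects each example's representation at its last valid timestep, as
    given by seq_lengths (unlike BoundaryPool's "last", which always takes
    the final, possibly padded, position)."""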

    def __init__(self, config: Module.Config, n_input: int) -> None:
        super().__init__(config)
        log_class_usage(__class__)

    def forward(self, inputs: torch.Tensor, seq_lengths: torch.Tensor) -> torch.Tensor:
        # inputs: (bsz, max_len, dim)
        # seq_lengths: (bsz,)
        if torch._C._get_tracing_state():
            # during ONNX export (tracing) the batch size is 1, so return the
            # last timestep directly to avoid introducing extra operators
            assert inputs.shape[0] == 1
            return inputs[:, -1, :]
        bsz, _, dim = inputs.shape
        # build a (bsz, 1, dim) index holding each example's last valid timestep
        idx = seq_lengths.unsqueeze(1).expand(bsz, dim).unsqueeze(1)
        return inputs.gather(1, idx - 1).squeeze(1)
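

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. It assumes
    # that the ConfigBase subclasses above can be constructed directly (with
    # defaults or keyword overrides) outside a full PyText config pipeline;
    # in real usage these modules are built from a model config.
    bsz, sent_len, rep_dim = 2, 5, 16
    inputs = torch.randn(bsz, sent_len, rep_dim)
    seq_lengths = torch.tensor([5, 3])

    attn = SelfAttention(SelfAttention.Config(), n_input=rep_dim)
    print("SelfAttention", attn(inputs).shape)  # torch.Size([2, 16])

    for pool_cls in (MaxPool, MeanPool, LastTimestepPool):
        pool = pool_cls(Module.Config(), rep_dim)
        print(pool_cls.__name__, pool(inputs, seq_lengths).shape)  # torch.Size([2, 16])

    first_last = BoundaryPool(BoundaryPool.Config(boundary_type="firstlast"), rep_dim)
    print("BoundaryPool", first_last(inputs).shape)  # torch.Size([2, 32])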