#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Optional, Dict, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from pytext.config import ConfigBase
from pytext.utils.usage import log_class_usage
from torch import Tensor
from torch import nn
from .base import PyTextIncrementalDecoderComponent
from .utils import Linear
def create_src_lengths_mask(
    batch_size: int, src_lengths, max_src_len: Optional[int] = None
):
    """
    Generate a 0/1 integer mask marking the valid (non-padded) source
    positions, used to prevent attention beyond the end of source.

    Inputs:
        batch_size : int
        src_lengths : [batch_size] of sentence lengths
        max_src_len : optional width of the returned mask. Defaults to
            ``src_lengths.max()``. Pass the padded source length when the
            batch is padded beyond its longest sentence.
    Outputs:
        [batch_size, max_src_len] int tensor, 1 where position < length
        and 0 elsewhere
    """
    if max_src_len is None:
        max_src_len = int(src_lengths.max())
    # Create the position indices directly with the dtype/device of
    # src_lengths instead of allocating on the default device and converting.
    src_indices = torch.arange(
        0, max_src_len, dtype=src_lengths.dtype, device=src_lengths.device
    ).unsqueeze(0)
    src_indices = src_indices.expand(batch_size, max_src_len)
    src_lengths = src_lengths.unsqueeze(dim=1).expand(batch_size, max_src_len)
    # returns [batch_size, max_src_len]
    return (src_indices < src_lengths).int().detach()


def masked_softmax(scores, src_lengths, src_length_masking: bool = True):
    """Apply source length masking then softmax.

    Input and output have shape bsz x src_len. When ``src_length_masking``
    is False, ``src_lengths`` is ignored and a plain softmax over the last
    dimension is returned.
    """
    if src_length_masking:
        bsz, max_src_len = scores.size()
        # Size the mask from the scores themselves: the padded width of
        # `scores` may exceed the longest length in `src_lengths`, and
        # masked_fill needs a mask that matches the scores' shape.
        src_mask = create_src_lengths_mask(bsz, src_lengths, max_src_len)
        # Fill pad positions with -inf
        scores = scores.masked_fill(src_mask == 0, float("-inf"))
    # Cast to float and then back again to prevent loss explosion under fp16.
    return F.softmax(scores.float(), dim=-1).type_as(scores)
class DotAttention(nn.Module):
    """
    Dot-product attention between a decoder state and the source hidden
    states. The decoder state is linearly projected to ``context_dim`` when
    the dimensions differ or when ``force_projection`` is set.
    """

    def __init__(
        self,
        decoder_hidden_state_dim,
        context_dim,
        force_projection=False,
        src_length_masking=True,
    ):
        super().__init__()
        self.decoder_hidden_state_dim = decoder_hidden_state_dim
        self.context_dim = context_dim

        # A projection is only needed to make the dot product well defined,
        # unless the caller explicitly asks for one.
        needs_projection = force_projection or (
            decoder_hidden_state_dim != context_dim
        )
        self.input_proj = (
            nn.Linear(decoder_hidden_state_dim, context_dim, bias=True)
            if needs_projection
            else None
        )
        self.src_length_masking = src_length_masking
        log_class_usage(__class__)

    def forward(self, decoder_state, source_hids, src_lengths):
        """
        Return ``(context, weights)``: the attention-weighted sum of the
        source hidden states and the transposed normalized attention scores.
        """
        # Swap the first two axes: bsz x src_len x context_dim
        batch_first_hids = source_hids.transpose(0, 1)

        # query: bsz x context_dim
        if self.input_proj is None:
            query = decoder_state
        else:
            query = self.input_proj(decoder_state)

        # (bsz x src_len x context_dim) * (bsz x context_dim x 1) -> bsz x src_len
        raw_scores = torch.bmm(batch_first_hids, query.unsqueeze(2)).squeeze(2)

        # Length masking (if enabled) followed by softmax, still bsz x src_len
        attn_probs = masked_softmax(raw_scores, src_lengths, self.src_length_masking)

        # Scale every source state by its weight, then sum over src_len
        weighted_hids = batch_first_hids * attn_probs.unsqueeze(2)
        context = weighted_hids.contiguous().sum(1)

        return context, attn_probs.t()
class MultiheadAttention(PyTextIncrementalDecoderComponent):
    """
    Refer Attention is All You Need for more details.

    This is a simplified implementation of multihead attention
    optimized for exporting using torchscript. Usage of nn.Linear() instead of
    F.linear() helps to quantize the linear layers.

    Query represents the output from last decoder step. Key and Values are obtained from
    encoder. Attention weights are obtained from the dot product of query and key.
    Attention weights multiplied by the value gives output.
    """

    class Config(ConfigBase):
        dropout: float = 0.0
        kdim: Optional[int] = None
        vdim: Optional[int] = None
        bias: bool = True

    @classmethod
    def from_config(cls, config, embed_dim, num_heads):
        """Build the module from ``config``; its fields become keyword arguments."""
        return cls(embed_dim, num_heads, **config._asdict())

    def __init__(self, embed_dim, num_heads, dropout, kdim=None, vdim=None, bias=True):
        super().__init__()
        self.embed_dim = embed_dim
        # Key/value input sizes default to the embedding size.
        self.kdim = embed_dim if kdim is None else kdim
        self.vdim = embed_dim if vdim is None else vdim

        self.q_proj = Linear(embed_dim, embed_dim, bias=bias)
        self.k_proj = Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = Linear(self.vdim, embed_dim, bias=bias)

        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.out_proj = Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor],
        need_weights: bool,
        incremental_state: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Compute multihead attention of ``query`` over ``key`` / ``value``.

        Inputs:
            query: target_len x bsz x embed_dim
            key, value: source tensors whose first dimension is src_len
            key_padding_mask: optional mask, broadcast over heads and target
                positions; positions where it is set are filled with -inf
                before the softmax (presumably bsz x src_len, confirm with
                the caller)
            need_weights: whether to also return the attention weights
            incremental_state: optional buffer dict. When given, the projected
                key/value are cached on the first step and reused afterwards.
        Outputs:
            attn: target_len x bsz x embed_dim
            attn_weights: bsz x target_len x src_len averaged over heads, or
                None when ``need_weights`` is False
        """
        target_len, bsz, embed_dim = query.size()
        src_len = key.size(0)
        assert embed_dim == self.embed_dim, (
            str(embed_dim) + " != " + str(self.embed_dim)
        )
        assert key is not None
        assert value is not None

        if incremental_state is not None:
            prev_key = self._get_input_buffer(incremental_state, "prev_key")
        else:
            prev_key = None

        bsz_X_num_heads = bsz * self.num_heads

        q = self.q_proj(query)
        q *= self.scaling
        q = (
            q.contiguous()
            .view(target_len, bsz_X_num_heads, self.head_dim)
            .transpose(0, 1)
        )

        if prev_key is not None and incremental_state is not None:
            # This happens if its incremental decoding and prev time step has been
            # cached. This condition won't be true for the first step in
            # incremental decoding.
            k = prev_key.view(bsz_X_num_heads, -1, self.head_dim)
            prev_value = self._get_input_buffer(incremental_state, "prev_value")
            assert prev_value is not None
            v = prev_value.view(bsz_X_num_heads, -1, self.head_dim)
        else:
            # We will recompute key and value for all regular training and
            # for first step of incremental decoding
            k = self.k_proj(key)
            k = k.contiguous().view(-1, bsz_X_num_heads, self.head_dim).transpose(0, 1)
            v = self.v_proj(value)
            v = v.contiguous().view(-1, bsz_X_num_heads, self.head_dim).transpose(0, 1)

            # incremental state needs to be set only for the first decoder step
            # when prev_key and prev_value was not present in incremental_state
            if incremental_state is not None:
                self._set_input_buffer(
                    incremental_state,
                    "prev_key",
                    k.view(bsz, self.num_heads, -1, self.head_dim),
                )
                self._set_input_buffer(
                    incremental_state,
                    "prev_value",
                    v.view(bsz, self.num_heads, -1, self.head_dim),
                )
                # Cache the caller's padding mask as given. It must not be
                # replaced by the buffer content here: nothing is cached yet
                # on the first step, so reading the buffer back would drop
                # the mask and let the softmax attend to padding.
                if key_padding_mask is not None:
                    self._set_input_buffer(
                        incremental_state, "prev_key_padding_mask", key_padding_mask
                    )

        # q.size() : bsz_X_num_heads, target_len, self.head_dim
        assert list(k.size()) == [
            bsz_X_num_heads,
            src_len,
            self.head_dim,
        ], f"key.size() :{ k.size()} [ bsz_X_num_heads, src_len, self.head_dim] : [{bsz_X_num_heads}, {src_len}, {self.head_dim}]"
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        # attn_weights.size() : bsz_X_num_heads, target_len, src_len

        # Don't attend to padding symbols
        if key_padding_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, target_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
            )
            attn_weights = attn_weights.view(bsz_X_num_heads, target_len, src_len)

        assert list(attn_weights.size()) == [bsz_X_num_heads, target_len, src_len]

        attn_weights = F.softmax(attn_weights, dim=-1)
        # Dropout is applied to the probabilities used for the output only;
        # the weights returned below are the un-dropped softmax values.
        attn_probs = self.dropout(attn_weights)

        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz_X_num_heads, target_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(target_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            # num_heads x bsz x target_len x src_len, then average over heads
            attn_weights = attn_weights.view(
                bsz, self.num_heads, target_len, src_len
            ).transpose(1, 0)
            attn_weights_out = attn_weights.mean(dim=0)
        else:
            attn_weights_out = None

        return attn, attn_weights_out

    def reorder_incremental_state(
        self, incremental_state: Dict[str, Tensor], new_order: Tensor
    ):
        """Reorder buffered internal state (for incremental generation)."""
        all_keys = ["prev_key", "prev_value", "prev_key_padding_mask"]
        # ARBABU : why do we need to reorder_incremental_state as encoder_out
        # is always the same?
        for key in all_keys:
            input_buffer = self._get_input_buffer(incremental_state, key)
            if input_buffer is not None:
                # During incremental decoding, all candidates will be along
                # the batch dimension. We pick top candidates
                input_buffer = input_buffer.index_select(0, new_order)
                self._set_input_buffer(incremental_state, key, input_buffer)

    def _get_input_buffer(self, incremental_state: Dict[str, Tensor], key: str):
        # Thin wrapper over get_incremental_state, which is not defined in
        # this class (presumably inherited from the base component).
        return self.get_incremental_state(incremental_state, key)

    def _set_input_buffer(
        self, incremental_state: Dict[str, Tensor], key: str, value: Tensor
    ):
        # Thin wrapper over set_incremental_state, which is not defined in
        # this class (presumably inherited from the base component).
        self.set_incremental_state(incremental_state, key, value)
class DecoupledMultiheadAttention(nn.Module):
    """
    Multiheaded Scaled Dot Product Attention. This function
    has the same exact signature as the one used in pytorch_translate
    with the added benefit of supporting torchscript
    """

    def __init__(
        self,
        embed_dim: int,
        context_dim: int,
        num_heads: int,
        dropout: float,
        unseen_mask=False,
        src_length_mask=True,
    ):
        super().__init__()
        assert embed_dim == context_dim
        d_model = embed_dim
        assert d_model % num_heads == 0

        if unseen_mask:
            raise NotImplementedError(
                "Unseen mask not supported with sequential decoding"
            )
        self._attn = MultiheadAttention(d_model, num_heads, dropout)
        self.use_src_length_mask = src_length_mask

    def forward(
        self,
        decoder_state: Tensor,
        source_hids: Tensor,
        src_len_mask: Optional[Tensor],
        squeeze: bool = True,
    ) -> Tuple[Tensor, Tensor]:
        """
        Computes MultiheadAttention with respect to either a vector
        or a tensor

        Inputs:
            decoder_state: (bsz x decoder_hidden_state_dim) or
                (bsz x T x decoder_hidden_state_dim)
            source_hids: srclen x bsz x context_dim
            src_len_mask: optional key padding mask forwarded to the
                underlying attention. It is ignored when the module was
                built with ``src_length_mask=False``.
            squeeze: Whether or not to squeeze on the time dimension.
                Even if decoder_state.dim() is 2 dimensional an
                explicit time step dimension will be unsqueezed.
        Outputs:
            attn: T x bsz x embed_dim, or bsz x embed_dim when ``squeeze``
                is set and T == 1
            attn_weights: src_len x T x bsz, or src_len x bsz when
                ``squeeze`` is set and T == 1

            T is 1 for a 2 dimensional decoder_state. Squeezing is a no-op
            when T != 1.
        """
        if decoder_state.dim() == 3:
            query = decoder_state
        elif decoder_state.dim() == 2:
            query = decoder_state.unsqueeze(1)
        else:
            raise ValueError("decoder state must be either 2 or 3 dimensional")
        # The underlying attention expects the time dimension first.
        query = query.transpose(0, 1)
        key = source_hids
        value = source_hids

        # Honour the constructor flag: with src_length_mask=False no padding
        # mask is applied even if the caller supplies one.
        key_padding_mask = src_len_mask
        if not self.use_src_length_mask:
            key_padding_mask = None

        attn, attn_weights = self._attn.forward(
            query, key, value, key_padding_mask=key_padding_mask, need_weights=True
        )
        # Need to satisfy torchscript here: narrow Optional[Tensor] to Tensor
        if attn_weights is None:
            raise NotImplementedError("")

        # attn.shape = T X bsz X embed_dim
        # attn_weights.shape = bsz X T X src_len

        attn_weights = attn_weights.transpose(0, 2)
        # attn_weights.shape = src_len X T X bsz

        if squeeze:
            attn = attn.squeeze(0)
            # attn.shape = squeeze(T) X bsz X embed_dim
            attn_weights = attn_weights.squeeze(1)
            # attn_weights.shape = src_len X squeeze(T) X bsz
        return attn, attn_weights