Source code for pytext.models.representations.transformer.multihead_attention

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import List

import numpy as np
import torch
from pytext.utils.usage import log_class_usage
from torch import nn
from torch.nn import functional as F


class MultiheadSelfAttention(nn.Module):
    """
    This is a TorchScriptable implementation of MultiheadAttention from fairseq
    for the purposes of creating a productionized RoBERTa model. It distills just
    the elements which are required to implement the RoBERTa use cases of
    MultiheadAttention, and within that is restructured and rewritten to be able
    to be compiled by TorchScript for production use cases.

    The default constructor values match those required to import the public
    RoBERTa weights. Unless you are pretraining your own model, there's no need to
    change them.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        scaling: float = 0.125,
        dropout: float = 0.1,
    ):
        """
        Args:
            embed_dim: size of the channel dimension of the input and output.
            num_heads: number of attention heads. The per-head size is
                ``embed_dim // num_heads`` (floor division), so ``embed_dim``
                is expected to be a multiple of ``num_heads``.
            scaling: factor the query projection is multiplied by before the
                attention scores are computed.
            dropout: dropout probability applied to the attention weights.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = scaling
        self.dropout = nn.Dropout(dropout)
        # One linear layer produces q, k and v at once; forward() splits the
        # result into three equal chunks along the channel dimension.
        self.input_projection = nn.Linear(embed_dim, 3 * embed_dim)
        self.output_projection = nn.Linear(embed_dim, embed_dim)
        log_class_usage(__class__)

    def prune_multi_heads(self, heads: List[int]):
        """
        Remove the given attention heads from both projection layers in place.

        The q, k and v rows of each listed head are dropped from
        ``input_projection`` and the matching input columns are dropped from
        ``output_projection``; ``num_heads`` is then reduced by the number of
        *distinct* heads removed. Negative values in ``[-num_heads, 0)`` select
        heads from the end, as with Python indexing.

        Args:
            heads: indices of the heads to remove. A head listed more than
                once is pruned (and counted) only once.

        Raises:
            IndexError: if an index falls outside the head mask.
        """
        # Rows [0, H) are the q heads, [H, 2H) the k heads and [2H, 3H) the
        # v heads, mirroring the layout forward() assumes when it chunks the
        # projection into three parts.
        mask = torch.ones(self.num_heads * 3, self.head_dim)
        for head in heads:
            mask[head] = 0
            mask[head + self.num_heads] = 0
            mask[head + self.num_heads * 2] = 0
        mask = mask.view(-1).contiguous().eq(1)

        if torch.onnx.is_in_onnx_export():
            index = np.arange(len(mask), dtype=np.int64)
            index = torch.from_numpy(index).to(device=mask.device)
        else:
            index = torch.arange(len(mask), dtype=torch.long, device=mask.device)

        # Prune linear layers
        self.input_projection = self._prune_linear_layer(
            self.input_projection, index[mask], dim=0
        )
        head_span = self.num_heads * self.head_dim
        self.output_projection = self._prune_linear_layer(
            self.output_projection,
            index[:head_span][mask[:head_span]],
            dim=1,
        )

        # Update hyper params. Count distinct heads only: the mask above is
        # idempotent, so a repeated (or aliased negative) index removes a
        # single head and must not shrink num_heads twice.
        if self.num_heads > 0:
            pruned_heads = {head % self.num_heads for head in heads}
            self.num_heads = self.num_heads - len(pruned_heads)

    def _prune_linear_layer(
        self, layer: torch.nn.Linear, index: torch.Tensor, dim: int = 0
    ) -> torch.nn.Linear:
        """
        Prune a linear layer (a model parameters) to keep only entries in index.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.

        Args:
            layer: the linear layer to prune; it is not modified.
            index: 1-D tensor of positions to keep along ``dim``.
            dim: 0 keeps the selected output rows (and the matching bias
                entries); 1 keeps the selected input columns and leaves the
                bias untouched.

        Returns:
            A new ``torch.nn.Linear`` on the same device and with the same
            dtype as ``layer``, holding copies of the kept entries.
        """
        index = index.to(layer.weight.device)
        kept_weight = layer.weight.index_select(dim, index).detach()

        kept_bias = None
        if layer.bias is not None:
            # Bias has one entry per output row, so it only shrinks when
            # output rows (dim 0) are being pruned.
            if dim == 1:
                kept_bias = layer.bias.detach()
            else:
                kept_bias = layer.bias[index].detach()

        out_features, in_features = kept_weight.shape
        # Carry over dtype as well as device so that a half/double layer is
        # not silently replaced by a float32 one.
        new_layer = torch.nn.Linear(
            in_features, out_features, bias=kept_bias is not None
        ).to(device=layer.weight.device, dtype=layer.weight.dtype)

        # Copy outside of autograd; the fresh parameters keep their default
        # requires_grad=True.
        with torch.no_grad():
            new_layer.weight.copy_(kept_weight)
            if kept_bias is not None:
                new_layer.bias.copy_(kept_bias)
        return new_layer

    def forward(self, query: torch.Tensor, key_padding_mask: torch.Tensor):
        """Input shape: Time x Batch x Channel

        Self-attention: keys and values are projected from `query` itself, so
        the source length equals the target length.

        Padding elements can be excluded from the key by passing a boolean
        mask (`key_padding_mask`) with shape: batch x source_length, where
        padding elements are indicated by True. Integer (e.g. byte) masks
        using 1 for padding are accepted and converted to bool.

        Returns a tensor of shape Time x Batch x C_out, where C_out is the
        output size of `output_projection`.

        Note: if every key of a sequence is masked, the softmax for that
        sequence is taken over -inf scores only.
        """
        target_length, batch_size, embed_dim = query.size()
        mask_batch_size, source_length = key_padding_mask.size()

        torch._assert(embed_dim == self.embed_dim, "query embed dim doesn't match")
        torch._assert(
            batch_size == mask_batch_size,
            "query and key_padding_mask batch sizes differed",
        )

        # input projection
        projection = self.input_projection(query)
        q, k, v = projection.chunk(3, dim=-1)
        q = self.scaling * q

        # Fold heads into the batch dimension: (T, B, H * D) -> (B * H, T, D)
        batch_heads = batch_size * self.num_heads
        q = q.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)

        torch._assert(
            k.size(1) == source_length, "key size should be equal to source length"
        )

        attn_weights = torch.bmm(q, k.transpose(1, 2))

        torch._assert(attn_weights.dim() == 3, "Unexpected attn_weights dim")
        torch._assert(
            attn_weights.size(0) == batch_heads,
            "attn_weights shape didn't match for batch heads",
        )
        torch._assert(
            attn_weights.size(1) == target_length,
            "attn_weights shape didn't match for target length",
        )
        torch._assert(
            attn_weights.size(2) == source_length,
            "attn_weights shape didn't match for source length",
        )

        # don't attend to padding symbols
        attn_weights = attn_weights.view(
            batch_size, self.num_heads, target_length, source_length
        )
        # masked_fill is documented to take a boolean mask; the cast is a
        # no-op when the caller already passes one.
        padding = key_padding_mask.to(torch.bool).unsqueeze(1).unsqueeze(2)
        attn_weights = attn_weights.masked_fill(padding, float("-inf"))
        attn_weights = attn_weights.view(batch_heads, target_length, source_length)

        # Softmax is computed in float32 and cast back to the input dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
            attn_weights
        )
        attn_weights = self.dropout(attn_weights)

        attn = torch.bmm(attn_weights, v)

        torch._assert(
            attn.dim() == 3,
            "unexpected attn dim size",
        )
        torch._assert(
            attn.size(0) == batch_heads,
            "attn shape didn't match for batch heads",
        )
        torch._assert(
            attn.size(1) == target_length,
            "attn shape didn't match for target length",
        )
        torch._assert(
            attn.size(2) == self.head_dim,
            "attn shape didn't match for head dim",
        )

        # Unfold heads again: (B * H, T, D) -> (T, B, H * D)
        attn = (
            attn.transpose(0, 1)
            .contiguous()
            .view(target_length, batch_size, self.head_dim * self.num_heads)
        )
        attn = self.output_projection(attn)

        return attn