Source code for pytext.models.representations.transformer.transformer
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List, Optional
import torch
import torch.nn.functional as F
from pytext.utils.usage import log_class_usage
from torch import nn
from .multihead_attention import MultiheadSelfAttention
from .positional_embedding import PositionalEmbedding
from .residual_mlp import ResidualMLP

DEFAULT_EMBEDDING_DIM = 768
DEFAULT_VOCAB_SIZE = 50265
DEFAULT_PADDING_IDX = 1
DEFAULT_NUM_LAYERS = 12
DEFAULT_MAX_SEQUENCE_LENGTH = 514
DEFAULT_NUM_ATTENTION_HEADS = 12


class TransformerLayer(nn.Module):
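    """A single transformer encoder layer: multi-head self-attention with
    dropout and a residual connection followed by layer norm, then a residual
    MLP and a final layer norm."""
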
def __init__(
self,
embedding_dim: int = DEFAULT_EMBEDDING_DIM,
attention: Optional[MultiheadSelfAttention] = None,
residual_mlp: Optional[ResidualMLP] = None,
dropout: float = 0.1,
):
super().__init__()
self.dropout = nn.Dropout(dropout)
self.attention = attention or MultiheadSelfAttention(
embedding_dim, num_heads=DEFAULT_NUM_ATTENTION_HEADS
)
self.residual_mlp = residual_mlp or ResidualMLP(
embedding_dim, hidden_dims=[embedding_dim * 4]
)
self.attention_layer_norm = nn.LayerNorm(embedding_dim)
self.final_layer_norm = nn.LayerNorm(embedding_dim)
log_class_usage(__class__)

    def forward(
        self, input: torch.Tensor, key_padding_mask: torch.Tensor
    ) -> torch.Tensor:
attention = self.attention(input, key_padding_mask)
attention = self.dropout(attention)
biased_input = input + attention
biased_input = self.attention_layer_norm(biased_input)
biased = self.residual_mlp(biased_input)
return self.final_layer_norm(biased)
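

# A minimal usage sketch (illustrative, not part of the original module).
# The dimensions below are assumptions for a quick smoke test; the layer
# expects input of shape T x B x C and a key_padding_mask of shape B x T
# (True at padding positions).
def _example_transformer_layer():
    seq_len, batch_size, embed_dim = 5, 2, 16
    layer = TransformerLayer(
        embedding_dim=embed_dim,
        # Override the default 12-head attention with a head count that
        # divides embed_dim.
        attention=MultiheadSelfAttention(embed_dim, num_heads=4),
    )
    x = torch.rand(seq_len, batch_size, embed_dim)
    mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)  # no padding
    out = layer(x, mask)
    assert out.shape == (seq_len, batch_size, embed_dim)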


class Transformer(nn.Module):
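    """A transformer encoder: token and positional embeddings with layer norm
    and dropout, followed by a stack of TransformerLayers.

    forward() takes a B x T tensor of token ids and returns a list of states
    (the embedding output plus one per layer), each of shape T x B x C.
    """
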
def __init__(
self,
vocab_size: int = DEFAULT_VOCAB_SIZE,
embedding_dim: int = DEFAULT_EMBEDDING_DIM,
padding_idx: int = DEFAULT_PADDING_IDX,
max_seq_len: int = DEFAULT_MAX_SEQUENCE_LENGTH,
layers: List[TransformerLayer] = (),
dropout: float = 0.1,
):
super().__init__()
self.padding_idx = padding_idx
self.token_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx)
self.layers = nn.ModuleList(
layers
or [TransformerLayer(embedding_dim) for _ in range(DEFAULT_NUM_LAYERS)]
)
self.positional_embedding = PositionalEmbedding(
max_seq_len, embedding_dim, padding_idx
)
self.embedding_layer_norm = nn.LayerNorm(embedding_dim)
self.dropout = nn.Dropout(dropout)
log_class_usage(__class__)

    def forward(self, tokens: torch.Tensor) -> List[torch.Tensor]:
        # Compute the padding mask; multi-head attention needs it to ignore pad tokens.
padding_mask = tokens.eq(self.padding_idx)
embedded = self.token_embedding(tokens)
embedded_positions = self.positional_embedding(tokens)
normed = self.embedding_layer_norm(embedded + embedded_positions)
normed = self.dropout(normed)
# account for padding while computing the representation
padded_normed = normed * (1 - padding_mask.unsqueeze(-1).type_as(normed))
# B x T x C -> T x B x C
encoded = padded_normed.transpose(0, 1)
states = [encoded]
for layer in self.layers:
encoded = layer(encoded, padding_mask)
states.append(encoded)
# states are returned as T x B x C
# commonly you can retrieve a single "sentence representation" as
# states[-1].transpose(0, 1)
return states
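

# A minimal usage sketch (illustrative, not part of the original module):
# encodes a small batch of token ids with a two-layer model. Vocabulary size,
# dimensions, and token ids are assumptions chosen for illustration.
def _example_transformer():
    embed_dim = 16
    model = Transformer(
        vocab_size=100,
        embedding_dim=embed_dim,
        max_seq_len=20,
        layers=[
            TransformerLayer(
                embedding_dim=embed_dim,
                attention=MultiheadSelfAttention(embed_dim, num_heads=4),
            )
            for _ in range(2)
        ],
    )
    tokens = torch.tensor([[5, 6, 7, 1], [8, 9, 1, 1]])  # 1 == DEFAULT_PADDING_IDX
    states = model(tokens)
    assert len(states) == 3  # embedding output plus one state per layer
    sentence_repr = states[-1].transpose(0, 1)  # back to B x T x C
    assert sentence_repr.shape == (2, 4, embed_dim)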


class SELFIETransformer(Transformer):
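    """Transformer variant that prepends each dense feature tensor to the
    token sequence as an extra leading "token". Each dense tensor's dimension
    must be <= embedding_dim."""
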

    def forward(
self, tokens: torch.Tensor, dense: List[torch.Tensor]
) -> List[torch.Tensor]:
        # Compute the padding mask; multi-head attention needs it to ignore pad tokens.
padding_mask = tokens.eq(self.padding_idx)
embedded = self.token_embedding(tokens)
embedded_positions = self.positional_embedding(tokens)
normed = self.embedding_layer_norm(embedded + embedded_positions)
normed = self.dropout(normed)
# account for padding while computing the representation
padded_normed = normed * (1 - padding_mask.unsqueeze(-1).type_as(normed))
        # The SELFIE transformer prepends each dense input as an extra first token.
        # For now, each dense tensor's dimension must be <= embedding_dim.
for i in range(len(dense)):
padded_dense = F.pad(
dense[i], (0, embedded.size(2) - dense[i].size(1), 0, 0), value=1.0
)
padded_normed = torch.cat([padded_dense.unsqueeze(1), padded_normed], dim=1)
padding_mask = F.pad(padding_mask, (1, 0, 0, 0), value=0.0)
# B x T x C -> T x B x C
encoded = padded_normed.transpose(0, 1)
states = [encoded]
for layer in self.layers:
encoded = layer(encoded, padding_mask)
states.append(encoded)
# states are returned as T x B x C
# commonly you can retrieve a single "sentence representation" as
# states[-1].transpose(0, 1)
return states
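

# A minimal usage sketch (illustrative, not part of the original module):
# each dense tensor (B x D, with D <= embedding_dim) is prepended as an extra
# leading token, so every returned state has sequence length T + len(dense).
def _example_selfie_transformer():
    embed_dim = 16
    model = SELFIETransformer(
        vocab_size=100,
        embedding_dim=embed_dim,
        max_seq_len=20,
        layers=[
            TransformerLayer(
                embedding_dim=embed_dim,
                attention=MultiheadSelfAttention(embed_dim, num_heads=4),
            )
        ],
    )
    tokens = torch.tensor([[5, 6, 7, 1]])  # B = 1, T = 4; 1 == DEFAULT_PADDING_IDX
    dense = [torch.rand(1, 8)]  # one B x D dense feature tensor
    states = model(tokens, dense)
    assert states[-1].shape == (5, 1, embed_dim)  # (T + 1) x B x C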