Source code for pytext.models.representations.transformer.representation

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import torch.nn as nn
from pytext.config import ConfigBase
from pytext.models.module import Module
from pytext.utils.usage import log_class_usage
from torch import Tensor

from .multihead_attention import MultiheadSelfAttention
from .residual_mlp import ResidualMLP
from .transformer import TransformerLayer


class TransformerRepresentation(Module):
    """
    Representation consisting of stacked multi-head self-attention and
    position-wise feed-forward layers. Unlike `Transformer`, we assume
    inputs are already embedded, thus this representation can be used as
    a drop-in replacement for other temporal representations over text
    inputs (e.g., `BiLSTM` and `DeepCNNRepresentation`).
    """

    class Config(ConfigBase):
        num_layers: int = 3
        num_attention_heads: int = 4
        ffnn_embed_dim: int = 32
        dropout: float = 0.0

    def __init__(self, config: Config, embed_dim: int) -> None:
        super().__init__()

        self.layers = nn.ModuleList(
            [
                self._create_transformer_layer(config, embed_dim)
                for _ in range(config.num_layers)
            ]
        )
        log_class_usage(__class__)

    def _create_transformer_layer(self, config: Config, embed_dim: int):
        return TransformerLayer(
            embedding_dim=embed_dim,
            attention=MultiheadSelfAttention(
                embed_dim=embed_dim, num_heads=config.num_attention_heads
            ),
            residual_mlp=ResidualMLP(
                input_dim=embed_dim,
                hidden_dims=[config.ffnn_embed_dim],
                dropout=config.dropout,
            ),
        )

    def forward(self, embedded_tokens: Tensor, padding_mask: Tensor) -> Tensor:
        """
        Forward inputs through the transformer layers.

        Args:
            embedded_tokens (B x T x H): Tokens previously encoded with token,
                positional, and segment embeddings.
            padding_mask (B x T): Boolean mask specifying token positions that
                self-attention should not operate on.

        Returns:
            last_state (B x T x H): Final transformer layer state.
        """
        last_state = embedded_tokens.transpose(0, 1)
        for layer in self.layers:
            last_state = layer(last_state, padding_mask)
        return last_state.transpose(0, 1)
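
A minimal usage sketch, assuming the class is importable from this module path, that `Config()` can be built directly from its defaults, and that the padding convention is "True marks positions to ignore" as described in the `forward` docstring; the dimensions and mask below are purely illustrative, not taken from pytext.

import torch

from pytext.models.representations.transformer.representation import (
    TransformerRepresentation,
)

embed_dim = 64  # hidden size H; must match the incoming embeddings
model = TransformerRepresentation(TransformerRepresentation.Config(), embed_dim)

batch_size, seq_len = 2, 10  # B and T
# Stand-in for tokens already run through token/positional/segment embeddings.
embedded_tokens = torch.randn(batch_size, seq_len, embed_dim)
# Boolean mask with True at padding positions self-attention should skip;
# here the last two positions of every sequence are treated as padding.
padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
padding_mask[:, -2:] = True

last_state = model(embedded_tokens, padding_mask)  # B x T x H
assert last_state.shape == (batch_size, seq_len, embed_dim)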