Source code for pytext.models.representations.sparse_transformer_sentence_encoder
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from fairseq.modules.sparse_transformer_sentence_encoder import (
    SparseTransformerSentenceEncoder as SparseTransformerSentenceEncoderModule,
)
from pytext.config import ConfigBase
from pytext.models.representations.transformer_sentence_encoder import (
    TransformerSentenceEncoder,
)
from pytext.utils.usage import log_class_usage


class SparseTransformerSentenceEncoder(TransformerSentenceEncoder):
    """
    Implementation of the Sparse Transformer Sentence Encoder. This directly
    makes use of the SparseTransformerSentenceEncoder module in Fairseq.

    A few interesting config options:

    - encoder_normalize_before determines whether the layer norm is applied
      before or after self_attention. This is similar to the original
      implementation from Google.
    - activation_fn can be set to 'gelu' instead of the default of 'relu'.
    - project_representation adds a linear projection + tanh to the pooled
      output in the style of BERT.
    """

    class Config(TransformerSentenceEncoder.Config, ConfigBase):
        # Dropout parameters
        dropout: float = 0.1
        attention_dropout: float = 0.1
        activation_dropout: float = 0.1
        # Parameters related to hidden states and self-attention
        embedding_dim: int = 768
        ffn_embedding_dim: int = 3072
        num_encoder_layers: int = 6
        num_attention_heads: int = 8
        num_segments: int = 2
        # Parameters related to positions
        use_position_embeddings: bool = True
        # The fairseq module for position embeddings offsets all position
        # ids by the padding index. Disable this offset by setting this flag
        # to False. This works correctly because we mask out the embeddings
        # associated with padding in the encoder.
        offset_positions_by_padding: bool = True
        # Model initialization parameters
        apply_bert_init: bool = True
        # Misc. params
        encoder_normalize_before: bool = True
        activation_fn: str = "relu"
        project_representation: bool = False
        max_seq_len: int = 128
        # multilingual is set to True for cross-lingual LM training
        multilingual: bool = False
        # Flags for freezing parameters (e.g. during fine-tuning)
        freeze_embeddings: bool = False
        n_trans_layers_to_freeze: int = 0
        # Sparse multihead attention parameters
        is_bidirectional: bool = True
        stride: int = 32
        expressivity: int = 8

    def __init__(
        self,
        config: Config,
        output_encoded_layers: bool,
        padding_idx: int,
        vocab_size: int,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            config,
            output_encoded_layers=output_encoded_layers,
            padding_idx=padding_idx,
            vocab_size=vocab_size,
        )
        self.sentence_encoder = SparseTransformerSentenceEncoderModule(
            padding_idx=padding_idx,
            vocab_size=vocab_size,
            num_encoder_layers=config.num_encoder_layers,
            embedding_dim=config.embedding_dim,
            ffn_embedding_dim=config.ffn_embedding_dim,
            num_attention_heads=config.num_attention_heads,
            dropout=config.dropout,
            attention_dropout=config.attention_dropout,
            activation_dropout=config.activation_dropout,
            max_seq_len=config.max_seq_len,
            num_segments=config.num_segments,
            use_position_embeddings=config.use_position_embeddings,
            offset_positions_by_padding=config.offset_positions_by_padding,
            encoder_normalize_before=config.encoder_normalize_before,
            apply_bert_init=config.apply_bert_init,
            activation_fn=config.activation_fn,
            freeze_embeddings=config.freeze_embeddings,
            n_trans_layers_to_freeze=config.n_trans_layers_to_freeze,
            export=self.export,
            is_bidirectional=config.is_bidirectional,
            stride=config.stride,
            expressivity=config.expressivity,
        )
        log_class_usage(__class__)
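

# A minimal usage sketch: constructing the encoder directly from its Config.
# The values below are illustrative (padding_idx=1, a BERT-base-sized vocab),
# and the keyword-style Config construction assumes the standard pytext
# ConfigBase behavior; treat this as a sketch rather than canonical usage.
if __name__ == "__main__":
    config = SparseTransformerSentenceEncoder.Config(
        num_encoder_layers=6,
        stride=32,  # block width of the fixed sparse attention pattern
        expressivity=8,  # per-block globally attended positions (assumed to
        # correspond to "c" in Child et al., 2019, whose fixed factorized
        # pattern fairseq's sparse attention implements)
    )
    encoder = SparseTransformerSentenceEncoder(
        config,
        output_encoded_layers=False,
        padding_idx=1,
        vocab_size=30522,
    )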