Source code for pytext.models.embeddings.mlp_embedding

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List, Optional

import torch
from pytext.config.field_config import MLPFeatConfig
from pytext.data.tensorizers import Tensorizer
from pytext.fields import FieldMeta
from pytext.models.embeddings.embedding_base import EmbeddingBase
from pytext.utils.usage import log_class_usage
from torch import nn
from torch.utils.tensorboard import SummaryWriter


class MLPEmbedding(EmbeddingBase):
    """
    An MLP embedding module that transforms dense float-tensor features
    through a stack of `torch.nn.Linear` + ReLU layers (rather than a
    `torch.nn.Embedding` lookup table).

    Args:
        embedding_dim (int): Size of the input feature vector.
        embeddings_weight (torch.Tensor): Pretrained weights to initialize
            the embedding with (accepted but not used by this module).
        init_range (List[int]): Range of the uniform distribution to
            initialize the weights with if `embeddings_weight` is None
            (accepted but not used by this module).
        init_std (float): Standard deviation for weight initialization
            (accepted but not used by this module).
        mlp_layer_dims (List[int]): List of layer dimensions (if any) to add
            on top of the input features; the last entry becomes the module's
            output embedding_dim.
    """

    Config = MLPFeatConfig
    @classmethod
    def from_config(
        cls,
        config: MLPFeatConfig,
        metadata: Optional[FieldMeta] = None,
        tensorizer: Optional[Tensorizer] = None,
        init_from_saved_state: Optional[bool] = False,
    ):
        """Factory method to construct an instance of MLPEmbedding from
        the module's config object and the field's metadata object.

        Args:
            config (MLPFeatConfig): Configuration object specifying all the
                parameters of MLPEmbedding.
            metadata (FieldMeta): Object containing this field's metadata
                (legacy path, used only when no tensorizer is given).
            tensorizer (Tensorizer): Tensorizer for this field; when provided,
                pretrained weights are not loaded from metadata.
            init_from_saved_state (bool): Whether weights will be initialized
                from a saved state (unused here).

        Returns:
            type: An instance of MLPEmbedding.
        """
        if tensorizer is not None:
            embeddings_weight = None
        else:
            # This else condition should go away after metadata goes away.
            embeddings_weight = metadata.pretrained_embeds_weight
        return cls(
            embedding_dim=config.embed_dim,
            embeddings_weight=embeddings_weight,
            init_range=config.embedding_init_range,
            init_std=config.embeddding_init_std,  # (sic) matches the config field's spelling
            mlp_layer_dims=config.mlp_layer_dims,
        )
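    # Usage sketch (hypothetical, not from the original source): construct the
    # module from its config. Field values are illustrative; any non-None
    # tensorizer skips the legacy metadata path above.
    #
    #     config = MLPFeatConfig(embed_dim=300, mlp_layer_dims=[128, 64])
    #     embedding = MLPEmbedding.from_config(config, tensorizer=my_tensorizer)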
    def __init__(
        self,
        embedding_dim: int = 300,
        embeddings_weight: Optional[torch.Tensor] = None,
        init_range: Optional[List[int]] = None,
        init_std: Optional[float] = None,
        mlp_layer_dims: List[int] = (),
    ) -> None:
        output_embedding_dim = mlp_layer_dims[-1] if mlp_layer_dims else embedding_dim
        EmbeddingBase.__init__(self, embedding_dim=output_embedding_dim)

        # Create MLP layers
        if mlp_layer_dims is None:
            mlp_layer_dims = []

        self.mlp = nn.Sequential(
            *(
                nn.Sequential(nn.Linear(m, n), nn.ReLU())
                for m, n in zip([embedding_dim] + list(mlp_layer_dims), mlp_layer_dims)
            )
        )
        log_class_usage(__class__)
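    # For example (illustrative dims, not from the original source):
    # embedding_dim=300 with mlp_layer_dims=[128, 64] builds the stack
    # Linear(300, 128) -> ReLU -> Linear(128, 64) -> ReLU, and the module
    # reports embedding_dim == 64 to downstream layers.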
    def forward(self, input):
        return self.mlp(input)
    def visualize(self, summary_writer: SummaryWriter):
        pass
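
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): build an
# MLPEmbedding directly and run a forward pass on a batch of dense float
# features. Batch size and dims are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    embedding = MLPEmbedding(embedding_dim=300, mlp_layer_dims=[128, 64])
    dense_features = torch.rand(8, 300)  # batch of 8 float feature vectors
    output = embedding(dense_features)
    # The final MLP layer size becomes the reported embedding_dim.
    assert output.shape == (8, 64) and embedding.embedding_dim == 64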