Source code for pytext.models.embeddings.word_embedding

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import List, Optional

import torch
from pytext.config.field_config import WordFeatConfig
from pytext.data.tensorizers import Tensorizer
from pytext.fields import FieldMeta
from pytext.models.embeddings.embedding_base import EmbeddingBase
from pytext.utils.embeddings import PretrainedEmbedding
from pytext.utils.usage import log_class_usage
from torch import nn
from torch.utils.tensorboard import SummaryWriter


class WordEmbedding(EmbeddingBase):
    """
    A word embedding wrapper module around `torch.nn.Embedding`, with options
    to initialize the word embedding weights and to add MLP layers acting on
    each word.

    Note: Embedding weights for the UNK token are always initialized to zeros.

    Args:
        num_embeddings (int): Total number of words/tokens (vocabulary size).
        embedding_dim (int): Size of the embedding vector.
        embeddings_weight (torch.Tensor): Pretrained weights to initialize the
            embedding table with.
        init_range (List[int]): Range of the uniform distribution used to
            initialize the weights if `embeddings_weight` is None.
        init_std (float): Standard deviation of the normal distribution used
            to initialize the weights if `embeddings_weight` is None.
        unk_token_idx (int): Index of the UNK token in the word vocabulary.
        mlp_layer_dims (List[int]): List of layer dimensions (if any) to add
            on top of the embedding lookup.
        padding_idx (int): Index of the padding token, passed through to
            `nn.Embedding`.
        vocab (List[str]): Vocabulary tokens, used by `visualize`.
    """

    Config = WordFeatConfig
    @classmethod
    def from_config(
        cls,
        config: WordFeatConfig,
        metadata: Optional[FieldMeta] = None,
        tensorizer: Optional[Tensorizer] = None,
        init_from_saved_state: Optional[bool] = False,
    ):
        """Factory method to construct an instance of WordEmbedding from the
        module's config object and either the field's metadata object (old
        data design) or the tensorizer (new data design).

        Args:
            config (WordFeatConfig): Configuration object specifying all the
                parameters of WordEmbedding.
            metadata (FieldMeta): Object containing this field's metadata;
                used when `tensorizer` is None.
            tensorizer (Tensorizer): Tensorizer providing the vocabulary in
                the new data design.
            init_from_saved_state (bool): If True, skip loading pretrained
                embeddings, since the weights will be loaded from a snapshot.

        Returns:
            WordEmbedding: An instance of WordEmbedding.
        """
        if tensorizer is not None:
            if config.vocab_from_pretrained_embeddings:
                raise ValueError(
                    "In new data design, to add tokens from a pretrained embeddings "
                    "file to the vocab, specify `vocab_file` in the token tensorizer."
                )

            embeddings_weight = None
            # We don't need to load pretrained embeddings if we know the
            # embedding weights are going to be loaded from a snapshot.
            if config.pretrained_embeddings_path and not init_from_saved_state:
                pretrained_embedding = PretrainedEmbedding(
                    config.pretrained_embeddings_path,  # doesn't support fbpkg
                    lowercase_tokens=config.lowercase_tokens,
                    skip_header=config.skip_header,
                    delimiter=config.delimiter,
                )
                embeddings_weight = pretrained_embedding.initialize_embeddings_weights(
                    tensorizer.vocab.idx,
                    tensorizer.vocab.unk_token,
                    config.embed_dim,
                    config.embedding_init_strategy,
                )
            num_embeddings = len(tensorizer.vocab)
            unk_token_idx = tensorizer.vocab.get_unk_index()
            vocab = tensorizer.vocab
            vocab_pad_idx = vocab.get_pad_index(value=-1)
            if vocab_pad_idx == -1:
                vocab_pad_idx = None
        else:
            # This else condition should go away after metadata goes away.
            num_embeddings = metadata.vocab_size
            embeddings_weight = metadata.pretrained_embeds_weight
            unk_token_idx = metadata.unk_token_idx
            vocab = metadata.vocab
            vocab_pad_idx = None

        return cls(
            num_embeddings=num_embeddings,
            embedding_dim=config.embed_dim,
            embeddings_weight=embeddings_weight,
            init_range=config.embedding_init_range,
            init_std=config.embeddding_init_std,  # sic: triple "d" matches the config field name
            unk_token_idx=unk_token_idx,
            mlp_layer_dims=config.mlp_layer_dims,
            padding_idx=config.padding_idx or vocab_pad_idx,
            vocab=vocab,
        )
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int = 300,
        embeddings_weight: Optional[torch.Tensor] = None,
        init_range: Optional[List[int]] = None,
        init_std: Optional[float] = None,
        unk_token_idx: int = 0,
        mlp_layer_dims: List[int] = (),
        padding_idx: Optional[int] = None,
        vocab: Optional[List[str]] = None,
    ) -> None:
        output_embedding_dim = mlp_layer_dims[-1] if mlp_layer_dims else embedding_dim
        EmbeddingBase.__init__(self, embedding_dim=output_embedding_dim)

        # Create word embedding
        self.word_embedding = nn.Embedding(
            num_embeddings,
            embedding_dim,
            _weight=embeddings_weight,
            padding_idx=padding_idx,
        )
        if embeddings_weight is None:
            if init_range:
                self.word_embedding.weight.data.uniform_(init_range[0], init_range[1])
            if init_std:
                self.word_embedding.weight.data.normal_(mean=0.0, std=init_std)

        # Initialize unk embedding with zeros
        # to guard the model against randomized decisions based on unknown words
        self.word_embedding.weight.data[unk_token_idx].fill_(0.0)

        # Create MLP layers
        if mlp_layer_dims is None:
            mlp_layer_dims = []
        self.mlp = nn.Sequential(
            *(
                nn.Sequential(nn.Linear(m, n), nn.ReLU())
                for m, n in zip([embedding_dim] + list(mlp_layer_dims), mlp_layer_dims)
            )
        )
        self.vocab = vocab
        self.padding_idx = padding_idx
        log_class_usage(__class__)

    def __getattr__(self, name):
        # Expose the underlying embedding table as `self.weight`.
        if name == "weight":
            return self.word_embedding.weight
        return super().__getattr__(name)
    def forward(self, input):
        return self.mlp(self.word_embedding(input))
    def freeze(self):
        for param in self.word_embedding.parameters():
            param.requires_grad = False
    def visualize(self, summary_writer: SummaryWriter):
        if self.vocab:
            summary_writer.add_embedding(
                self.word_embedding.weight, metadata=self.vocab
            )
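
A minimal usage sketch of the module above, assuming PyText is installed and importable; the vocabulary, sizes, and token indices here are illustrative assumptions, and the final visualization step additionally requires the tensorboard package:

import torch
from torch.utils.tensorboard import SummaryWriter
from pytext.models.embeddings.word_embedding import WordEmbedding

# Illustrative vocabulary; index 0 is UNK, whose embedding row is zeroed.
vocab = ["<unk>", "<pad>", "hello", "world"]
embedding = WordEmbedding(
    num_embeddings=len(vocab),
    embedding_dim=8,
    unk_token_idx=0,
    mlp_layer_dims=[4],  # one Linear+ReLU layer on top of the lookup
    padding_idx=1,
    vocab=vocab,
)

tokens = torch.tensor([[2, 3, 1], [3, 0, 1]])  # batch of token indices
out = embedding(tokens)
print(out.shape)  # torch.Size([2, 3, 4]) -- last dim is mlp_layer_dims[-1]

embedding.freeze()  # stop gradient updates to the embedding table
writer = SummaryWriter()
embedding.visualize(writer)  # log the table for the TensorBoard projector
writer.close()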