Source code for pytext.models.embeddings.contextual_token_embedding

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
from pytext.config.field_config import ContextualTokenEmbeddingConfig

from .embedding_base import EmbeddingBase


class ContextualTokenEmbedding(EmbeddingBase):
    """Embedding module fed with token embeddings from a pretrained model.

    The module does not compute embeddings itself: ``forward`` receives an
    already-computed, flattened embedding tensor and restores its per-token
    layout.
    """

    Config = ContextualTokenEmbeddingConfig

    @classmethod
    def from_config(cls, config: ContextualTokenEmbeddingConfig, *args, **kwargs):
        """Create the module from ``config``.

        Only ``config.embed_dim`` is used; any extra positional or keyword
        arguments are accepted and ignored.
        """
        embed_dim = config.embed_dim
        return cls(embed_dim)

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        """Unflatten ``embedding`` into one vector per token.

        Args:
            embedding: Flattened embeddings of shape
                (batch_size, seq_len * embedding_dim).

        Returns:
            The same data reshaped to (batch_size, seq_len, embedding_dim).

        Raises:
            ValueError: If the second dimension of ``embedding`` is not a
                multiple of ``self.embedding_dim``.
        """
        embedding_shape = torch.onnx.operators.shape_as_tensor(embedding)

        # The input arrives flattened, so its second dimension must split
        # evenly into chunks of embedding_dim; reject it otherwise.
        leftover = embedding_shape[1].item() % self.embedding_dim
        if leftover != 0:
            raise ValueError(
                f"Input embedding_dim {embedding_shape[1]} is not a"
                + f" multiple of specified embedding_dim {self.embedding_dim}"
            )

        # True division followed by a cast to long is intentional here.
        # TODO: Predictor forward compatibility issue: change back once
        # D18945316 is available on predictor:
        #   num_tokens = embedding_shape[1] // self.embedding_dim
        token_count = (embedding_shape[1] / self.embedding_dim).type(dtype=torch.long)

        # Assemble the target shape (-1, num_tokens, embedding_dim) as a
        # tensor; -1 lets the reshape infer the batch dimension.
        batch_part = torch.LongTensor([-1])
        token_part = token_count.view(1)
        dim_part = torch.LongTensor([self.embedding_dim])
        target_shape = torch.cat((batch_part, token_part, dim_part))

        return torch.onnx.operators.reshape_from_tensor_shape(embedding, target_shape)