Source code for pytext.models.embeddings.embedding_base

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Dict, List

import torch.nn as nn
from pytext.models.module import Module
from pytext.utils.usage import log_class_usage
from torch.utils.tensorboard import SummaryWriter


class EmbeddingBase(Module):
    """Base class for token level embedding modules.

    Args:
        embedding_dim (int): Size of embedding vector.

    Attributes:
        num_emb_modules (int): Number of ways to embed a token.
        embedding_dim (int): Size of embedding vector.
    """

    __EXPANSIBLE__ = True

    def __init__(self, embedding_dim: int):
        super().__init__()
        # By default there is a single embedding, which is the module itself;
        # for EmbeddingList this number can be greater than 1.
        self.num_emb_modules = 1
        self.embedding_dim = embedding_dim
        log_class_usage(__class__)
    def visualize(self, summary_writer: SummaryWriter):
        """Overridden in subclasses to implement TensorBoard visualization of
        the embedding space."""
        pass
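
A minimal sketch (an assumption for illustration, not part of this module or of PyText) of how EmbeddingBase is typically subclassed: the hypothetical ConstantEmbedding below embeds every token with a single learned vector and overrides visualize to log that vector to TensorBoard.

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter


class ConstantEmbedding(EmbeddingBase):
    """Hypothetical subclass: embeds every token with one shared learned vector."""

    def __init__(self, embedding_dim: int):
        super().__init__(embedding_dim)
        # A single learnable vector shared by all tokens.
        self.weight = nn.Parameter(torch.zeros(embedding_dim))

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Expand the shared vector to shape (batch_size, seq_len, embedding_dim).
        return self.weight.expand(*tokens.size(), self.embedding_dim)

    def visualize(self, summary_writer: SummaryWriter):
        # Log the single embedding vector so it shows up in TensorBoard.
        summary_writer.add_embedding(
            self.weight.unsqueeze(0), tag="constant_embedding"
        )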