#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, List
import torch.nn as nn
from pytext.models.module import Module
from pytext.utils.usage import log_class_usage
from torch.utils.tensorboard import SummaryWriter
class EmbeddingBase(Module):
    """Base class for token level embedding modules.

    Args:
        embedding_dim (int): Size of embedding vector.

    Attributes:
        num_emb_modules (int): Number of ways to embed a token. This base
            class always sets it to 1; composite embeddings such as
            EmbeddingList can have more than one.
        embedding_dim (int): Size of embedding vector.
    """

    # NOTE(review): this flag is not read anywhere in this class; presumably
    # it is consumed by PyText's config/component machinery — confirm.
    __EXPANSIBLE__ = True

    def __init__(self, embedding_dim: int):
        super().__init__()
        # By default there is exactly one embedding (this module itself); for
        # EmbeddingList this number can be greater than 1.
        self.num_emb_modules = 1
        self.embedding_dim = embedding_dim
        # `__class__` is the implicit cell for the class this method is
        # defined in (EmbeddingBase), not type(self), so this call always
        # passes EmbeddingBase even when a subclass is being constructed.
        log_class_usage(__class__)

    def visualize(self, summary_writer: SummaryWriter):
        """Hook for subclasses to add a TensorBoard visualization of the
        embedding space.

        The base implementation does nothing; subclasses override it.

        Args:
            summary_writer: TensorBoard writer the visualization is logged to.
        """