Source code for pytext.models.embeddings.embedding_list

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Dict, Iterable, List, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import ModuleList
from torch.utils.tensorboard import SummaryWriter

from .embedding_base import EmbeddingBase


class EmbeddingList(EmbeddingBase, ModuleList):
    """
    There is more than one way to embed a token, and this module provides a
    way to hold a list of sub-embeddings and either concatenate their output
    tensors into a single Tensor or return them as a tuple of Tensors that can
    be used by downstream modules.

    Args:
        embeddings (Iterable[EmbeddingBase]): A sequence of embedding modules
            to embed a token. Falsy entries (e.g. ``None``) are dropped.
        concat (bool): Whether to concatenate the embedding vectors emitted
            from `embeddings` modules.

    Attributes:
        num_emb_modules (int): Number of flattened embeddings in `embeddings`,
            e.g: ((e1, e2), e3) has 3 in total. Sub-embeddings whose
            `embedding_dim` is 0 are counted here even though they are not
            registered as children.
        input_start_indices (List[int]): For each registered sub-embedding,
            the index in the `forward` inputs at which its inputs start.
        concat (bool): Whether to concatenate the embedding vectors emitted
            from `embeddings` modules.
        embedding_dim: Total embedding size, can be a single int or tuple of
            int depending on concat setting.
    """

    def __init__(self, embeddings: Iterable[EmbeddingBase], concat: bool) -> None:
        # The real embedding_dim is only known once the children are
        # registered; initialise the base with 0 and overwrite it below.
        EmbeddingBase.__init__(self, 0)
        embeddings = list(filter(None, embeddings))
        self.num_emb_modules = sum(emb.num_emb_modules for emb in embeddings)

        embeddings_list, input_start_indices = [], []
        start = 0
        for emb in embeddings:
            # NOTE(review): if `emb` is a nested EmbeddingList built with
            # concat=False its embedding_dim is a tuple, and `tuple > 0`
            # would raise TypeError - confirm whether that case is supported.
            if emb.embedding_dim > 0:
                embeddings_list.append(emb)
                input_start_indices.append(start)
            # Always advance, even for skipped (zero-dim) embeddings, so the
            # recorded start indices stay aligned with the forward() inputs.
            start += emb.num_emb_modules

        ModuleList.__init__(self, embeddings_list)
        self.input_start_indices = input_start_indices
        self.concat = concat
        assert len(self) > 0, "must have at least 1 sub embedding"

        embedding_dims = tuple(emb.embedding_dim for emb in self)
        self.embedding_dim = sum(embedding_dims) if concat else embedding_dims

    def forward(self, *emb_input) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Get embeddings from all sub-embeddings and either concatenate them
        into one Tensor or return them in a tuple.

        Args:
            *emb_input (type): Sequence of token level embeddings to combine.
                The inputs should match the size of configured embeddings.
                Each of them is either a Tensor or a tuple of Tensors.

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, ...]]: If `concat` is True
                then a Tensor is returned by concatenating all embeddings
                along the last (embedding) dimension. Otherwise the
                embeddings are returned in a tuple, or as the bare Tensor
                when there is only one sub-embedding.

        Raises:
            ValueError: If the number of inputs differs from
                `num_emb_modules`.
        """
        # tokens dim: (bsz, max_seq_len) -> (bsz, max_seq_len, dim) OR
        # (bsz, max_num_sen, max_seq_len) -> (bsz, max_num_sen, max_seq_len, dim)
        # for seqnn
        if self.num_emb_modules != len(emb_input):
            raise ValueError(
                f"expecting {self.num_emb_modules} embeddings, "
                + f"but got {len(emb_input)} input"
            )

        tensors = []
        for emb, start in zip(self, self.input_start_indices):
            end = start + emb.num_emb_modules
            sub_input = emb_input[start:end]
            # A single embedding whose one input is itself a tuple or list of
            # tensors: unwrap it so its elements become positional arguments.
            if len(sub_input) == 1 and isinstance(sub_input[0], (list, tuple)):
                [sub_input] = sub_input
            tensors.append(emb(*sub_input))

        if self.concat:
            # Concatenate along the last axis (the embedding dimension) so
            # that both the 3-D and the 4-D (seqnn) layouts described above
            # yield a trailing size matching the summed `embedding_dim`.
            return torch.cat(tensors, dim=-1)
        return tuple(tensors) if len(tensors) > 1 else tensors[0]

    def get_param_groups_for_optimizer(self) -> List[Dict[str, nn.Parameter]]:
        """
        Organize child embedding parameters into param_groups (or layers), so
        the optimizer and / or schedulers can have custom behavior per layer.
        The param_groups from each child embedding are aligned at the first
        (lowest) param_group.

        Returns:
            List[Dict[str, nn.Parameter]]: One dict per param_group index;
                each key is the child's parameter name prefixed with the
                child's module name and a dot.
        """
        param_groups: List[Dict[str, nn.Parameter]] = []
        for module_name, embedding_module in self.named_children():
            child_params = embedding_module.get_param_groups_for_optimizer()
            for i, child_param_group in enumerate(child_params):
                # Grow the result lazily: a child may expose more groups than
                # any child seen so far.
                if i >= len(param_groups):
                    param_groups.append({})
                for param_name, param in child_param_group.items():
                    param_groups[i][f"{module_name}.{param_name}"] = param
        return param_groups

    def visualize(self, summary_writer: SummaryWriter):
        """Forward the `visualize` call to every registered sub-embedding."""
        for child in self:
            child.visualize(summary_writer)