Source code for pytext.models.decoders.mlp_decoder_two_tower

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from enum import Enum
from typing import List

import torch
import torch.nn as nn
from pytext.config.module_config import Activation
from pytext.models.decoders.decoder_base import DecoderBase
from pytext.optimizer import get_activation
from pytext.utils import precision
from pytext.utils.usage import log_class_usage


class ExportType(Enum):
    RIGHT = "RIGHT"
    LEFT = "LEFT"
    NONE = "NONE"
class MLPDecoderTwoTower(DecoderBase):
    """
    Implements a 'two-tower' MLPDecoder: one MLP tower for the left input and
    one for the right, whose output embeddings are concatenated and passed
    through a joint MLP.
    """
    class Config(DecoderBase.Config):
        # Intermediate hidden dimensions
        right_hidden_dims: List[int] = []
        left_hidden_dims: List[int] = []
        hidden_dims: List[int] = []
        layer_norm: bool = False
        dropout: float = 0.0
    def __init__(
        self,
        config: Config,
        right_dim: int,
        left_dim: int,
        to_dim: int,
        export_type=ExportType.NONE,
    ) -> None:
        super().__init__(config)

        self.mlp_for_right = MLPDecoderTwoTower.get_mlp(
            right_dim,
            0,
            config.right_hidden_dims,
            config.layer_norm,
            config.dropout,
            export_embedding=True,
        )
        self.mlp_for_left = MLPDecoderTwoTower.get_mlp(
            left_dim,
            0,
            config.left_hidden_dims,
            config.layer_norm,
            config.dropout,
            export_embedding=True,
        )
        from_dim = config.right_hidden_dims[-1] + config.left_hidden_dims[-1]
        self.mlp = MLPDecoderTwoTower.get_mlp(
            from_dim, to_dim, config.hidden_dims, config.layer_norm, config.dropout
        )
        self.out_dim = to_dim
        self.export_type = export_type
        log_class_usage(__class__)
    @staticmethod
    def get_mlp(
        from_dim: int,
        to_dim: int,
        hidden_dims: List[int],
        layer_norm: bool,
        dropout: float,
        export_embedding: bool = False,
    ):
        layers = []
        for i in range(len(hidden_dims)):
            dim = hidden_dims[i]
            layers.append(nn.Linear(from_dim, dim, True))
            # Skip ReLU, LayerNorm, and dropout for the last layer if export_embedding
            if not (export_embedding and i == len(hidden_dims) - 1):
                layers.append(get_activation(Activation.RELU))
                if layer_norm:
                    layers.append(nn.LayerNorm(dim))
                if dropout > 0:
                    layers.append(nn.Dropout(dropout))
            from_dim = dim
        if to_dim > 0:
            layers.append(nn.Linear(from_dim, to_dim, True))
        return nn.Sequential(*layers)
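    # A minimal sketch of what get_mlp assembles (illustrative dims, not part
    # of the original module): with export_embedding=True and to_dim=0, the
    # last hidden layer is left as a bare Linear so the tower emits a raw
    # embedding rather than an activated, normalized one, e.g.
    #
    #   MLPDecoderTwoTower.get_mlp(16, 0, [32, 8], True, 0.1, export_embedding=True)
    #   # -> Sequential(Linear(16, 32), ReLU(), LayerNorm(32), Dropout(0.1),
    #   #               Linear(32, 8))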
    def forward(self, *x: List[torch.Tensor]) -> torch.Tensor:
        # x[0]: right_text_emb, x[1]: left_text_emb, x[2]: right_dense, x[3]: left_dense
        assert len(x) == 4

        if self.export_type == ExportType.RIGHT or self.export_type == ExportType.NONE:
            right_tensor = (
                torch.cat((x[0], x[2]), 1).half()
                if precision.FP16_ENABLED
                else torch.cat((x[0], x[2]), 1).float()
            )
            right_output = self.mlp_for_right(right_tensor)
            if self.export_type == ExportType.RIGHT:
                return right_output

        if self.export_type == ExportType.LEFT or self.export_type == ExportType.NONE:
            left_tensor = (
                torch.cat((x[1], x[3]), 1).half()
                if precision.FP16_ENABLED
                else torch.cat((x[1], x[3]), 1).float()
            )
            left_output = self.mlp_for_left(left_tensor)
            if self.export_type == ExportType.LEFT:
                return left_output

        return self.mlp(torch.cat((right_output, left_output), 1))
    def get_decoder(self) -> List[nn.Module]:
        return [self.mlp_for_left, self.mlp_for_right]
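# A hedged usage sketch, not part of the original module: the dims and config
# fields below are illustrative, and it assumes pytext's ConfigBase accepts
# keyword overrides at construction time.
if __name__ == "__main__":
    config = MLPDecoderTwoTower.Config(
        right_hidden_dims=[32], left_hidden_dims=[32], hidden_dims=[16]
    )
    # right_dim/left_dim must equal text_emb_dim + dense_dim for each tower,
    # since forward() concatenates those two inputs before the tower MLP
    # (here 16 + 8 = 24).
    decoder = MLPDecoderTwoTower(config, right_dim=24, left_dim=24, to_dim=2)

    batch = 4
    right_text_emb, left_text_emb = torch.randn(batch, 16), torch.randn(batch, 16)
    right_dense, left_dense = torch.randn(batch, 8), torch.randn(batch, 8)

    out = decoder(right_text_emb, left_text_emb, right_dense, left_dense)
    print(out.shape)  # expected: torch.Size([4, 2])

    # With export_type=ExportType.RIGHT, forward() short-circuits and returns
    # the raw right-tower embedding instead of the joint MLP output.
    right_only = MLPDecoderTwoTower(
        config, right_dim=24, left_dim=24, to_dim=2, export_type=ExportType.RIGHT
    )
    emb = right_only(right_text_emb, left_text_emb, right_dense, left_dense)
    print(emb.shape)  # expected: torch.Size([4, 32])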