Source code for pytext.models.decoders.mlp_decoder_query_response

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import List

import torch
import torch.nn as nn
from pytext.utils.usage import log_class_usage

from .decoder_base import DecoderBase

class MLPDecoderQueryResponse(DecoderBase):
    """
    Implements a 'two-tower' MLP: one for query and one for response.

    Used in search pairwise ranking: both pos_response and neg_response
    use the response-MLP, while the query goes through its own MLP.
    """

    class Config(DecoderBase.Config):
        # Intermediate hidden dimensions. An empty list yields a single
        # linear projection from ``from_dim`` to ``to_dim`` in each tower.
        hidden_dims: List[int] = []

    def __init__(self, config: Config, from_dim: int, to_dim: int) -> None:
        """
        Build the response tower and the query tower.

        Both towers are constructed from the same ``from_dim``, ``to_dim``
        and ``config.hidden_dims``, but they are separate modules and do
        not share parameters.

        Args:
            config: Decoder configuration; ``hidden_dims`` is read from it.
            from_dim: Input feature size of each tower.
            to_dim: Output feature size of each tower.
        """
        super().__init__(config)
        self.mlp_for_response = MLPDecoderQueryResponse.get_mlp(
            from_dim, to_dim, config.hidden_dims
        )
        self.mlp_for_query = MLPDecoderQueryResponse.get_mlp(
            from_dim, to_dim, config.hidden_dims
        )
        # forward() returns three tensors, each produced by a tower whose
        # final layer has ``to_dim`` output features.
        self.out_dim = (3, to_dim)
        # The original statement was the bare name ``log_class_usage``, which
        # evaluates the function object and discards it, so nothing was ever
        # logged. Actually invoke it with the defining class.
        log_class_usage(__class__)

    @staticmethod
    def get_mlp(from_dim: int, to_dim: int, hidden_dims: List[int]):
        """
        Assemble a feed-forward network as an ``nn.Sequential``.

        Each entry of ``hidden_dims`` contributes a ``Linear`` layer followed
        by a ``ReLU``; a final ``Linear`` maps to ``to_dim`` with no
        activation after it.

        Args:
            from_dim: Input feature size of the first layer.
            to_dim: Output feature size of the last layer.
            hidden_dims: Sizes of the intermediate layers, in order. ``None``
                or an empty list produces a single ``Linear`` layer.

        Returns:
            The ``nn.Sequential`` holding the layers described above.
        """
        layers = []
        current_dim = from_dim
        # ``or []`` lets callers pass None for "no hidden layers".
        for dim in hidden_dims or []:
            layers.append(nn.Linear(current_dim, dim))
            layers.append(nn.ReLU())
            current_dim = dim
        layers.append(nn.Linear(current_dim, to_dim))
        return nn.Sequential(*layers)

    def forward(self, *x: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Run the three inputs through their towers.

        ``x[0]`` and ``x[1]`` are passed through the response MLP and
        ``x[2]`` through the query MLP.
        NOTE(review): presumably x[0] is pos_response and x[1] is
        neg_response, per the class docstring -- confirm against the caller.

        Args:
            *x: Exactly three positional inputs.

        Returns:
            A list of three outputs, in the same order as the inputs.
        """
        assert len(x) == 3, "expected exactly 3 inputs (two responses and a query)"
        return [
            self.mlp_for_response(x[0]),
            self.mlp_for_response(x[1]),
            self.mlp_for_query(x[2]),
        ]

    def get_decoder(self) -> List[nn.Module]:
        """Return the two towers as ``[response MLP, query MLP]``."""
        return [self.mlp_for_response, self.mlp_for_query]