Source code for pytext.models.representations.transformer.residual_mlp

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import List

from pytext.utils.usage import log_class_usage
from torch import nn
from torch.nn import functional as F


class GeLU(nn.Module):
    """Component class to wrap F.gelu."""

    def forward(self, input):
        return F.gelu(input)


class ResidualMLP(nn.Module):
    """A square MLP component which can learn a bias on an input vector.
    This MLP in particular defaults to using GeLU as its activation function
    (this can be changed by passing a different activation function), and
    retains a residual connection to its original input to help with gradient
    propagation.

    Unlike pytext's MLPDecoder, it doesn't currently allow adding a LayerNorm
    between hidden layers.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        dropout: float = 0.1,
        activation=GeLU,
    ):
        super().__init__()
        modules = []
        for last_dim, dim in zip([input_dim] + hidden_dims, hidden_dims):
            modules.extend(
                [nn.Linear(last_dim, dim), activation(), nn.Dropout(dropout)]
            )

        last_dim = hidden_dims[-1] if hidden_dims else input_dim
        # Unlike the normal PyText MLP, we don't put an activation layer at the end.
        modules.extend([nn.Linear(last_dim, input_dim), nn.Dropout(dropout)])

        self.mlp = nn.Sequential(*modules)
        log_class_usage(__class__)

    def forward(self, input):
        bias = self.mlp(input)
        return input + bias
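

# Usage sketch (illustrative, not part of the module): ResidualMLP is "square",
# so its output has the same dimensionality as its input, and forward() adds the
# learned correction back onto the input. The dimensions and batch size below
# are assumptions chosen for the example, not values taken from pytext.
if __name__ == "__main__":
    import torch

    mlp = ResidualMLP(input_dim=768, hidden_dims=[3072], dropout=0.1)
    x = torch.randn(8, 768)        # batch of 8 vectors of size input_dim
    out = mlp(x)                   # input + learned "bias" from the MLP
    assert out.shape == x.shape    # residual connection preserves the shape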