Source code for pytext.models.representations.docnn

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytext.config import ConfigBase
from pytext.config.module_config import CNNParams, PoolingType
from pytext.utils.usage import log_class_usage

from .representation_base import RepresentationBase


class DocNNRepresentation(RepresentationBase):
    """CNN based representation of a document.

    One ``nn.Conv1d`` is created per configured kernel size. Each
    convolution's ReLU output is pooled over its last dimension, and the
    pooled results are concatenated and passed through dropout.
    """

    class Config(RepresentationBase.Config):
        # Probability handed to nn.Dropout for the concatenated output.
        dropout: float = 0.4
        # Provides kernel_num and kernel_sizes for the convolutions.
        cnn: CNNParams = CNNParams()
        # Reduction that conv_and_pool applies after each convolution.
        pooling: PoolingType = PoolingType.MAX

    def __init__(self, config: Config, embed_dim: int) -> None:
        """Build the convolution stack and dropout layer.

        Args:
            config: Settings for dropout, CNN kernels and pooling.
            embed_dim: Number of input channels given to every convolution.
        """
        super().__init__(config)
        cnn_params = config.cnn
        self.max_kernel = max(cnn_params.kernel_sizes)

        # One convolution per kernel size; each pads by its own kernel size.
        conv_layers = []
        for kernel_size in cnn_params.kernel_sizes:
            conv_layers.append(
                nn.Conv1d(
                    in_channels=embed_dim,
                    out_channels=cnn_params.kernel_num,
                    kernel_size=kernel_size,
                    padding=kernel_size,
                )
            )
        self.convs = nn.ModuleList(conv_layers)

        self.dropout = nn.Dropout(config.dropout)
        # Every convolution contributes kernel_num pooled features.
        self.representation_dim = (
            len(cnn_params.kernel_sizes) * cnn_params.kernel_num
        )
        self.pooling_type = config.pooling
        log_class_usage(__class__)

    def forward(self, embedded_tokens: torch.Tensor, *args) -> torch.Tensor:
        """Compute the document representation.

        Args:
            embedded_tokens: Tensor of size (N,W,D).
            *args: Accepted for interface compatibility; not used here.

        Returns:
            Tensor of size (N,len(Ks)*Co) after dropout.
        """
        # nn.Conv1d expects (batch_size x embed_dim x seq_len), so swap the
        # last two axes of the (N,W,D) input.
        channels_first = embedded_tokens.transpose(1, 2)
        pooled_features = [
            self.conv_and_pool(channels_first, conv) for conv in self.convs
        ]
        concatenated = torch.cat(pooled_features, 1)
        return self.dropout(concatenated)

    def conv_and_pool(self, x, conv):
        """Apply ``conv`` and ReLU to ``x``, then pool over dimension 2.

        Args:
            x: Input passed to ``conv``.
            conv: The convolution module to apply.

        Returns:
            The activations reduced over dimension 2 according to
            ``self.pooling_type``.

        Raises:
            NotImplementedError: If ``self.pooling_type`` is not MAX, MEAN
                or LOGSUMEXP.
        """
        activations = F.relu(conv(x))
        pooling = self.pooling_type

        if pooling == PoolingType.MAX:
            # torch.max along a dim returns (values, indices); keep values.
            return torch.max(activations, dim=2)[0]
        if pooling == PoolingType.MEAN:
            return torch.mean(activations, dim=2)
        if pooling == PoolingType.LOGSUMEXP:
            return torch.logsumexp(activations, dim=2)
        raise NotImplementedError