Source code for pytext.models.decoders.intent_slot_model_decoder

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytext.models.module import create_module
from pytext.utils.usage import log_class_usage

from .decoder_base import DecoderBase
from .mlp_decoder import MLPDecoder


class IntentSlotModelDecoder(DecoderBase):
    """Decoder layer for joint intent-slot models.

    An intent-slot model predicts an utterance-level intent (document
    classification) and per-word slots (word tagging) together. This decoder
    owns one sub-decoder per task and can optionally feed the intent
    probability distribution into the slot decoder.

    Args:
        config (Config): Configuration object of type
            `IntentSlotModelDecoder.Config`.
        in_dim_doc (int): Input dimension of the document decoder.
        in_dim_word (int): Input dimension of the word decoder, before the
            optional intent-probability features are appended.
        out_dim_doc (int): Output dimension of the document decoder.
        out_dim_word (int): Output dimension of the word decoder.

    Attributes:
        use_doc_probs_in_word (bool): Whether intent probabilities are
            appended to the word representation before slot decoding.
        doc_decoder: Document/intent decoder module built from
            `config.doc_decoder`.
        word_decoder: Word/slot decoder module built from
            `config.word_decoder`.
    """

    class Config(DecoderBase.Config):
        """Configuration for `IntentSlotModelDecoder`.

        Attributes:
            use_doc_probs_in_word (bool): Whether to use intent probabilities
                when predicting slots. Defaults to False.
            doc_decoder (MLPDecoder.Config): Config of the document decoder.
            word_decoder (MLPDecoder.Config): Config of the word decoder.
        """

        use_doc_probs_in_word: bool = False
        doc_decoder: MLPDecoder.Config = MLPDecoder.Config()
        word_decoder: MLPDecoder.Config = MLPDecoder.Config()

    def __init__(
        self,
        config: Config,
        in_dim_doc: int,
        in_dim_word: int,
        out_dim_doc: int,
        out_dim_word: int,
    ) -> None:
        """Build the document and word sub-decoders from `config`.

        NOTE(review): only `out_dim_doc` is ever added to the word input
        width here; the width of the optional `dense` tensor accepted by
        `forward` is not. Presumably callers already include it in
        `in_dim_doc` / `in_dim_word` -- confirm against the model code.
        """
        super().__init__(config)
        self.use_doc_probs_in_word = config.use_doc_probs_in_word

        # The document decoder is created first so that submodule
        # registration order stays doc -> word.
        self.doc_decoder = create_module(
            config.doc_decoder, in_dim=in_dim_doc, out_dim=out_dim_doc
        )

        # When intent probabilities are fed to the slot decoder, its input
        # grows by one feature per intent class.
        word_in_dim = in_dim_word
        if self.use_doc_probs_in_word:
            word_in_dim = in_dim_word + out_dim_doc
        self.word_decoder = create_module(
            config.word_decoder, in_dim=word_in_dim, out_dim=out_dim_word
        )

        log_class_usage(__class__)

    def forward(
        self, x_d: torch.Tensor, x_w: torch.Tensor, dense: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decode document logits and word logits.

        Args:
            x_d: Document representation. It is concatenated with `dense`
                along dim 1, so it is presumably (batch, dim_doc).
            x_w: Word representation. Extra features are concatenated along
                dim 2 and tiled over dim 1, so it is presumably
                (batch, seq_len, dim_word).
            dense: Optional extra features, appended to the document input
                and, tiled over dim 1 of `x_w`, to every word input.

        Returns:
            A tuple `(logit_d, logit_w)` holding the outputs of the document
            decoder and the word decoder, in that order.
        """
        # Document branch: append the dense features when they are given.
        if dense is not None:
            doc_input = torch.cat((x_d, dense), 1)
        else:
            doc_input = x_d
        logit_d = self.doc_decoder(doc_input)

        # Word branch: start from the raw word representation and append
        # the optional per-utterance features, copied to every position.
        word_input = x_w
        if self.use_doc_probs_in_word:
            # Intent probability distribution over dim 1 of the doc logits.
            intent_prob = F.softmax(logit_d, 1)
            tiled_prob = intent_prob.unsqueeze(1).repeat(1, word_input.size(1), 1)
            word_input = torch.cat((word_input, tiled_prob), 2)

        if dense is not None:
            tiled_dense = dense.unsqueeze(1).repeat(1, word_input.size(1), 1)
            word_input = torch.cat((word_input, tiled_dense), 2)

        logit_w = self.word_decoder(word_input)
        return logit_d, logit_w

    def get_decoder(self) -> List[nn.Module]:
        """Return the document and word decoder modules, in that order."""
        decoders = [self.doc_decoder, self.word_decoder]
        return decoders