Source code for pytext.models.output_layers.intent_slot_output_layer

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from caffe2.python import core
from pytext.common.constants import DatasetFieldName
from pytext.data.utils import Vocabulary
from pytext.models.module import create_module
from pytext.utils.usage import log_class_usage
from torch import jit

from .doc_classification_output_layer import ClassificationOutputLayer
from .output_layer_base import OutputLayerBase
from .utils import query_word_reprs
from .word_tagging_output_layer import CRFOutputLayer, WordTaggingOutputLayer


[docs]class IntentSlotScores(nn.Module): def __init__(self, doc_scores: jit.ScriptModule, word_scores: jit.ScriptModule): super().__init__() self.doc_scores = doc_scores self.word_scores = word_scores log_class_usage(__class__)
[docs] def forward( self, logits: Tuple[torch.Tensor, torch.Tensor], context: Dict[str, torch.Tensor], ) -> Tuple[List[Dict[str, float]], List[List[Dict[str, float]]]]: d_logits, w_logits = logits if "token_indices" in context: w_logits = query_word_reprs(w_logits, context["token_indices"]) d_results = self.doc_scores(d_logits) w_results = self.word_scores(w_logits, context) return d_results, w_results
[docs]class IntentSlotOutputLayer(OutputLayerBase): """ Output layer for joint intent classification and slot-filling models. Intent classification is a document classification problem and slot filling is a word tagging problem. Thus terms these can be used interchangeably in the documentation. Args: doc_output (ClassificationOutputLayer): Output layer for intent classification task. See :class:`~pytext.models.output_layers.ClassificationOutputLayer` for details. word_output (WordTaggingOutputLayer): Output layer for slot filling task. See :class:`~pytext.models.output_layers.WordTaggingOutputLayer` for details. Attributes: doc_output (type): Output layer for intent classification task. word_output (type): Output layer for slot filling task. """
[docs] class Config(OutputLayerBase.Config): doc_output: ClassificationOutputLayer.Config = ( ClassificationOutputLayer.Config() ) word_output: Union[ WordTaggingOutputLayer.Config, CRFOutputLayer.Config ] = WordTaggingOutputLayer.Config()
[docs] @classmethod def from_config( cls, config: Config, doc_labels: Vocabulary, word_labels: Vocabulary ): return cls( create_module(config.doc_output, labels=doc_labels), create_module(config.word_output, labels=word_labels), )
def __init__( self, doc_output: ClassificationOutputLayer, word_output: WordTaggingOutputLayer ) -> None: super().__init__() self.doc_output = doc_output self.word_output = word_output log_class_usage(__class__)
[docs] def get_loss( self, logits: Tuple[torch.Tensor, torch.Tensor], targets: Tuple[torch.Tensor, torch.Tensor], context: Dict[str, Any] = None, *args, **kwargs ) -> torch.Tensor: """Compute and return the averaged intent and slot-filling loss. Args: logit (Tuple[torch.Tensor, torch.Tensor]): Logits returned by :class:`~pytext.models.joint_model.JointModel`. It is a tuple containing logits for intent classification and slot filling. targets (Tuple[torch.Tensor, torch.Tensor]): Tuple of target Tensors containing true document label/target and true word labels/targets. context (Dict[str, Any]): Context is a dictionary of items that's passed as additional metadata. Defaults to None. Returns: torch.Tensor: Averaged intent and slot loss. """ d_logit, w_logit = logits if DatasetFieldName.TOKEN_INDICES in context: w_logit = query_word_reprs(w_logit, context[DatasetFieldName.TOKEN_INDICES]) d_target, w_target = targets d_weight = context[DatasetFieldName.DOC_WEIGHT_FIELD] # noqa w_weight = context[DatasetFieldName.WORD_WEIGHT_FIELD] # noqa d_loss = self.doc_output.get_loss( d_logit, d_target, context=context, reduce=False ) w_loss = self.word_output.get_loss( w_logit, w_target, context=context, reduce=False ) # w_loss could have been flattened w_hard_target = w_target[0] if type(w_target) is tuple else w_target if w_loss.size()[0] != w_hard_target.size()[0]: w_loss = w_loss.reshape(w_hard_target.size()) w_loss = torch.mean(w_loss, dim=1) d_weighted_loss = torch.mean(torch.mul(d_loss, d_weight)) w_weighted_loss = torch.mean(torch.mul(w_loss, w_weight)) return d_weighted_loss + w_weighted_loss
[docs] def get_pred( self, logits: Tuple[torch.Tensor, torch.Tensor], targets: Optional[torch.Tensor] = None, context: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute and return prediction and scores from the model. Args: logit (Tuple[torch.Tensor, torch.Tensor]): Logits returned by :class:`~pytext.models.joint_model.JointModel`. It's tuple containing logits for intent classification and slot filling. targets (Optional[torch.Tensor]): Not applicable. Defaults to None. context (Optional[Dict[str, Any]]): Context is a dictionary of items that's passed as additional metadata. Defaults to None. Returns: Tuple[torch.Tensor, torch.Tensor]: Model prediction and scores. """ d_logit, w_logit = logits if DatasetFieldName.TOKEN_INDICES in context: w_logit = query_word_reprs(w_logit, context[DatasetFieldName.TOKEN_INDICES]) d_pred, d_score = self.doc_output.get_pred(d_logit, None, context) w_pred, w_score = self.word_output.get_pred(w_logit, None, context) return (d_pred, w_pred), (d_score, w_score)
[docs] def export_to_caffe2( self, workspace: core.workspace, init_net: core.Net, predict_net: core.Net, model_out: List[torch.Tensor], doc_out_name: str, word_out_name: str, ) -> List[core.BlobReference]: """ Exports the intent slot output layer to Caffe2. See `OutputLayerBase.export_to_caffe2()` for details. """ return self.doc_output.export_to_caffe2( workspace, init_net, predict_net, model_out[0], doc_out_name ) + self.word_output.export_to_caffe2( workspace, init_net, predict_net, model_out[1], word_out_name )
[docs] def torchscript_predictions(self): doc_scores = self.doc_output.torchscript_predictions() word_scores = self.word_output.torchscript_predictions() return jit.script(IntentSlotScores(doc_scores, word_scores))