Source code for pytext.models.output_layers.doc_classification_output_layer

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from caffe2.python import core
from pytext.common import Padding
from pytext.config.component import create_loss
from pytext.data.utils import Vocabulary
from pytext.fields import FieldMeta
from pytext.loss import (
    AUCPRHingeLoss,
    BinaryCrossEntropyLoss,
    BinaryCrossEntropyWithLogitsLoss,
    CrossEntropyLoss,
    HingeLoss,
    KLDivergenceBCELoss,
    KLDivergenceCELoss,
    LabelSmoothedCrossEntropyLoss,
    MultiLabelSoftMarginLoss,
)
from pytext.utils.label import get_label_weights
from torch import jit

from .output_layer_base import OutputLayerBase
from .utils import OutputLayerUtils


[docs]class ClassificationOutputLayer(OutputLayerBase): """ Output layer for document classification models. It supports `CrossEntropyLoss` and `BinaryCrossEntropyLoss` per document. Args: loss_fn (Union[CrossEntropyLoss, BinaryCrossEntropyLoss]): The loss function to use for computing loss. Defaults to None. Attributes: loss_fn: The loss function to use for computing loss. """
[docs] class Config(OutputLayerBase.Config): loss: Union[ CrossEntropyLoss.Config, BinaryCrossEntropyLoss.Config, BinaryCrossEntropyWithLogitsLoss.Config, MultiLabelSoftMarginLoss.Config, AUCPRHingeLoss.Config, HingeLoss.Config, KLDivergenceBCELoss.Config, KLDivergenceCELoss.Config, LabelSmoothedCrossEntropyLoss.Config, ] = CrossEntropyLoss.Config() label_weights: Optional[Dict[str, float]] = None
[docs] @classmethod def from_config( cls, config: Config, metadata: Optional[FieldMeta] = None, labels: Optional[Vocabulary] = None, ): if labels is not None: vocab = list(labels) vocab_dict = labels.idx pad_token_idx = labels.idx.get( labels.pad_token, Padding.DEFAULT_LABEL_PAD_IDX ) else: vocab = metadata.vocab.itos vocab_dict = metadata.vocab.stoi pad_token_idx = getattr(metadata, "pad_token_idx", -1) label_weights = ( get_label_weights(vocab_dict, config.label_weights) if config.label_weights else None ) loss = create_loss( config.loss, weight=label_weights, ignore_index=pad_token_idx ) if isinstance(loss, BinaryCrossEntropyLoss): cls = BinaryClassificationOutputLayer elif isinstance(loss, MultiLabelSoftMarginLoss): cls = MultiLabelOutputLayer elif isinstance(loss, BinaryCrossEntropyWithLogitsLoss): cls = MultiLabelOutputLayer else: cls = MulticlassOutputLayer return cls(vocab, loss, config)
[docs] def get_pred(self, logit, *args, **kwargs): """Compute and return prediction and scores from the model. Prediction is computed using argmax over the document label/target space. Scores are sigmoid or softmax scores over the model logits depending on the loss component being used. Args: logit (torch.Tensor): Logits returned :class:`~pytext.models.doc_model.DocModel`. Returns: Tuple[torch.Tensor, torch.Tensor]: Model prediction and scores. """ raise NotImplementedError
[docs]class ClassificationScores(jit.ScriptModule): def __init__(self, classes, score_function, score_function_dim=None): super().__init__() self.classes = jit.Attribute(classes, List[str]) self.score_function = score_function self.score_function_dim = score_function_dim @jit.script_method def forward(self, logits: torch.Tensor): # In pure python, this code would be implemented as follows: # scores = self.score_function(logits) # return [ # {class: score for class, score in zip(self.classes, example_scores} # for example_scores in scores.tolist() # ] # Extra verbosity is due to jit.script. if self.score_function_dim is None: scores = self.score_function(logits) else: scores = self.score_function(logits, dim=self.score_function_dim) results = jit.annotate(List[Dict[str, float]], []) for example_scores in scores.chunk(len(scores)): example_scores = example_scores.squeeze(dim=0) example_response = jit.annotate(Dict[str, float], {}) for i in range(len(self.classes)): example_response[self.classes[i]] = float(example_scores[i].item()) results.append(example_response) return results
[docs]class BinaryClassificationOutputLayer(ClassificationOutputLayer):
[docs] def get_pred(self, logit, *args, **kwargs): """See `OutputLayerBase.get_pred()`.""" preds = torch.max(logit, -1)[1] scores = F.logsigmoid(logit) return preds, scores
[docs] def torchscript_predictions(self): return ClassificationScores(self.target_names, F.logsigmoid)
[docs] def export_to_caffe2( self, workspace: core.workspace, init_net: core.Net, predict_net: core.Net, model_out: torch.Tensor, output_name: str, ) -> List[core.BlobReference]: """See `OutputLayerBase.export_to_caffe2()`.""" probability_out = predict_net.Sigmoid(output_name) return OutputLayerUtils.gen_additional_blobs( predict_net, probability_out, model_out, output_name, self.target_names )
[docs]class MulticlassOutputLayer(ClassificationOutputLayer):
[docs] def get_pred(self, logit, *args, **kwargs): """See `OutputLayerBase.get_pred()`.""" preds = torch.max(logit, -1)[1] scores = F.log_softmax(logit, -1) return preds, scores
[docs] def torchscript_predictions(self): return ClassificationScores(self.target_names, F.log_softmax, -1)
[docs] def export_to_caffe2( self, workspace: core.workspace, init_net: core.Net, predict_net: core.Net, model_out: torch.Tensor, output_name: str, ) -> List[core.BlobReference]: """See `OutputLayerBase.export_to_caffe2()`.""" probability_out = predict_net.Softmax(output_name, axis=model_out.dim() - 1) return OutputLayerUtils.gen_additional_blobs( predict_net, probability_out, model_out, output_name, self.target_names )
[docs]class MultiLabelOutputLayer(ClassificationOutputLayer):
[docs] def get_pred(self, logit, *args, **kwargs): """See `OutputLayerBase.get_pred()`.""" preds = logit > 0 scores = F.logsigmoid(logit) return preds, scores
[docs] def torchscript_predictions(self): return ClassificationScores(self.target_names, F.logsigmoid)
[docs] def export_to_caffe2( self, workspace: core.workspace, init_net: core.Net, predict_net: core.Net, model_out: torch.Tensor, output_name: str, ) -> List[core.BlobReference]: """See `OutputLayerBase.export_to_caffe2()`.""" probability_out = predict_net.Sigmoid(output_name) return OutputLayerUtils.gen_additional_blobs( predict_net, probability_out, model_out, output_name, self.target_names )