# Source code for pytext.models.output_layers.output_layer_base

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, Dict, List, Optional, Tuple

import torch
from caffe2.python import core
from pytext.loss import Loss
from pytext.models.module import Module
from pytext.utils.usage import log_class_usage


class OutputLayerBase(Module):
    """
    Base class for all output layers in PyText. The responsibilities of this
    layer are

    1. Implement how loss is computed from logits and targets.
    2. Implement how to get predictions from logits.
    3. Implement the Caffe2 operator for performing the above tasks. This is
       used when PyText exports a PyTorch model to Caffe2.

    Args:
        target_names (Optional[List[str]]): Names of the targets/labels. Stored
            as ``self.target_names``; this base class does not read it itself.
            Defaults to None.
        loss_fn (Optional[Loss]): The loss function object to use for computing
            loss. Defaults to None.
        *args, **kwargs: Accepted and ignored by this base class.

    Attributes:
        target_names (Optional[List[str]]): The ``target_names`` passed to the
            constructor.
        loss_fn (Optional[Loss]): The loss function object to use for computing
            loss, or None if none was given.
    """

    def __init__(
        self,
        target_names: Optional[List[str]] = None,
        loss_fn: Optional[Loss] = None,
        *args,
        **kwargs
    ) -> None:
        super().__init__()
        self.target_names = target_names
        self.loss_fn = loss_fn
        # ``__class__`` is the compiler-provided cell for the enclosing class,
        # so this always records OutputLayerBase, not the runtime subclass.
        log_class_usage(__class__)

    def get_loss(
        self,
        logit: torch.Tensor,
        target: torch.Tensor,
        context: Optional[Dict[str, Any]] = None,
        reduce: bool = True,
    ) -> Optional[torch.Tensor]:
        """Compute and return the loss given logits and targets.

        Args:
            logit (torch.Tensor): Logits returned by :class:`~pytext.models.Model`.
            target (torch.Tensor): True label/target to compute loss against.
            context (Optional[Dict[str, Any]]): Context is a dictionary of items
                that's passed as additional metadata by the
                :class:`~pytext.data.DataHandler`. Not used by this base
                implementation. Defaults to None.
            reduce (bool): Whether to reduce loss over the batch. Defaults to True.

        Returns:
            Optional[torch.Tensor]: Model loss, or None when no ``loss_fn`` was
            configured.
        """
        if self.loss_fn is None:
            return None
        # ``reduce`` is forwarded positionally as the loss function's third argument.
        return self.loss_fn(logit, target, reduce)

    def get_pred(
        self,
        logit: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        context: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute and return prediction and scores from the model.

        This base implementation returns ``logit`` unchanged as the prediction
        and ``None`` as the scores; subclasses override it.

        Args:
            logit (torch.Tensor): Logits returned by :class:`~pytext.models.Model`.
            targets (Optional[torch.Tensor]): True label/target. Only used by
                :class:`~pytext.models.output_layers.LMOutputLayer`.
                Defaults to None.
            context (Optional[Dict[str, Any]]): Context is a dictionary of items
                that's passed as additional metadata by the
                :class:`~pytext.data.DataHandler`. Defaults to None.

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]: Model prediction and
            scores (scores are None in this base implementation).
        """
        return logit, None

    def export_to_caffe2(
        self,
        workspace: core.workspace,
        init_net: core.Net,
        predict_net: core.Net,
        model_out: torch.Tensor,
        output_name: str,
    ) -> List[core.BlobReference]:
        """
        Export the output layer to Caffe2 by manually adding the necessary
        operators to ``init_net`` and ``predict_net``, and return the list of
        external output blobs to be added to the model. By default this adds
        nothing and returns an empty list, so any sub-class must override this
        method if it needs Caffe2 export.

        To learn about Caffe2 computation graphs and why we need two networks,
        `init_net` and `predict_net`/`exec_net`, read
        https://caffe2.ai/docs/intro-tutorial#null__nets-and-operators.

        Args:
            workspace (core.workspace): Caffe2 `workspace` to use for adding the
                operator. See https://caffe2.ai/docs/workspace.html to learn
                about Caffe2 workspace.
            init_net (core.Net): Caffe2 `init_net` to add the operator to.
            predict_net (core.Net): Caffe2 `predict_net` to add the operator to.
            model_out (torch.Tensor): Output logit Tensor from the model.
            output_name (str): Name of `model_out` to use in the Caffe2 net.

        Returns:
            List[core.BlobReference]: List of output blobs that the
            `output_layer` generates (empty for this base class).
        """
        return []