Source code for pytext.models.output_layers.utils

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List

import torch
from caffe2.python import core


class OutputLayerUtils:
    """Static helpers for building Caffe2 net ops from output-layer results."""

    @staticmethod
    def gen_additional_blobs(
        predict_net: core.Net,
        probability_out,
        model_out: torch.Tensor,
        output_name: str,
        label_names: List[str],
    ) -> List[core.BlobReference]:
        """Append ops that expose one human-readable score blob per label.

        The log of ``probability_out`` is split along the last dimension of
        ``model_out`` into one piece per entry of ``label_names``. Each piece
        is then copied into a blob named ``"<output_name>:<label>"``.

        Args:
            predict_net: net that the Log, Split and Copy ops are appended to.
            probability_out: input to the Log op; presumably a blob reference
                or blob name holding probabilities (TODO confirm at callers).
            model_out: tensor whose rank selects the split axis
                (``model_out.dim() - 1``, i.e. its last dimension).
            output_name: prefix used for the returned blob names.
            label_names: labels to expose; they are also passed verbatim to
                Split as the names of its intermediate output blobs.

        Returns:
            The Copy output blobs, ordered like ``label_names``.
        """
        log_scores = predict_net.Log(probability_out)
        split_axis = model_out.dim() - 1
        per_label = predict_net.Split(log_scores, label_names, axis=split_axis)

        # Split may hand back a lone blob instead of a tuple (presumably when
        # only one output is requested); wrap it so that zip pairs it with its
        # label rather than iterating over the blob itself.
        if not isinstance(per_label, tuple):
            per_label = (per_label,)

        return [
            predict_net.Copy(blob, "{}:{}".format(output_name, label))
            for label, blob in zip(label_names, per_label)
        ]