# Source code for pytext.metric_reporters.calibration_metric_reporter

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, Dict, List, Optional

from pytext.config import PyTextConfig
from pytext.metric_reporters.channel import Channel, ConsoleChannel
from pytext.metrics import LabelPrediction
from pytext.metrics.calibration_metrics import compute_calibration
from torch import Tensor

from .metric_reporter import MetricReporter


class CalibrationMetricReporter(MetricReporter):
    """Metric reporter that computes calibration metrics.

    Per-batch predictions, targets and scores are flattened and accumulated
    by the ``aggregate_*`` methods. ``calculate_metric`` drops padded
    positions (entries whose target equals ``pad_index``) and passes the
    remaining entries, as ``LabelPrediction`` records, to
    ``compute_calibration``.
    """

    def __init__(self, channels: List[Channel], pad_index: int = -1) -> None:
        """
        Args:
            channels: Channels forwarded to the base ``MetricReporter``.
            pad_index: Target value that marks a padded position; entries
                whose target equals it are excluded from the metric.
        """
        super().__init__(channels)
        self.pad_index = pad_index

    @classmethod
    def from_config(cls, config: PyTextConfig, pad_index: int = -1):
        """Build a reporter that reports to a single ``ConsoleChannel``.

        ``config`` is accepted but not currently read.
        """
        return cls(channels=[ConsoleChannel()], pad_index=pad_index)

    def aggregate_preds(
        self, batch_preds: Tensor, batch_context: Optional[Dict[str, Any]] = None
    ) -> None:
        """Flatten ``batch_preds`` and store it as one batch of predictions.

        ``batch_context`` is not used by this reporter.
        """
        self.all_preds.append(batch_preds.flatten().tolist())

    def aggregate_targets(
        self, batch_targets: Tensor, batch_context: Optional[Dict[str, Any]] = None
    ) -> None:
        """Flatten ``batch_targets`` and store it as one batch of targets.

        ``batch_context`` is not used by this reporter.
        """
        self.all_targets.append(batch_targets.flatten().tolist())

    def aggregate_scores(self, batch_scores: Tensor) -> None:
        """Store ``batch_scores`` as a list of rows over its last dimension.

        All leading dimensions are collapsed, so each stored row holds the
        scores along the last dimension (presumably one score per label --
        confirm against the model output).
        """
        # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
        # tensors (e.g. after a transpose); ``reshape`` copies only if needed.
        batch_scores = batch_scores.reshape(-1, batch_scores.size(-1))
        self.all_scores.append(batch_scores.tolist())

    def calculate_metric(self):
        """Compute calibration metrics over all aggregated, non-padded entries.

        Returns:
            The result of ``compute_calibration`` on the collected
            ``LabelPrediction`` list, in aggregation order.

        Raises:
            ValueError: if scores, predictions and targets were aggregated
                for a different number of batches, or if a batch holds a
                different number of entries for any of the three.
        """
        n_batches = (len(self.all_scores), len(self.all_preds), len(self.all_targets))
        if not n_batches[0] == n_batches[1] == n_batches[2]:
            # zip() would silently truncate to the shortest of the three.
            raise ValueError(
                f"Mismatched number of aggregated batches "
                f"(scores, preds, targets): {n_batches}"
            )

        label_predictions: List[LabelPrediction] = []
        for batch_idx, (batch_scores, batch_preds, batch_targets) in enumerate(
            zip(self.all_scores, self.all_preds, self.all_targets)
        ):
            # Validate before filtering: once padded entries are removed by
            # index, the three lists are equal-length by construction.
            sizes = (len(batch_scores), len(batch_preds), len(batch_targets))
            if not sizes[0] == sizes[1] == sizes[2]:
                raise ValueError(
                    f"Batch {batch_idx} has mismatched lengths "
                    f"(scores, preds, targets): {sizes}"
                )
            for score, pred, target in zip(batch_scores, batch_preds, batch_targets):
                if target == self.pad_index:
                    continue  # padded position; not a real example
                label_predictions.append(LabelPrediction(score, pred, target))

        return compute_calibration(label_predictions)