Source code for pytext.metric_reporters.classification_metric_reporter

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from enum import Enum
from typing import List, Optional

from pytext.common.constants import Stage
from pytext.data import CommonMetadata
from pytext.metrics import (
    RECALL_AT_PRECISION_THRESHOLDS,
    LabelListPrediction,
    LabelPrediction,
    compute_classification_metrics,
    compute_multi_label_classification_metrics,
)

from .channel import Channel, ConsoleChannel, FileChannel
from .metric_reporter import MetricReporter


# Key under which ClassificationMetricReporter.get_meta returns the label names.
META_LABEL_NAMES = "label_names"


class ComparableClassificationMetric(Enum):
    """Metrics that a classification reporter can use for model selection.

    Each member's value is the metric's string name.
    """

    # Aggregate metrics over all labels.
    ACCURACY = "accuracy"
    ROC_AUC = "roc_auc"
    MCC = "mcc"
    MACRO_F1 = "macro_f1"
    # Per-label metrics, read for one configured target label.
    LABEL_F1 = "label_f1"
    LABEL_AVG_PRECISION = "label_avg_precision"
    LABEL_ROC_AUC = "label_roc_auc"
    # use negative because the reporter's lower_is_better value is False
    NEGATIVE_LOSS = "negative_loss"
class ClassificationMetricReporter(MetricReporter):
    """Metric reporter for single-label classification.

    Computes metrics with ``compute_classification_metrics`` from the
    ``all_scores``, ``all_preds`` and ``all_targets`` attributes. These are not
    set in this class and are presumably populated by ``MetricReporter``.
    It also maps label indices back to ``label_names`` for reporting and picks
    the model-selection value named by ``model_select_metric``.
    """

    __EXPANSIBLE__ = True

    class Config(MetricReporter.Config):
        #: Metric whose value ``get_model_select_metric`` returns.
        model_select_metric: ComparableClassificationMetric = (
            ComparableClassificationMetric.ACCURACY
        )
        #: Label read by the LABEL_F1 / LABEL_AVG_PRECISION / LABEL_ROC_AUC
        #: metrics. It must be set, and must be a known label, when one of
        #: those metrics is selected.
        target_label: Optional[str] = None
        #: These column names correspond to raw input data columns. Text in these
        #: columns (usually just 1 column) will be concatenated and output in
        #: the IntentModelChannel as an evaluation tsv.
        text_column_names: List[str] = ["text"]
        #: These column names correspond to raw input data columns, that
        #: will be read by data_source into context, and included in the
        #: run_model output file along with other saving results.
        additional_column_names: List[str] = []
        #: Forwarded to the metric computation in ``calculate_metric``.
        recall_at_precision_thresholds: List[float] = RECALL_AT_PRECISION_THRESHOLDS

    def __init__(
        self,
        label_names: List[str],
        channels: List[Channel],
        model_select_metric: ComparableClassificationMetric = (
            ComparableClassificationMetric.ACCURACY
        ),
        target_label: Optional[str] = None,
        text_column_names: List[str] = Config.text_column_names,
        additional_column_names: List[str] = Config.additional_column_names,
        recall_at_precision_thresholds: List[float] = (
            Config.recall_at_precision_thresholds
        ),
    ) -> None:
        """
        Args:
            label_names: Label names indexed by class id. They are used to turn
                predicted and target indices into strings.
            channels: Output channels, passed to ``MetricReporter``.
            model_select_metric: Metric whose value
                ``get_model_select_metric`` returns.
            target_label: Label read by the per-label model-selection metrics.
            text_column_names: Raw-input columns joined into the ``"text"``
                context entry by ``batch_context``.
            additional_column_names: Raw-input columns copied unchanged into
                the batch context.
            recall_at_precision_thresholds: Forwarded to
                ``compute_classification_metrics``.
        """
        super().__init__(channels)
        self.label_names = label_names
        self.model_select_metric = model_select_metric
        self.target_label = target_label
        # Store private copies. The defaults are objects shared via ``Config``,
        # so keeping them directly would let an in-place edit on one reporter's
        # list leak into every other reporter built with the default.
        self.text_column_names = list(text_column_names)
        self.additional_column_names = list(additional_column_names)
        self.recall_at_precision_thresholds = recall_at_precision_thresholds

    @classmethod
    def from_config(
        cls, config, meta: Optional[CommonMetadata] = None, tensorizers=None
    ):
        """Build a reporter from ``config``.

        Label names come from the ``"labels"`` tensorizer's vocab when
        ``tensorizers`` is given. Otherwise they come from
        ``meta.target.vocab.itos``.

        NOTE(review): on the ``meta`` path this sets
        ``config.text_column_names = []`` on the caller's config object in place.
        """
        # TODO: refactor metric reporting and remove this hack
        if tensorizers:
            labels = list(tensorizers["labels"].vocab)
        else:
            labels = meta.target.vocab.itos
            config.text_column_names = []
        return cls.from_config_and_label_names(config, labels)

    @classmethod
    def from_config_and_label_names(cls, config, label_names: List[str]):
        """Validate ``config`` against ``label_names`` and build the reporter.

        Output goes to a console channel and to a file channel for the TEST
        stage, written to ``config.output_path``.

        Raises:
            AssertionError: a per-label metric is selected without a valid
                ``target_label``, or ROC_AUC/MCC is selected for a task that
                does not have exactly two labels.
        """
        if config.model_select_metric in (
            ComparableClassificationMetric.LABEL_F1,
            ComparableClassificationMetric.LABEL_AVG_PRECISION,
            ComparableClassificationMetric.LABEL_ROC_AUC,
        ):
            assert config.target_label is not None, (
                f"{config.model_select_metric} requires target_label to be set"
            )
            assert config.target_label in label_names, (
                f"target_label {config.target_label!r} is not one of the "
                f"labels {label_names}"
            )
        if config.model_select_metric in (
            ComparableClassificationMetric.ROC_AUC,
            ComparableClassificationMetric.MCC,
        ):
            assert len(label_names) == 2, (
                f"{config.model_select_metric} requires exactly 2 labels, "
                f"got {len(label_names)}"
            )
        return cls(
            label_names,
            [ConsoleChannel(), FileChannel((Stage.TEST,), config.output_path)],
            config.model_select_metric,
            config.target_label,
            config.text_column_names,
            config.additional_column_names,
            config.recall_at_precision_thresholds,
        )

    def batch_context(self, raw_batch, batch):
        """Extend the base batch context with ``"text"`` and any extra columns.

        ``context["text"]`` holds one string per row of ``raw_batch``: the
        configured text columns, stringified and joined by ``" | "``. Each
        additional column is copied into the context under its own name.
        """
        context = super().batch_context(raw_batch, batch)
        context["text"] = [
            " | ".join(str(row[column_name]) for column_name in self.text_column_names)
            for row in raw_batch
        ]
        # Additional raw columns are carried through unchanged so they can be
        # included in the run_model output file.
        for additional_colname in self.additional_column_names:
            context[additional_colname] = [row[additional_colname] for row in raw_batch]
        return context

    def calculate_metric(self):
        """Compute classification metrics over the accumulated predictions."""
        return compute_classification_metrics(
            [
                LabelPrediction(scores, pred, expect)
                for scores, pred, expect in zip(
                    self.all_scores, self.all_preds, self.all_targets
                )
            ],
            self.label_names,
            self.calculate_loss(),
            recall_at_precision_thresholds=self.recall_at_precision_thresholds,
        )

    def predictions_to_report(self):
        """
        Generate human readable predictions
        """
        return [self.label_names[pred] for pred in self.all_preds]

    def targets_to_report(self):
        """
        Generate human readable targets
        """
        return [self.label_names[target] for target in self.all_targets]

    def get_meta(self):
        """Return metadata exposing the label names under ``META_LABEL_NAMES``."""
        return {META_LABEL_NAMES: self.label_names}

    def get_model_select_metric(self, metrics):
        """Return the value from ``metrics`` named by ``model_select_metric``.

        Raises:
            ValueError: ``model_select_metric`` is not a recognised member.
            AssertionError: the selected value is ``None``.
        """
        if self.model_select_metric == ComparableClassificationMetric.ACCURACY:
            metric = metrics.accuracy
        elif self.model_select_metric == ComparableClassificationMetric.ROC_AUC:
            metric = metrics.roc_auc
        elif self.model_select_metric == ComparableClassificationMetric.MCC:
            metric = metrics.mcc
        elif self.model_select_metric == ComparableClassificationMetric.MACRO_F1:
            metric = metrics.macro_prf1_metrics.macro_scores.f1
        elif self.model_select_metric == ComparableClassificationMetric.LABEL_F1:
            metric = metrics.macro_prf1_metrics.per_label_scores[self.target_label].f1
        elif (
            self.model_select_metric
            == ComparableClassificationMetric.LABEL_AVG_PRECISION
        ):
            metric = metrics.per_label_soft_scores[self.target_label].average_precision
        elif self.model_select_metric == ComparableClassificationMetric.LABEL_ROC_AUC:
            metric = metrics.per_label_soft_scores[self.target_label].roc_auc
        elif self.model_select_metric == ComparableClassificationMetric.NEGATIVE_LOSS:
            # Negated so that "higher is better" holds for every selectable metric.
            metric = -metrics.loss
        else:
            raise ValueError(f"unknown metric: {self.model_select_metric}")
        assert metric is not None, (
            f"model selection metric {self.model_select_metric} is None in the "
            f"computed metrics"
        )
        return metric
class MultiLabelClassificationMetricReporter(ClassificationMetricReporter):
    """Classification metric reporter for multi-label tasks.

    Each element of ``all_preds`` / ``all_targets`` is treated as an iterable
    of label indices rather than a single index.
    """

    def calculate_metric(self):
        """Compute multi-label metrics over the accumulated scores, predictions
        and targets."""
        per_example = [
            LabelListPrediction(*example)
            for example in zip(self.all_scores, self.all_preds, self.all_targets)
        ]
        return compute_multi_label_classification_metrics(
            per_example,
            self.label_names,
            self.calculate_loss(),
            recall_at_precision_thresholds=self.recall_at_precision_thresholds,
        )

    def predictions_to_report(self):
        """
        Generate human readable predictions
        """
        names = self.label_names
        return [[names[index] for index in row] for row in self.all_preds]

    def targets_to_report(self):
        """
        Generate human readable targets
        """
        names = self.label_names
        return [[names[index] for index in row] for row in self.all_targets]