#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
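"""Metric reporters for classification models.

Defines ClassificationMetricReporter for single-label classification and
MultiLabelClassificationMetricReporter for multi-label classification. Both
aggregate predictions, scores and targets across batches and compute the
metrics used for reporting and model selection.
"""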
from enum import Enum
from itertools import tee, zip_longest
from typing import Generator, List, Optional
import torch
from pytext.common.constants import Stage
from pytext.data import CommonMetadata
from pytext.metrics import (
RECALL_AT_PRECISION_THRESHOLDS,
LabelListPrediction,
LabelPrediction,
compute_classification_metrics,
compute_multi_label_classification_metrics,
)
from .channel import Channel, ConsoleChannel, FileChannel
from .metric_reporter import MetricReporter
META_LABEL_NAMES = "label_names"
class ComparableClassificationMetric(Enum):
ACCURACY = "accuracy"
ROC_AUC = "roc_auc"
MCC = "mcc"
MACRO_F1 = "macro_f1"
LABEL_F1 = "label_f1"
LABEL_AVG_PRECISION = "label_avg_precision"
LABEL_ROC_AUC = "label_roc_auc"
# use negative because the reporter's lower_is_better value is False
NEGATIVE_LOSS = "negative_loss"
class ClassificationMetricReporter(MetricReporter):
__EXPANSIBLE__ = True
    class Config(MetricReporter.Config):
model_select_metric: ComparableClassificationMetric = (
ComparableClassificationMetric.ACCURACY
)
target_label: Optional[str] = None
#: These column names correspond to raw input data columns. Text in these
#: columns (usually just 1 column) will be concatenated and output in
#: the IntentModelChannel as an evaluation tsv.
text_column_names: List[str] = ["text"]
        #: These column names correspond to raw input data columns that
        #: will be read by the data_source into the context and included in
        #: the run_model output file along with the other saved results.
additional_column_names: List[str] = []
recall_at_precision_thresholds: List[float] = RECALL_AT_PRECISION_THRESHOLDS
        # If True, run a more memory-efficient version of the metric reporter:
        # store as little information as possible in memory and, as a result,
        # skip all metrics other than accuracy and F1. This is useful when the
        # label space is huge and we don't want to keep the score for every
        # label of every example in memory. It also means that the output debug
        # file in the Test operator will contain only predictions and no label
        # scores.
is_memory_efficient: bool = False
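        # A hypothetical sketch (not a shipped default): for a task with a very
        # large label space one might enable the memory-efficient path, e.g.
        #   ClassificationMetricReporter.Config(is_memory_efficient=True)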
def __init__(
self,
label_names: List[str],
channels: List[Channel],
model_select_metric: ComparableClassificationMetric = (
ComparableClassificationMetric.ACCURACY
),
target_label: Optional[str] = None,
text_column_names: List[str] = Config.text_column_names,
additional_column_names: List[str] = Config.additional_column_names,
recall_at_precision_thresholds: List[float] = (
Config.recall_at_precision_thresholds
),
is_memory_efficient: bool = Config.is_memory_efficient,
) -> None:
super().__init__(channels)
self.label_names = label_names
self.model_select_metric = model_select_metric
self.target_label = target_label
self.text_column_names = text_column_names
self.additional_column_names = additional_column_names
self.recall_at_precision_thresholds = recall_at_precision_thresholds
self.is_memory_efficient = is_memory_efficient
    @classmethod
def from_config(cls, config, meta: CommonMetadata = None, tensorizers=None):
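        """
        Build the reporter from a config, reading label names either from the
        "labels" tensorizer vocab (if tensorizers are given) or from the
        target vocab in meta.
        """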
# TODO: refactor metric reporting and remove this hack
if tensorizers:
labels = list(tensorizers["labels"].vocab)
else:
labels = meta.target.vocab.itos
config.text_column_names = []
return cls.from_config_and_label_names(config, labels)
    @classmethod
def from_config_and_label_names(cls, config, label_names: List[str]):
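        """
        Build the reporter from a config and an explicit list of label names,
        checking that the selected model_select_metric is compatible with the
        config (a target_label for per-label metrics, exactly two labels for
        ROC AUC and MCC).
        """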
if config.model_select_metric in (
ComparableClassificationMetric.LABEL_F1,
ComparableClassificationMetric.LABEL_AVG_PRECISION,
ComparableClassificationMetric.LABEL_ROC_AUC,
):
assert (
config.target_label is not None
), "target_label must be set for selected metric"
assert config.target_label in label_names
if config.model_select_metric in (
ComparableClassificationMetric.ROC_AUC,
ComparableClassificationMetric.MCC,
):
assert (
len(label_names) == 2
), "selected metric is valid for binary labels only"
return cls(
label_names,
[ConsoleChannel(), FileChannel((Stage.TEST,), config.output_path)],
config.model_select_metric,
config.target_label,
config.text_column_names,
config.additional_column_names,
config.recall_at_precision_thresholds,
config.is_memory_efficient,
)
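    # A hypothetical usage sketch (the label names and metric choice below are
    # illustrative, not taken from any shipped config):
    #   config = ClassificationMetricReporter.Config(
    #       model_select_metric=ComparableClassificationMetric.MACRO_F1
    #   )
    #   reporter = ClassificationMetricReporter.from_config_and_label_names(
    #       config, label_names=["negative", "positive"]
    #   )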
    def batch_context(self, raw_batch, batch):
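        """
        Build per-batch context: concatenate the configured text columns into
        a single "text" field and copy any additional raw columns so they can
        be written out with the results.
        """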
context = super().batch_context(raw_batch, batch)
context["text"] = [
" | ".join(str(row[column_name]) for column_name in self.text_column_names)
for row in raw_batch
]
        # if there are additional column names, read their values into the context
if len(self.additional_column_names) > 0:
for additional_colname in self.additional_column_names:
context[additional_colname] = [
row[additional_colname] for row in raw_batch
]
return context
    def add_batch_stats(
self, n_batches, preds, targets, scores, loss, m_input, **context
):
"""
        Aggregates a batch of output data (predictions, scores, targets/true labels
        and loss).

        Args:
            n_batches (int): number of the current batch
            preds (torch.Tensor): predictions of the current batch
            targets (torch.Tensor): targets of the current batch
            scores (torch.Tensor): scores of the current batch
            loss (float): average loss of the current batch
            m_input (Tuple[torch.Tensor, ...]): model inputs of the current batch
            context (Dict[str, Any]): any additional context data; it can be
                either a list with one entry per example or a single value
                for the batch
"""
self.n_batches = n_batches
self.aggregate_preds(preds, context)
self.aggregate_targets(targets, context)
        # If we are running in memory-efficient mode, don't store all the scores.
        # This list is generally the most memory-hungry of all the data
        # structures: for a problem with 10K classes we would store 10K floats
        # for every instance in the epoch.
if not self.is_memory_efficient:
self.aggregate_scores(scores)
for key, val in context.items():
if not (isinstance(val, torch.Tensor) or isinstance(val, List)):
continue
if key not in self.all_context:
self.all_context[key] = []
self.aggregate_data(self.all_context[key], val)
        # some loss functions (e.g. in NewBertRegressionTask) return a tensor;
        # convert it to a float
if loss is not None:
self.all_loss.append(float(loss))
# append the size of the first tensor (should be the batch size)
        if isinstance(m_input, Generator):
            # peek at the first element via a tee copy so the original
            # iterator is not consumed
            first = next(tee(m_input, 1)[0])
            self.batch_size.append(len(first))
else:
self.batch_size.append(len(m_input[0]))
    def calculate_metric(self):
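        """
        Compute classification metrics over all aggregated predictions,
        scores and targets for the epoch.
        """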
# If we are running in memory efficient mode, then scores in
# LabelPrediction should be an empty list
label_predictions = [
LabelPrediction(scores, pred, expect)
for scores, pred, expect in zip_longest(
self.all_scores, self.all_preds, self.all_targets, fillvalue=[]
)
]
return compute_classification_metrics(
label_predictions,
self.label_names,
self.calculate_loss(),
# Compute soft-metrics only if self.is_memory_efficient is False
average_precisions=(not self.is_memory_efficient),
recall_at_precision_thresholds=self.recall_at_precision_thresholds,
)
    def predictions_to_report(self):
"""
Generate human readable predictions
"""
return [self.label_names[pred] for pred in self.all_preds]
    def targets_to_report(self):
"""
Generate human readable targets
"""
return [self.label_names[target] for target in self.all_targets]
    def get_model_select_metric(self, metrics):
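        """
        Return the scalar used for model selection, chosen from the computed
        metrics according to self.model_select_metric.
        """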
if self.model_select_metric == ComparableClassificationMetric.ACCURACY:
metric = metrics.accuracy
elif self.model_select_metric == ComparableClassificationMetric.ROC_AUC:
metric = metrics.roc_auc
elif self.model_select_metric == ComparableClassificationMetric.MCC:
metric = metrics.mcc
elif self.model_select_metric == ComparableClassificationMetric.MACRO_F1:
metric = metrics.macro_prf1_metrics.macro_scores.f1
elif self.model_select_metric == ComparableClassificationMetric.LABEL_F1:
metric = metrics.macro_prf1_metrics.per_label_scores[self.target_label].f1
elif (
self.model_select_metric
== ComparableClassificationMetric.LABEL_AVG_PRECISION
):
metric = metrics.per_label_soft_scores[self.target_label].average_precision
elif self.model_select_metric == ComparableClassificationMetric.LABEL_ROC_AUC:
metric = metrics.per_label_soft_scores[self.target_label].roc_auc
elif self.model_select_metric == ComparableClassificationMetric.NEGATIVE_LOSS:
metric = -metrics.loss
else:
raise ValueError(f"unknown metric: {self.model_select_metric}")
assert metric is not None
return metric
class MultiLabelClassificationMetricReporter(ClassificationMetricReporter):
    class Config(ClassificationMetricReporter.Config):
pass
    def calculate_metric(self):
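        """
        Compute multi-label classification metrics over the aggregated
        per-example score, prediction and target lists.
        """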
return compute_multi_label_classification_metrics(
[
LabelListPrediction(scores, pred, expect)
for scores, pred, expect in zip(
self.all_scores, self.all_preds, self.all_targets
)
],
self.label_names,
self.calculate_loss(),
recall_at_precision_thresholds=self.recall_at_precision_thresholds,
)
    def predictions_to_report(self):
"""
Generate human readable predictions
"""
return [
[
self.label_names[pred_idx]
for pred_idx, pred in enumerate(predictions)
if pred == 1
]
for predictions in self.all_preds
]
    def targets_to_report(self):
"""
Generate human readable targets
"""
return [
[self.label_names[target] for target in targets if target != -1]
for targets in self.all_targets
]