Source code for pytext.metric_reporters.regression_metric_reporter

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from pytext.metrics import compute_regression_metrics

from .channel import ConsoleChannel
from .metric_reporter import MetricReporter


class RegressionMetricReporter(MetricReporter):
    """Metric reporter for regression tasks.

    Computes regression metrics over the accumulated predictions and targets
    via ``compute_regression_metrics``, and uses the Pearson correlation from
    those metrics as the model-selection metric.
    """

    # The model-selection metric is Pearson correlation (see
    # get_model_select_metric), for which larger values are better.
    # NOTE(review): presumably read by the base class / trainer when comparing
    # model-selection metrics across epochs -- confirm against MetricReporter.
    lower_is_better = False

    class Config(MetricReporter.Config):
        """Configuration for RegressionMetricReporter (no extra fields)."""

        pass

    @classmethod
    def from_config(cls, config, tensorizers=None):
        """Build a reporter that writes to a single console channel.

        Args:
            config: Reporter config. Currently unused.
            tensorizers: Accepted for interface compatibility. Currently unused.

        Returns:
            A new instance constructed with ``[ConsoleChannel()]``.
        """
        return cls([ConsoleChannel()])

    def calculate_metric(self):
        """Compute regression metrics over all accumulated predictions.

        Reads ``self.all_preds`` and ``self.all_targets``, which are presumably
        populated by the ``MetricReporter`` base class during evaluation.
        TODO: confirm.

        Returns:
            The result of ``compute_regression_metrics(all_preds, all_targets)``.

        Raises:
            AssertionError: If the number of predictions differs from the
                number of targets.
        """
        num_preds = len(self.all_preds)
        num_targets = len(self.all_targets)
        # An explicit check instead of a bare ``assert``: bare asserts are
        # stripped under ``python -O``, which would let mismatched inputs reach
        # the metric computation. AssertionError is kept for compatibility with
        # the original behavior.
        if num_preds != num_targets:
            raise AssertionError(
                f"Number of predictions ({num_preds}) does not match "
                f"number of targets ({num_targets})"
            )
        return compute_regression_metrics(self.all_preds, self.all_targets)

    def get_model_select_metric(self, metrics):
        """Return the metric used for model selection.

        Args:
            metrics: The object returned by ``calculate_metric``. It must
                expose a ``pearson_correlation`` attribute.

        Returns:
            ``metrics.pearson_correlation``.
        """
        return metrics.pearson_correlation