#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, Optional
from pytext.common.constants import BatchContext
from .metric_reporter import MetricReporter
# Key in the metrics dict produced by report_metric() that holds the average
# (loss-weighted) loss over all batches of all sub-tasks.
AVRG_LOSS = "_avrg_loss"


class DisjointMultitaskMetricReporter(MetricReporter):
    """Metric reporter for a disjoint multitask model.

    Each batch is routed to the sub-task reporter named by the batch's
    ``BatchContext.TASK_NAME`` entry. This reporter additionally tracks the
    average loss across all sub-tasks and combines the per-task metrics into
    a single model-selection metric.
    """

    # The value returned by get_model_select_metric() is always oriented so
    # that larger is better: lower-is-better sub-metrics and the average loss
    # are negated before being returned.
    lower_is_better = False

    class Config(MetricReporter.Config):
        # When no target task is configured, select the best epoch by the sum
        # of the sub-tasks' own model-selection metrics instead of the
        # average loss.
        use_subtask_select_metric: bool = False

    def __init__(
        self,
        reporters: Dict[str, MetricReporter],
        loss_weights: Dict[str, float],
        target_task_name: Optional[str],
        use_subtask_select_metric: bool,
    ) -> None:
        """Create a reporter that delegates to per-task metric reporters.

        Args:
            reporters (Dict[str, MetricReporter]):
                Dictionary of sub-task metric-reporters, keyed by task name.
            loss_weights (Dict[str, float]):
                Per-task weight applied to each task's loss by the model.
                Used to undo that weighting before the loss is handed to the
                sub-task reporter.
            target_task_name (Optional[str]):
                If it names one of ``reporters``, the dev metric for this task
                will be used to select the best epoch.
            use_subtask_select_metric (bool):
                If there is no target task, select the best epoch by summing
                the sub-task selection metrics rather than by average loss.
        """
        super().__init__(None)
        self.reporters = reporters
        self.target_task_name = target_task_name or ""
        # None when no (valid) target task name was given.
        self.target_reporter = self.reporters.get(self.target_task_name, None)
        self.loss_weights = loss_weights
        self.use_subtask_select_metric = use_subtask_select_metric

    def _reset(self):
        # Running totals used to compute the average loss across sub-tasks.
        self.total_loss = 0
        self.num_batches = 0

    def batch_context(self, raw_batch, batch):
        """Return the task name plus the owning sub-task's batch context."""
        context = {BatchContext.TASK_NAME: batch[BatchContext.TASK_NAME]}
        reporter = self.reporters[context[BatchContext.TASK_NAME]]
        context.update(reporter.batch_context(raw_batch, batch))
        return context

    def add_batch_stats(
        self, n_batches, preds, targets, scores, loss, m_input, **context
    ):
        """Accumulate the loss and forward the batch to its sub-task reporter.

        The weighted ``loss`` is added to the cross-task running total. The
        sub-task reporter receives the loss with the task weight divided out.
        """
        task_name = context[BatchContext.TASK_NAME]
        self.total_loss += loss.item()
        self.num_batches += 1
        # Losses are weighted in DisjointMultitaskModel. Here we undo the
        # weighting for proper reporting. Division is done out of place so
        # the caller's loss object (e.g. a tensor) is not mutated. A zero
        # weight cannot be undone, so the loss is passed through unchanged.
        weight = self.loss_weights[task_name]
        if weight != 0:
            loss = loss / weight
        self.reporters[task_name].add_batch_stats(
            n_batches, preds, targets, scores, loss, m_input, **context
        )

    def add_channel(self, channel):
        """Register ``channel`` with every sub-task reporter."""
        for reporter in self.reporters.values():
            reporter.add_channel(channel)

    def report_metric(
        self, model, stage, epoch, reset=True, print_to_channels=True, optimizer=None
    ):
        """Report metrics for every sub-task.

        Returns:
            If a target task is configured, the metrics returned by that task's
            reporter. Otherwise, a dict mapping ``AVRG_LOSS`` to the average
            loss and each task name to that task's model-selection metric.
        """
        # Initialize `metrics_dict` with the average loss across sub-tasks.
        metrics_dict = {AVRG_LOSS: self.total_loss / self.num_batches}
        # Store computed metrics for each sub-task in `metrics_dict`.
        for name, reporter in self.reporters.items():
            print(f"Reporting on task: {name}")
            metrics_dict[name] = reporter.report_metric(
                model, stage, epoch, reset, print_to_channels, optimizer=optimizer
            )
        # Reset loss and batch counters.
        if reset:
            self._reset()
        # If the target task is specified, return all target metrics.
        if self.target_reporter is not None:
            return metrics_dict[self.target_task_name]
        # Otherwise, for each task, return its model selection metric.
        for name, reporter in self.reporters.items():
            metrics_dict[name] = reporter.get_model_select_metric(metrics_dict[name])
        return metrics_dict

    def get_model_select_metric(self, metrics):
        """Reduce the output of report_metric() to a larger-is-better scalar.

        Selection order:
        1. With a target task, use the target reporter's selection metric.
        2. Otherwise, with ``use_subtask_select_metric``, sum the per-task
           selection metrics.
        3. Otherwise, use the negated average loss.
        In cases 1 and 2, sub-metrics from lower-is-better reporters are
        negated first.
        """
        if self.target_reporter is not None:
            metric = self.target_reporter.get_model_select_metric(metrics)
            if self.target_reporter.lower_is_better:
                metric = -metric
        elif self.use_subtask_select_metric:
            metric = 0.0
            for name, reporter in self.reporters.items():
                sub_metric = metrics[name]
                if reporter.lower_is_better:
                    sub_metric = -sub_metric
                metric += sub_metric
        else:  # default to training loss
            metric = -metrics[AVRG_LOSS]
        return metric

    def report_realtime_metric(self, stage):
        """Forward the realtime-metric report to every sub-task reporter."""
        for reporter in self.reporters.values():
            reporter.report_realtime_metric(stage)