# Source code for pytext.metric_reporters.seq2seq_metric_reporter

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, List

import torch
from pytext.common.constants import (
    BatchContext,
    DatasetFieldName,
    RawExampleFieldName,
    Stage,
)
from pytext.data.tensorizers import Tensorizer
from pytext.metric_reporters.channel import ConsoleChannel, FileChannel
from pytext.metric_reporters.metric_reporter import MetricReporter
from pytext.metrics import safe_division
from pytext.metrics.seq2seq_metrics import Seq2SeqMetrics, compute_f1


# Newer fairseq releases moved the BLEU scorer into the ``fairseq.scoring``
# subpackage; fall back to the legacy top-level ``fairseq.bleu`` module so
# older fairseq versions keep working.
try:
    from fairseq.scoring import bleu
except ImportError:
    from fairseq import bleu


class Seq2SeqFileChannel(FileChannel):
    """File channel that writes seq2seq test predictions to a TSV file.

    Each row pairs an example's row index and raw source text with the
    stringified best prediction and the stringified target sequence.
    """

    def __init__(self, stages, file_path, tensorizers):
        super().__init__(stages, file_path)
        # Kept so token-id sequences can be mapped back to strings per row.
        self.tensorizers = tensorizers

    def get_title(self, context_keys=()):
        """Return the column headers for the output file."""
        return ("doc_index", "raw_input", "predictions", "targets")

    def gen_content(self, metrics, loss, preds, targets, scores, context):
        """Yield one output row per example in the accumulated results."""
        num_rows = len(targets)
        assert num_rows == len(context[DatasetFieldName.RAW_SEQUENCE]) == len(preds)
        stringify = self.tensorizers["trg_seq_tokens"].stringify
        for row in range(num_rows):
            yield [
                context[BatchContext.INDEX][row],
                context[DatasetFieldName.RAW_SEQUENCE][row],
                # preds[row] holds beam hypotheses; index 0 is the top beam.
                stringify(preds[row][0]),
                stringify(targets[row]),
            ]
class Seq2SeqMetricReporter(MetricReporter):
    """Metric reporter for sequence-to-sequence models.

    Accumulates beam predictions, targets, and source tokens across batches
    and reports exact match, token-level F1, and BLEU. Model selection is
    based on loss, so lower is better.
    """

    lower_is_better = True

    class Config(MetricReporter.Config):
        pass

    def __init__(self, channels, log_gradient, tensorizers):
        super().__init__(channels, log_gradient)
        self.tensorizers = tensorizers

    def _reset(self):
        super()._reset()
        self.all_target_lens: List = []
        self.all_src_tokens: List = []

    @classmethod
    def from_config(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
        """Build the reporter with a console channel plus a TEST-stage file channel."""
        channels = [
            ConsoleChannel(),
            Seq2SeqFileChannel([Stage.TEST], config.output_path, tensorizers),
        ]
        return cls(channels, config.log_gradient, tensorizers)

    def add_batch_stats(
        self, n_batches, preds, targets, scores, loss, m_input, **context
    ):
        super().add_batch_stats(
            n_batches, preds, targets, scores, loss, m_input, **context
        )
        # m_input[0] carries the source token ids for the batch.
        self.aggregate_src_tokens(m_input[0])

    def calculate_metric(self):
        """Compute exact match, F1, and BLEU over all accumulated examples."""
        num_samples = len(self.all_targets)
        exact_match_count = 0
        f1_sum = 0.0

        trg_vocab = self.tensorizers["trg_seq_tokens"].vocab
        bleu_scorer = bleu.Scorer(
            bleu.BleuConfig(
                pad=trg_vocab.get_pad_index(),
                eos=trg_vocab.get_eos_index(),
                unk=trg_vocab.get_unk_index(),
            )
        )

        for beam_preds, target in zip(self.all_preds, self.all_targets):
            top_pred = beam_preds[0]
            if self._compare_target_prediction_tokens(top_pred, target):
                exact_match_count += 1
            f1_sum += compute_f1(top_pred, target)
            # Bleu Metric calculation is always done with tensors on CPU or
            # type checks in fairseq/bleu.py:add() will fail
            bleu_scorer.add(
                torch.IntTensor(target).cpu(), torch.IntTensor(top_pred).cpu()
            )

        loss = self.calculate_loss()
        exact_match = round(safe_division(exact_match_count, num_samples) * 100.0, 2)
        f1 = round(safe_division(f1_sum, num_samples) * 100.0, 2)
        bleu_score = round(0.0 if not self.all_preds else bleu_scorer.score(), 2)
        return Seq2SeqMetrics(loss, exact_match, f1, bleu_score)

    def aggregate_targets(self, new_batch, context=None):
        """Accumulate target token sequences (pad stripped) and their lengths."""
        if new_batch is None:
            return
        pad_idx = self.tensorizers["trg_seq_tokens"].vocab.get_pad_index()
        cleaned = [
            self._remove_tokens(targets, [pad_idx])
            for targets in self._make_simple_list(new_batch[0])
        ]
        self.aggregate_data(self.all_targets, cleaned)
        self.aggregate_data(self.all_target_lens, new_batch[1])

    def aggregate_preds(self, new_batch, context=None):
        """Accumulate beam predictions as-is."""
        if new_batch is None:
            return
        self.aggregate_data(self.all_preds, new_batch)

    def aggregate_src_tokens(self, new_batch):
        """Accumulate source token sequences with padding removed."""
        pad_idx = self.tensorizers["src_seq_tokens"].vocab.get_pad_index()
        cleaned = [
            self._remove_tokens(src, [pad_idx])
            for src in self._make_simple_list(new_batch)
        ]
        self.aggregate_data(self.all_src_tokens, cleaned)

    def gen_extra_context(self, *args):
        """Expose accumulated source tokens to channels via the context dict."""
        self.all_context[DatasetFieldName.SOURCE_SEQ_FIELD] = self.all_src_tokens

    def _compare_target_prediction_tokens(self, prediction, target):
        # Exact-match criterion: token sequences must be identical.
        return prediction == target

    def _remove_tokens(self, tokens_list, remove_tokens_list):
        """Recursively drop the given token ids from a (possibly nested) list."""
        return [
            self._remove_tokens(tok, remove_tokens_list)
            if isinstance(tok, list)
            else tok
            for tok in tokens_list
            if isinstance(tok, list) or tok not in remove_tokens_list
        ]

    def get_model_select_metric(self, metrics):
        """Select on loss (lower_is_better=True)."""
        return metrics.loss

    def batch_context(self, raw_batch, batch):
        """Extract per-row raw source text and row indices for reporting."""
        return {
            DatasetFieldName.RAW_SEQUENCE: [
                row["source_sequence"] for row in raw_batch
            ],
            BatchContext.INDEX: [
                row[RawExampleFieldName.ROW_INDEX] for row in raw_batch
            ],
        }