# Source code for pytext.metric_reporters.language_model_metric_reporter

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import json
import math
import operator
import time

import torch
import torch.nn.functional as F
from pytext.common.constants import Stage
from pytext.config.module_config import PerplexityType
from pytext.data import CommonMetadata
from pytext.metrics.language_model_metrics import (
    LanguageModelMetric,
    compute_language_model_metric,
)
from pytext.utils import cuda

from .channel import ConsoleChannel, FileChannel
from .metric_reporter import MetricReporter

    PerplexityType.MIN: torch.min,
    PerplexityType.MAX: torch.max,
    PerplexityType.MEAN: torch.mean,
    PerplexityType.MEDIAN: torch.median,
    PerplexityType.EOS: operator.itemgetter(-1),

def get_perplexity_func(perplexity_type):
    """Return the reduction function registered for ``perplexity_type``.

    Args:
        perplexity_type: key into ``PERPLEXITY_FUNC_MAP`` (a ``PerplexityType``).

    Returns:
        The callable that reduces a tensor of per-token losses to one value.

    Raises:
        NotImplementedError: if no function is registered for the type.
    """
    func = PERPLEXITY_FUNC_MAP.get(perplexity_type)
    # Compare against None explicitly: absence, not truthiness, is the signal.
    if func is None:
        raise NotImplementedError(f"Unsupported perplexity type: {perplexity_type}")
    return func
class LanguageModelChannel(FileChannel):
    """File channel that emits one ``[utterance, perplexity]`` row per score."""

    def get_title(self, context_keys=()):
        # Fixed column headers; ``context_keys`` is accepted but not used.
        return ("text", "perplexity")

    def gen_content(self, metrics, loss, preds, targets, scores, contexts):
        # Pair each score with the utterance stored at the same position.
        # The lookup stays inside the loop so an empty ``scores`` never
        # touches ``contexts``.
        # NOTE(review): raises KeyError if no "utterance" context was
        # collected (batch_context only adds it when raw rows carry "text").
        for position, score in enumerate(scores):
            yield [contexts["utterance"][position], score]
class LanguageModelMetricReporter(MetricReporter):
    """Metric reporter for language models.

    Accumulates a token-weighted loss across batches (reported through
    ``compute_language_model_metric``) and, when ``aggregate_metrics`` is
    enabled, a perplexity score for every sentence that has at least one
    non-padding target.
    """

    UTTERANCE_COLUMN = "utterance"
    RAW_TEXT_COLUMN = "text"
    TOKENS_COLUMN = "tokens"
    LABELS_COLUMN = "labels"
    lower_is_better = True

    class Config(MetricReporter.Config):
        # Compute and collect per-sentence perplexity scores.
        aggregate_metrics: bool = True
        # Reduction applied to each sentence's per-token losses.
        perplexity_type: PerplexityType = PerplexityType.MEDIAN

    @classmethod
    def from_config(cls, config: Config, meta: CommonMetadata = None, tensorizers=None):
        return cls(
            [ConsoleChannel(), LanguageModelChannel((Stage.TEST,), config.output_path)],
            meta,
            tensorizers,
            config.aggregate_metrics,
            config.perplexity_type,
            config.pep_format,
            config.log_gradient,
        )

    def __init__(
        self,
        channels,
        metadata,
        tensorizers,
        aggregate_metrics,
        perplexity_type,
        pep_format,
        log_gradient=False,
    ):
        super().__init__(channels, log_gradient=log_gradient, pep_format=pep_format)
        self.metadata = metadata
        self.tensorizers = tensorizers
        self.aggregate_metrics = aggregate_metrics
        assert metadata or tensorizers
        if metadata:
            # NOTE(review): the right-hand side of this assignment was missing
            # from the source; reconstructed as the target field's pad token
            # index -- confirm against CommonMetadata.
            self.pad_index = metadata.target.pad_token_idx
        if tensorizers:
            # Tensorizers, when given, override the metadata pad index.
            tensorizer = tensorizers[self._pad_source_column(tensorizers)]
            if hasattr(tensorizer, "vocab"):
                self.pad_index = tensorizer.vocab.get_pad_index()
            else:
                self.pad_index = tensorizer.PAD_BYTE
        self.perplexity_func = get_perplexity_func(perplexity_type)

    def _pad_source_column(self, tensorizers):
        """Return the column whose tensorizer defines the padding index."""
        if self.TOKENS_COLUMN in tensorizers:
            return self.TOKENS_COLUMN
        if self.LABELS_COLUMN in tensorizers:
            return self.LABELS_COLUMN
        raise ValueError(
            f"tensorizers must contain a {self.TOKENS_COLUMN!r} or "
            f"{self.LABELS_COLUMN!r} column to determine the pad index"
        )

    def add_batch_stats(
        self, n_batches, preds, targets, scores, loss, m_input, **context
    ):
        """Accumulate loss, and optionally per-sentence scores, for one batch.

        ``targets[1]`` is summed to get the batch's token count, which weights
        the loss. When ``targets`` is a tuple, ``targets[0]`` is used as the
        target ids for scoring. The incoming ``scores`` argument is ignored
        and recomputed from ``preds``.
        """
        if isinstance(loss, torch.Tensor):
            loss = loss.item()
        num_words_in_batch = targets[1].sum().item()
        # Re-weight by token count so calculate_loss() is a token-weighted
        # mean across batches (presumably `loss` is a per-token mean -- confirm).
        self.aggregate_loss += loss * num_words_in_batch
        self.total_num_tokens += num_words_in_batch
        if not (self.aggregate_metrics and num_words_in_batch > 0):
            return
        if isinstance(targets, tuple):
            targets = targets[0]
        sentence_scores = self.compute_scores(preds, targets)
        # None means every target in the batch equals self.pad_index.
        if sentence_scores is not None:
            self.aggregate_scores(sentence_scores)
            self.aggregate_context(context)

    def calculate_loss(self) -> float:
        """Return the token-weighted mean loss, or 0.0 if no tokens were seen."""
        if self.total_num_tokens == 0:
            return 0.0
        return self.aggregate_loss / float(self.total_num_tokens)

    def _reset(self):
        super()._reset()
        self.aggregate_loss = 0.0
        self.total_num_tokens = 0

    def calculate_metric(self) -> LanguageModelMetric:
        # In a language model self.total_loss is the loss per word
        # (presumably populated by the base reporter from calculate_loss()).
        return compute_language_model_metric(self.total_loss)

    def get_model_select_metric(self, metrics) -> float:
        return metrics.perplexity_per_word

    def batch_context(self, raw_batch, batch):
        """Collect the raw text of each row when any row provides it."""
        context = {}
        if any(self.RAW_TEXT_COLUMN in row for row in raw_batch):
            context.update(
                {
                    self.UTTERANCE_COLUMN: [
                        row.get(self.RAW_TEXT_COLUMN) for row in raw_batch
                    ]
                }
            )
        return context

    def compute_scores(self, logits, targets):
        """Compute one perplexity score per sentence in the batch.

        Per-token cross-entropy is computed without reduction. For each
        sentence, positions whose target equals ``self.pad_index`` are
        dropped, ``self.perplexity_func`` reduces the rest, and the result
        is exponentiated.

        Args:
            logits: model outputs, permuted with ``(0, 2, 1)`` before the loss,
                so presumably ``(batch, seq_len, vocab)`` -- confirm.
            targets: target ids, presumably ``(batch, seq_len)``.

        Returns:
            A list of floats, one per sentence with at least one non-padding
            target, or ``None`` if every target is padding.

        NOTE(review): sentences that are entirely padding are skipped, so the
        returned list can be shorter than the batch and no longer line up
        index-for-index with the collected utterance context.
        """
        token_losses = F.cross_entropy(
            logits.permute(0, 2, 1),
            targets,
            ignore_index=self.pad_index,
            reduction="none",
        )
        sentence_scores = []
        for sentence_loss, sentence_target in zip(token_losses, targets):
            kept = sentence_loss[sentence_target != self.pad_index]
            # Reductions such as median or EOS are undefined on an empty
            # tensor, so all-padding sentences are skipped.
            if kept.numel() > 0:
                sentence_scores.append(torch.exp(self.perplexity_func(kept)).item())
        if not sentence_scores:
            return None
        return sentence_scores

    def aggregate_scores(self, scores):
        self.all_scores.extend(scores)

    def aggregate_context(self, context):
        """Append each context list onto the running per-key collection."""
        for key, val in context.items():
            self.all_context.setdefault(key, []).extend(val)
class MaskedLMMetricReporter(LanguageModelMetricReporter):
    """Metric reporter for masked language models.

    Loss is weighted by the number of masked tokens per batch. Training
    throughput (tokens/s) and perplexity are printed in real time, reduced
    across workers when ``cuda.DISTRIBUTED_WORLD_SIZE > 1``.
    """

    @classmethod
    def from_config(cls, config, meta: CommonMetadata = None, tensorizers=None):
        return cls(
            [ConsoleChannel()],
            meta,
            tensorizers,
            config.aggregate_metrics,
            config.perplexity_type,
            config.pep_format,
            # Honor gradient logging like the parent's from_config does.
            config.log_gradient,
        )

    def add_batch_stats(
        self, n_batches, preds, targets, scores, loss, m_input, **context
    ):
        """Accumulate masked-token loss and realtime throughput for one batch.

        ``targets[1]`` is summed as the masked-token count, and ``targets[2]``
        is summed as the total token count used for tokens/s.
        """
        now = time.time()
        # Accept both tensors and plain numbers, matching the parent class.
        batch_loss = loss.item() if isinstance(loss, torch.Tensor) else float(loss)
        total_masked_tokens = targets[1].sum().item()
        self.aggregate_loss += batch_loss * total_masked_tokens
        self.total_masked_tokens += total_masked_tokens
        # realtime stats
        total_tokens = float(targets[2].sum())
        self.realtime_meters["tps"].update(total_tokens)
        # The epsilon avoids division by zero for back-to-back calls.
        self.last_batch_tps = total_tokens / (now - self.time + 1e-6)
        self.last_batch_loss = batch_loss
        self.total_batches = n_batches
        self.time = now

    @staticmethod
    def _safe_exp(value):
        """Return ``math.exp(value)``, or ``inf`` instead of raising on overflow."""
        try:
            return math.exp(value)
        except OverflowError:
            return float("inf")

    def report_realtime_metric(self, stage):
        """Print realtime throughput and perplexity; only active during training."""
        if stage != Stage.TRAIN:
            return

        if cuda.DISTRIBUTED_WORLD_SIZE > 1:
            all_reduce_stats = cuda.tensor(
                [
                    self.last_batch_tps,
                    self.last_batch_loss,
                    self.aggregate_loss,
                    self.total_masked_tokens,
                    self.realtime_meters["tps"].n,
                ],
                dtype=torch.float32,
            )
            total_elapsed_time = self.realtime_meters["tps"].elapsed_time

            torch.distributed.all_reduce(all_reduce_stats)
            # average last_batch_loss by distributed_world_size; the other
            # entries stay as cross-worker sums
            all_reduce_stats[1:2].div_(cuda.DISTRIBUTED_WORLD_SIZE)
            [
                last_batch_tps,
                last_batch_loss,
                aggregate_loss,
                total_masked_tokens,
                total_tokens,
            ] = all_reduce_stats.tolist()
            # Guard against a zero elapsed time on the very first report.
            tps = total_tokens / max(total_elapsed_time, 1e-6)
        else:
            last_batch_tps = self.last_batch_tps
            last_batch_loss = self.last_batch_loss
            aggregate_loss = self.aggregate_loss
            total_masked_tokens = self.total_masked_tokens
            tps = self.realtime_meters["tps"].avg

        agg_loss = self._calculate_loss(aggregate_loss, total_masked_tokens)
        print(
            f"Tokens/s: {last_batch_tps:.0f}, "
            f"batch ppl: {self._safe_exp(last_batch_loss):.2f}, "
            f"agg ppl: {self._safe_exp(agg_loss):.2f}, "
            f"number of batches: {self.total_batches:.0f}, "
            f"accumulated tokens/s: {tps:.0f}",
            flush=True,
        )
        # TODO: remove GPU0 report
        print(
            f"GPU-0 tokens/s: {self.last_batch_tps:.0f}, "
            f"batch ppl: {self._safe_exp(self.last_batch_loss):.2f}, "
            f"agg ppl: {self._safe_exp(self.calculate_loss()):.2f}, "
            f"number of batches: {self.total_batches}, "
            f"accumulated tokens/s: {self.realtime_meters['tps'].avg:.0f}",
            flush=True,
        )

        if self.pep_format:
            # used for pep regression benchmark
            print(
                "PyTorchObserver "
                + json.dumps(
                    {
                        "type": "MLM",
                        "metric": "tps",
                        "unit": "token/sec",
                        "value": f"{tps:.0f}",
                    }
                ),
                flush=True,
            )

    def calculate_loss(self) -> float:
        """Return the loss averaged over all masked tokens seen so far."""
        return self._calculate_loss(self.aggregate_loss, self.total_masked_tokens)

    def _calculate_loss(self, aggregate_loss, total_masked_tokens) -> float:
        # max(1, ...) avoids division by zero before any masked tokens arrive.
        return aggregate_loss / max(1, total_masked_tokens)

    def _reset(self):
        super()._reset()
        self.aggregate_loss = 0.0
        self.total_masked_tokens = 0

    def _reset_realtime(self):
        super()._reset_realtime()
        self.last_batch_tps = 0
        self.last_batch_loss = 0
        self.total_batches = 0
        self.time = time.time()