Source code for pytext.metric_reporters.dense_retrieval_metric_reporter

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from enum import Enum
from typing import Any, Dict, List

import numpy as np
from pytext.common.constants import Stage
from pytext.metric_reporters.channel import Channel, ConsoleChannel, FileChannel
from pytext.metric_reporters.metric_reporter import MetricReporter
from pytext.metrics.dense_retrieval_metrics import DenseRetrievalMetrics


class DenseRetrievalMetricNames(Enum):
    ACCURACY = "accuracy"
    AVG_RANK = "avg_rank"
    MEAN_RECIPROCAL_RANK = "mean_reciprocal_rank"
    # use negative because the reporter's lower_is_better value is False
    NEGATIVE_LOSS = "negative_loss"


class DenseRetrievalMetricReporter(MetricReporter):
    class Config(MetricReporter.Config):
        text_column_names: List[str] = ["question", "positive_ctx", "negative_ctxs"]
        model_select_metric: DenseRetrievalMetricNames = (
            DenseRetrievalMetricNames.ACCURACY
        )
        # We need these because the id of the positive index depends on the
        # batch size. They are used to set the global id of the positive
        # contexts when computing average rank.
        # Set by PairwiseClassificationForDenseRetrievalTask._init_tensorizers()
        task_batch_size: int = 0
        num_negative_ctxs: int = 0

    @classmethod
    def from_config(cls, config, *args, tensorizers=None, **kwargs):
        return cls(
            channels=[
                ConsoleChannel(),
                FileChannel((Stage.TEST,), config.output_path),
            ],
            text_column_names=config.text_column_names,
            model_select_metric=config.model_select_metric,
            task_batch_size=config.task_batch_size,
            num_negative_ctxs=config.num_negative_ctxs,
        )

    def __init__(
        self,
        channels: List[Channel],
        text_column_names: List[str],
        model_select_metric: DenseRetrievalMetricNames,
        task_batch_size: int,
        num_negative_ctxs: int = 0,
    ) -> None:
        super().__init__(channels)
        self.channels = channels
        self.text_column_names = text_column_names
        self.model_select_metric = model_select_metric

        # Assert these values to make sure that they are set explicitly.
        assert (
            task_batch_size != 0
        ), "DenseRetrievalMetricReporter: Batch size cannot be zero."
        print(f"DenseRetrievalMetricReporter: task_batch_size = {task_batch_size}")
        assert (
            num_negative_ctxs != 0
        ), "DenseRetrievalMetricReporter: Number of hard negatives cannot be zero."
        print(f"DenseRetrievalMetricReporter: num_negative_ctxs = {num_negative_ctxs}")
        self.task_batch_size = task_batch_size
        self.num_negative_ctxs = num_negative_ctxs

    def _reset(self):
        super()._reset()
        self.all_question_logits = []
        self.all_context_logits = []

    def aggregate_preds(self, preds, context):
        preds, question_logits, context_logits = preds
        super().aggregate_preds(preds)
        self.aggregate_data(self.all_question_logits, question_logits)
        self.aggregate_data(self.all_context_logits, context_logits)

    def batch_context(self, raw_batch, batch) -> Dict[str, Any]:
        context = super().batch_context(raw_batch, batch)
        for name in self.text_column_names:
            context[name] = [row[name] for row in raw_batch]
        return context

    def calculate_metric(self) -> DenseRetrievalMetrics:
        average_rank, mean_reciprocal_rank = self._get_ranking_metrics()
        return DenseRetrievalMetrics(
            num_examples=len(self.all_preds),
            accuracy=self._get_accuracy(),
            average_rank=average_rank,
            mean_reciprocal_rank=mean_reciprocal_rank,
        )

    def get_model_select_metric(self, metrics: DenseRetrievalMetrics):
        if self.model_select_metric == DenseRetrievalMetricNames.ACCURACY:
            metric = metrics.accuracy
        elif self.model_select_metric == DenseRetrievalMetricNames.AVG_RANK:
            metric = metrics.average_rank
        elif (
            self.model_select_metric == DenseRetrievalMetricNames.MEAN_RECIPROCAL_RANK
        ):
            metric = metrics.mean_reciprocal_rank
        else:
            raise ValueError(f"Unknown metric: {self.model_select_metric}")
        return metric

    def _get_accuracy(self):
        num_correct = sum(
            int(p == t) for p, t in zip(self.all_preds, self.all_targets)
        )
        return num_correct / len(self.all_preds)

    def _get_ranking_metrics(self):
        dot_products = np.matmul(
            self.all_question_logits, np.transpose(self.all_context_logits)
        )
        inverse_sorted_indices = np.argsort(dot_products, axis=1)  # ascending
        positive_indices_per_question = self._get_positive_indices()

        num_questions = inverse_sorted_indices.shape[0]
        num_docs = inverse_sorted_indices.shape[1]
        rank_sum = 0
        reciprocal_rank_sum = 0
        # Sum up the rank of the positive context in sorted scores for each question
        for i, pos_ctx_idx in enumerate(positive_indices_per_question):
            # Numpy returns a tuple of lists. So handle that.
            gold_idx = (inverse_sorted_indices[i] == pos_ctx_idx).nonzero()[0][0]
            rank = num_docs - gold_idx
            rank_sum += rank
            reciprocal_rank_sum += 1.0 / rank

        average_rank = rank_sum / num_questions
        mean_reciprocal_rank = reciprocal_rank_sum / num_questions
        return average_rank, mean_reciprocal_rank

    def _get_positive_indices(self):
        positive_indices_per_question = []
        batch_id, total_ctxs = 0, 0
        batch_size = self.task_batch_size * (1 + self.num_negative_ctxs)
        for i, pos_ctx_idx in enumerate(self.all_targets):
            batch_id = i // self.task_batch_size
            total_ctxs = batch_id * batch_size
            positive_indices_per_question.append(total_ctxs + pos_ctx_idx)
        return positive_indices_per_question
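
The ranking metrics above are easiest to follow on a small worked example. The sketch below is not part of the module: it re-derives the same positive-index and rank arithmetic as _get_positive_indices and _get_ranking_metrics on invented data, assuming a hypothetical run with task_batch_size=2, num_negative_ctxs=1, four questions, and random 3-dimensional encodings.

# Standalone sketch (illustration only, not part of the PyText source).
# All values below are made up; only the arithmetic mirrors the reporter.
import numpy as np

# With task_batch_size = 2 and num_negative_ctxs = 1, every batch of 2
# questions contributes 2 * (1 + 1) = 4 context rows to the score matrix.
task_batch_size = 2
num_negative_ctxs = 1
ctxs_per_batch = task_batch_size * (1 + num_negative_ctxs)

# Per-question positive index *within its own batch* (what all_targets holds).
all_targets = [0, 1, 0, 1]

# Same shift as _get_positive_indices: offset each in-batch index by the
# number of contexts contributed by earlier batches.
positive_indices = [
    (i // task_batch_size) * ctxs_per_batch + pos_idx
    for i, pos_idx in enumerate(all_targets)
]
print(positive_indices)  # [0, 1, 4, 5]

# Toy encodings: 4 questions, 8 contexts, 3-dimensional embeddings.
rng = np.random.default_rng(0)
question_logits = rng.normal(size=(4, 3))
context_logits = rng.normal(size=(8, 3))

# Same scoring as _get_ranking_metrics: dot products, then ascending argsort.
scores = question_logits @ context_logits.T
order = np.argsort(scores, axis=1)

num_docs = scores.shape[1]
ranks = []
for i, pos_idx in enumerate(positive_indices):
    gold_pos = (order[i] == pos_idx).nonzero()[0][0]
    ranks.append(num_docs - gold_pos)  # rank 1 means the positive scored highest

average_rank = sum(ranks) / len(ranks)
mean_reciprocal_rank = sum(1.0 / r for r in ranks) / len(ranks)
print(average_rank, mean_reciprocal_rank)

Because every question is scored against every context accumulated over the whole evaluation run, each question's in-batch positive index must be shifted by the contexts contributed by all earlier batches; that is exactly the task_batch_size * (1 + num_negative_ctxs) offset applied in _get_positive_indices.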