Source code for pytext.metric_reporters.squad_metric_reporter

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import re
import string
from collections import Counter
from typing import Dict, List

import numpy as np
from pytext.common.constants import Stage
from pytext.metric_reporters.channel import Channel, ConsoleChannel, FileChannel
from pytext.metric_reporters.metric_reporter import MetricReporter
from pytext.metrics.squad_metrics import SquadMetrics



class SquadFileChannel(FileChannel):
    def get_title(self, context_keys=()):
        return (
            "index",
            "ques",
            "doc",
            "predicted_answer",
            "true_answers",
            "predicted_start_pos",
            "predicted_end_pos",
            "true_start_pos",
            "true_end_pos",
            "start_pos_scores",
            "end_pos_scores",
            "predicted_has_answer",
            "true_has_answer",
            "has_answer_scores",
        )

    def gen_content(self, metrics, loss, preds, targets, scores, contexts, *args):
        pred_answers, pred_start_pos, pred_end_pos, pred_has_answer = preds
        true_answers, true_start_pos, true_end_pos, true_has_answer = targets
        start_pos_scores, end_pos_scores, has_answer_scores = scores
        for i in range(len(pred_answers)):
            yield [
                contexts[SquadMetricReporter.ROW_INDEX][i],
                contexts[SquadMetricReporter.QUES_COLUMN][i],
                contexts[SquadMetricReporter.DOC_COLUMN][i],
                pred_answers[i],
                true_answers[i],
                pred_start_pos[i],
                pred_end_pos[i],
                true_start_pos[i],
                true_end_pos[i],
                start_pos_scores[i],
                end_pos_scores[i],
                pred_has_answer[i],
                true_has_answer[i],
                has_answer_scores[i],
            ]
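

# Each row yielded by gen_content lines up positionally with the column names
# returned by get_title; FileChannel writes one such row per output line. A
# sketch of a single row, with hypothetical values:
#   [0, "Who ...?", "<doc text>", "Turing", ["Turing"], 12, 13, 12, 13,
#    0.9, 0.8, 1, 1, 0.95]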


class SquadMetricReporter(MetricReporter):
    QUES_COLUMN = "question"
    ANSWERS_COLUMN = "answers"
    DOC_COLUMN = "doc"
    ROW_INDEX = "id"

    class Config(MetricReporter.Config):
        n_best_size: int = 5
        max_answer_length: int = 16
        ignore_impossible: bool = True
        false_label: str = "False"

    @classmethod
    def from_config(cls, config, *args, tensorizers=None, **kwargs):
        return cls(
            channels=[
                ConsoleChannel(),
                SquadFileChannel((Stage.TEST,), config.output_path),
            ],
            n_best_size=config.n_best_size,
            max_answer_length=config.max_answer_length,
            ignore_impossible=config.ignore_impossible,
            has_answer_labels=tensorizers["has_answer"].vocab._vocab,
            tensorizer=tensorizers["squad_input"],
            false_label=config.false_label,
        )

    def __init__(
        self,
        channels: List[Channel],
        n_best_size: int,
        max_answer_length: int,
        ignore_impossible: bool,
        has_answer_labels: List[str],
        tensorizer=None,
        false_label=Config.false_label,
    ) -> None:
        super().__init__(channels)
        self.channels = channels
        self.tensorizer = tensorizer
        self.ignore_impossible = ignore_impossible
        self.has_answer_labels = has_answer_labels
        self.false_label = false_label
        self.false_idx = 1 if has_answer_labels[1] == false_label else 0
        self.true_idx = 1 - self.false_idx

    def _reset(self):
        self.all_start_pos_preds: List = []
        self.all_start_pos_targets: List = []
        self.all_start_pos_scores: List = []
        self.all_end_pos_preds: List = []
        self.all_end_pos_targets: List = []
        self.all_end_pos_scores: List = []
        self.all_has_answer_targets: List = []
        self.all_has_answer_preds: List = []
        self.all_has_answer_scores: List = []
        self.all_preds = (
            self.all_start_pos_preds,
            self.all_end_pos_preds,
            self.all_has_answer_preds,
        )
        self.all_targets = (
            self.all_start_pos_targets,
            self.all_end_pos_targets,
            self.all_has_answer_targets,
        )
        self.all_scores = (
            self.all_start_pos_scores,
            self.all_end_pos_scores,
            self.all_has_answer_scores,
        )
        self.all_context: Dict = {}
        self.all_loss: List = []
        self.all_pred_answers: List = []
        # self.all_true_answers: List = []
        self.batch_size: List = []
        self.n_batches = 0

    def _add_decoded_answer_batch_stats(self, m_input, preds, **contexts):
        # For BERT, doc_tokens = concatenated tokens from question and document.
        doc_tokens = m_input[0]
        starts, ends, has_answers = preds
        pred_answers = [
            self._unnumberize(tokens[start : end + 1].tolist(), doc_str)
            for tokens, start, end, doc_str in zip(
                doc_tokens, starts, ends, contexts[self.DOC_COLUMN]
            )
        ]
        self.aggregate_data(self.all_pred_answers, pred_answers)
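
    # Example (hypothetical values) of the span decoding above: if a row has
    # predicted start=3, end=5 and its doc_tokens tensor holds ids
    # [101, 7, 8, 9, 10, 11, 102], then tokens[3 : 5 + 1] -> [9, 10, 11] is
    # the span that _unnumberize maps back to characters of the raw document.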

    def add_batch_stats(
        self, n_batches, preds, targets, scores, loss, m_input, **contexts
    ):
        # The contexts object is the dict returned by self.batch_context().
        super().add_batch_stats(
            n_batches, preds, targets, scores, loss, m_input, **contexts
        )
        self._add_decoded_answer_batch_stats(m_input, preds, **contexts)

    def aggregate_preds(self, new_batch, context=None):
        self.aggregate_data(self.all_start_pos_preds, new_batch[0])
        self.aggregate_data(self.all_end_pos_preds, new_batch[1])
        self.aggregate_data(self.all_has_answer_preds, new_batch[2])

    def aggregate_targets(self, new_batch, context=None):
        self.aggregate_data(self.all_start_pos_targets, new_batch[0])
        self.aggregate_data(self.all_end_pos_targets, new_batch[1])
        self.aggregate_data(self.all_has_answer_targets, new_batch[2])

    def aggregate_scores(self, new_batch):
        self.aggregate_data(self.all_start_pos_scores, new_batch[0])
        self.aggregate_data(self.all_end_pos_scores, new_batch[1])
        self.aggregate_data(self.all_has_answer_scores, new_batch[2])

    def batch_context(self, raw_batch, batch):
        context = super().batch_context(raw_batch, batch)
        context[self.ROW_INDEX] = [row[self.ROW_INDEX] for row in raw_batch]
        context[self.QUES_COLUMN] = [row[self.QUES_COLUMN] for row in raw_batch]
        context[self.ANSWERS_COLUMN] = [row[self.ANSWERS_COLUMN] for row in raw_batch]
        context[self.DOC_COLUMN] = [row[self.DOC_COLUMN] for row in raw_batch]
        return context
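
    # A sketch of one raw_batch row consumed above (the keys come from the
    # class constants at the top of this class; values are hypothetical):
    #   {"id": 0, "question": "Who ...?", "answers": ["Turing"], "doc": "..."}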

    def calculate_metric(self):
        all_rows = zip(
            self.all_context[self.ROW_INDEX],
            self.all_context[self.ANSWERS_COLUMN],
            self.all_context[self.QUES_COLUMN],
            self.all_context[self.DOC_COLUMN],
            self.all_pred_answers,
            self.all_start_pos_preds,
            self.all_end_pos_preds,
            self.all_has_answer_preds,
            self.all_start_pos_targets,
            self.all_end_pos_targets,
            self.all_has_answer_targets,
            self.all_start_pos_scores,
            self.all_end_pos_scores,
            self.all_has_answer_scores,
        )
        # Group rows by example id; the same id can occur more than once when
        # a long document is split into several passages.
        all_rows_dict = {}
        for row in all_rows:
            try:
                all_rows_dict[row[0]].append(row)
            except KeyError:
                all_rows_dict[row[0]] = [row]
        # For each id, keep the row with the highest start + end position score.
        all_rows = []
        for rows in all_rows_dict.values():
            argmax = np.argmax([row[11] + row[12] for row in rows])
            all_rows.append(rows[argmax])
        # Sort by numeric id so the output order is deterministic; sorted()
        # returns a new list, so its result must be assigned back.
        all_rows = sorted(all_rows, key=lambda x: int(x[0]))
        (
            self.all_context[self.ROW_INDEX],
            self.all_context[self.ANSWERS_COLUMN],
            self.all_context[self.QUES_COLUMN],
            self.all_context[self.DOC_COLUMN],
            self.all_pred_answers,
            self.all_start_pos_preds,
            self.all_end_pos_preds,
            self.all_has_answer_preds,
            self.all_start_pos_targets,
            self.all_end_pos_targets,
            self.all_has_answer_targets,
            self.all_start_pos_scores,
            self.all_end_pos_scores,
            self.all_has_answer_scores,
        ) = zip(*all_rows)
        exact_matches, count = self._compute_exact_matches(
            self.all_pred_answers,
            self.all_context[self.ANSWERS_COLUMN],
            self.all_has_answer_preds,
            self.all_has_answer_targets,
        )
        f1_score = self._compute_f1_score(
            self.all_pred_answers,
            self.all_context[self.ANSWERS_COLUMN],
            self.all_has_answer_preds,
            self.all_has_answer_targets,
        )
        self.all_preds = (
            self.all_pred_answers,
            self.all_start_pos_preds,
            self.all_end_pos_preds,
            self.all_has_answer_preds,
        )
        self.all_targets = (
            self.all_context[self.ANSWERS_COLUMN],
            self.all_start_pos_targets,
            self.all_end_pos_targets,
            self.all_has_answer_targets,
        )
        self.all_scores = (
            self.all_start_pos_scores,
            self.all_end_pos_scores,
            self.all_has_answer_scores,
        )
        metrics = SquadMetrics(
            exact_matches=100.0 * exact_matches / count,
            f1_score=f1_score,
            num_examples=count,
        )
        return metrics
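
    # Example with hypothetical scores: if id 7 appears in two passages with
    # start + end scores 1.2 and 2.7, only the row scoring 2.7 survives the
    # argmax above and feeds the exact-match and F1 computations.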

    def get_model_select_metric(self, metric: SquadMetrics):
        return metric.f1_score

    def _compute_exact_matches(
        self,
        pred_answer_list,
        target_answers_list,
        pred_has_answer_list,
        target_has_answer_list,
    ):
        exact_matches = 0
        for pred_answer, target_answers, pred_has_answer, target_has_answer in zip(
            pred_answer_list,
            target_answers_list,
            pred_has_answer_list,
            target_has_answer_list,
        ):
            if not self.ignore_impossible:
                if pred_has_answer != target_has_answer:
                    continue
                if pred_has_answer == self.false_idx:
                    exact_matches += 1
                    continue
            pred = self._normalize_answer(pred_answer)
            for answer in target_answers:
                true = self._normalize_answer(answer)
                if pred == true:
                    exact_matches += 1
                    break
        return exact_matches, len(pred_answer_list)

    def _compute_f1_score(
        self,
        pred_answer_list,
        target_answers_list,
        pred_has_answer_list,
        target_has_answer_list,
    ):
        f1_scores_sum = 0.0
        for pred_answer, target_answers, pred_has_answer, target_has_answer in zip(
            pred_answer_list,
            target_answers_list,
            pred_has_answer_list,
            target_has_answer_list,
        ):
            if not self.ignore_impossible:
                if pred_has_answer != target_has_answer:
                    continue
                if pred_has_answer == self.false_idx:
                    f1_scores_sum += 1.0
                    continue
            f1_scores_sum += max(
                self._compute_f1_per_answer(answer, pred_answer)
                for answer in target_answers
            )
        return 100.0 * f1_scores_sum / len(pred_answer_list)

    def _unnumberize(self, ans_tokens, doc_str):
        """
        ans_tokens is the span of token ids that the model predicted. We
        re-tokenize and re-numberize the raw context (doc_str) here to recover
        doc_tokens together with the start_idx and end_idx character mappings.
        At this point, ans_tokens should be a sub-list of doc_tokens (provided
        the model predicted a span inside the context). We then locate
        ans_tokens inside doc_tokens and return the corresponding span of the
        raw text using the character-index mapping.
        """
        # start_idx and end_idx are lists of char start and end positions in doc_str.
        doc_tokens, start_idx, end_idx = self.tensorizer._lookup_tokens(doc_str)
        doc_tokens = list(doc_tokens)
        num_ans_tokens = len(ans_tokens)
        answer_str = ""
        # The + 1 makes the upper bound inclusive, so a span that ends at the
        # very last document token can still match.
        for doc_token_idx in range(len(doc_tokens) - num_ans_tokens + 1):
            if doc_tokens[doc_token_idx : doc_token_idx + num_ans_tokens] == ans_tokens:
                start_char_idx = start_idx[doc_token_idx]
                end_char_idx = end_idx[doc_token_idx + num_ans_tokens - 1]
                answer_str = doc_str[start_char_idx:end_char_idx]
                break
        return answer_str

    # The following three functions are copied from SQuAD's evaluation script:
    # https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
    def _normalize_answer(self, s):
        """Lower text and remove punctuation, articles and extra whitespace."""

        def white_space_fix(text):
            return " ".join(text.split())

        def remove_articles(text):
            regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
            return re.sub(regex, " ", text)

        def remove_punc(text):
            exclude = set(string.punctuation)
            return "".join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(s))))

    def _get_tokens(self, s):
        if not s:
            return []
        return self._normalize_answer(s).split()

    def _compute_f1_per_answer(self, a_gold, a_pred):
        gold_toks = self._get_tokens(a_gold)
        pred_toks = self._get_tokens(a_pred)
        common = Counter(gold_toks) & Counter(pred_toks)
        num_same = sum(common.values())
        if len(gold_toks) == 0 or len(pred_toks) == 0:
            # If either is no-answer, then F1 is 1 if they agree, 0 otherwise.
            return int(gold_toks == pred_toks)
        if num_same == 0:
            return 0
        precision = 1.0 * num_same / len(pred_toks)
        recall = 1.0 * num_same / len(gold_toks)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1
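

# A minimal usage sketch, not part of the module: it exercises the pure-string
# SQuAD helpers, which need neither a tensorizer nor real model output. The
# constructor arguments are hypothetical placeholders, and this assumes the
# base MetricReporter constructor needs only the channels list.
if __name__ == "__main__":
    reporter = SquadMetricReporter(
        channels=[],
        n_best_size=5,
        max_answer_length=16,
        ignore_impossible=True,
        has_answer_labels=["True", "False"],
    )
    # Lowercases, strips punctuation and articles: prints "quick brown fox".
    print(reporter._normalize_answer("The Quick, Brown Fox!"))
    # After normalization the prediction is ["quick", "fox"]: precision 2/2,
    # recall 2/3, so F1 is 0.8.
    print(reporter._compute_f1_per_answer("quick brown fox", "a quick fox"))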