Source code for pytext.data.squad_for_bert_tensorizer

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import itertools
from typing import List

import torch
from pytext.data.bert_tensorizer import BERTTensorizer
from pytext.data.roberta_tensorizer import RoBERTaTensorizer
from pytext.data.tensorizers import lookup_tokens
from pytext.data.utils import pad_and_tensorize
from pytext.torchscript.tensorizer import ScriptRoBERTaTensorizerWithIndices
from pytext.torchscript.vocab import ScriptVocabulary


class SquadForBERTTensorizer(BERTTensorizer):
    """Produces BERT inputs and answer spans for Squad."""

    __EXPANSIBLE__ = True
    SPAN_PAD_IDX = -100

    class Config(BERTTensorizer.Config):
        columns: List[str] = ["question", "doc"]
        # for labels
        answers_column: str = "answers"
        answer_starts_column: str = "answer_starts"
        max_seq_len: int = 256

    @classmethod
    def from_config(cls, config: Config, **kwargs):
        # reuse parent class's from_config, which will pass extra args
        # in **kwargs to cls.__init__
        return super().from_config(
            config,
            answers_column=config.answers_column,
            answer_starts_column=config.answer_starts_column,
            **kwargs,
        )

    def __init__(
        self,
        answers_column: str = Config.answers_column,
        answer_starts_column: str = Config.answer_starts_column,
        **kwargs,
    ):
        # Arguments which are common to both current and base class are passed
        # as **kwargs. These are then passed to the __init__ of the base class.
        super().__init__(**kwargs)
        self.answers_column = answers_column
        self.answer_starts_column = answer_starts_column

    def _lookup_tokens(self, text: str, seq_len: int = None):
        # BoS token is added explicitly in numberize(), -1 from max_seq_len
        max_seq_len = (seq_len or self.max_seq_len) - 1
        return lookup_tokens(
            text,
            tokenizer=self.tokenizer,
            vocab=self.vocab,
            bos_token=None,
            eos_token=self.vocab.eos_token,
            max_seq_len=max_seq_len,
        )

    def _calculate_answer_indices(self, row, offset, start_idx, end_idx):
        # now map original answer spans to tokenized spans
        start_idx_map = {}
        end_idx_map = {}
        for tokenized_idx, (raw_start_idx, raw_end_idx) in enumerate(
            zip(start_idx[:-1], end_idx[:-1])
        ):
            start_idx_map[raw_start_idx] = tokenized_idx + offset
            end_idx_map[raw_end_idx] = tokenized_idx + offset
        answer_start_indices = [
            start_idx_map.get(raw_idx, self.SPAN_PAD_IDX)
            for raw_idx in row[self.answer_starts_column]
        ]
        answer_end_indices = [
            end_idx_map.get(raw_idx + len(answer), self.SPAN_PAD_IDX)
            for raw_idx, answer in zip(
                row[self.answer_starts_column], row[self.answers_column]
            )
        ]
        if not (answer_start_indices and answer_end_indices):
            answer_start_indices = [self.SPAN_PAD_IDX]
            answer_end_indices = [self.SPAN_PAD_IDX]
        return answer_start_indices, answer_end_indices

    def numberize(self, row):
        question_column, doc_column = self.columns
        doc_tokens, start_idx, end_idx = self._lookup_tokens(row[doc_column])
        question_tokens, _, _ = self._lookup_tokens(row[question_column])
        question_tokens = [self.vocab.get_bos_index()] + question_tokens
        seq_lens = (len(question_tokens), len(doc_tokens))
        segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
        tokens = list(itertools.chain(question_tokens, doc_tokens))
        segment_labels = list(itertools.chain(*segment_labels))
        seq_len = len(tokens)
        positions = list(range(seq_len))
        # now map original answer spans to tokenized spans
        offset = len(question_tokens)
        answer_start_indices, answer_end_indices = self._calculate_answer_indices(
            row, offset, start_idx, end_idx
        )
        return (
            tokens,
            segment_labels,
            seq_len,
            positions,
            answer_start_indices,
            answer_end_indices,
        )

    def tensorize(self, batch):
        (
            tokens,
            segment_labels,
            seq_len,
            positions,
            answer_start_idx,
            answer_end_idx,
        ) = zip(*batch)
        tokens = pad_and_tensorize(tokens, self.vocab.get_pad_index())
        segment_labels = pad_and_tensorize(segment_labels, self.vocab.get_pad_index())
        pad_mask = (tokens != self.vocab.get_pad_index()).long()
        positions = pad_and_tensorize(positions)
        answer_start_idx = pad_and_tensorize(answer_start_idx, self.SPAN_PAD_IDX)
        answer_end_idx = pad_and_tensorize(answer_end_idx, self.SPAN_PAD_IDX)
        return (
            tokens,
            pad_mask,
            segment_labels,
            positions,
            answer_start_idx,
            answer_end_idx,
        )
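
# ---------------------------------------------------------------------------
# Added commentary: a small worked example of the span mapping performed by
# numberize()/_calculate_answer_indices() above. The concrete tokens and
# character offsets below are made up for illustration; real values come from
# the configured tokenizer (the trailing EoS entry returned by lookup_tokens
# is dropped by the [:-1] slices).
#
#   doc         = "Paris is the capital"
#   doc tokens  = ["Paris", "is", "the", "capital", EOS]
#   start_idx   = (0, 6, 9, 13, ...)   # char offset where each token starts
#   end_idx     = (5, 8, 12, 20, ...)  # char offset just past each token
#
# With a 5-token question (BoS included), offset = 5. An answer "capital" with
# answer_start 13 maps to token index 3 + 5 = 8 for the start, and character
# 13 + len("capital") = 20 maps to the same token index for the end. Any
# answer whose character offsets do not coincide with token boundaries falls
# back to SPAN_PAD_IDX (-100), which downstream losses can ignore.
# ---------------------------------------------------------------------------
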

class SquadForBERTTensorizerForKD(SquadForBERTTensorizer):
    class Config(SquadForBERTTensorizer.Config):
        start_logits_column: str = "start_logits"
        end_logits_column: str = "end_logits"
        has_answer_logits_column: str = "has_answer_logits"
        pad_mask_column: str = "pad_mask"
        segment_labels_column: str = "segment_labels"

    @classmethod
    def from_config(cls, config: Config, **kwargs):
        return super().from_config(
            config,
            start_logits_column=config.start_logits_column,
            end_logits_column=config.end_logits_column,
            has_answer_logits_column=config.has_answer_logits_column,
            pad_mask_column=config.pad_mask_column,
            segment_labels_column=config.segment_labels_column,
        )

    def __init__(
        self,
        start_logits_column=Config.start_logits_column,
        end_logits_column=Config.end_logits_column,
        has_answer_logits_column=Config.has_answer_logits_column,
        pad_mask_column=Config.pad_mask_column,
        segment_labels_column=Config.segment_labels_column,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.start_logits_column = start_logits_column
        self.end_logits_column = end_logits_column
        self.has_answer_logits_column = has_answer_logits_column
        self.pad_mask_column = pad_mask_column
        self.segment_labels_column = segment_labels_column
        # For logging
        self.total = 0
        self.mismatches = 0

    def __del__(self):
        print("Destroying SquadForBERTTensorizerForKD object")
        print(f"SquadForBERTTensorizerForKD: Number of rows read: {self.total}")
        print(f"SquadForBERTTensorizerForKD: Number of rows dropped: {self.mismatches}")

    def numberize(self, row):
        self.total += 1
        numberized_row_tuple = super().numberize(row)
        try:
            tup = numberized_row_tuple + (
                self._get_token_logits(
                    row[self.start_logits_column], row[self.pad_mask_column]
                ),
                self._get_token_logits(
                    row[self.end_logits_column], row[self.pad_mask_column]
                ),
                row[self.has_answer_logits_column],
            )
        except KeyError:
            # Logits for KD Tensorizer not provided, using padding.
            tup = numberized_row_tuple + (
                [self.vocab.get_pad_index()] * len(numberized_row_tuple[0]),
                [self.vocab.get_pad_index()] * len(numberized_row_tuple[0]),
                [self.vocab.get_pad_index()] * 2,
            )
        try:
            assert len(tup[0]) == len(tup[6])
        except AssertionError:
            self.mismatches += 1
            print(
                f"len(tup[0]) = {len(tup[0])} and len(tup[6]) = {len(tup[6])}",
                flush=True,
            )
            raise
        return tup

    def tensorize(self, batch):
        (
            tokens,
            segment_labels,
            seq_lens,
            positions,
            answer_start_idx,
            answer_end_idx,
            start_logits,
            end_logits,
            has_answer_logits,
        ) = zip(*batch)
        tensor_tuple = super().tensorize(
            zip(
                tokens,
                segment_labels,
                seq_lens,
                positions,
                answer_start_idx,
                answer_end_idx,
            )
        )
        return tensor_tuple + (
            pad_and_tensorize(start_logits, dtype=torch.float),
            pad_and_tensorize(end_logits, dtype=torch.float),
            pad_and_tensorize(
                has_answer_logits,
                dtype=torch.float,
                pad_shape=[len(has_answer_logits), len(has_answer_logits[0])],
            ),
        )
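
    # Added commentary: pad_shape pins the padded has_answer_logits tensor to
    # [batch_size, len(has_answer_logits[0])] (typically [batch_size, 2] for a
    # two-way has-answer distribution), instead of letting pad_and_tensorize
    # infer the width from the longest row in the batch.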

    def _get_token_logits(self, logits, pad_mask):
        try:
            pad_start = pad_mask.index(self.vocab.get_pad_index())
        except ValueError:  # pad_index doesn't exist in pad_mask
            pad_start = len(logits)
        return logits[:pad_start]
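
# ---------------------------------------------------------------------------
# Added commentary: a sketch of the row format the KD tensorizer expects. The
# field values are illustrative assumptions; the column names follow the
# Config defaults above.
#
#   row = {
#       "question": "...",
#       "doc": "...",
#       "answers": [...],
#       "answer_starts": [...],
#       # teacher outputs, present only in distillation data:
#       "start_logits": [...],       # one float per teacher input token
#       "end_logits": [...],
#       "has_answer_logits": [...],  # e.g. [no_answer, has_answer] scores
#       "pad_mask": [...],           # 1 for real tokens, pad index for pads
#   }
#
# _get_token_logits() trims the teacher logits at the first pad position: with
# pad index 0 and pad_mask = [1, 1, 1, 0, 0], padding starts at position 3, so
# logits[:3] is kept. When a row has no teacher columns at all, the KeyError
# branch in numberize() fills the extra fields with pad values instead.
# ---------------------------------------------------------------------------
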

class SquadForRoBERTaTensorizer(RoBERTaTensorizer, SquadForBERTTensorizer):
    """Produces RoBERTa inputs and answer spans for Squad."""

    __EXPANSIBLE__ = True

    class Config(RoBERTaTensorizer.Config):
        columns: List[str] = ["question", "doc"]
        # for labels
        answers_column: str = "answers"
        answer_starts_column: str = "answer_starts"
        max_seq_len: int = 256

    @classmethod
    def from_config(cls, config: Config, **kwargs):
        # reuse parent class's from_config, which will pass extra args
        # in **kwargs to cls.__init__
        return super().from_config(
            config,
            answers_column=config.answers_column,
            answer_starts_column=config.answer_starts_column,
            **kwargs,
        )

    def __init__(
        self,
        answers_column: str = Config.answers_column,
        answer_starts_column: str = Config.answer_starts_column,
        **kwargs,
    ):
        # Arguments which are common to both current and base class are passed
        # as **kwargs. These are then passed to the __init__ of the base class.
        super().__init__(**kwargs)
        self.answers_column = answers_column
        self.answer_starts_column = answer_starts_column
        self.wrap_special_tokens = False

    def _lookup_tokens(self, text: str, seq_len: int = None):
        # BoS token is added explicitly in numberize()
        return lookup_tokens(
            text,
            tokenizer=self.tokenizer,
            vocab=self.vocab,
            bos_token=None,
            eos_token=self.vocab.eos_token,
            max_seq_len=seq_len if seq_len else self.max_seq_len,
        )

    def torchscriptify(self):
        return ScriptRoBERTaTensorizerWithIndices(
            tokenizer=self.tokenizer.torchscriptify(),
            vocab=ScriptVocabulary(
                list(self.vocab),
                pad_idx=self.vocab.get_pad_index(),
                bos_idx=self.vocab.get_bos_index(),
                eos_idx=self.vocab.get_eos_index(),
            ),
            max_seq_len=self.max_seq_len,
        )
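
# ---------------------------------------------------------------------------
# Added commentary: torchscriptify() above is typically invoked at export
# time so that tokenization and numberization travel with the TorchScript
# model. A hedged sketch (the surrounding export flow lives elsewhere in
# PyText):
#
#   script_tensorizer = tensorizer.torchscriptify()
#   # script_tensorizer can then be bundled with the scripted model, letting
#   # the exported artifact accept raw (question, doc) text directly.
# ---------------------------------------------------------------------------
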

class SquadForRoBERTaTensorizerForKD(SquadForRoBERTaTensorizer):
    class Config(SquadForRoBERTaTensorizer.Config):
        start_logits_column: str = "start_logits"
        end_logits_column: str = "end_logits"
        has_answer_logits_column: str = "has_answer_logits"
        pad_mask_column: str = "pad_mask"
        segment_labels_column: str = "segment_labels"

    @classmethod
    def from_config(cls, config: Config, **kwargs):
        return super().from_config(
            config,
            start_logits_column=config.start_logits_column,
            end_logits_column=config.end_logits_column,
            has_answer_logits_column=config.has_answer_logits_column,
            pad_mask_column=config.pad_mask_column,
            segment_labels_column=config.segment_labels_column,
        )

    def __init__(
        self,
        start_logits_column=Config.start_logits_column,
        end_logits_column=Config.end_logits_column,
        has_answer_logits_column=Config.has_answer_logits_column,
        pad_mask_column=Config.pad_mask_column,
        segment_labels_column=Config.segment_labels_column,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.start_logits_column = start_logits_column
        self.end_logits_column = end_logits_column
        self.has_answer_logits_column = has_answer_logits_column
        self.pad_mask_column = pad_mask_column
        self.segment_labels_column = segment_labels_column
        # For logging
        self.total = 0
        self.mismatches = 0

    def __del__(self):
        print("Destroying SquadForRoBERTaTensorizerForKD object")
        print(f"SquadForRoBERTaTensorizerForKD: Number of rows read: {self.total}")
        print(
            f"SquadForRoBERTaTensorizerForKD: Number of rows dropped: {self.mismatches}"
        )

    def numberize(self, row):
        self.total += 1
        numberized_row_tuple = super().numberize(row)
        try:
            tup = numberized_row_tuple + (
                self._get_token_logits(
                    row[self.start_logits_column], row[self.pad_mask_column]
                ),
                self._get_token_logits(
                    row[self.end_logits_column], row[self.pad_mask_column]
                ),
                row[self.has_answer_logits_column],
            )
        except KeyError:
            # Logits for KD Tensorizer not provided, using padding.
            tup = numberized_row_tuple + (
                [self.vocab.get_pad_index()] * len(numberized_row_tuple[0]),
                [self.vocab.get_pad_index()] * len(numberized_row_tuple[0]),
                [self.vocab.get_pad_index()] * 2,
            )
        try:
            assert len(tup[0]) == len(tup[6])
        except AssertionError:
            self.mismatches += 1
            print(
                f"len(tup[0]) = {len(tup[0])} and len(tup[6]) = {len(tup[6])}",
                flush=True,
            )
            raise
        return tup

    def tensorize(self, batch):
        (
            tokens,
            segment_labels,
            seq_lens,
            positions,
            answer_start_idx,
            answer_end_idx,
            start_logits,
            end_logits,
            has_answer_logits,
        ) = zip(*batch)
        tensor_tuple = super().tensorize(
            zip(
                tokens,
                segment_labels,
                seq_lens,
                positions,
                answer_start_idx,
                answer_end_idx,
            )
        )
        return tensor_tuple + (
            pad_and_tensorize(start_logits, dtype=torch.float),
            pad_and_tensorize(end_logits, dtype=torch.float),
            pad_and_tensorize(
                has_answer_logits,
                dtype=torch.float,
                pad_shape=[len(has_answer_logits), len(has_answer_logits[0])],
            ),
        )

    def _get_token_logits(self, logits, pad_mask):
        # Note: truncate at the *position* of the first pad token; counting
        # pad occurrences would yield the number of pads, not where real
        # tokens end.
        try:
            pad_start = pad_mask.index(self.vocab.get_pad_index())
        except ValueError:  # pad_index doesn't exist in pad_mask
            pad_start = len(logits)
        return logits[:pad_start]
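
# ---------------------------------------------------------------------------
# Added commentary: a hedged end-to-end usage sketch. It assumes the standard
# PyText flow, in which from_config() builds the tokenizer and the
# tensorizer's vocab is initialized from data before numberize() is called;
# the row values are illustrative.
#
#   tensorizer = SquadForBERTTensorizer.from_config(
#       SquadForBERTTensorizer.Config(max_seq_len=256)
#   )
#   # ... vocab initialization over the training data happens here ...
#   numberized = tensorizer.numberize(
#       {
#           "question": "What is the capital of France?",
#           "doc": "Paris is the capital",
#           "answers": ["Paris"],
#           "answer_starts": [0],
#       }
#   )
#   batch = tensorizer.tensorize([numberized])  # padded tensors for the model
# ---------------------------------------------------------------------------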