#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List
import torch
from pytext.common.constants import SpecialTokens
from pytext.config.component import ComponentType, create_component
from pytext.data.tensorizers import TokenTensorizer
from pytext.data.tokenizers import Tokenizer, WordPieceTokenizer
from pytext.data.utils import VocabBuilder, Vocabulary, pad_and_tensorize
class SquadTensorizer(TokenTensorizer):
"""Produces inputs and answer spans for Squad."""
__EXPANSIBLE__ = True
    # Pad value for answer spans; -100 is also the default ignore_index of
    # torch.nn.CrossEntropyLoss, so a loss can skip padded span positions.
    SPAN_PAD_IDX = -100
    class Config(TokenTensorizer.Config):
# for model inputs
doc_column: str = "doc"
ques_column: str = "question"
# for labels
answers_column: str = "answers"
answer_starts_column: str = "answer_starts"
# Since Tokenizer is __EXPANSIBLE__, we don't need a Union type to
# support WordPieceTokenizer.
tokenizer: Tokenizer.Config = Tokenizer.Config(split_regex=r"\W+")
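        # Illustrative note (assuming the base Tokenizer's default lowercasing):
        # with split_regex=r"\W+", runs of non-word characters delimit tokens,
        # so "Who wrote Hamlet?" tokenizes to ["who", "wrote", "hamlet"].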
max_ques_seq_len: int = 64
max_doc_seq_len: int = 256
    @classmethod
def from_config(cls, config: Config, **kwargs):
tokenizer = create_component(ComponentType.TOKENIZER, config.tokenizer)
vocab = None
if isinstance(tokenizer, WordPieceTokenizer):
print("Using WordPieceTokenizer")
replacements = {
"[UNK]": SpecialTokens.UNK,
"[PAD]": SpecialTokens.PAD,
"[CLS]": SpecialTokens.BOS,
"[SEP]": SpecialTokens.EOS,
"[MASK]": SpecialTokens.MASK,
}
vocab = Vocabulary(
[token for token, _ in tokenizer.vocab.items()],
replacements=replacements,
)
doc_tensorizer = TokenTensorizer(
text_column=config.doc_column,
tokenizer=tokenizer,
vocab=vocab,
max_seq_len=config.max_doc_seq_len,
)
ques_tensorizer = TokenTensorizer(
text_column=config.ques_column,
tokenizer=tokenizer,
vocab=vocab,
max_seq_len=config.max_ques_seq_len,
)
return cls(
doc_tensorizer=doc_tensorizer,
ques_tensorizer=ques_tensorizer,
doc_column=config.doc_column,
ques_column=config.ques_column,
answers_column=config.answers_column,
answer_starts_column=config.answer_starts_column,
tokenizer=tokenizer,
vocab=vocab,
**kwargs,
)
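
    # Hedged usage sketch (not part of the original module); the config values
    # shown are just the illustrative defaults:
    #
    #   config = SquadTensorizer.Config(max_ques_seq_len=64, max_doc_seq_len=256)
    #   tensorizer = SquadTensorizer.from_config(config)
    #   # tensorizer.vocab is None at this point unless a WordPieceTokenizer
    #   # supplied its own vocabulary; otherwise initialize() builds it from data.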
def __init__(
self,
doc_tensorizer: TokenTensorizer,
ques_tensorizer: TokenTensorizer,
doc_column: str = Config.doc_column,
ques_column: str = Config.ques_column,
answers_column: str = Config.answers_column,
answer_starts_column: str = Config.answer_starts_column,
**kwargs,
):
super().__init__(text_column=None, **kwargs)
self.ques_tensorizer = ques_tensorizer
self.doc_tensorizer = doc_tensorizer
self.doc_column = doc_column
self.ques_column = ques_column
self.answers_column = answers_column
self.answer_starts_column = answer_starts_column
    def initialize(self, vocab_builder=None, from_scratch=True):
"""Build vocabulary based on training corpus."""
if isinstance(self.tokenizer, WordPieceTokenizer):
return
if not self.vocab_builder or from_scratch:
self.vocab_builder = vocab_builder or VocabBuilder()
self.vocab_builder.pad_index = 0
self.vocab_builder.unk_index = 1
ques_initializer = self.ques_tensorizer.initialize(
self.vocab_builder, from_scratch
)
doc_initializer = self.doc_tensorizer.initialize(
self.vocab_builder, from_scratch
)
        # Prime both sub-generators so they are ready to receive rows.
        ques_initializer.send(None)
        doc_initializer.send(None)
try:
while True:
row = yield
ques_initializer.send(row)
doc_initializer.send(row)
except GeneratorExit:
self.vocab = self.vocab_builder.make_vocab()
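
    # Hedged sketch of the coroutine protocol above (row contents invented):
    #
    #   init = tensorizer.initialize()
    #   init.send(None)  # prime
    #   init.send({"doc": "PyText is a library.", "question": "What is PyText?"})
    #   init.close()     # raises GeneratorExit inside -> vocab is built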
def _lookup_tokens(self, text, source_is_doc=True):
# This is useful in SquadMetricReporter._unnumberize()
return (
self.doc_tensorizer._lookup_tokens(text)
if source_is_doc
else self.ques_tensorizer._lookup_tokens(text)
)
    def numberize(self, row):
assert len(self.vocab) == len(self.ques_tensorizer.vocab)
assert len(self.vocab) == len(self.doc_tensorizer.vocab)
# Do NOT use self._lookup_tokens() because it won't enforce max_ques_seq_len.
ques_tokens, _, _ = self.ques_tensorizer._lookup_tokens(row[self.ques_column])
        # Start and end indices are those of the tokens in the original text,
        # and they are character-level. The behavior doesn't change for
        # WordPieceTokenizer: for a word piece such as "##ly", the start and
        # end indices are those of "ly" in the original text.
doc_tokens, orig_start_idx, orig_end_idx = self.doc_tensorizer._lookup_tokens(
row[self.doc_column]
)
        # Map the original character-level answer spans to token-level spans.
start_idx_map = {}
end_idx_map = {}
for token_idx, (start_idx, end_idx) in enumerate(
zip(orig_start_idx, orig_end_idx)
):
start_idx_map[start_idx] = token_idx
end_idx_map[end_idx] = token_idx
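        # Illustrative example (assumed doc "Hello world"): whitespace-style
        # tokenization yields orig_start_idx = [0, 6] and orig_end_idx = [5, 11],
        # so start_idx_map == {0: 0, 6: 1} and end_idx_map == {5: 0, 11: 1};
        # a raw answer ("world", start 6) then maps to token span (1, 1) below.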
answer_start_token_indices = [
start_idx_map.get(raw_idx, self.SPAN_PAD_IDX)
for raw_idx in row[self.answer_starts_column]
]
answer_end_token_indices = [
end_idx_map.get(raw_idx + len(answer), self.SPAN_PAD_IDX)
for raw_idx, answer in zip(
row[self.answer_starts_column], row[self.answers_column]
)
] # The end index is inclusive. Span = doc_tokens[start:end+1]
if (
not (answer_start_token_indices and answer_end_token_indices)
or self._only_pad(answer_start_token_indices)
or self._only_pad(answer_end_token_indices)
):
answer_start_token_indices = [self.SPAN_PAD_IDX]
answer_end_token_indices = [self.SPAN_PAD_IDX]
return (
doc_tokens,
len(doc_tokens),
ques_tokens,
len(ques_tokens),
answer_start_token_indices,
answer_end_token_indices,
)
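
    # Sketch of one numberized row (token ids invented): for doc "Hello world",
    # question "What?", and answer "world" at char 6, numberize() returns
    # roughly ([12, 47], 2, [8], 1, [1], [1]): doc ids, doc length, question
    # ids, question length, and inclusive answer start/end token indices.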
    def tensorize(self, batch):
(
doc_tokens,
doc_seq_len,
ques_tokens,
ques_seq_len,
answer_start_idx,
answer_end_idx,
) = zip(*batch)
doc_tokens = pad_and_tensorize(doc_tokens, self.vocab.get_pad_index())
doc_mask = (doc_tokens == self.vocab.get_pad_index()).byte() # 1 => pad
ques_tokens = pad_and_tensorize(ques_tokens, self.vocab.get_pad_index())
ques_mask = (ques_tokens == self.vocab.get_pad_index()).byte() # 1 => pad
answer_start_idx = pad_and_tensorize(answer_start_idx, self.SPAN_PAD_IDX)
answer_end_idx = pad_and_tensorize(answer_end_idx, self.SPAN_PAD_IDX)
# doc_tokens must be returned as the first element for
# SquadMetricReporter._add_decoded_answer_batch_stats() to work
return (
doc_tokens,
pad_and_tensorize(doc_seq_len),
doc_mask,
ques_tokens,
pad_and_tensorize(ques_seq_len),
ques_mask,
answer_start_idx,
answer_end_idx,
)
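
    # Shape sketch for a batch of B rows (assumed): doc_tokens and ques_tokens
    # are (B, max_len) LongTensors, the masks are ByteTensors of the same
    # shapes, the seq-len tensors are (B,), and the answer index tensors are
    # (B, max_num_answers) padded with SPAN_PAD_IDX.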
    def sort_key(self, row):
raise NotImplementedError("SquadTensorizer.sort_key() should not be called.")
    def _only_pad(self, token_id_list: List[int]) -> bool:
        return all(token_id == self.SPAN_PAD_IDX for token_id in token_id_list)
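

# End-to-end sketch (hypothetical row; not part of the original module):
#
#   row = {
#       "doc": "PyText is a deep-learning NLP library.",
#       "question": "What is PyText?",
#       "answers": ["a deep-learning NLP library"],
#       "answer_starts": [10],
#   }
#   tensorizer = SquadTensorizer.from_config(SquadTensorizer.Config())
#   init = tensorizer.initialize()
#   init.send(None)
#   init.send(row)
#   init.close()
#   tensors = tensorizer.tensorize([tensorizer.numberize(row)])
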
class SquadTensorizerForKD(SquadTensorizer):
    """SquadTensorizer that also carries teacher logits for knowledge distillation."""
    class Config(SquadTensorizer.Config):
start_logits_column: str = "start_logits"
end_logits_column: str = "end_logits"
has_answer_logits_column: str = "has_answer_logits"
pad_mask_column: str = "pad_mask"
segment_labels_column: str = "segment_labels"
    @classmethod
    def from_config(cls, config: Config, **kwargs):
        return super().from_config(
            config,
            start_logits_column=config.start_logits_column,
            end_logits_column=config.end_logits_column,
            has_answer_logits_column=config.has_answer_logits_column,
            pad_mask_column=config.pad_mask_column,
            segment_labels_column=config.segment_labels_column,
            **kwargs,  # forward any extra kwargs to the base factory
        )
def __init__(
self,
start_logits_column=Config.start_logits_column,
end_logits_column=Config.end_logits_column,
has_answer_logits_column=Config.has_answer_logits_column,
pad_mask_column=Config.pad_mask_column,
segment_labels_column=Config.segment_labels_column,
**kwargs,
):
super().__init__(**kwargs)
self.start_logits_column = start_logits_column
self.end_logits_column = end_logits_column
self.has_answer_logits_column = has_answer_logits_column
self.pad_mask_column = pad_mask_column
self.segment_labels_column = segment_labels_column
    def numberize(self, row):
numberized_row_tuple = super().numberize(row)
start_logits = self._get_doc_logits(
row[self.start_logits_column],
row[self.pad_mask_column],
row[self.segment_labels_column],
)
end_logits = self._get_doc_logits(
row[self.end_logits_column],
row[self.pad_mask_column],
row[self.segment_labels_column],
)
return numberized_row_tuple + (
start_logits,
end_logits,
row[self.has_answer_logits_column],
)
    def tensorize(self, batch):
(
doc_tokens,
doc_seq_len,
ques_tokens,
ques_seq_len,
answer_start_idx,
answer_end_idx,
start_logits,
end_logits,
has_answer_logits,
) = zip(*batch)
tensor_tuple = super().tensorize(
zip(
doc_tokens,
doc_seq_len,
ques_tokens,
ques_seq_len,
answer_start_idx,
answer_end_idx,
)
)
return tensor_tuple + (
pad_and_tensorize(start_logits, dtype=torch.float),
pad_and_tensorize(end_logits, dtype=torch.float),
pad_and_tensorize(has_answer_logits, dtype=torch.float),
)
    def _get_doc_logits(self, logits, pad_mask, segment_labels):
        # The question occupies segment 0, so the document starts at the
        # first segment label equal to 1.
        ques_seq_len = segment_labels.index(1)
try:
pad_start = pad_mask.index(0)
        except ValueError:  # 0 doesn't exist in pad_mask, i.e. nothing is padded
pad_start = len(logits)
return logits[ques_seq_len : pad_start - 1] # Last non-pad token is [SEP]
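
    # Worked example (invented values): with segment_labels = [0, 0, 1, 1, 1, 1]
    # (two question tokens, then the document), pad_mask = [1, 1, 1, 1, 1, 0],
    # and logits = [a, b, c, d, e, f]: ques_seq_len = 2, pad_start = 5, and the
    # slice logits[2:4] keeps the document logits [c, d] while dropping the
    # trailing [SEP] logit (e) and the padding (f).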