#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import json
import math
from random import choice
from typing import List, Optional
from pytext.data.sources.data_source import (
DataSource,
JSONString,
SafeFileWrapper,
generator_property,
)
from pytext.data.sources.tsv import TSV
from pytext.utils.file_io import PathManager
from pytext.utils.path import get_absolute_path
def _shift_answers(orig_starts, piece_start, piece_end):
# Re-align answer index for each piece when we split a long document.
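    # Illustrative example: an answer starting at character 150 of the full
    # document, for a piece spanning [100, 300), is re-indexed to start 50;
    # answers that fall outside the piece are dropped.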
answer_starts = []
has_answer = False
for start in orig_starts:
        if piece_start <= start < piece_end:
answer_starts.append(start - piece_start)
has_answer = True
return answer_starts, has_answer
def _split_document(
id,
doc,
question,
answers,
answer_starts,
has_answer,
ignore_impossible,
max_character_length,
min_overlap,
):
pieces = []
min_overlap = math.floor(max_character_length * min_overlap)
if has_answer or not ignore_impossible:
n_pieces = 1 + math.ceil(
max(0, len(doc) - max_character_length)
/ (max_character_length - min_overlap)
)
overlap = (
math.floor((n_pieces * max_character_length - len(doc)) / (n_pieces - 1))
if n_pieces > 1
else 0
)
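        # Worked example with illustrative numbers: a 2500-character doc with
        # max_character_length=1000 and min_overlap=0.1 gives
        # n_pieces = 1 + ceil(1500 / 900) = 3 and
        # overlap = floor((3000 - 2500) / 2) = 250, so the pieces cover
        # [0, 1000), [750, 1750) and [1500, 2500).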
for n in range(n_pieces):
start, end = (
n * (max_character_length - overlap),
(n + 1) * (max_character_length - overlap) + overlap,
)
            # Shift against the original offsets, keeping answer_starts intact
            # so later pieces are not computed from already-shifted values.
            piece_answer_starts, piece_has_answer = _shift_answers(
                answer_starts, start, end
            )
pieces.append(
{
"id": id,
"doc": doc[start:end],
"question": question,
"answers": answers,
"answer_starts": answer_starts,
"has_answer": str(has_answer and piece_has_answer),
}
)
return pieces
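# A minimal sketch of the output shape (values are illustrative, not taken
# from a real SQuAD record):
#
#   _split_document(
#       0, "Some long context.", "What is discussed?", ["long context"], [5],
#       True, True, 2 ** 20, 0.1,
#   )
#   == [{
#       "id": 0,
#       "doc": "Some long context.",
#       "question": "What is discussed?",
#       "answers": ["long context"],
#       "answer_starts": [5],
#       "has_answer": "True",
#   }]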
class SquadDataSource(DataSource):
    """
    Download data from https://rajpurkar.github.io/SQuAD-explorer/
    Yields rows containing (doc, question, answers, answer_starts, has_answer).
    """
__EXPANSIBLE__ = True
DEFAULT_SCHEMA = {
"id": int,
"doc": str,
"question": str,
"answers": List[str],
"answer_starts": List[int],
"answer_ends": List[int],
"has_answer": str,
}
    class Config(DataSource.Config):
train_filename: Optional[str] = "train-v2.0.json"
test_filename: Optional[str] = "dev-v2.0.json"
eval_filename: Optional[str] = "dev-v2.0.json"
ignore_impossible: bool = True
max_character_length: int = 2 ** 20
min_overlap: float = 0.1 # Expressed as a fraction of the max_character_length.
delimiter: str = "\t"
quoted: bool = False
    @classmethod
def from_config(cls, config: Config, schema=DEFAULT_SCHEMA):
return cls(
train_filename=config.train_filename,
test_filename=config.test_filename,
eval_filename=config.eval_filename,
ignore_impossible=config.ignore_impossible,
max_character_length=config.max_character_length,
min_overlap=config.min_overlap,
delimiter=config.delimiter,
quoted=config.quoted,
)
def __init__(
self,
train_filename=None,
test_filename=None,
eval_filename=None,
ignore_impossible=Config.ignore_impossible,
max_character_length=Config.max_character_length,
min_overlap=Config.min_overlap,
delimiter=Config.delimiter,
quoted=Config.quoted,
schema=DEFAULT_SCHEMA,
):
super().__init__(schema)
self.train_filename = train_filename
self.test_filename = test_filename
self.eval_filename = eval_filename
self.ignore_impossible = ignore_impossible
self.max_character_length = max_character_length
self.min_overlap = min_overlap
self.delimiter = delimiter
self.quoted = quoted
    def process_file(self, fname):
# Pick which method to use based on extension.
if fname.split(".")[-1] == "json":
return self.process_squad_json(fname=fname)
else:
return self.process_squad_tsv(fname=fname)
    def process_squad_json(self, fname):
if not fname:
return
if not PathManager.exists(fname):
print(f"{fname} does not exist. Not unflattening.")
return
with PathManager.open(fname) as infile:
dump = json.load(infile)
id = 0
for article in dump["data"]:
for paragraph in article["paragraphs"]:
doc = paragraph["context"]
for question in paragraph["qas"]:
has_answer = not question.get("is_impossible", False)
answers = (
question["answers"]
if has_answer
else question["plausible_answers"]
)
question = question["question"]
answer_texts = [answer["text"] for answer in answers]
answer_starts = [int(answer["answer_start"]) for answer in answers]
for piece_dict in _split_document(
id,
doc,
question,
answer_texts,
answer_starts,
has_answer,
self.ignore_impossible,
self.max_character_length,
self.min_overlap,
):
yield piece_dict
id += 1
    def process_squad_tsv(self, fname):
if not fname:
print("Empty file name!")
return
field_names = ["doc", "question", "answers", "answer_starts", "has_answer"]
tsv_file = SafeFileWrapper(
get_absolute_path(fname), encoding="utf-8", errors="replace"
)
tsv = TSV(
tsv_file,
field_names=field_names,
delimiter=self.delimiter,
quoted=self.quoted,
drop_incomplete_rows=True,
)
for id, row in enumerate(tsv):
parts = (row[f] for f in field_names)
doc, question, answers, answer_starts, has_answer = parts
try:
# if we have paraphrases for question
question = json.loads(question)
if isinstance(question, list):
question = choice(question)
except ValueError:
pass
answers = json.loads(answers)
answer_starts = json.loads(answer_starts)
for piece_dict in _split_document(
id,
doc,
question,
answers,
answer_starts,
has_answer == "True",
self.ignore_impossible,
self.max_character_length,
self.min_overlap,
):
yield piece_dict
@generator_property
def train(self):
return self.process_file(self.train_filename)
@generator_property
def test(self):
return self.process_file(self.test_filename)
@generator_property
def eval(self):
return self.process_file(self.eval_filename)
class SquadDataSourceForKD(SquadDataSource):
    """
    SQuAD-like data along with soft labels (logits), e.g. for knowledge
    distillation.
    Yields rows containing (
        doc, question, answers, answer_starts, has_answer,
        start_logits, end_logits, has_answer_logits, pad_mask, segment_labels
    )
    """
def __init__(self, **kwargs):
kwargs["schema"] = {
"id": int,
"doc": JSONString,
"question": JSONString,
"answers": List[str],
"answer_starts": List[int],
"has_answer": JSONString,
"start_logits": List[float],
"end_logits": List[float],
"has_answer_logits": List[float],
"pad_mask": List[int],
"segment_labels": List[int],
}
super().__init__(**kwargs)
    def process_squad_tsv(self, fname):
# Process SQUAD TSV for KD
if not fname:
print("Empty file name!")
return
field_names = [
"id1",
"doc",
"question",
"answers",
"answer_starts",
"has_answer",
"id2",
"start_logits",
"end_logits",
"has_answer_logits",
"pad_mask",
"segment_labels",
]
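        # Every column, including doc and question, is expected to be
        # JSON-serialized (see the json.loads below); e.g. start_logits might
        # look like "[0.1, -2.3, 0.7]" (values are illustrative).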
tsv_file = SafeFileWrapper(
get_absolute_path(fname), encoding="utf-8", errors="replace"
)
tsv = TSV(
tsv_file,
field_names=field_names,
delimiter=self.delimiter,
quoted=self.quoted,
drop_incomplete_rows=True,
)
for id, row in enumerate(tsv):
parts = (row[f] for f in field_names)
            # All model outputs for KD are dumped using json serialization.
(
id1,
doc,
question,
answers,
answer_starts,
has_answer,
id2,
start_logits,
end_logits,
has_answer_logits,
pad_mask,
segment_labels,
) = (json.loads(s) for s in parts)
if isinstance(question, list):
# if we have paraphrases for question
question = choice(question)
for piece_dict in _split_document(
id,
doc,
question,
answers,
answer_starts,
has_answer == "True",
self.ignore_impossible,
self.max_character_length,
self.min_overlap,
):
piece_dict.update(
{
"start_logits": start_logits,
"end_logits": end_logits,
"has_answer_logits": has_answer_logits,
"pad_mask": pad_mask,
"segment_labels": segment_labels,
}
)
yield piece_dict