#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import unittest
from pytext.common.constants import RawExampleFieldName, Stage
from pytext.data import Batcher, Data, PoolingBatcher
from pytext.data.data import RowData
from pytext.data.sources import RawExample
from pytext.data.sources.data_source import SafeFileWrapper
from pytext.data.sources.tsv import TSVDataSource
from pytext.data.tensorizers import LabelTensorizer, TokenTensorizer
from pytext.data.utils import pad
from pytext.utils import precision
from pytext.utils.test import import_tests_module
# Resolve the project's shared test-assets module once at import time;
# tests use tests_module.test_file(...) to locate tiny TSV fixtures.
tests_module = import_tests_module()
class DataTest(unittest.TestCase):
    """Tests for Data: batching, tensorizer initialization, sorting,
    caching, and fp16 padding behavior, backed by a tiny TSV fixture."""

    def setUp(self):
        # Tiny train/test TSV fixture shared by all tests in this class.
        self.data_source = TSVDataSource(
            SafeFileWrapper(tests_module.test_file("train_dense_features_tiny.tsv")),
            SafeFileWrapper(tests_module.test_file("test_dense_features_tiny.tsv")),
            eval_file=None,
            field_names=["label", "slots", "text", "dense"],
            schema={"text": str, "label": str},
        )
        self.tensorizers = {
            "tokens": TokenTensorizer(text_column="text"),
            "labels": LabelTensorizer(label_column="label", allow_unknown=True),
        }

    def test_create_data_no_batcher_provided(self):
        """Data should fall back to a default batcher when none is given."""
        data = Data(self.data_source, self.tensorizers)
        batches = list(data.batches(Stage.TRAIN))
        # We should have made at least one non-empty batch
        self.assertTrue(batches)
        raw_batch, batch = next(iter(batches))
        self.assertTrue(batch)

    def test_create_batches(self):
        """A batch size larger than the dataset yields one full batch."""
        data = Data(self.data_source, self.tensorizers, Batcher(train_batch_size=16))
        batches = list(data.batches(Stage.TRAIN))
        self.assertEqual(1, len(batches))
        raw_batch, batch = next(iter(batches))
        self.assertEqual(set(self.tensorizers), set(batch))
        tokens, seq_lens, _ = batch["tokens"]
        self.assertEqual(10, len(raw_batch))
        # Raw rows keep their source columns plus the injected row index.
        self.assertEqual(
            {"text", "label", RawExampleFieldName.ROW_INDEX}, set(raw_batch[0])
        )
        self.assertEqual((10,), seq_lens.size())
        self.assertEqual((10,), batch["labels"].size())
        self.assertEqual({"tokens", "labels"}, set(batch))
        self.assertEqual(10, len(tokens))

    def test_create_batches_different_tensorizers(self):
        """The numberized batch only contains keys for configured tensorizers."""
        tensorizers = {"tokens": TokenTensorizer(text_column="text")}
        data = Data(self.data_source, tensorizers, Batcher(train_batch_size=16))
        batches = list(data.batches(Stage.TRAIN))
        self.assertEqual(1, len(batches))
        raw_batch, batch = next(iter(batches))
        self.assertEqual({"tokens"}, set(batch))
        tokens, seq_lens, _ = batch["tokens"]
        self.assertEqual((10,), seq_lens.size())
        self.assertEqual(10, len(tokens))

    def test_data_initializes_tensorsizers(self):
        """Constructing Data should initialize (build vocabs for) tensorizers."""
        tensorizers = {
            "tokens": TokenTensorizer(text_column="text"),
            "labels": LabelTensorizer(label_column="label"),
        }
        # Verify TokenTensorizer isn't in an initialized state yet.
        # (unittest assertion, not a bare `assert`, so it survives `python -O`)
        self.assertIsNone(tensorizers["tokens"].vocab)
        Data(self.data_source, tensorizers)
        # Tensorizers should have been initialized
        self.assertEqual(49, len(tensorizers["tokens"].vocab))
        self.assertEqual(7, len(tensorizers["labels"].vocab))

    def test_data_iterate_multiple_times(self):
        """The batches iterable must be re-iterable and deterministic."""
        data = Data(self.data_source, self.tensorizers)
        batches = data.batches(Stage.TRAIN)
        data1 = list(batches)
        data2 = list(batches)
        # We should have made at least one non-empty batch
        self.assertTrue(data1)
        self.assertTrue(data2)
        _, (batch1, _) = data1[0]
        _, (batch2, _) = data2[0]
        # pytorch tensors don't have equals comparisons, so comparing the tensor
        # dicts is non-trivial, but they should also be equal
        self.assertEqual(batch1, batch2)

    def test_sort(self):
        """sort_key orders examples within a batch by descending token length."""
        data = Data(
            self.data_source,
            self.tensorizers,
            Batcher(train_batch_size=5),
            sort_key="tokens",
        )

        def assert_sorted(batch):
            _, seq_lens, _ = batch["tokens"]
            seq_lens = seq_lens.tolist()
            for i in range(len(seq_lens) - 1):
                self.assertTrue(seq_lens[i] >= seq_lens[i + 1])

        batches = iter(list(data.batches(Stage.TRAIN)))
        first_raw_batch, first_batch = next(batches)
        assert_sorted(first_batch)
        # make sure labels are also in the same order of sorted tokens
        self.assertEqual(
            self.tensorizers["labels"].vocab[first_batch["labels"][1]],
            "alarm/set_alarm",
        )
        self.assertEqual(first_raw_batch[1][RawExampleFieldName.ROW_INDEX], 1)
        second_raw_batch, second_batch = next(batches)
        assert_sorted(second_batch)
        self.assertEqual(
            self.tensorizers["labels"].vocab[second_batch["labels"][1]],
            "alarm/time_left_on_alarm",
        )
        self.assertEqual(second_raw_batch[0][RawExampleFieldName.ROW_INDEX], 6)
        self.assertEqual(second_raw_batch[1][RawExampleFieldName.ROW_INDEX], 5)

    def test_create_batches_with_cache(self):
        """in_memory=True caches numberized rows; concurrent iteration raises."""
        data = Data(
            self.data_source,
            self.tensorizers,
            Batcher(train_batch_size=1),
            in_memory=True,
        )
        list(data.batches(Stage.TRAIN))
        self.assertEqual(10, len(data.numberized_cache[Stage.TRAIN]))
        data1 = Data(
            self.data_source,
            self.tensorizers,
            Batcher(train_batch_size=1),
            in_memory=True,
        )
        with self.assertRaises(Exception):
            # Concurrent iteration not supported
            batches1 = data1.batches(Stage.TRAIN)
            batches2 = data1.batches(Stage.TRAIN)
            for _ in batches1:
                for _ in batches2:
                    continue

    def test_fp16_padding(self):
        """With FP16 enabled, pad() pads sequence length up to a multiple of 8."""
        nested_lists = [[1, 2, 3], [4, 5]]
        padded_lists = pad(nested_lists, pad_token=0)
        expected_lists = [[1, 2, 3], [4, 5, 0]]
        self.assertEqual(padded_lists, expected_lists)
        # Guard the global toggle so a failing assertion can't leak FP16
        # state into other tests.
        precision.FP16_ENABLED = True
        try:
            padded_lists = pad(nested_lists, pad_token=0)
            expected_lists = [[1, 2, 3, 0, 0, 0, 0, 0], [4, 5, 0, 0, 0, 0, 0, 0]]
            self.assertEqual(padded_lists, expected_lists)
        finally:
            precision.FP16_ENABLED = False
class BatcherTest(unittest.TestCase):
    """Tests for Batcher and PoolingBatcher batchify behavior."""

    def test_batcher(self):
        """Batcher slices rows into fixed-size batches, last batch partial."""
        data = [
            RowData({"text": "something"}, {"a": i, "b": 10 + i, "c": 20 + i})
            for i in range(10)
        ]
        batcher = Batcher(train_batch_size=3)
        batches = list(batcher.batchify(data))
        self.assertEqual(len(batches), 4)
        self.assertEqual(len(batches[0].raw_data), 3)
        self.assertEqual("something", batches[1].raw_data[0]["text"])
        self.assertEqual(batches[1].numberized["a"], [3, 4, 5])
        self.assertEqual(batches[3].numberized["b"], [19])

    def test_pooling_batcher(self):
        """PoolingBatcher sorts within a pool of batches, keeping all rows."""
        data = [
            RowData({"text": "something"}, {"a": i, "b": 10 + i, "c": 20 + i})
            for i in range(10)
        ]
        batcher = PoolingBatcher(train_batch_size=3, pool_num_batches=2)
        batches = list(batcher.batchify(data, sort_key=lambda x: x.numberized["a"]))
        self.assertEqual(len(batches), 4)
        # No row is dropped or duplicated across the pools.
        a_vals = {a for raw_batch, batch in batches for a in batch["a"]}
        self.assertSetEqual(a_vals, set(range(10)))
        # First pool (2 batches of size 3) covers values < 6, sorted descending.
        for raw_batch, batch in batches[:2]:
            self.assertEqual([{"text": "something"}] * len(raw_batch), list(raw_batch))
            self.assertGreater(batch["a"][0], batch["a"][-1])
            for a in batch["a"]:
                self.assertLess(a, 6)
        # Second pool covers the remaining values >= 6.
        for _, batch in batches[2:]:
            for a in batch["a"]:
                self.assertGreaterEqual(a, 6)
class RawExampleTest(unittest.TestCase):
    """Tests for the RawExample mapping type."""

    def test_raw_example_hashable(self):
        """RawExample must be hashable even with nested unhashable values."""
        example = RawExample()
        example["one"] = 111
        example["two"] = "222"
        # Nested lists and dicts are unhashable on their own; RawExample
        # must still produce a hash for the whole example.
        example["three"] = [3, 33, [333, 3333], {"33333": 333333}]
        example["four"] = {"4": {"44": [444, 4444]}, "44444": 444444}
        self.assertTrue(hash(example))