# Source code for pytext.data.test.data_test

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest

from pytext.common.constants import RawExampleFieldName, Stage
from pytext.data import Batcher, Data, PoolingBatcher
from pytext.data.data import RowData
from pytext.data.sources import RawExample
from pytext.data.sources.data_source import SafeFileWrapper
from pytext.data.sources.tsv import TSVDataSource
from pytext.data.tensorizers import LabelTensorizer, TokenTensorizer
from pytext.data.utils import pad
from pytext.utils import precision
from pytext.utils.test import import_tests_module


# Resolve the shared test-utilities module once at import time; its
# ``test_file`` helper is used below to locate the tiny TSV fixtures.
tests_module = import_tests_module()


class DataTest(unittest.TestCase):
    """Tests for ``Data``: batch creation, tensorizer initialization,
    sorting, numberized caching, and fp16-aware padding."""

    def setUp(self):
        # Tiny TSV fixture with dense features; only the text and label
        # columns are tensorized in these tests.
        self.data_source = TSVDataSource(
            SafeFileWrapper(tests_module.test_file("train_dense_features_tiny.tsv")),
            SafeFileWrapper(tests_module.test_file("test_dense_features_tiny.tsv")),
            eval_file=None,
            field_names=["label", "slots", "text", "dense"],
            schema={"text": str, "label": str},
        )
        self.tensorizers = {
            "tokens": TokenTensorizer(text_column="text"),
            "labels": LabelTensorizer(label_column="label", allow_unknown=True),
        }

    def test_create_data_no_batcher_provided(self):
        """Data should fall back to a default batcher when none is given."""
        data = Data(self.data_source, self.tensorizers)
        batches = list(data.batches(Stage.TRAIN))
        # We should have made at least one non-empty batch
        self.assertTrue(batches)
        raw_batch, batch = next(iter(batches))
        self.assertTrue(batch)

    def test_create_batches(self):
        """All 10 fixture rows fit in one batch of size 16; check shapes/keys."""
        data = Data(self.data_source, self.tensorizers, Batcher(train_batch_size=16))
        batches = list(data.batches(Stage.TRAIN))
        self.assertEqual(1, len(batches))
        raw_batch, batch = next(iter(batches))
        self.assertEqual(set(self.tensorizers), set(batch))
        tokens, seq_lens, _ = batch["tokens"]
        self.assertEqual(10, len(raw_batch))
        # Raw rows keep the schema columns plus the injected row index.
        self.assertEqual(
            {"text", "label", RawExampleFieldName.ROW_INDEX}, set(raw_batch[0])
        )
        self.assertEqual((10,), seq_lens.size())
        self.assertEqual((10,), batch["labels"].size())
        self.assertEqual({"tokens", "labels"}, set(batch))
        self.assertEqual(10, len(tokens))

    def test_create_batches_different_tensorizers(self):
        """Batch contents must track exactly the configured tensorizers."""
        tensorizers = {"tokens": TokenTensorizer(text_column="text")}
        data = Data(self.data_source, tensorizers, Batcher(train_batch_size=16))
        batches = list(data.batches(Stage.TRAIN))
        self.assertEqual(1, len(batches))
        raw_batch, batch = next(iter(batches))
        self.assertEqual({"tokens"}, set(batch))
        tokens, seq_lens, _ = batch["tokens"]
        self.assertEqual((10,), seq_lens.size())
        self.assertEqual(10, len(tokens))

    def test_data_initializes_tensorsizers(self):
        """Constructing Data should initialize (build vocabs for) tensorizers."""
        tensorizers = {
            "tokens": TokenTensorizer(text_column="text"),
            "labels": LabelTensorizer(label_column="label"),
        }
        # verify TokenTensorizer isn't in an initialized state yet
        assert tensorizers["tokens"].vocab is None
        Data(self.data_source, tensorizers)
        # Tensorizers should have been initialized
        self.assertEqual(49, len(tensorizers["tokens"].vocab))
        self.assertEqual(7, len(tensorizers["labels"].vocab))

    def test_data_iterate_multiple_times(self):
        """The batches iterable must be restartable and deterministic."""
        data = Data(self.data_source, self.tensorizers)
        batches = data.batches(Stage.TRAIN)
        data1 = list(batches)
        data2 = list(batches)
        # We should have made at least one non-empty batch
        self.assertTrue(data1)
        self.assertTrue(data2)
        _, (batch1, _) = data1[0]
        _, (batch2, _) = data2[0]
        # pytorch tensors don't have equals comparisons, so comparing the tensor
        # dicts is non-trivial, but they should also be equal
        self.assertEqual(batch1, batch2)

    def test_sort(self):
        """With sort_key="tokens", rows inside a batch come out by
        descending token length, and labels/raw rows follow the same order."""
        data = Data(
            self.data_source,
            self.tensorizers,
            Batcher(train_batch_size=5),
            sort_key="tokens",
        )

        def assert_sorted(batch):
            # Sequence lengths must be non-increasing within the batch.
            _, seq_lens, _ = batch["tokens"]
            seq_lens = seq_lens.tolist()
            for i in range(len(seq_lens) - 1):
                self.assertTrue(seq_lens[i] >= seq_lens[i + 1])

        batches = iter(list(data.batches(Stage.TRAIN)))
        first_raw_batch, first_batch = next(batches)
        assert_sorted(first_batch)
        # make sure labels are also in the same order of sorted tokens
        self.assertEqual(
            self.tensorizers["labels"].vocab[first_batch["labels"][1]],
            "alarm/set_alarm",
        )
        self.assertEqual(first_raw_batch[1][RawExampleFieldName.ROW_INDEX], 1)
        second_raw_batch, second_batch = next(batches)
        assert_sorted(second_batch)
        self.assertEqual(
            self.tensorizers["labels"].vocab[second_batch["labels"][1]],
            "alarm/time_left_on_alarm",
        )
        self.assertEqual(second_raw_batch[0][RawExampleFieldName.ROW_INDEX], 6)
        self.assertEqual(second_raw_batch[1][RawExampleFieldName.ROW_INDEX], 5)

    def test_create_batches_with_cache(self):
        """in_memory=True caches numberized rows; concurrent iteration of the
        same Data's batches must raise."""
        data = Data(
            self.data_source,
            self.tensorizers,
            Batcher(train_batch_size=1),
            in_memory=True,
        )
        list(data.batches(Stage.TRAIN))
        self.assertEqual(10, len(data.numberized_cache[Stage.TRAIN]))

        data1 = Data(
            self.data_source,
            self.tensorizers,
            Batcher(train_batch_size=1),
            in_memory=True,
        )
        with self.assertRaises(Exception):
            # Concurrent iteration not supported
            batches1 = data1.batches(Stage.TRAIN)
            batches2 = data1.batches(Stage.TRAIN)
            for _ in batches1:
                for _ in batches2:
                    continue

    def test_fp16_padding(self):
        """pad() pads to the max row length normally, and to a multiple of 8
        when fp16 is enabled."""
        nested_lists = [[1, 2, 3], [4, 5]]
        padded_lists = pad(nested_lists, pad_token=0)
        expected_lists = [[1, 2, 3], [4, 5, 0]]
        self.assertEqual(padded_lists, expected_lists)

        # BUGFIX: reset the global flag in a finally block. The original reset
        # it only on the success path, so a failing assertion left
        # FP16_ENABLED=True and polluted every subsequent test in the run.
        precision.FP16_ENABLED = True
        try:
            padded_lists = pad(nested_lists, pad_token=0)
            expected_lists = [[1, 2, 3, 0, 0, 0, 0, 0], [4, 5, 0, 0, 0, 0, 0, 0]]
            self.assertEqual(padded_lists, expected_lists)
        finally:
            precision.FP16_ENABLED = False
class BatcherTest(unittest.TestCase):
    """Exercise the plain and pooling batchers on small synthetic rows."""

    def test_batcher(self):
        # Ten rows whose numberized fields encode their index.
        rows = []
        for idx in range(10):
            rows.append(
                RowData(
                    {"text": "something"}, {"a": idx, "b": 10 + idx, "c": 20 + idx}
                )
            )
        produced = list(Batcher(train_batch_size=3).batchify(rows))
        # 10 rows at batch size 3 -> three full batches plus one remainder.
        self.assertEqual(len(produced), 4)
        self.assertEqual(len(produced[0].raw_data), 3)
        self.assertEqual("something", produced[1].raw_data[0]["text"])
        self.assertEqual(produced[1].numberized["a"], [3, 4, 5])
        self.assertEqual(produced[3].numberized["b"], [19])

    def test_pooling_batcher(self):
        rows = [
            RowData({"text": "something"}, {"a": n, "b": 10 + n, "c": 20 + n})
            for n in range(10)
        ]
        batcher = PoolingBatcher(train_batch_size=3, pool_num_batches=2)
        produced = list(
            batcher.batchify(rows, sort_key=lambda row: row.numberized["a"])
        )
        self.assertEqual(len(produced), 4)
        # Every original "a" value appears exactly once across all batches.
        seen = set()
        for _, numberized in produced:
            seen.update(numberized["a"])
        self.assertSetEqual(seen, set(range(10)))
        # The first pool (two batches of 3) holds the six smallest values,
        # ordered descending within each batch.
        for raw, numberized in produced[:2]:
            self.assertEqual([{"text": "something"}] * len(raw), list(raw))
            self.assertGreater(numberized["a"][0], numberized["a"][-1])
            for value in numberized["a"]:
                self.assertLess(value, 6)
        # The second pool holds the remaining (larger) values.
        for _, numberized in produced[2:]:
            for value in numberized["a"]:
                self.assertGreaterEqual(value, 6)
class RawExampleTest(unittest.TestCase):
    """RawExample must stay hashable even when it holds unhashable values."""

    def test_raw_example_hashable(self):
        fields = {
            "one": 111,
            "two": "222",
            "three": [3, 33, [333, 3333], {"33333": 333333}],
            "four": {"4": {"44": [444, 4444]}, "44444": 444444},
        }
        example = RawExample()
        # Assign item-by-item, as callers do, including nested lists/dicts.
        for key, value in fields.items():
            example[key] = value
        self.assertTrue(hash(example))