Source code for pytext.data.test.tsv_data_source_test

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest
from typing import List

from pytext.data.sources.data_source import SafeFileWrapper
from pytext.data.sources.tsv import (
    BlockShardedTSVDataSource,
    SessionTSVDataSource,
    TSVDataSource,
)
from pytext.utils.test import import_tests_module


tests_module = import_tests_module()


[docs]class TSVDataSourceTest(unittest.TestCase):
[docs] def setUp(self): self.data = TSVDataSource( SafeFileWrapper(tests_module.test_file("train_dense_features_tiny.tsv")), SafeFileWrapper(tests_module.test_file("test_dense_features_tiny.tsv")), eval_file=None, field_names=["label", "slots", "text", "dense"], schema={"text": str, "label": str}, )
[docs] def test_read_data_source(self): data = list(self.data.train) self.assertEqual(10, len(data)) example = next(iter(data)) self.assertEqual(2, len(example)) self.assertEqual({"label", "text"}, set(example))
[docs] def test_quoting(self): """ The text column of the first row of this file opens a quote but does not close it. """ data_source = TSVDataSource( SafeFileWrapper(tests_module.test_file("test_tsv_quoting.tsv")), SafeFileWrapper(tests_module.test_file("test_tsv_quoting.tsv")), eval_file=None, field_names=["label", "text"], schema={"text": str, "label": str}, ) data = list(data_source.train) self.assertEqual(4, len(data))
[docs] def test_bad_quoting(self): """ The text column of the first row of this file opens a quote but does not close it. """ data_source = TSVDataSource( SafeFileWrapper(tests_module.test_file("test_tsv_quoting.tsv")), SafeFileWrapper(tests_module.test_file("test_tsv_quoting.tsv")), eval_file=None, field_names=["label", "text"], schema={"text": str, "label": str}, quoted=True, ) data = list(data_source.train) self.assertEqual(1, len(data))
[docs] def test_csv(self): data_source = TSVDataSource( SafeFileWrapper(tests_module.test_file("test_data_tiny_csv.tsv")), test_file=None, eval_file=None, field_names=["label", "slots", "text"], delimiter=",", schema={"text": str, "label": str}, quoted=True, ) for row in data_source.train: self.assertEqual("alarm/set_alarm", row["label"]) self.assertTrue(row["text"].startswith("this is the text"))
[docs] def test_read_test_data_source(self): data = list(self.data.test) self.assertEqual(4, len(data)) example = next(iter(data)) self.assertEqual(2, len(example)) self.assertEqual({"label", "text"}, set(example))
[docs] def test_read_eval_data_source(self): data = list(self.data.eval) self.assertEqual(0, len(data))
[docs] def test_iterate_training_data_multiple_times(self): train = self.data.train data = list(train) data2 = list(train) self.assertEqual(10, len(data)) self.assertEqual(10, len(data2)) example = next(iter(data2)) self.assertEqual(2, len(example)) self.assertEqual({"label", "text"}, set(example))
[docs] def test_read_data_source_with_column_remapping(self): data_source = TSVDataSource( SafeFileWrapper(tests_module.test_file("train_dense_features_tiny.tsv")), SafeFileWrapper(tests_module.test_file("test_dense_features_tiny.tsv")), eval_file=None, field_names=["remapped_label", "slots", "remapped_text", "dense"], column_mapping={"remapped_label": "label", "remapped_text": "text"}, schema={"text": str, "label": str}, ) data = list(data_source.train) self.assertEqual(10, len(data)) example = next(iter(data)) self.assertEqual(2, len(example)) self.assertEqual({"label", "text"}, set(example))
[docs] def test_read_data_source_with_utf8_issues(self): schema = {"text": str, "label": str} data_source = TSVDataSource.from_config( TSVDataSource.Config( train_filename=tests_module.test_file("test_utf8_errors.tsv"), field_names=["label", "text"], ), schema, ) list(data_source.train)
[docs]class SessionTSVDataSourceTest(unittest.TestCase):
[docs] def setUp(self): self.data = SessionTSVDataSource( SafeFileWrapper(tests_module.test_file("seq_tagging_example.tsv")), field_names=["session_id", "intent", "goals", "label"], schema={"intent": List[str], "goals": List[str], "label": List[str]}, )
[docs] def test_read_session_data(self): self.assertEqual(3, len(list(self.data.train))) # validate multiple iteration self.assertEqual(3, len(list(self.data.train))) it = iter(self.data.train) example = next(it) self.assertEqual(4, len(example)) self.assertEqual("id1", example["session_id"]) self.assertEqual(["int11", "int12"], example["intent"]) self.assertEqual(["g11", "g12"], example["goals"]) self.assertEqual(["0", "0"], example["label"]) example = next(it) example = next(it) self.assertEqual("id3", example["session_id"]) self.assertEqual(["int31", "int32", "int33"], example["intent"]) self.assertEqual(["g31", "g32", "g33"], example["goals"]) self.assertEqual(["0", "1", "1"], example["label"])
[docs]class BlockShardedTSVDataSourceTest(unittest.TestCase):
[docs] def test_quoting(self): """ The text column of the first row of this file opens a quote but does not close it. """ data_source = BlockShardedTSVDataSource( train_file=SafeFileWrapper(tests_module.test_file("test_tsv_quoting.tsv")), test_file=None, eval_file=None, field_names=["label", "text"], schema={"text": str, "label": str}, ) data = list(data_source.train_unsharded) self.assertEqual(4, len(data)) data = list(data_source.train) self.assertEqual(4, len(data))
[docs] def test_bad_quoting(self): """ The text column of the first row of this file opens a quote but does not close it. """ data_source = BlockShardedTSVDataSource( train_file=SafeFileWrapper(tests_module.test_file("test_tsv_quoting.tsv")), test_file=None, eval_file=None, field_names=["label", "text"], schema={"text": str, "label": str}, quoted=True, ) data = list(data_source.train_unsharded) self.assertEqual(1, len(data)) data = list(data_source.train) self.assertEqual(1, len(data))