Source code for pytext.data.sources.tsv

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import csv
import logging
import sys
import threading
from typing import Dict, List, Optional, Type

from pytext.utils.path import get_absolute_path

from .data_source import (
    RootDataSource,
    SafeFileWrapper,
    ShardedDataSource,
    generator_property,
)
from .session import SessionDataSource


class TSV:
    def __init__(
        self,
        file,
        field_names=None,
        delimiter="\t",
        quoted=False,
        drop_incomplete_rows=False,
    ):
        self.file = file
        self.field_names = field_names
        self.delimiter = delimiter
        self.quoted = quoted
        self.drop_incomplete_rows = drop_incomplete_rows
        self.total_rows_count = 0
        self.incomplete_rows_count = 0
        self._access_lock = threading.Lock()
        csv.field_size_limit(sys.maxsize)

    def __iter__(self):
        can_acquire = self._access_lock.acquire(blocking=False)
        if not can_acquire:
            raise Exception("Concurrent iteration not supported")
        self.file.seek(0)
        try:
            reader = csv.DictReader(
                (line.replace("\0", "") for line in self.file),
                fieldnames=self.field_names,
                delimiter=self.delimiter,
                escapechar="\\",
                quoting=csv.QUOTE_MINIMAL if self.quoted else csv.QUOTE_NONE,
            )
            if self.drop_incomplete_rows:
                for row in reader:
                    self.total_rows_count += 1
                    if any(map(lambda v: v is None, row.values())):  # drop!
                        self.incomplete_rows_count += 1
                        continue
                    yield row
            else:
                yield from reader
        finally:
            self._access_lock.release()

    def __del__(self):
        logging.debug("Destroying TSV object")
        logging.debug(f"Total number of rows read: {self.total_rows_count}")
        logging.debug(f"Total number of rows dropped: {self.incomplete_rows_count}")

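As a minimal usage sketch (the file path and field names below are hypothetical, not part of pytext), TSV wraps a seekable text file object and yields one dict per row:

# Hypothetical example: iterate a three-column TSV file directly.
f = SafeFileWrapper("data/train.tsv", encoding="utf-8", errors="replace")
tsv = TSV(f, field_names=["label", "slots", "text"], drop_incomplete_rows=True)
for row in tsv:
    # each row is a dict keyed by field name
    print(row["label"], row["text"])
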
class TSVDataSource(RootDataSource):
    """DataSource which loads data from TSV sources. Uses python's csv library."""

    class Config(RootDataSource.Config):
        #: Filename of training set. If not set, iteration will be empty.
        train_filename: Optional[str] = None
        #: Filename of testing set. If not set, iteration will be empty.
        test_filename: Optional[str] = None
        #: Filename of eval set. If not set, iteration will be empty.
        eval_filename: Optional[str] = None
        #: Field names for the TSV. If this is not set, the first line of each file
        #: will be assumed to be a header containing the field names.
        field_names: Optional[List[str]] = None
        #: The column delimiter passed to Python's csv library. Change to "," for csv.
        delimiter: str = "\t"
        #: Whether the columns can use quotes to include delimiters or not.
        #: Rows with unclosed quotes will be merged with \n inside.
        #: Change to True for quoted csv.
        quoted: bool = False
        #: Flag to turn on dropping rows with fewer than the expected number of
        #: columns. This prevents passing None/null as column values down to
        #: the tensorizer.
        drop_incomplete_rows: bool = False

    @classmethod
    def from_config(cls, config: Config, schema: Dict[str, Type], **kwargs):
        args = config._asdict()
        train_filename = args.pop("train_filename")
        test_filename = args.pop("test_filename")
        eval_filename = args.pop("eval_filename")
        train_file = (
            SafeFileWrapper(
                get_absolute_path(train_filename), encoding="utf-8", errors="replace"
            )
            if train_filename
            else None
        )
        test_file = (
            SafeFileWrapper(
                get_absolute_path(test_filename), encoding="utf-8", errors="replace"
            )
            if test_filename
            else None
        )
        eval_file = (
            SafeFileWrapper(
                get_absolute_path(eval_filename), encoding="utf-8", errors="replace"
            )
            if eval_filename
            else None
        )
        return cls(
            train_file=train_file,
            test_file=test_file,
            eval_file=eval_file,
            schema=schema,
            **args,
            **kwargs,
        )

    def __init__(
        self,
        train_file=None,
        test_file=None,
        eval_file=None,
        field_names=None,
        delimiter=Config.delimiter,
        quoted=Config.quoted,
        drop_incomplete_rows=Config.drop_incomplete_rows,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._init_tsv(
            field_names,
            delimiter,
            train_file,
            test_file,
            eval_file,
            quoted,
            drop_incomplete_rows,
        )

    def _init_tsv(
        self,
        field_names,
        delimiter,
        train_file,
        test_file,
        eval_file,
        quoted,
        drop_incomplete_rows,
    ):
        def make_tsv(file):
            return TSV(
                file,
                field_names=field_names,
                delimiter=delimiter,
                quoted=quoted,
                drop_incomplete_rows=drop_incomplete_rows,
            )

        self._train_tsv = make_tsv(train_file) if train_file else []
        self._test_tsv = make_tsv(test_file) if test_file else []
        self._eval_tsv = make_tsv(eval_file) if eval_file else []

    def raw_train_data_generator(self):
        return iter(self._train_tsv)

    def raw_test_data_generator(self):
        return iter(self._test_tsv)

    def raw_eval_data_generator(self):
        return iter(self._eval_tsv)

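As a construction sketch (the file name, field names, and schema below are hypothetical), a TSVDataSource can also be built directly rather than via from_config; the schema maps column names to types and is consumed by RootDataSource:

# Hypothetical example: build the data source directly and iterate it.
source = TSVDataSource(
    train_file=SafeFileWrapper("train.tsv", encoding="utf-8", errors="replace"),
    field_names=["label", "text"],
    schema={"label": str, "text": str},
)
for row in source.train:  # generator over parsed training rows
    pass
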
class MultilingualTSVDataSource(TSVDataSource):
    """
    Data Source for multi-lingual data. The input data can have multiple
    text fields, and each field can be in the same language or in a
    different language. The `data_source_languages` dict contains the
    language information for each text field; the number of languages it
    lists should match the number of language identifiers specified in
    `language_columns`.
    """

    class Config(TSVDataSource.Config):
        data_source_languages: Dict[str, List[str]] = {
            "train": ["en"],
            "eval": ["en"],
            "test": ["en"],
        }
        language_columns: List[str] = ["language"]

    def __init__(
        self,
        train_file=None,
        test_file=None,
        eval_file=None,
        field_names=None,
        delimiter=Config.delimiter,
        data_source_languages=Config.data_source_languages,
        language_columns=Config.language_columns,
        **kwargs,
    ):
        super().__init__(
            train_file, test_file, eval_file, field_names, delimiter, **kwargs
        )
        self.data_source_languages = data_source_languages
        self.language_columns = language_columns
        assert len(data_source_languages["train"]) == len(
            self.language_columns
        ), "Number of languages and language columns should be the same."

    def _convert_raw_source(self, source, languages):
        for row in source:
            example = self._read_example(row)
            if example is None:
                continue
            for col, lang in zip(self.language_columns, languages):
                example[col] = lang
            yield example

    @generator_property
    def train(self):
        return self._convert_raw_source(
            self.raw_train_data_generator(), self.data_source_languages["train"]
        )

    @generator_property
    def test(self):
        return self._convert_raw_source(
            self.raw_test_data_generator(), self.data_source_languages["test"]
        )

    @generator_property
    def eval(self):
        return self._convert_raw_source(
            self.raw_eval_data_generator(), self.data_source_languages["eval"]
        )

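A usage sketch (all file names, column names, and language codes here are hypothetical): each configured language column is filled with a fixed identifier for every row of the corresponding split.

# Hypothetical example: two text fields in different languages; the
# "language1"/"language2" columns are populated from data_source_languages.
ml_source = MultilingualTSVDataSource(
    train_file=SafeFileWrapper("train.tsv", encoding="utf-8", errors="replace"),
    field_names=["text1", "text2"],
    schema={"text1": str, "text2": str, "language1": str, "language2": str},
    data_source_languages={
        "train": ["en", "es"],
        "eval": ["en", "es"],
        "test": ["en", "es"],
    },
    language_columns=["language1", "language2"],
)
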
class BlockShardedTSV:
    """Take a TSV file, split it into N pieces (by byte location), and return
    an iterator over one of the pieces. The pieces are equal in byte size,
    not in number of rows, so care needs to be taken when using this for
    distributed training; otherwise the number of batches can differ between
    workers.
    """

    def __init__(
        self,
        file,
        field_names=None,
        delimiter="\t",
        quoted=False,
        block_id=0,
        num_blocks=1,
        drop_incomplete_rows=False,
    ):
        self.file = file
        self.field_names = field_names
        self.delimiter = delimiter
        self.quoted = quoted
        self.block_id = block_id
        self.num_blocks = num_blocks
        self.drop_incomplete_rows = drop_incomplete_rows
        csv.field_size_limit(sys.maxsize)

    def __iter__(self):
        # (self.begin, self.end) are the pointers to the begin and end
        # of the file segment; integer division keeps them valid seek offsets
        self.file.seek(0, 2)
        end = self.file.tell()
        self.begin = self.block_id * end // self.num_blocks
        self.end = (self.block_id + 1) * end // self.num_blocks
        self.file.seek(self.begin, 0)
        # make sure we're at the beginning of a full row
        if self.begin:
            self.file.readline()

        reader = csv.DictReader(
            (line.replace("\0", "") for line in iter(self.file.readline, "")),
            fieldnames=self.field_names,
            delimiter=self.delimiter,
            quoting=csv.QUOTE_MINIMAL if self.quoted else csv.QUOTE_NONE,
        )
        # iterate until we're at the end of the segment
        for row in reader:
            if self.file.tell() > self.end:
                break
            if self.drop_incomplete_rows:
                if any(map(lambda v: v is None, row.values())):  # drop!
                    continue
            yield row

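To illustrate the byte-range arithmetic (the file and shard parameters below are hypothetical): with num_blocks=4, the shard with block_id=1 covers roughly the second quarter of the file by bytes, skipping an initial partial line so it always starts on a full row.

# Hypothetical example: read one of four byte-equal shards of a file.
f = SafeFileWrapper("train.tsv", encoding="utf-8", errors="replace")
shard = BlockShardedTSV(f, field_names=["label", "text"], block_id=1, num_blocks=4)
for row in shard:
    pass
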
class BlockShardedTSVDataSource(TSVDataSource, ShardedDataSource):
    def __init__(self, rank=0, world_size=1, **kwargs):
        self.rank = rank
        self.world_size = world_size
        # calls __init__ of TSVDataSource
        super().__init__(**kwargs)
        # super(TSVDataSource, self) skips TSVDataSource in the MRO,
        # so this calls __init__ of ShardedDataSource
        super(TSVDataSource, self).__init__(schema=self.schema)

    def _init_tsv(
        self,
        field_names,
        delimiter,
        train_file,
        test_file,
        eval_file,
        quoted,
        drop_incomplete_rows,
    ):
        def make_tsv(file, rank=0, world_size=1):
            return BlockShardedTSV(
                file,
                field_names=field_names,
                delimiter=delimiter,
                block_id=rank,
                num_blocks=world_size,
                quoted=quoted,
                drop_incomplete_rows=drop_incomplete_rows,
            )

        self._train_tsv = (
            make_tsv(train_file, self.rank, self.world_size) if train_file else []
        )
        self._test_tsv = make_tsv(test_file) if test_file else []
        self._eval_tsv = make_tsv(eval_file) if eval_file else []
        self._train_unsharded = (
            TSV(train_file, field_names=field_names, delimiter=delimiter, quoted=quoted)
            if train_file
            else []
        )

    @generator_property
    def train_unsharded(self):
        return iter(self._train_unsharded)

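A distributed-training sketch (the rank and world_size values below are hypothetical; they would normally come from the trainer's distributed context). Only the training file is sharded; test and eval stay unsharded:

# Hypothetical example: worker 1 of 4 sees only its block of train.tsv.
sharded = BlockShardedTSVDataSource(
    rank=1,
    world_size=4,
    train_file=SafeFileWrapper("train.tsv", encoding="utf-8", errors="replace"),
    field_names=["label", "text"],
    schema={"label": str, "text": str},
)
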
class SessionTSVDataSource(TSVDataSource, SessionDataSource):
    def __init__(
        self,
        train_file=None,
        test_file=None,
        eval_file=None,
        field_names=None,
        **kwargs,
    ):
        # requires the first column to be the session id
        assert len(field_names) >= 2, "should specify at least 2 columns"
        super().__init__(
            train_file=train_file,
            test_file=test_file,
            eval_file=eval_file,
            field_names=field_names,
            id_col=field_names[0],
            **kwargs,
        )
        self.schema[self.id_col] = str
        self.field_names = field_names
        self._validate_schema()

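A construction sketch (the file and column names below are hypothetical): the first entry of field_names is taken as the session id column, and rows sharing that id are treated as one session by SessionDataSource.

# Hypothetical example: session data keyed by a leading "session_id" column.
session_source = SessionTSVDataSource(
    train_file=SafeFileWrapper("sessions.tsv", encoding="utf-8", errors="replace"),
    field_names=["session_id", "text", "label"],
    schema={"text": str, "label": str},
)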