#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import json
import logging
import re
from typing import Any, Dict, List, Type, TypeVar

from pytext.config.component import Component, ComponentType
from pytext.data.utils import shard
from pytext.utils.data import Slot, parse_slot_string
from pytext.utils.file_io import PathManager


Schema = Dict[str, Type]


class RawExample(dict):
    """A wrapper class for a single example row with a dict interface.
    This is here for any logic we want row objects to have that dicts don't do.
    """

    def __hash__(self):
        """
        Makes examples hashable. This allows them to be cached.
        Currently this assumes keys (column names) are already hashable,
        and values are either dicts, lists, or a type that is already hashable.
        """

        def list_iter(l):
            for item in l:
                if isinstance(item, dict):
                    yield frozenset(dict_iter(item))
                elif isinstance(item, list):
                    yield tuple(list_iter(item))
                else:
                    yield item

        def dict_iter(d):
            for kvp in d.items():
                if isinstance(kvp[1], dict):
                    yield (kvp[0], frozenset(dict_iter(kvp[1])))
                elif isinstance(kvp[1], list):
                    yield (kvp[0], tuple(list_iter(kvp[1])))
                else:
                    yield kvp

        return hash(frozenset(dict_iter(self)))


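# Example (illustrative): unlike a plain dict, a RawExample with nested
# lists/dicts can be hashed, e.g. for caching.
#
#   row = RawExample({"text": "hi", "slots": [{"loc": {"score": 1.0}}]})
#   hash(row)  # works; hash() on the equivalent dict would raise TypeError

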
class SafeFileWrapper:
    """A simple wrapper class for files which allows file descriptors to be
    managed with normal Python ref counts.

    Without using this, if you create a file in a from_config you will see a
    warning along the lines of "ResourceWarning: self._file is acquired but
    not always released". This is because we're opening a file outside of a
    context manager (with statement). We want to do it this way because it
    lets us pass a file object to the DataSource, rather than a filename.
    This exposes a lot more flexibility and testability; passing filenames is
    one of the paths towards pain. However, we don't have a clear resource
    management system set up for configuration. from_config functions are the
    tool that we have to allow objects to specify how they should be created
    from a configuration, which generally should only happen from the command
    line, whereas in e.g. a notebook you should build the objects with
    constructors directly. If building from constructors, you can just open a
    file and pass it, but from_config here needs to create a file object from
    a configured filename.

    Python files don't close automatically, so you also need a system that
    will close them when the Python interpreter shuts down. If you don't, it
    will print a resource warning at runtime as the interpreter manually
    closes the file handles. (Modern OSs are pretty okay with having open
    file handles; it's hard for me to justify exactly why Python is so strict
    about this. I think the main reason you might actually care is that a
    writable file handle might not have flushed properly when the C runtime
    exits, although Python doesn't actually distinguish between writable and
    non-writable file handles here.)

    This class is a wrapper that creates a system for (sort of) safely
    closing the file handles before the runtime exits. It does this by
    closing the file when the object's deleter is called. Although the Python
    standard doesn't actually make any guarantees about when deleters are
    called, CPython is reference counted and so, as an implementation detail,
    will call a deleter whenever the last reference to an object is removed.
    That generally happens for all objects created during program execution,
    as long as there aren't reference cycles (I don't actually know off-hand
    whether the cycle collector is run before shutdown, but the cycles would
    have to include objects that the runtime itself maintains pointers to,
    which you'd have to work hard to do and wouldn't do accidentally). This
    isn't true for other Python implementations like PyPy or Jython, which
    use generational garbage collection and so don't always call destructors
    before the system shuts down, but again this is only really relevant for
    writable files.

    An alternative implementation would be to build a resource management
    system into PyText, something like a function that we use for opening
    system resources which registers them and then makes sure they are all
    closed before system shutdown. That would probably technically be the
    right solution, but I didn't think of it first and it's also a bit longer
    to implement. If you are seeing resource warnings on your system, please
    file a GitHub issue.
    """

    def __init__(self, *args, **kwargs):
        self._file = PathManager.open(*args, **kwargs)

    def __del__(self):
        self._file.close()

    def __iter__(self):
        """Some file utilities check hasattr(o, "__iter__") explicitly."""
        return iter(self._file)

    def __getattr__(self, attr):
        return getattr(self._file, attr)


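# Example sketch (hypothetical filename): the wrapper delegates attribute
# access and iteration to the underlying file, so it can be passed anywhere
# a file object is expected, and it is closed when the last reference drops.
#
#   f = SafeFileWrapper("train.tsv", "r")
#   header = f.readline()  # __getattr__ forwards to the real file
#   for line in f:         # __iter__ forwards too
#       ...

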
class GeneratorIterator:
    """Create an object which can be iterated over multiple times from a
    generator call. Each iteration will call the generator and allow
    iterating over it. This is unsafe to use on generators which have side
    effects, such as file readers; it's up to the callers to safely manage
    these scenarios.
    """

    def __init__(self, generator, *args, **kwargs):
        self.generator = generator
        self.args = args
        self.kwargs = kwargs

    def __iter__(self):
        return self.generator(*self.args, **self.kwargs)


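# Example (illustrative): each iteration re-invokes the generator, so the
# result can be consumed more than once, unlike a bare generator object.
#
#   def count_to(n):
#       yield from range(n)
#
#   numbers = GeneratorIterator(count_to, 3)
#   assert list(numbers) == [0, 1, 2]
#   assert list(numbers) == [0, 1, 2]  # a fresh generator each time

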
class GeneratorMethodProperty:
    """Identify a generator method as a property. This will allow instances
    to iterate over the property multiple times, and not consume the
    generator. It accomplishes this by wrapping the generator and creating
    multiple generator instances if iterated over multiple times.
    """

    def __init__(self, generator):
        self.generator = generator

    def __get__(self, obj, objtype=None):
        return GeneratorIterator(self.generator, obj)


# Use the more typical property decorator style
generator_property = GeneratorMethodProperty


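# Example (illustrative) of the decorator style used by DataSource below;
# each attribute access returns a re-iterable wrapper around the method.
#
#   class Numbers:
#       @generator_property
#       def evens(self):
#           yield from (0, 2, 4)
#
#   n = Numbers()
#   assert list(n.evens) == [0, 2, 4]
#   assert list(n.evens) == [0, 2, 4]

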
class DataSource(Component):
    """
    Data sources are simple components that stream data from somewhere using
    Python's iteration interface. A data source should expose three
    iterators: "train", "test", and "eval". Each of these should be able to
    be iterated over any number of times, and iterating over it should yield
    dictionaries whose values are deserialized Python types.

    Simply put, these data sources exist as an interface to read through
    datasets in a pythonic way, with pythonic types, and abstract away the
    form that they are stored in.
    """

    __COMPONENT_TYPE__ = ComponentType.DATA_SOURCE
    __EXPANSIBLE__ = True

    def __init__(self, schema: Schema):
        self.schema = schema

    @generator_property
    def train(self):
        raise NotImplementedError

    @generator_property
    def test(self):
        raise NotImplementedError

    @generator_property
    def eval(self):
        raise NotImplementedError


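# A minimal sketch (hypothetical, for illustration) of the interface a
# concrete DataSource exposes: each split is a generator_property, so it can
# be iterated any number of times. Only train is shown; test and eval would
# be analogous.
#
#   class InMemoryDataSource(DataSource):
#       def __init__(self, schema, rows):
#           super().__init__(schema)
#           self.rows = rows
#
#       @generator_property
#       def train(self):
#           yield from self.rows

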
class ShardedDataSource(DataSource):
    """Base class for sharded data sources."""


class RowShardedDataSource(ShardedDataSource):
    """Shards a given datasource by row."""

    def __init__(self, data_source: DataSource, rank=0, world_size=1):
        super().__init__(data_source.schema)
        self.data_source = data_source
        self.rank = rank
        self.world_size = world_size
        self.eval = data_source.eval
        self.test = data_source.test

    @generator_property
    def train(self):
        return shard(iter(self.data_source.train), self.rank, self.world_size)

    @generator_property
    def train_unsharded(self):
        """Used to initialize the tensorizer on the entire dataset."""
        return iter(self.data_source.train)


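# Illustrative sketch: with world_size=2, ranks 0 and 1 iterate disjoint
# subsets of train that together cover the full dataset, while test and eval
# are passed through unsharded. (This assumes `shard` from pytext.data.utils
# deals successive rows out across workers.)
#
#   sharded = RowShardedDataSource(source, rank=0, world_size=2)
#   for row in sharded.train:
#       ...  # only this worker's share of the rows

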
class RootDataSource(DataSource):
    """A data source which actually loads data from a location. This data
    source needs to be responsible for converting types based on a schema,
    because it should be the only part of the system that actually needs to
    understand details about the underlying storage system.

    RootDataSource presents a simpler abstraction than DataSource, where the
    rows are automatically converted to the right DataTypes. A RootDataSource
    should implement `raw_train_data_generator`, `raw_test_data_generator`,
    and `raw_eval_data_generator`. These functions should yield dictionaries
    of raw objects which the loading system can convert using the schema
    loading functions.
    """

    DATA_SOURCE_TYPES = {}

    class Config(Component.Config):
        #: An optional column mapping, allowing the columns in the raw data source
        #: to not map directly to the column names in the schema. This mapping will
        #: remap names from the raw data source to names in the schema.
        column_mapping: Dict[str, str] = {}

    def __init__(self, schema: Schema, column_mapping: Dict[str, str] = ()):
        super().__init__(schema)
        self.column_mapping = dict(column_mapping)

    def _read_example(self, row):
        example = RawExample()
        for column_name, value in row.items():
            name = self.column_mapping.get(column_name, column_name)
            if name not in self.schema:
                continue
            example[name] = self.load(value, self.schema[name])
        if len(example) != len(self.schema):
            # We might need to re-evaluate this for multi-task training
            logging.warning(
                "Skipping row missing values: row {} -> schema {}".format(
                    list(row.keys()), list(self.schema.keys())
                )
            )
            return None
        return example

    def _convert_raw_source(self, source):
        """Convert a raw iterable source, e.g. from
        `DataSource.raw_train_data_generator`, to an iterable that will yield
        `pytext.data.type.DataType` objects according to the schema and the
        converters for this DataSource.
        """
        for row in source:
            example = self._read_example(row)
            if example is None:
                continue
            yield example

    @classmethod
    def register_type(cls, type):
        # Make sure we don't accidentally use RootDataSource's registry
        if "DATA_SOURCE_TYPES" not in vars(cls):
            cls.DATA_SOURCE_TYPES = {}

        def decorator(fn):
            cls.DATA_SOURCE_TYPES[type] = fn
            return fn

        return decorator

    def load(self, value, schema_type):
        for cls in type(self).__mro__:
            if schema_type in getattr(cls, "DATA_SOURCE_TYPES", {}):
                converter = cls.DATA_SOURCE_TYPES[schema_type]
                return converter(value)
        raise Exception(f'Type not registered in data source: "{schema_type}"')

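    # Example (hypothetical subclass, for illustration): because `load` walks
    # the MRO, a converter registered on a subclass shadows the one registered
    # on RootDataSource for the same schema type.
    #
    #   class HexDataSource(RootDataSource):
    #       pass
    #
    #   @HexDataSource.register_type(int)
    #   def load_hex_int(s):
    #       return int(s, 16)  # used instead of load_int for HexDataSource
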
    def raw_train_data_generator(self):
        """
        Returns a generator that yields the TRAIN data one item at a time
        in a dictionary where each key is a field and the value is of the
        raw type from the source. DataSources need to implement this.
        """
        raise NotImplementedError

    def raw_test_data_generator(self):
        """
        Returns a generator that yields the TEST data one item at a time
        in a dictionary where each key is a field and the value is of the
        raw type from the source. DataSources need to implement this.
        """
        raise NotImplementedError

    def raw_eval_data_generator(self):
        """
        Returns a generator that yields the EVAL data one item at a time
        in a dictionary where each key is a field and the value is of the
        raw type from the source. DataSources need to implement this.
        """
        raise NotImplementedError

    @generator_property
    def train(self):
        return self._convert_raw_source(self.raw_train_data_generator())

    @generator_property
    def test(self):
        return self._convert_raw_source(self.raw_test_data_generator())

    @generator_property
    def eval(self):
        return self._convert_raw_source(self.raw_eval_data_generator())


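# End-to-end sketch (hypothetical subclass, for illustration): implement the
# raw generators, and _convert_raw_source applies the converters registered
# below according to the schema.
#
#   class ListDataSource(RootDataSource):
#       def __init__(self, schema, rows):
#           super().__init__(schema)
#           self.rows = rows
#
#       def raw_train_data_generator(self):
#           yield from self.rows
#
#   source = ListDataSource(
#       {"text": str, "weight": float},
#       [{"text": "hello", "weight": "0.5"}],
#   )
#   list(source.train)  # [{"text": "hello", "weight": 0.5}]

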
@RootDataSource.register_type(Any)
@RootDataSource.register_type(str)
def load_text(s):
    return s


@RootDataSource.register_type(List[Slot])
def load_slots(s):
    return parse_slot_string(s)


Gazetteer = List[Dict[str, Dict[str, float]]]
JSONString = TypeVar("JSONString", str, bytes)


@RootDataSource.register_type(Gazetteer)
@RootDataSource.register_type(List[str])
@RootDataSource.register_type(List[int])
def load_json(s):
    if isinstance(s, List) and all(isinstance(x, (str, int)) for x in s):
        return s
    return json.loads(s)


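# Example: a gazetteer cell typically arrives as a JSON string and
# deserializes to a list of dicts; values that are already deserialized
# lists of str/int pass through unchanged.
#
#   load_json('[{"Rome": {"city": 0.9}}]')  # -> [{"Rome": {"city": 0.9}}]
#   load_json(["a", "b"])                   # -> ["a", "b"] (pass-through)

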
@RootDataSource.register_type(List[float])
def load_float_list(s):
    if isinstance(s, List) and all(isinstance(x, float) for x in s):
        return s
    # replace spaces between float numbers with commas (regex101.com/r/C2705x/1)
    processed = re.sub(r"(?<=[\d.])\s*,?\s+(?=[+-]?[\d.])", ",", s)
    # remove a dot not followed by a digit (regex101.com/r/goSmuG/1/)
    processed = re.sub(r"(?<=\d)\.(?![\d])", "", processed)
    try:
        parsed = json.loads(processed)
    except json.decoder.JSONDecodeError as e:
        raise ValueError(
            f"Unable to parse float list `{s}` (normalized to `{processed}`)"
        ) from e
    if not isinstance(parsed, list):
        raise ValueError(f"Expected float list for float feature, got {parsed}")
    return [float(f) for f in parsed]


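# Example: the two substitutions normalize loosely formatted float lists
# into valid JSON before parsing.
#
#   load_float_list("[1.0 2.5 3.]")   # -> [1.0, 2.5, 3.0]
#   load_float_list("[1.0, 2.5, 3]")  # already valid -> [1.0, 2.5, 3.0]

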
@RootDataSource.register_type(JSONString)
def load_json_string(s):
    parsed = json.loads(s)
    if not isinstance(parsed, str):
        raise TypeError(
            "Expected input to be parsed into a string object. "
            + f"Got {type(parsed)} type.\n"
            + f"Original: <<{s}>>, Parsed: <<{parsed}>>"
        )
    return parsed


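# Example: the raw value is itself JSON-encoded text, so the converter both
# decodes it and checks that it decodes to a string.
#
#   load_json_string('"hello"')  # -> "hello"
#   load_json_string('[1, 2]')   # raises TypeError (parsed into a list)

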
@RootDataSource.register_type(float)
def load_float(f):
    return float(f)


@RootDataSource.register_type(int)
def load_int(x):
    return int(x)