#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, Optional, Type
from pandas import DataFrame
from pytext.data.sources.data_source import RootDataSource
from .session import SessionDataSource
class PandasDataSource(RootDataSource):
    """
    DataSource which loads data from pandas DataFrames.

    Inputs:
        train_df: DataFrame for training
        eval_df: DataFrame for evaluation
        test_df: DataFrame for test
        schema: same as base DataSource, defines the list of output values
            with their types
        column_mapping: maps the column names in the DataFrames to the names
            defined in schema

    Any of the three DataFrames may be None; the raw data generator for that
    split then yields nothing.
    """

    class Config(RootDataSource.Config):
        # One DataFrame per split; None means "no data for this split".
        train_df: Optional[DataFrame] = None
        test_df: Optional[DataFrame] = None
        eval_df: Optional[DataFrame] = None

    @classmethod
    def from_config(cls, config: Config, schema: Dict[str, Type]):
        """Build a data source from its Config plus the expected column schema."""
        return cls(
            train_df=config.train_df,
            eval_df=config.eval_df,
            test_df=config.test_df,
            schema=schema,
            column_mapping=config.column_mapping,
        )

    def __init__(
        self,
        train_df: Optional[DataFrame] = None,
        eval_df: Optional[DataFrame] = None,
        test_df: Optional[DataFrame] = None,
        **kwargs
    ):
        # Remaining keyword arguments (e.g. schema, column_mapping, as passed
        # by from_config) are forwarded to the base class unchanged.
        super().__init__(**kwargs)
        self.train_df = train_df
        self.eval_df = eval_df
        self.test_df = test_df

    @staticmethod
    def raw_generator(df: Optional[DataFrame]):
        """Lazily yield each row of ``df`` as a ``{column_name: value}`` dict.

        Yields nothing when ``df`` is None.

        Rows are read with ``itertuples`` rather than ``iterrows``. ``iterrows``
        packs every row into a single Series, which upcasts mixed numeric
        columns to a common dtype (e.g. an int column becomes float when a
        float column is present). ``itertuples`` keeps each column's own dtype
        and avoids building a Series per row.

        NOTE(review): consumers are assumed to treat raw rows as mappings
        (``.items()`` / ``row[column]``); confirm no caller relies on
        Series-only API.
        """
        if df is None:
            return
        columns = list(df.columns)
        # name=None yields plain tuples, so column labels that are not valid
        # Python identifiers (spaces, duplicates, non-str) are handled safely.
        for values in df.itertuples(index=False, name=None):
            yield dict(zip(columns, values))

    def raw_train_data_generator(self):
        """Return a generator over the rows of the training DataFrame."""
        return self.raw_generator(self.train_df)

    def raw_eval_data_generator(self):
        """Return a generator over the rows of the evaluation DataFrame."""
        return self.raw_generator(self.eval_df)

    def raw_test_data_generator(self):
        """Return a generator over the rows of the test DataFrame."""
        return self.raw_generator(self.test_df)
class SessionPandasDataSource(PandasDataSource, SessionDataSource):
    """
    Session-aware DataSource backed by pandas DataFrames.

    Combines PandasDataSource (DataFrame row loading) with SessionDataSource.
    The ``id_col`` column is always typed as ``str`` in the schema, and
    ``id_col`` is forwarded to the parent constructors.

    NOTE(review): the ``from_config`` inherited from PandasDataSource does not
    pass ``id_col``, which this constructor requires. Building this class from
    a config would therefore raise TypeError. Confirm whether a
    Config/from_config pair carrying ``id_col`` is needed.
    """

    def __init__(
        self,
        schema: Dict[str, Type],
        id_col: str,
        train_df: Optional[DataFrame] = None,
        eval_df: Optional[DataFrame] = None,
        test_df: Optional[DataFrame] = None,
        column_mapping: Dict[str, str] = (),
    ):
        # Work on a copy so constructing a data source does not mutate the
        # caller's schema dict. An existing id_col entry keeps its position
        # and is overridden with str; otherwise it is appended.
        schema = {**schema, id_col: str}
        super().__init__(
            schema=schema,
            train_df=train_df,
            test_df=test_df,
            eval_df=eval_df,
            column_mapping=column_mapping,
            id_col=id_col,
        )
        # _validate_schema is not defined in this class; presumably provided
        # by SessionDataSource — verify there.
        self._validate_schema()