Source code for pytext.data.sources.session

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from itertools import chain
from typing import List

from pytext.config.serialize import _get_class_type
from pytext.data.sources.data_source import RootDataSource


class SessionDataSource(RootDataSource):
    """
    Data source for session based data. The input data is organized in
    sessions, and each session may have multiple rows. The first column is
    always the session id. Raw input rows are consolidated by session id and
    returned as one session per example.

    Only *consecutive* rows carrying the same session id are merged, so rows
    of one session that are separated by rows of another session end up in
    separate examples.
    """

    def __init__(self, id_col, **kwargs):
        """
        Args:
            id_col: name of the column that holds the session id.
            **kwargs: forwarded unchanged to the parent constructor.
        """
        # NOTE(review): these attributes are assigned before calling
        # super().__init__ -- presumably the parent constructor may rely on
        # them; keep this order.
        self.id_col = id_col
        # Grouping buffer used by _convert_raw_source while streaming rows.
        self.current_id = None
        self.current_session = []
        super().__init__(**kwargs)

    def _validate_schema(self):
        """
        Make sure the input schema are all list type, which is the return
        value type, and convert it to the actual type (e.g List[T] -> T)
        when reading the raw data from file.

        The id column is exempt from the check and left untouched.

        Raises:
            AssertionError: if a non-id column is not declared as a list type.
        """
        # Only values of existing keys are replaced, so mutating the dict
        # while iterating over it is safe here.
        for k, v in self.schema.items():
            if k != self.id_col:
                assert _get_class_type(v) is list, f"{k} is not a list type!"
                # Unwrap List[T] to its element type T.
                self.schema[k] = v.__args__[0]

    def merge_session(self, session):
        """
        Merge the rows of one session into a single example.

        Args:
            session: non-empty list of row mappings belonging to one session.

        Returns:
            A dict holding the session id (taken from the first row) under
            ``id_col`` and, for every other column, the list of that
            column's values in row order.
        """
        res = {self.id_col: session[0][self.id_col]}
        for k, v in chain.from_iterable(s.items() for s in session):
            if k != self.id_col:
                res.setdefault(k, []).append(v)
        return res

    def _convert_raw_source(self, source):
        """
        Group consecutive rows by session id and yield one merged example
        per session.

        Rows for which ``_read_example`` returns ``None`` are skipped.

        Args:
            source: iterable of raw rows.

        Yields:
            The result of ``merge_session`` for each run of consecutive rows
            sharing the same session id.
        """
        # Start every pass from a clean buffer. The buffer lives on the
        # instance and is otherwise only cleared once the source has been
        # exhausted, so a generator abandoned mid-stream would leak its
        # pending rows into the first session of the next pass.
        self.current_id = None
        self.current_session = []

        for row in source:
            example = self._read_example(row)
            if example is None:
                continue
            if example[self.id_col] == self.current_id:
                self.current_session.append(example)
            else:
                # A new session starts: swap the buffer *before* yielding so
                # the instance state is consistent while the consumer runs.
                self.current_id = example[self.id_col]
                session = self.current_session
                self.current_session = [example]
                if session:
                    yield self.merge_session(session)

        # Flush the last buffered session and leave the instance clean.
        self.current_id = None
        session = self.current_session
        self.current_session = []
        if session:
            yield self.merge_session(session)