#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, Optional
from pytext.common.constants import BatchContext, Stage
from pytext.config.component import Component, ComponentType, create_component
from pytext.data import (
BaseBatchSampler,
Data,
EvalBatchSampler,
RoundRobinBatchSampler,
generator_iterator,
)
from pytext.data.data import BatchData
class DisjointMultitaskData(Data):
    """
    Wrapper for doing multitask training using multiple data objects.
    Takes a dictionary of data objects and does round robin over their
    iterators using a BatchSampler.

    Args:
        data_dict (Dict[str, Data]): Data objects to do roundrobin over,
            keyed by task name.
        samplers (Dict[Stage, BaseBatchSampler]): Batch sampler to use for
            each stage (train/eval/test).
        test_key (Optional[str]): Task whose data source and tensorizers are
            used at test time; defaults to the first task in ``data_dict``.
        task_key (str): Batch-context key under which the originating task
            name is exposed on every yielded batch.

    Attributes:
        data_dict (Dict[str, Data]): Data handlers to do roundrobin over.
    """

    class Config(Component.Config):
        # Sampler used to interleave per-task batches during training;
        # eval and test always use EvalBatchSampler (see from_config).
        sampler: BaseBatchSampler.Config = RoundRobinBatchSampler.Config()
        # Task whose data source is used at test time; None means "use the
        # first task in data_dict".
        test_key: Optional[str] = None

    @classmethod
    def from_config(
        cls,
        config: Config,
        data_dict: Dict[str, Data],
        task_key: str = BatchContext.TASK_NAME,
        rank=0,
        world_size=1,
        init_tensorizers=True,
    ):
        """Construct from config: build the per-stage samplers and delegate
        to ``__init__``.

        NOTE: ``rank``, ``world_size`` and ``init_tensorizers`` are accepted
        for interface compatibility with other Data factories but are not
        used here — the sub Data objects in ``data_dict`` are assumed to be
        already constructed with the right sharding/tensorizer state.
        """
        samplers = {
            # Only training uses the configurable sampler (e.g. round robin);
            # eval/test must visit every task exactly once.
            Stage.TRAIN: create_component(ComponentType.BATCH_SAMPLER, config.sampler),
            Stage.EVAL: EvalBatchSampler(),
            Stage.TEST: EvalBatchSampler(),
        }
        return cls(data_dict, samplers, config.test_key, task_key)

    def __init__(
        self,
        data_dict: Dict[str, Data],
        samplers: Dict[Stage, BaseBatchSampler],
        test_key: Optional[str] = None,  # was annotated `str = None` (implicit Optional)
        task_key: str = BatchContext.TASK_NAME,
    ) -> None:
        # Fall back to the first task when no explicit test_key is given
        # (next(iter(...)) avoids materializing the full key list).
        self.test_key = test_key or next(iter(data_dict))
        # Currently the way training is set up is that the data object needs
        # to specify a data_source which is used at test time. For multitask
        # this is set to the data_source associated with the test_key.
        data_source = data_dict[self.test_key].data_source
        tensorizers = data_dict[self.test_key].tensorizers
        super().__init__(data_source, tensorizers)
        self.data_dict = data_dict
        self.samplers = samplers
        self.task_key = task_key

    @generator_iterator
    def batches(self, stage: Stage, data_source=None, load_early=False):
        """Yield batches from each task, sampled according to the stage's
        sampler.

        This batcher additionally exposes a task name in the batch (under
        ``self.task_key``) to allow the model to route examples to the
        appropriate task.

        Args:
            stage (Stage): TRAIN / EVAL / TEST — selects the sampler.
            data_source: When not None we are in the test workflow, so only
                the test task's batches are yielded, unannotated.
            load_early (bool): Forwarded to the sub data objects' batchers.
        """
        if data_source is not None:
            # Test workflow: delegate entirely to the test task's batcher.
            for batch in self.data_dict[self.test_key].batches(
                stage, data_source, load_early
            ):
                yield batch
        else:
            # One live batch iterator per task; the sampler decides the
            # interleaving across them.
            all_batches = {
                name: task.batches(stage, load_early=load_early)
                for name, task in self.data_dict.items()
            }
            sampled_batches = self.samplers[stage].batchify(all_batches)
            for name, (raw_batch, batch) in sampled_batches:
                # Tag each batch with its originating task so the model can
                # filter examples to the right task head.
                batch[self.task_key] = name
                yield BatchData(raw_batch, batch)