#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from collections.abc import Iterator
from typing import Dict, List, Optional

import numpy as np

from pytext.config.component import Component, ComponentType
class BaseBatchSampler(Component):
    """Base interface for batch samplers.

    A batch sampler receives a dictionary mapping iterator names to batch
    iterators and yields ``(name, batch)`` pairs; subclasses decide in what
    order and proportion the iterators are consumed.
    """

    __COMPONENT_TYPE__ = ComponentType.BATCH_SAMPLER
    __EXPANSIBLE__ = True

    @classmethod
    def from_config(cls, config: Component.Config):
        """Build the sampler; the base sampler takes no config options."""
        return cls()

    def __init__(self):
        pass

    def batchify(self, iterators: Dict[str, Iterator]):
        """Override to yield ``(iterator_name, batch)`` pairs."""
        pass
class EvalBatchSampler(BaseBatchSampler):
    """
    This sampler takes in a dictionary of Iterators and returns batches
    associated with each key in the dictionary. It guarantees that we will see
    each batch associated with each key exactly once in the epoch.

    Example:
        Iterator 1: [A, B, C, D], Iterator 2: [a, b]
        Output: [A, B, C, D, a, b]
    """

    def batchify(self, iterators: Dict[str, Iterator]):
        """
        Exhaust each key's iterator in turn, yielding every batch exactly once.

        Args:
            iterators: Dictionary of iterators
        """
        # Materialize all iterators up front, then drain them one key at a time.
        active = {name: iter(iterable) for name, iterable in iterators.items()}
        for name, batches in active.items():
            yield from ((name, batch) for batch in batches)
class RoundRobinBatchSampler(BaseBatchSampler):
    """
    This sampler takes a dictionary of Iterators and returns batches in a round
    robin fashion until the end of one of the iterators is reached. The end
    is specified by `iter_to_set_epoch`.

    If `iter_to_set_epoch` is set, cycle batches from each iterator until one
    epoch of the target iterator is fulfilled. Iterators with fewer batches
    than the target iterator are repeated, so they never run out.

    If `iter_to_set_epoch` is None, cycle over batches from each iterator until the
    shortest iterator completes one epoch.

    Example:
        Iterator 1: [A, B, C, D], Iterator 2: [a, b]
        iter_to_set_epoch = "Iterator 1"
        Output: [A, a, B, b, C, a, D, b]
        iter_to_set_epoch = None
        Output: [A, a, B, b]

    Args:
        iter_to_set_epoch (Optional[str]): Name of iterator to define epoch size.
            If this is not set, epoch size defaults to the length of
            the shortest iterator.
    """

    __COMPONENT_TYPE__ = ComponentType.BATCH_SAMPLER

    class Config(Component.Config):
        # Empty string means "unset": the first iterator to run out will then
        # define the epoch (see batchify below).
        iter_to_set_epoch: str = ""

    @classmethod
    def from_config(cls, config: Config):
        return cls(config.iter_to_set_epoch)

    def __init__(self, iter_to_set_epoch: Optional[str] = None) -> None:
        self.iter_to_set_epoch = iter_to_set_epoch

    def batchify(self, iterators: Dict[str, Iterator]):
        """
        Loop through each key in the input dict and generate batches from
        the iterator associated with that key until the target iterator reaches
        its end.

        Args:
            iterators: Dictionary of iterators
        """
        iter_dict = {name: iter(iterator) for name, iterator in iterators.items()}
        while True:
            # One round-robin pass: take one batch from each iterator in dict order.
            for name, it in iter_dict.items():
                try:
                    yield name, next(it)
                except StopIteration:
                    # This iterator is exhausted; restart it from its source so
                    # a later epoch (or the rest of this one) can keep drawing.
                    new_iter = iter(iterators[name])
                    iter_dict[name] = new_iter
                    if (not self.iter_to_set_epoch) or name == self.iter_to_set_epoch:
                        # Either this is the designated epoch iterator, or none
                        # was designated — in which case the first iterator to
                        # run out becomes the epoch iterator for all subsequent
                        # epochs (persisted on self, not reset per call).
                        self.iter_to_set_epoch = name
                        # end of epoch
                        return
                    else:
                        # A non-epoch-defining iterator ran out mid-epoch:
                        # immediately yield from its fresh restart so the
                        # round-robin pattern is not interrupted.
                        yield name, next(new_iter)
def select_key_and_batch(
    iterator_names: List[str],
    iterator_probs: List[float],
    iter_dict: Dict[str, Iterator],
    iterators: Dict[str, Iterator],
):
    """Helper function for RandomizedBatchSampler and AlternatingRandomizedBatchSampler
    to select a key from iterator_names using iterator_probs and return a batch
    for the selected key using iter_dict and iterators.

    Args:
        iterator_names: Candidate iterator names, positionally aligned with
            ``iterator_probs``. (The previous ``Dict[str, str]`` annotation was
            wrong: ``np.random.choice`` requires a sequence of candidates.)
        iterator_probs: Sampling probability for each candidate; must sum to 1.
        iter_dict: Mutable map of live iterators. An exhausted entry is
            replaced in place, so the caller's state advances across calls.
        iterators: The original iterables, used to restart exhausted iterators.

    Returns:
        Tuple of ``(selected_key, batch)`` where ``batch`` is the next item
        from the selected key's iterator.
    """
    # Sample a key according to the given (already-normalized) probabilities —
    # not uniformly, as an older comment here claimed.
    selected_key = np.random.choice(iterator_names, p=iterator_probs)
    try:
        batch = next(iter_dict[selected_key])
    except StopIteration:
        # The chosen iterator ran out: restart it from its source iterable so
        # callers can keep sampling indefinitely.
        iter_dict[selected_key] = iter(iterators[selected_key])
        batch = next(iter_dict[selected_key])
    return selected_key, batch
class RandomizedBatchSampler(BaseBatchSampler):
    """
    This sampler takes in a dictionary of iterators and returns batches according
    to the specified probabilities by `unnormalized_iterator_probs`. We cycle through
    the iterators (restarting any that "run out") indefinitely. Set batches_per_epoch
    in Trainer.Config.

    Example:
        Iterator A: [A, B, C, D], Iterator B: [a, b]
        batches_per_epoch = 3, unnormalized_iterator_probs = {"A": 0, "B": 1}
        Epoch 1 = [a, b, a]
        Epoch 2 = [b, a, b]

    Args:
        unnormalized_iterator_probs (Dict[str, float]): Iterator sampling probabilities.
            The keys should be the same as the keys of the underlying iterators, and the
            values will be normalized to sum to 1.
    """

    __COMPONENT_TYPE__ = ComponentType.BATCH_SAMPLER

    class Config(Component.Config):
        unnormalized_iterator_probs: Dict[str, float]

    @classmethod
    def from_config(cls, config: Config):
        return cls(config.unnormalized_iterator_probs)

    def __init__(self, unnormalized_iterator_probs: Dict[str, float]) -> None:
        self.iterator_names, self.iterator_probs = extract_iterator_properties(
            unnormalized_iterator_probs
        )
        # Note: we need to make `iter_dict` an instance attribute so that it persists
        # across calls to `batchify()`. This way subsequent epochs will continue from
        # previous states of the iterators (instead of recreating them).
        self.iter_dict = None

    def batchify(self, iterators: Dict[str, Iterator]):
        """Yield ``(key, batch)`` pairs forever, sampling keys by probability.

        This generator never terminates on its own — epoch size is bounded by
        the trainer (batches_per_epoch), not by this sampler. (A dead
        `num_batches` counter that was incremented but never read has been
        removed.)

        Args:
            iterators: Dictionary of iterators; its keys must match the keys of
                `unnormalized_iterator_probs`.
        """
        assert set(iterators) == set(self.iterator_names)
        if self.iter_dict is None:
            self.iter_dict = {
                name: iter(iterator) for name, iterator in iterators.items()
            }
        while True:
            selected_key, batch = select_key_and_batch(
                self.iterator_names, self.iterator_probs, self.iter_dict, iterators
            )
            yield selected_key, batch
class AlternatingRandomizedBatchSampler(RandomizedBatchSampler):
    """
    This sampler takes in a dictionary of iterators and returns batches alternating
    between keys and probabilities specified by `unnormalized_iterator_probs` and
    'second_unnormalized_iterator_probs', This is used for example in XLM
    pre-training where we alternate between MLM and TLM batches.
    """

    __COMPONENT_TYPE__ = ComponentType.BATCH_SAMPLER

    class Config(Component.Config):
        unnormalized_iterator_probs: Dict[str, float]
        second_unnormalized_iterator_probs: Dict[str, float]

    @classmethod
    def from_config(cls, config: Config):
        # Both distributions must be non-empty, otherwise alternation is meaningless.
        assert (
            config.unnormalized_iterator_probs
            and config.second_unnormalized_iterator_probs
        )
        return cls(
            unnormalized_iterator_probs=config.unnormalized_iterator_probs,
            second_unnormalized_iterator_probs=(
                config.second_unnormalized_iterator_probs
            ),
        )

    def __init__(
        self,
        unnormalized_iterator_probs: Dict[str, float],
        second_unnormalized_iterator_probs: Dict[str, float],
    ) -> None:
        super().__init__(unnormalized_iterator_probs)
        second_properties = extract_iterator_properties(
            second_unnormalized_iterator_probs
        )
        self.second_iterator_names, self.second_iterator_probs = second_properties
        # Start on the primary distribution; flipped after every yielded batch.
        self.is_secondary_turn = False

    def batchify(self, iterators: Dict[str, Iterator]):
        """Yield batches forever, alternating between the two distributions.

        Args:
            iterators: Dictionary of iterators covering the union of keys from
                both probability distributions.
        """
        expected_keys = set(self.iterator_names).union(set(self.second_iterator_names))
        assert set(iterators) == expected_keys
        if self.iter_dict is None:
            self.iter_dict = {
                name: iter(iterator) for name, iterator in iterators.items()
            }
        while True:
            if self.is_secondary_turn:
                names = self.second_iterator_names
                probs = self.second_iterator_probs
            else:
                names = self.iterator_names
                probs = self.iterator_probs
            selected_key, batch = select_key_and_batch(
                names, probs, self.iter_dict, iterators
            )
            self.is_secondary_turn = not self.is_secondary_turn
            yield selected_key, batch
class NaturalBatchSampler(RandomizedBatchSampler):
    """
    This sampler iterates over all the datasets, sampling according to
    the weighted number of samples in each dataset.
    """

    class Config(Component.Config):
        # Optional pre-computed per-dataset example counts; any dataset missing
        # here is counted by iterating it once on the first batchify() call.
        dataset_counts: Dict[str, int] = {}

    @classmethod
    def from_config(cls, config: Config):
        return cls(config.dataset_counts)

    def __init__(self, dataset_counts: Dict[str, int]) -> None:
        # Copy the dict: counts for missing datasets are filled in later, and
        # writing into the caller's dict would corrupt it — in particular the
        # shared class-level `{}` default on Config.
        self.dataset_counts = dict(dataset_counts)
        # Note: we need to make `iter_dict` an instance attribute so that it persists
        # across calls to `batchify()`. This way subsequent epochs will continue from
        # previous states of the iterators (instead of recreating them).
        self.iter_dict = None
        # Computed lazily on the first batchify() call, once iterators are known.
        self.iterator_names, self.iterator_probs = None, None

    def _dataset_count(self, iterator: Iterator) -> int:
        """
        Count number of items in a dataset
        """
        # Each item is assumed to be a (raw_batch, batch) pair; len(raw_batch)
        # is the number of examples in that batch. Note this consumes one full
        # pass over the iterable, so it must be re-iterable.
        return sum(len(raw_batch) for raw_batch, _batch in iter(iterator))

    def _initialize_iterate_probs_from_counts(self, iterators: Dict[str, Iterator]):
        """Fill in any missing dataset counts and derive sampling probabilities."""
        for iterator_name, iterator in iterators.items():
            if iterator_name not in self.dataset_counts:
                self.dataset_counts[iterator_name] = self._dataset_count(iterator)
        print(self.dataset_counts)
        print("Total data points:", sum(self.dataset_counts.values()))
        self.iterator_names, self.iterator_probs = extract_iterator_properties(
            self.dataset_counts
        )

    def batchify(self, iterators: Dict[str, Iterator]):
        """Yield batches forever, sampling proportionally to dataset sizes.

        Args:
            iterators: Dictionary of iterators, one per dataset.
        """
        if self.iterator_names is None or self.iterator_probs is None:
            self._initialize_iterate_probs_from_counts(iterators=iterators)
        yield from super().batchify(iterators=iterators)