Source code for pytext.data.test.dynamic_pooling_batcher_test

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest
from typing import List

from pytext.data.dynamic_pooling_batcher import (
    BatcherSchedulerConfig,
    DynamicPoolingBatcher,
    ExponentialBatcherSchedulerConfig,
    ExponentialDynamicPoolingBatcher,
    LinearDynamicPoolingBatcher,
)
from pytext.data.sources import RawExample
from pytext.utils.test import import_tests_module


# Resolve the shared pytext tests module via import_tests_module(). It is not
# referenced by DynamicPoolingBatcherTest below; presumably kept for parity with
# other pytext test files (e.g. access to test data helpers) -- TODO confirm.
tests_module = import_tests_module()


class DynamicPoolingBatcherTest(unittest.TestCase):
    """Tests for the scheduled (dynamic) batch-size pooling batchers.

    Every batcher is built with fixed train/eval/test batch sizes of 1, a
    single pool (``pool_num_batches=1``) and a single shuffled pool
    (``num_shuffled_pools=1``). Each test drives successive epochs by calling
    ``batchify`` repeatedly and checks the size of the first batch produced by
    each call.
    """

    @classmethod
    def _get_dataset(cls, dataset_size: int) -> List[RawExample]:
        """Return ``dataset_size`` placeholder examples.

        The tests only count examples, so every entry is the same ``([], {})``
        object repeated.
        """
        return [([], {})] * dataset_size

    @staticmethod
    def _make_batcher(batcher_cls, scheduler_config):
        """Build ``batcher_cls`` from its own ``Config`` using ``scheduler_config``.

        The fixed (non-dynamic) batch sizes are all 1 and a single pool is
        used, so the first batch's size reflects the scheduled batch size
        (or the dataset size, whichever is smaller).
        """
        batcher_config = batcher_cls.Config(
            train_batch_size=1,
            eval_batch_size=1,
            test_batch_size=1,
            pool_num_batches=1,
            num_shuffled_pools=1,
            scheduler_config=scheduler_config,
        )
        return batcher_cls.from_config(batcher_config)

    @staticmethod
    def _first_batch_size(batcher, data) -> int:
        """Run one epoch of ``batcher.batchify(data)``; return its first batch's size."""
        # Materialise every batch of the epoch rather than stopping after the
        # first one, exactly as the epoch was driven before; presumably the
        # batcher only advances its schedule once iteration completes -- confirm.
        batches = list(batcher.batchify(data))
        return len(batches[0].raw_data)

    def test_linear_scheduler(self):
        data = self._get_dataset(dataset_size=100)
        batcher = self._make_batcher(
            LinearDynamicPoolingBatcher,
            BatcherSchedulerConfig(
                start_batch_size=32, end_batch_size=256, epoch_period=5, step_size=1
            ),
        )
        # epoch 1
        self.assertEqual(self._first_batch_size(batcher, data), 32)
        # epoch 2
        # new size ((256 - 32) / 5) + 32 = 76.8 ~ 77
        self.assertEqual(self._first_batch_size(batcher, data), 77)

    def test_exponential_scheduler(self):
        data = self._get_dataset(dataset_size=100)
        batcher = self._make_batcher(
            ExponentialDynamicPoolingBatcher,
            ExponentialBatcherSchedulerConfig(
                start_batch_size=32,
                end_batch_size=256,
                epoch_period=5,
                step_size=1,
                gamma=2,
            ),
        )
        # epoch 1
        self.assertEqual(self._first_batch_size(batcher, data), 32)
        # epoch 2
        # new size 32 * 2^1 = 64
        self.assertEqual(self._first_batch_size(batcher, data), 64)

    def test_batch_size_greater_than_data(self):
        data = self._get_dataset(dataset_size=50)
        batcher = self._make_batcher(
            ExponentialDynamicPoolingBatcher,
            ExponentialBatcherSchedulerConfig(
                start_batch_size=32,
                end_batch_size=256,
                epoch_period=5,
                step_size=1,
                gamma=2,
            ),
        )
        # epoch 1
        self.assertEqual(self._first_batch_size(batcher, data), 32)
        # epoch 2
        # scheduled size 32 * 2^1 = 64 exceeds the 50 available examples,
        # so the first batch holds all 50
        self.assertEqual(self._first_batch_size(batcher, data), 50)

    def test_end_of_scheduler(self):
        # Previously named ``end_of_scheduler``; without the ``test`` prefix
        # unittest never collected it, so it silently never ran.
        data = self._get_dataset(dataset_size=300)
        batcher = self._make_batcher(
            ExponentialDynamicPoolingBatcher,
            ExponentialBatcherSchedulerConfig(
                start_batch_size=32,
                end_batch_size=256,
                epoch_period=2,
                step_size=4,
                gamma=2,
            ),
        )
        # epoch 1
        self.assertEqual(self._first_batch_size(batcher, data), 32)
        # pass N epochs without checking them
        no_op_epochs = 4
        for _ in range(no_op_epochs):
            self._first_batch_size(batcher, data)
        # after the period (2 epochs) has passed, the batch size should be the
        # max batch size
        self.assertEqual(self._first_batch_size(batcher, data), 256)

    def test_step_size(self):
        data = self._get_dataset(dataset_size=64)
        batcher = self._make_batcher(
            ExponentialDynamicPoolingBatcher,
            ExponentialBatcherSchedulerConfig(
                start_batch_size=32,
                end_batch_size=256,
                epoch_period=2,
                step_size=2,
                gamma=2,
            ),
        )
        # epoch 1
        self.assertEqual(self._first_batch_size(batcher, data), 32)
        # epoch 2
        # step_size=2: no change to the batch size yet
        self.assertEqual(self._first_batch_size(batcher, data), 32)
        # epoch 3
        # new size 32 * 2^1 = 64
        self.assertEqual(self._first_batch_size(batcher, data), 64)
        # epoch 4
        # the dataset has only 64 examples, so the first batch cannot exceed 64
        self.assertEqual(self._first_batch_size(batcher, data), 64)