Source code for pytext.data.dynamic_pooling_batcher

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import math
from typing import Iterable

from pytext.common.constants import Stage
from pytext.config.module_config import ModuleConfig

from .data import PoolingBatcher
from .sources import RawExample


class BatcherSchedulerConfig(ModuleConfig):
    # the initial batch size used for training, this is per node
    start_batch_size: int = 32

    # the final or max batch size to use, any scheduler should
    # not go over this batch size, this is per node
    end_batch_size: int = 256

    # the number of epochs to increase the batch size over
    epoch_period: int = 10

    # the batch size is kept constant for `step_size` number of epochs
    step_size: int = 1
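
# Illustrative note (not part of the original module): with the defaults
# above, a scheduler starts at a per-node batch size of 32, re-computes the
# size once per epoch (step_size=1), and caps at 256 once epoch_period=10
# epochs have passed. Assuming ConfigBase-style keyword construction, a
# gentler hypothetical schedule could look like:
#
#   schedule = BatcherSchedulerConfig(
#       start_batch_size=16, end_batch_size=128, epoch_period=4
#   )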


class ExponentialBatcherSchedulerConfig(BatcherSchedulerConfig):
    # the gamma to increase batch size by
    gamma: float = 5


class DynamicPoolingBatcher(PoolingBatcher):
    """
    Allows dynamic batch training: extends PoolingBatcher with a scheduler
    config, which specifies how the batch size should increase over epochs.
    """

    class Config(PoolingBatcher.Config):
        scheduler_config: BatcherSchedulerConfig = BatcherSchedulerConfig()

    @classmethod
    def from_config(cls, config: Config):
        return cls(
            config.train_batch_size,
            config.eval_batch_size,
            config.test_batch_size,
            config.pool_num_batches,
            config.num_shuffled_pools,
            config.scheduler_config,
        )

    def __init__(
        self,
        train_batch_size=Config.train_batch_size,
        eval_batch_size=Config.eval_batch_size,
        test_batch_size=Config.test_batch_size,
        pool_num_batches=Config.pool_num_batches,
        num_shuffled_pools=Config.num_shuffled_pools,
        scheduler_config=Config.scheduler_config,
    ):
        super().__init__(
            train_batch_size,
            eval_batch_size,
            test_batch_size,
            pool_num_batches,
            num_shuffled_pools,
        )
        self.scheduler_config = scheduler_config
        self.curr_batch_size: int = 0
        self.curr_epoch: int = -1
        self.step_epoch()

    def compute_dynamic_batch_size(
        self, curr_epoch: int, scheduler_config: BatcherSchedulerConfig, curr_steps: int
    ) -> int:
        raise NotImplementedError()

    def step_epoch(self):
        self.curr_epoch += 1
        if self.curr_epoch > self.scheduler_config.epoch_period:
            # past the dynamic period: hold the batch size at its final value
            self.curr_batch_size = self.scheduler_config.end_batch_size
        else:
            if self.curr_epoch % self.scheduler_config.step_size == 0:
                old_size = self.curr_batch_size
                new_size = self.compute_dynamic_batch_size(
                    curr_epoch=self.curr_epoch,
                    scheduler_config=self.scheduler_config,
                    curr_steps=self.curr_epoch // self.scheduler_config.step_size,
                )
                print(f"increasing batch size from {old_size} to {new_size}")
                self.curr_batch_size = new_size
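
    # Worked example (illustrative, not from the original source): curr_epoch
    # starts at -1 and __init__ calls step_epoch() once, so the first schedule
    # step happens at epoch 0. With step_size=2 the modulo check above fires
    # on epochs 0, 2, 4, ..., holding the batch size constant for two epochs
    # at a time, and compute_dynamic_batch_size sees curr_steps = 0, 1, 2, ...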

    def finished_dynamic(self) -> bool:
        return (
            self.scheduler_config.epoch_period != -1
            and self.curr_epoch >= self.scheduler_config.epoch_period
        )

    def get_batch_size(self, stage: Stage) -> int:
        if stage == Stage.TRAIN:
            print(f"using dynamic batch size {self.curr_batch_size}")
            return self.curr_batch_size
        else:
            return self._batch_sizes[stage]

    def batchify(
        self, iterable: Iterable[RawExample], sort_key=None, stage=Stage.TRAIN
    ):
        """
        From an iterable of dicts, yield dicts of lists:

        1. Load `num_shuffled_pools` pools of data, and shuffle them.
        2. Load a pool (`batch_size * pool_num_batches` examples).
        3. Sort rows, if necessary.
        4. Shuffle the order in which the batches are returned, if necessary.
        """
        for item in super().batchify(iterable=iterable, sort_key=sort_key, stage=stage):
            yield item
        if stage == Stage.TRAIN:
            # only step scheduler when in train
            self.step_epoch()
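

# Usage note (an assumption about the surrounding training loop, not part of
# the original module): one exhausted call to batchify() corresponds to one
# epoch, since step_epoch() runs only after the iterable is consumed. Driving
# batchify() once per epoch therefore advances the batch-size schedule; a
# runnable sketch using one of the concrete subclasses below appears at the
# end of this module.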


class LinearDynamicPoolingBatcher(DynamicPoolingBatcher):
    """
    Linear Dynamic Batch Scheduler: scales up the batch size linearly.
    """

    def compute_dynamic_batch_size(
        self, curr_epoch: int, scheduler_config: BatcherSchedulerConfig, curr_steps: int
    ) -> int:
        batch_delta = (
            scheduler_config.end_batch_size - scheduler_config.start_batch_size
        )
        curr_est: float = (
            batch_delta / (scheduler_config.epoch_period / scheduler_config.step_size)
        ) * curr_steps + scheduler_config.start_batch_size
        return math.ceil(curr_est)
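

# Worked example (illustrative, not from the original source): with the
# default schedule (start=32, end=256, epoch_period=10, step_size=1),
# batch_delta is 224 and each step adds 224 / (10 / 1) = 22.4, so
# curr_steps=3 yields ceil(32 + 3 * 22.4) = 100 and curr_steps=10 yields 256.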


class ExponentialDynamicPoolingBatcher(DynamicPoolingBatcher):
    """
    Exponential Dynamic Batch Scheduler: scales up the batch size by a factor
    of gamma.
    """

    class Config(DynamicPoolingBatcher.Config):
        scheduler_config: ExponentialBatcherSchedulerConfig

    def __init__(self, *args, **kwargs):
        self.max_steps = None
        super().__init__(*args, **kwargs)

    def get_max_steps(self):
        # cache the computed step count; checking `is not None` (rather than
        # truthiness) keeps a legitimate value of 0 from being recomputed
        if self.max_steps is not None:
            return self.max_steps
        self.max_steps: int = math.floor(
            math.log(
                self.scheduler_config.end_batch_size
                / self.scheduler_config.start_batch_size
            )
            / math.log(self.scheduler_config.gamma)
        )
        return self.max_steps

    def finished_dynamic(self) -> bool:
        return (
            self.scheduler_config.epoch_period != -1
            and self.curr_epoch >= self.get_max_steps()
        )

    def compute_dynamic_batch_size(
        self,
        curr_epoch: int,
        scheduler_config: ExponentialBatcherSchedulerConfig,
        curr_steps: int,
    ) -> int:
        if curr_steps > self.get_max_steps():
            return scheduler_config.end_batch_size
        curr_est: float = scheduler_config.start_batch_size * math.pow(
            scheduler_config.gamma, curr_steps
        )
        return min(math.ceil(curr_est), scheduler_config.end_batch_size)
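

# Worked example (illustrative, not from the original source): with
# start_batch_size=32, end_batch_size=256, and gamma=5, get_max_steps() is
# floor(log(256 / 32) / log(5)) = floor(1.29...) = 1, so the schedule yields
# 32 at step 0 and min(ceil(32 * 5), 256) = 160 at step 1; any later step is
# clamped to 256 by the early return above.


# Minimal runnable sketch (an assumption about the surrounding pipeline, not
# part of the original module): drives a linear batcher over a toy in-memory
# dataset for three epochs, so the printed batch size grows with each pass.
# The dict rows standing in for RawExample are hypothetical.
if __name__ == "__main__":
    batcher = LinearDynamicPoolingBatcher(
        scheduler_config=BatcherSchedulerConfig()
    )
    data = [{"text": f"example {i}"} for i in range(100)]
    for epoch in range(3):
        # each exhausted pass over batchify() steps the epoch scheduler once
        for batch in batcher.batchify(data, stage=Stage.TRAIN):
            pass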