Source code for pytext.trainers.trainer

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import itertools
import time
from contextlib import ExitStack as contextlib_ExitStack
from typing import Any, Iterable, List, Optional, Tuple

import torch
from pytext.common.constants import BatchContext, Stage
from pytext.config import PyTextConfig
from pytext.config.component import (
    Component,
    ComponentType,
    create_optimizer,
    create_privacy_engine,
    create_scheduler,
    create_sparsifier,
)
from pytext.config.pytext_config import ConfigBase
from pytext.data.data_handler import BatchIterator
from pytext.metric_reporters import MetricReporter
from pytext.models.distributed_model import DistributedModel
from pytext.models.model import Model
from pytext.optimizer import Adam, Optimizer, PrivacyEngine, learning_rates
from pytext.optimizer.fp16_optimizer import FP16Optimizer, FP16OptimizerFairseq
from pytext.optimizer.scheduler import Scheduler
from pytext.optimizer.sparsifiers.sparsifier import Sparsifier
from pytext.task.serialize import save
from pytext.trainers.training_state import TrainingState
from pytext.utils import cuda, distributed, precision, timing


class TrainerBase(Component):
    __COMPONENT_TYPE__ = ComponentType.TRAINER

def cycle(iterator: Iterable[Any]) -> Iterable[Any]:
    """Like itertools.cycle, but calls iter on the original iterable instead.
    This means it cannot run on, say, raw generators, but it also doesn't store
    a copy of the iterable in memory for repetition."""
    while True:
        yield from iterator

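# Usage sketch (illustrative; not part of the original module): unlike
# itertools.cycle, `cycle` calls iter() on the source again for every pass, so an
# iterable that reshuffles itself on each __iter__ (e.g. a shuffling DataLoader)
# keeps reshuffling, and no copy of the elements is kept in memory:
#
#     >>> list(itertools.islice(cycle(["b0", "b1"]), 5))
#     ['b0', 'b1', 'b0', 'b1', 'b0']
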
def maybe_accumulate_gradients(exit_stack, model, index, sample_size):
    # index == sample_size - 1 represents the last backward pass
    if (
        cuda.DISTRIBUTED_WORLD_SIZE > 1
        and hasattr(model, "no_sync")
        and index < sample_size - 1
    ):
        """
        Whenever *samples* contains more than one mini-batch (e.g. sample_size > 1),
        we want to accumulate gradients locally and only call all-reduce in the
        last backward pass.
        """
        exit_stack.enter_context(model.no_sync())

    if precision.FP16_ENABLED and index < sample_size - 1:
        """
        Whenever *samples* contains more than one mini-batch (e.g. sample_size > 1),
        we want to accumulate gradients in FP16 parameters (i.e. delay unscale)
        and only unscale to FP32 parameters after the last backward pass.
        """
        exit_stack.enter_context(precision.delay_unscale())

class Trainer(TrainerBase):
    """
    Base Trainer class that provides ways to:

    1. Train a model, compute metrics against an eval set and use the metrics
       for model selection.
    2. Test a trained model, compute and publish metrics against a blind test set.

    Attributes:
        epochs (int): Training epochs.
        early_stop_after (int): Stop after this many epochs if the eval metric
            is not improving.
        max_clip_norm (Optional[float]): Clip gradient norm if set.
        report_train_metrics (bool): Whether metrics on training data should be
            computed and reported.
        target_time_limit_seconds (float): Target time limit for training in
            seconds. If the expected time to train another epoch exceeds this
            limit, stop training.
    """

    class Config(ConfigBase):
        #: Training epochs.
        epochs: int = 10
        #: Stop after this many epochs if the eval metric is not improving.
        early_stop_after: int = 0
        #: Clip gradient norm if set.
        max_clip_norm: Optional[float] = None
        #: Whether metrics on training data should be computed and reported.
        report_train_metrics: bool = True
        #: Target time limit for training; the default (None) means no time limit.
        target_time_limit_seconds: Optional[int] = None
        #: Whether to do evaluation and model selection based on it.
        do_eval: bool = True
        #: If do_eval, whether to load the best model state dict after training
        #: or just use the latest model state.
        load_best_model_after_train: bool = True
        #: Number of samples for logging training progress.
        num_samples_to_log_progress: int = 1000
        #: Number of forward & backward passes per batch before updating gradients;
        #: actual_batch_size = batch_size x num_accumulated_batches.
        num_accumulated_batches: int = 1
        #: Define an epoch as a fixed number of batches. Subsequent epochs will
        #: continue to iterate through the data, cycling through it when they
        #: reach the end. If not set, use exactly one pass through the dataset as
        #: one epoch. This configuration only affects train epochs; test and eval
        #: will always run on their entire datasets.
        num_batches_per_epoch: Optional[int] = None
        #: Config for the optimizer, used in the parameter update.
        optimizer: Optimizer.Config = Adam.Config()
        scheduler: Optional[Scheduler.Config] = None
        sparsifier: Optional[Sparsifier.Config] = None
        #: Arguments for fp16 training. An fp16_optimizer will be created to wrap
        #: the original optimizer; it scales the loss during backward, while the
        #: master weights are maintained in the original optimizer.
        #: https://arxiv.org/abs/1710.03740
        fp16_args: FP16Optimizer.Config = FP16OptimizerFairseq.Config()
        #: PrivacyEngine-related args.
        privacy_engine: Optional[PrivacyEngine.Config] = None
        use_tensorboard: bool = False

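    # Example (illustrative values, not the module defaults): with a data batch
    # size of 16 and the config below, each optimizer step covers an effective
    # batch of 16 x 4 = 64 examples, and training stops early once the eval
    # metric has not improved for 3 consecutive epochs.
    #
    #     >>> example_trainer_config = Trainer.Config(
    #     ...     epochs=20,
    #     ...     early_stop_after=3,
    #     ...     num_accumulated_batches=4,
    #     ...     report_train_metrics=False,
    #     ... )
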
    def __init__(self, config: Config, model: torch.nn.Module):
        if config.early_stop_after > 0:
            assert config.do_eval, "can't do early stopping when not running evaluation"

        if precision.FP16_ENABLED:
            self.optimizer: torch.optim.Optimizer = create_optimizer(
                config.fp16_args,
                model,
                config.optimizer,
                config.num_accumulated_batches,
            )
        else:
            self.optimizer: torch.optim.Optimizer = create_optimizer(
                config.optimizer, model
            )

        self.privacy_engine: PrivacyEngine = (
            create_privacy_engine(config.privacy_engine, model, self.optimizer)
            if config.privacy_engine
            else None
        )

        self.scheduler: torch.optim.lr_scheduler = (
            create_scheduler(config.scheduler, self.optimizer)
            if config.scheduler
            else Scheduler()
        )
        self.sparsifier: Sparsifier = (
            create_sparsifier(config.sparsifier) if config.sparsifier else Sparsifier()
        )
        self.config = config

    @classmethod
    def from_config(cls, config: Config, model: torch.nn.Module, *args, **kwargs):
        return cls(config, model)

    @timing.time("Trainer.test")
    def test(self, test_iter, model, metric_reporter: MetricReporter):
        state = TrainingState(stage=Stage.TEST, model=model, epoch=1)
        if cuda.CUDA_ENABLED:
            state.model.cuda()
        state.model.eval()
        with torch.no_grad():
            return self.run_epoch(state, test_iter, metric_reporter)

    @timing.time("pre-training")
    def set_up_training(self, state: TrainingState, training_data: BatchIterator):
        if cuda.CUDA_ENABLED:
            state.model.cuda()
        state.scheduler.prepare(training_data, self.config.epochs)

        if cuda.DISTRIBUTED_WORLD_SIZE > 1:
            device_id = torch.cuda.current_device()
            state.model = DistributedModel(
                module=state.model,
                device_ids=[device_id],
                output_device=device_id,
                broadcast_buffers=False,
                find_unused_parameters=state.model.find_unused_parameters,
                process_group=distributed._round_robin_process_group,
            )
        state.start_time = time.time()

        if self.config.num_batches_per_epoch:
            # Set the training_data iterator to cycle, so it will never run out,
            # but rather after reaching the end will loop back to the beginning.
            training_data = cycle(training_data)
        return training_data

    @timing.time("zero gradients")
    def zero_grads(self, state):
        if state.stage != Stage.TRAIN:
            return
        state.optimizer.zero_grad()

    @timing.time("backprop")
    def backprop(self, state, loss):
        if state.stage != Stage.TRAIN:
            return

        with timing.time("loss.backward"):
            state.optimizer.backward(loss)

    @timing.time("optimizer")
    def optimizer_step(self, state):
        if state.stage != Stage.TRAIN:
            return

        try:
            grad_norm = state.optimizer.clip_grad_norm(
                self.config.max_clip_norm, state.model
            )
        except OverflowError as e:
            print(f"Gradient overflow. Skipping step, {e}")
            return None

        state.scheduler.step_batch()
        with timing.time("optimizer.step"):
            state.optimizer.step()
        state.step_counter += 1
        # grad_norm could be used to check grads sync in distributed training
        return grad_norm

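    # Conceptual sketch (an assumption about what the wrapped clip_grad_norm
    # does; the pytext Optimizer may differ in detail): when max_clip_norm is
    # set, this step behaves roughly like the stock PyTorch utility, i.e. the
    # global gradient norm is computed and gradients are rescaled in place if
    # that norm exceeds the threshold:
    #
    #     >>> import torch
    #     >>> grad_norm = torch.nn.utils.clip_grad_norm_(
    #     ...     state.model.parameters(), max_norm=self.config.max_clip_norm
    #     ... )
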
    @timing.time("sparsifier")
    def sparsification_step(self, state):
        # sparsification only if a sparsifier is used
        if not self.config.sparsifier:
            return
        self.sparsifier.sparsify(state)

    def continue_training(self, state: TrainingState) -> bool:
        # Are we done?
        if state.epoch >= self.config.epochs:
            return False

        # Check whether the model has improved recently enough.
        # Only do this if we're bothering to evaluate the model.
        if self.config.do_eval and state.epochs_since_last_improvement >= (
            self.config.early_stop_after or float("inf")
        ):
            print(
                f"Worker {state.rank}: Eval metric hasn't changed for "
                + f"{state.epochs_since_last_improvement} epochs. Stopping now."
            )
            return False

        # Check whether we think the next epoch will put us over the configured
        # time limit.
        epochs_run = state.epoch + 1
        time_elapsed = time.time() - state.start_time
        mean_epoch_time = time_elapsed / epochs_run
        expected_next_epoch_time = time_elapsed + mean_epoch_time
        target_time_limit = (
            float("inf")
            if self.config.target_time_limit_seconds is None
            else self.config.target_time_limit_seconds
        )
        if expected_next_epoch_time > target_time_limit:
            print(
                f"Worker {state.rank}: Stopping training after {epochs_run} epochs "
                f"and {int(time_elapsed)} seconds, due to the target max training "
                f"time of {self.config.target_time_limit_seconds} seconds."
            )
            return False

        return True

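    # Worked example of the time-limit projection above (illustrative numbers):
    # if epochs_run = 3 and time_elapsed = 300s, then mean_epoch_time = 100s and
    # expected_next_epoch_time = 300 + 100 = 400s. With
    # target_time_limit_seconds = 350, we have 400 > 350, so continue_training
    # returns False rather than overshooting the limit mid-epoch.
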
    def move_state_dict_to_cpu(self, state_dict):
        for key, maybe_parameter in state_dict.items():
            if isinstance(maybe_parameter, torch.Tensor):
                state_dict[key] = maybe_parameter.cpu()
            else:
                self.move_state_dict_to_cpu(maybe_parameter)
        return state_dict

    def move_state_dict_to_gpu(self, state_dict):
        for key, maybe_parameter in state_dict.items():
            if isinstance(maybe_parameter, torch.Tensor):
                state_dict[key] = maybe_parameter.cuda()
            else:
                self.move_state_dict_to_gpu(maybe_parameter)
        return state_dict

    def update_best_model(
        self, state: TrainingState, train_config: PyTextConfig, eval_metric
    ):
        # This should be updated by all workers so they agree on when to stop
        # training when `early_stop_after` is specified.
        state.epochs_since_last_improvement = 0
        state.best_model_metric = eval_metric
        print("Found a better model!")

        # Only one worker should save checkpoints,
        # unless doing iterative pruning
        if state.rank != 0 and not self.sparsifier.save_model_state_for_all_rank():
            return

        model_state = state.model.state_dict()
        # save to cpu to avoid multiple model copies in gpu memory
        if cuda.CUDA_ENABLED:
            self.move_state_dict_to_cpu(model_state)
        state.best_model_state = model_state

    @timing.time("save checkpoint")
    def save_checkpoint(self, state: TrainingState, train_config: PyTextConfig) -> str:
        # Only one worker should save checkpoints
        if state.rank != 0:
            return

        if train_config.save_module_checkpoints or train_config.save_all_checkpoints:
            # saves per-epoch sub-modules when save_all_checkpoints or
            # save_module_checkpoints is enabled
            state.model.save_modules(
                base_path=train_config.modules_save_dir, suffix=f"-ep{state.epoch}"
            )
            if state.epochs_since_last_improvement == 0:
                # state.epochs_since_last_improvement == 0 means we found a better
                # model in the current epoch, so update the best model's sub-modules
                state.model.save_modules(base_path=train_config.modules_save_dir)

        # TODO: add a new config option and implementation for checkpointing frequency
        if train_config.save_all_checkpoints:
            return save(
                config=train_config,
                model=state.model,
                meta=None,
                tensorizers=None,
                training_state=state,
                identifier=str(state.epoch),
            )

    def load_best_model(self, state: TrainingState):
        if cuda.CUDA_ENABLED:
            # Move current model to CPU to avoid multiple models in GPU memory
            state.model.cpu()
            state.model.load_state_dict(state.best_model_state)
            # Move model back to GPU
            state.model.cuda()
        else:
            state.model.load_state_dict(state.best_model_state)

    def train(
        self,
        training_data: BatchIterator,
        eval_data: BatchIterator,
        model: Model,
        metric_reporter: MetricReporter,
        train_config: PyTextConfig,
        rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
        """
        Train and eval a model; the model states will be modified.

        Args:
            training_data (BatchIterator): batch iterator of training data
            eval_data (BatchIterator): batch iterator of evaluation data
            model (Model): model to be trained
            metric_reporter (MetricReporter): computes metrics based on training
                output and reports results to console, file, etc.
            train_config (PyTextConfig): training config
            rank (int): only used in distributed training; the rank of the current
                training thread. Evaluation will only be done in rank 0.

        Returns:
            model, best_metric: the trained model together with the best metric
        """
        state = TrainingState(
            model=model,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            sparsifier=self.sparsifier,
            privacy_engine=self.privacy_engine,
            rank=rank,
        )
        return self.train_from_state(
            state, training_data, eval_data, metric_reporter, train_config
        )

    @timing.time("Trainer.train_from_state")
    def train_from_state(
        self,
        state: TrainingState,
        training_data: BatchIterator,
        eval_data: BatchIterator,
        metric_reporter: MetricReporter,
        train_config: PyTextConfig,
    ) -> Tuple[torch.nn.Module, Any]:
        """
        Train and eval a model from a given training state; the model states will
        be modified. This function iterates over the epochs specified in the
        config, and for each epoch:

        1. Train the model using training data, aggregate and report training results.
        2. Adjust the learning rate if a scheduler is specified.
        3. Evaluate the model using evaluation data.
        4. Calculate metrics based on evaluation results and select the best model.

        Args:
            state (TrainingState): contains stateful information needed to be able
                to restore a training job
            training_data (BatchIterator): batch iterator of training data
            eval_data (BatchIterator): batch iterator of evaluation data
            metric_reporter (MetricReporter): computes metrics based on training
                output and reports results to console, file, etc.
            train_config (PyTextConfig): training config

        Returns:
            model, best_metric: the trained model together with the best metric
        """
        training_data = self.set_up_training(state, training_data)
        model = state.model
        rank = state.rank
        trainable_params = sum(
            p.numel() for p in state.model.parameters() if p.requires_grad
        )
        print(f"Model: {model}")
        print(f"Num trainable parameters: {trainable_params}")
        self.sparsifier.initialize(
            self, state, eval_data, metric_reporter, train_config
        )

        while self.continue_training(state):
            self.sparsifier.op_pre_epoch(self, state)
            state.epoch += 1
            state.epochs_since_last_improvement += 1
            lrs = learning_rates(state.optimizer)
            print(f"\nWorker {state.rank} starting epoch {state.epoch}")
            print(f"Learning rate(s): {', '.join(map(str, lrs))}")

            with timing.time("train epoch"):
                state.stage = Stage.TRAIN
                state.model.train()
                print(f"start training epoch {state.epoch}")
                epoch_data = training_data
                if self.config.num_batches_per_epoch:
                    # We want to limit the number of batches in the epoch;
                    # equivalent to epoch_data[:num_batches_per_epoch] for iterators.
                    # In this case we set the training data iterator to cycle earlier
                    # in the training process, so when it reaches the end it will
                    # loop back to the beginning.
                    epoch_data = itertools.islice(
                        epoch_data, self.config.num_batches_per_epoch
                    )
                self.run_epoch(state, epoch_data, metric_reporter)

            if not self.config.do_eval:
                continue

            with timing.time("eval epoch"):
                state.stage = Stage.EVAL
                model.eval(Stage.EVAL)
                print(f"start evaluating epoch {state.epoch}")
                with torch.no_grad():
                    eval_metric = self.run_epoch(state, eval_data, metric_reporter)

            # Step the learning rate scheduler(s)
            assert eval_metric is not None
            state.scheduler.step_epoch(
                metrics=metric_reporter.get_model_select_metric(eval_metric),
                epoch=state.epoch,
            )

            # Did we train a better model?
            better_model = metric_reporter.compare_metric(
                eval_metric, state.best_model_metric
            )
            if better_model:
                self.update_best_model(state, train_config, eval_metric)
            if better_model or train_config.save_all_checkpoints:
                self.save_checkpoint(state, train_config)

        if self.optimizer.finalize():
            should_update_model = True
            eval_metric = None
            if self.config.do_eval:
                state.stage = Stage.EVAL
                model.eval(Stage.EVAL)
                print("start evaluating finalized state")
                with torch.no_grad():
                    eval_metric = self.run_epoch(state, eval_data, metric_reporter)
                should_update_model = metric_reporter.compare_metric(
                    eval_metric, state.best_model_metric
                )
            if should_update_model:
                self.update_best_model(state, train_config, eval_metric)
            if should_update_model or train_config.save_all_checkpoints:
                self.save_checkpoint(state, train_config)

        # Only bother loading the best model for the master worker
        if (
            rank == 0
            and state.best_model_state is not None
            and self.config.load_best_model_after_train
        ):
            self.load_best_model(state)

        return state.model, state.best_model_metric

    @timing.report_snapshot
    def run_epoch(
        self, state: TrainingState, data: BatchIterator, metric_reporter: MetricReporter
    ):
        # This method is due for some refactoring, pushing it off because it interacts
        # with the metric reporter too much. Much of the logic here either changes in
        # the NewTaskTrainer or should change with a better metric reporter design.
        report_metric = state.stage != Stage.TRAIN or self.config.report_train_metrics
        model = state.model

        samples = []
        is_data_empty = True

        """
        Sometimes a batch of inputs is too large to fit into GPU memory and has to
        be split into several micro-batches. However, to improve efficiency, it
        helps to only sync params/gradients at the original batch boundaries
        instead of at micro-batch boundaries. num_accumulated_batches specifies
        the number of batches over which gradients are accumulated locally before
        syncing; total training_batch_size = train_batch_size x
        num_accumulated_batches, which improves system performance by reducing the
        total number of bytes transferred over the network.
        """
        for sample in enumerate(data):
            is_data_empty = False
            samples.append(sample)
            if (
                state.stage != Stage.TRAIN
                or len(samples) == self.config.num_accumulated_batches
            ):
                self.run_step(samples, state, metric_reporter, report_metric)
                samples = []
        if samples:
            self.run_step(samples, state, metric_reporter, report_metric)
            samples = []

        metrics = None
        if report_metric:
            if is_data_empty:
                error_msg = (
                    f"Trying to report metric for stage {state.stage}, but no data was "
                    "found. Either disable metric reporting for this stage, pass in "
                    "non-empty data, or see if data fields are misnamed (warnings "
                    "would appear in preceding stdout logs)."
                )
                raise ValueError(error_msg)
            with timing.time("report metrics"):
                metrics = metric_reporter.report_metric(
                    model,
                    state.stage,
                    state.epoch,
                    print_to_channels=(state.rank == 0),
                    optimizer=getattr(
                        state, "optimizer", None
                    ),  # optimizer is not present during test
                )
        else:
            metric_reporter._reset()

        if state.rank == 0 and self.config.sparsifier:
            current_sparsity = self.sparsifier.get_current_sparsity(state.model)
            print(f"sparsity in the model: {current_sparsity}")

        return metrics

    @timing.time("run_step")
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        sample_size = len(samples)
        assert sample_size <= self.config.num_accumulated_batches

        model = state.model
        self.zero_grads(state)
        for idx, (batch_id, (inputs, targets, context)) in enumerate(samples):
            with contextlib_ExitStack() as exit_stack:
                maybe_accumulate_gradients(exit_stack, model, idx, sample_size)
                # pass context to model to use in forward call if needed
                model.contextualize(context)
                with timing.time("model.forward"):
                    logits = model(*inputs)

                with timing.time("compute loss"):
                    loss = precision.maybe_float(
                        model.get_loss(logits, targets, context)
                    )
                    if BatchContext.IGNORE_LOSS in context:
                        loss *= 0
                    elif sample_size > 1:
                        # gradients are averaged per batch and accumulated across
                        # samples; divide by sample_size so gradients end up
                        # averaged per example
                        loss = loss / sample_size

                self.backprop(state, loss)

            if report_metric:
                with timing.time("get pred"):
                    preds, scores = model.get_pred(
                        logits, targets, context, state.stage, *inputs
                    )
                with timing.time("add metrics"):
                    metric_reporter.add_batch_stats(
                        batch_id, preds, targets, scores, loss.item(), inputs, **context
                    )

                if batch_id % self.config.num_samples_to_log_progress == 0:
                    print(
                        f"Running batch {batch_id} for epoch {state.epoch} "
                        f"in {state.stage} stage",
                        flush=True,
                    )
        # update gradients after len(samples) forward & backward passes
        self.optimizer_step(state)
        with timing.time("add gradients"):
            if report_metric and state.stage == Stage.TRAIN:
                metric_reporter.add_gradients(state.model)
        self.sparsification_step(state)

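    # Sketch of the accumulation arithmetic above (assumed numbers): with
    # num_accumulated_batches = 4, run_epoch hands run_step groups of 4
    # micro-batches. Each micro-batch loss is divided by 4 before backprop, so
    # the gradients summed across the 4 backward passes equal the average
    # per-batch gradient of the full effective batch. The all-reduce is deferred
    # (via maybe_accumulate_gradients) to the last backward of the group, and
    # optimizer.step() runs once per group instead of once per micro-batch.
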
class TaskTrainer(Trainer):
    __EXPANSIBLE__ = True

    class Config(Trainer.Config):
        """Make mypy happy"""

    @timing.time("run_step")
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        """Our run_step is a bit different, because we're wrapping the model
        forward call with model.train_batch, which arranges tensors and gets
        loss, etc.

        Whenever "samples" contains more than one mini-batch (sample_size > 1),
        we want to accumulate gradients locally and only call all-reduce in the
        last backward pass.
        """
        sample_size = len(samples)
        assert sample_size <= self.config.num_accumulated_batches

        model = state.model
        self.zero_grads(state)
        for idx, (batch_id, (raw_batch, batch)) in enumerate(samples):
            with contextlib_ExitStack() as exit_stack:
                # enter the DDP no_sync context and the fp16 delay_unscale
                # context if needed
                maybe_accumulate_gradients(exit_stack, model, idx, sample_size)
                with timing.time("model.train_batch"):
                    loss, metric_data = model.train_batch(model, batch, state)
                    if sample_size > 1:
                        # gradients are averaged per batch and accumulated across
                        # samples; divide by sample_size so gradients end up
                        # averaged per example
                        loss = loss / sample_size
                self.backprop(state, loss)

            if report_metric:
                with timing.time("add metrics"):
                    metric_reporter.add_batch_stats(
                        batch_id,
                        *metric_data,
                        # TODO: merge this step into add_batch_stats once all data
                        # migration is done
                        **metric_reporter.batch_context(raw_batch, batch),
                    )
                if batch_id % self.config.num_samples_to_log_progress == 0:
                    metric_reporter.report_realtime_metric(state.stage)
        # update gradients after len(samples) forward & backward passes
        self.optimizer_step(state)
        with timing.time("add gradients"):
            if report_metric and state.stage == Stage.TRAIN:
                metric_reporter.add_gradients(state.model)
        self.sparsification_step(state)

    def _prepare_scheduler(self, training_batches, scheduler=None):
        """Batch-based schedulers require knowing the number of batches in the
        data. We're not supporting that yet with the Data API; we need to figure
        out how to expose this info or restructure batch-based schedulers to not
        need it."""
        if scheduler.batch_based_schedulers:
            raise Exception("New tasks don't yet support batch-based scheduling")
        return scheduler