# Source code for pytext.trainers.hogwild_trainer

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import copy

import torch
import torch.multiprocessing as mp
from pytext.common.constants import Stage
from pytext.config.pytext_config import ConfigBase
from pytext.metric_reporters import MetricReporter
from pytext.trainers.trainer import TaskTrainer, Trainer, TrainingState
from pytext.utils import cuda

try:
    from torchtext.legacy.data import Iterator
except ImportError:
    from torchtext.data import Iterator


class HogwildTrainer_Deprecated(Trainer):
    """Deprecated multi-process "Hogwild!" trainer.

    Wraps a regular ``Trainer`` and, during the TRAIN stage, forks
    ``num_workers`` CPU processes that each run an epoch concurrently over
    shared-memory model parameters (lock-free Hogwild-style updates).
    Evaluation/test stages run single-process via the parent class.
    """

    class Config(ConfigBase):
        # Config for the wrapped single-process trainer.
        real_trainer: Trainer.Config = Trainer.Config()
        # Number of worker processes; 1 means no multiprocessing.
        num_workers: int = 1

    @classmethod
    def from_config(cls, config: Config, model: torch.nn.Module, *args, **kwargs):
        """Build a trainer from config.

        Falls back to a plain ``Trainer`` when CUDA is enabled (Hogwild
        requires shared-memory CPU tensors) or when only one worker is
        requested, since multiprocessing would add overhead for no benefit.
        """
        # can't run hogwild on cuda
        if cuda.CUDA_ENABLED or config.num_workers == 1:
            return Trainer(config.real_trainer, model)
        return cls(config.real_trainer, config.num_workers, model, *args, **kwargs)

    def __init__(
        self, real_trainer_config, num_workers, model: torch.nn.Module, *args, **kwargs
    ):
        super().__init__(real_trainer_config, model, *args, **kwargs)
        self.num_workers = num_workers

    def run_epoch(
        self, state: TrainingState, data_iter: Iterator, metric_reporter: MetricReporter
    ):
        """Run one epoch.

        In the TRAIN stage, spawn ``num_workers`` processes that each execute
        the parent class's ``run_epoch``; otherwise delegate directly (and
        return its result).
        """
        if state.stage == Stage.TRAIN:
            processes = []
            for worker_rank in range(self.num_workers):
                # Initialize the batches with different random states.
                worker_state = copy.copy(state)
                worker_state.rank = worker_rank
                data_iter.batches.init_epoch()
                # BUG FIX: pass the per-worker state (carrying its distinct
                # rank) to the child process. The original passed the shared
                # `state`, leaving `worker_state` unused so every worker ran
                # with the same rank.
                p = mp.Process(
                    target=super().run_epoch,
                    args=(worker_state, data_iter, metric_reporter),
                )
                processes.append(p)
                p.start()
            for p in processes:
                p.join()
            # NOTE(review): the TRAIN branch intentionally returns None; only
            # the non-TRAIN path propagates the parent's return value.
        else:
            return super().run_epoch(state, data_iter, metric_reporter)

    def set_up_training(self, state: TrainingState, training_data):
        """Prepare training, then move model parameters into shared memory.

        Sharing parameter storage is what lets the worker processes apply
        concurrent (lock-free) gradient updates to a single model.
        """
        training_data = super().set_up_training(state, training_data)
        # Share memory of tensors for concurrent updates from multiple processes.
        if self.num_workers > 1:
            for param in state.model.parameters():
                param.share_memory_()
        return training_data
class HogwildTrainer(Trainer):
    """Multi-process "Hogwild!" trainer built on ``TaskTrainer``.

    Same behavior as ``HogwildTrainer_Deprecated`` but its config wraps a
    ``TaskTrainer.Config`` and the single-process fallback constructs a
    ``TaskTrainer`` instead of the deprecated ``Trainer``.
    """

    class Config(ConfigBase):
        # Config for the wrapped single-process task trainer.
        real_trainer: TaskTrainer.Config = TaskTrainer.Config()
        # Number of worker processes; 1 means no multiprocessing.
        num_workers: int = 1

    @classmethod
    def from_config(cls, config: Config, model: torch.nn.Module, *args, **kwargs):
        """Build a trainer from config, falling back to ``TaskTrainer``
        when CUDA is enabled or only one worker is requested."""
        # can't run hogwild on cuda
        if cuda.CUDA_ENABLED or config.num_workers == 1:
            return TaskTrainer(config.real_trainer, model)
        return cls(config.real_trainer, config.num_workers, model, *args, **kwargs)

    # Reuse the deprecated implementation's methods unchanged rather than
    # duplicating the multiprocessing logic here.
    __init__ = HogwildTrainer_Deprecated.__init__
    run_epoch = HogwildTrainer_Deprecated.run_epoch
    set_up_training = HogwildTrainer_Deprecated.set_up_training