pytext.trainers package
Submodules
pytext.trainers.ensemble_trainer module
class pytext.trainers.ensemble_trainer.EnsembleTrainer(real_trainers)
Bases: pytext.trainers.trainer.TrainerBase
Trainer for ensemble models.
- real_trainer (Trainer) – the actual trainer to run
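A framework-free sketch of the delegation idea, assuming that an ensemble trainer holds one "real" trainer and runs it once per sub-model. `ToyModel`, `ToyTrainer`, `ToyEnsembleTrainer`, and `ensemble_predict` are hypothetical stand-ins, not PyText's API, and the averaging step is only one common way to combine sub-models.

```python
from typing import List


class ToyModel:
    """Hypothetical sub-model with a single weight."""

    def __init__(self, weight: float = 0.0) -> None:
        self.weight = weight


class ToyTrainer:
    """Hypothetical stand-in for the wrapped ("real") trainer."""

    def __init__(self, epochs: int, lr: float) -> None:
        self.epochs = epochs
        self.lr = lr

    def train(self, model: ToyModel, data: List[float]) -> ToyModel:
        # SGD on the loss (w - x)^2 / 2, whose gradient is (w - x).
        for _ in range(self.epochs):
            for x in data:
                model.weight -= self.lr * (model.weight - x)
        return model


class ToyEnsembleTrainer:
    """Delegates to one real trainer and runs it once per sub-model."""

    def __init__(self, real_trainer: ToyTrainer) -> None:
        self.real_trainer = real_trainer

    def train(self, models: List[ToyModel], data: List[float]) -> List[ToyModel]:
        return [self.real_trainer.train(model, data) for model in models]


def ensemble_predict(models: List[ToyModel]) -> float:
    # Combine the sub-models by averaging their outputs (one common choice).
    return sum(model.weight for model in models) / len(models)
```

Each sub-model starts from a different weight but is trained by the same wrapped trainer; how PyText's ensemble models merge their members is not shown here.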
pytext.trainers.hogwild_trainer module
class pytext.trainers.hogwild_trainer.HogwildTrainer(real_trainer_config, num_workers, model: torch.nn.modules.module.Module, *args, **kwargs)
Bases: pytext.trainers.trainer.Trainer
- classmethod from_config(config: pytext.trainers.hogwild_trainer.HogwildTrainer.Config, model: torch.nn.modules.module.Module, *args, **kwargs)
- run_epoch(state: pytext.trainers.training_state.TrainingState, data_iter: torchtext.legacy.data.iterator.Iterator, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter)
- set_up_training(state: pytext.trainers.training_state.TrainingState, training_data)
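The class name and the `num_workers` argument suggest Hogwild!-style training, in which several workers update shared model parameters concurrently without any locking. The sketch below illustrates that general technique with Python threads and a shared parameter list on a toy, noise-free linear-regression problem. It is an assumption about the approach, not HogwildTrainer's implementation (which operates on torch.nn.Module models); `make_data` and `hogwild_train` are hypothetical.

```python
import random
import threading
from typing import List, Tuple


def make_data(n: int, true_w: float, true_b: float, seed: int = 0) -> List[Tuple[float, float]]:
    """Noise-free samples of y = true_w * x + true_b."""
    rng = random.Random(seed)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    return [(x, true_w * x + true_b) for x in xs]


def hogwild_train(
    data: List[Tuple[float, float]],
    num_workers: int = 4,
    epochs: int = 50,
    lr: float = 0.05,
) -> List[float]:
    # Shared parameters [w, b]; every worker reads and writes them with no lock.
    params = [0.0, 0.0]
    shards = [data[i::num_workers] for i in range(num_workers)]

    def worker(shard: List[Tuple[float, float]]) -> None:
        for _ in range(epochs):
            for x, y in shard:
                w, b = params  # the read may be stale if another worker just wrote
                err = (w * x + b) - y
                # Lock-free SGD update of the shared parameters.
                params[0] -= lr * err * x
                params[1] -= lr * err

    threads = [threading.Thread(target=worker, args=(shard,)) for shard in shards]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return params
```

Occasional lost or stale updates are tolerated by design; the workers still converge because each update is small.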
class pytext.trainers.hogwild_trainer.HogwildTrainer_Deprecated(real_trainer_config, num_workers, model: torch.nn.modules.module.Module, *args, **kwargs)
Bases: pytext.trainers.trainer.Trainer
- classmethod from_config(config: pytext.trainers.hogwild_trainer.HogwildTrainer_Deprecated.Config, model: torch.nn.modules.module.Module, *args, **kwargs)
pytext.trainers.trainer module
class pytext.trainers.trainer.TaskTrainer(config: pytext.trainers.trainer.Trainer.Config, model: torch.nn.modules.module.Module)
Bases: pytext.trainers.trainer.Trainer
- run_step(samples: List[Any], state: pytext.trainers.training_state.TrainingState, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter, report_metric: bool)
This run_step differs from Trainer.run_step because it wraps the model's forward call with model.train_batch, which arranges the tensors and computes the loss, among other things.
Whenever samples contains more than one mini-batch (sample_size > 1), gradients are accumulated locally and all-reduce is called only in the last backward pass.
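The accumulation rule can be illustrated without PyTorch. This sketch assumes a one-parameter model with a squared-error loss and simulates data-parallel workers in a single process. `FakeAllReduce` is a hypothetical stand-in for a collective such as `torch.distributed.all_reduce`; its call counter shows one synchronization per step, no matter how many mini-batches each worker processes. How TaskTrainer actually suppresses the intermediate all-reduces is not shown.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs for the toy model y_hat = w * x


def batch_grad(w: float, batch: Batch) -> float:
    """Gradient of the mean squared error over one mini-batch with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


class FakeAllReduce:
    """Hypothetical stand-in for a collective all-reduce that averages across workers."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self, local_grads: List[float]) -> List[float]:
        self.calls += 1
        mean = sum(local_grads) / len(local_grads)
        return [mean] * len(local_grads)


def run_step(w: float, samples_per_worker: List[List[Batch]], all_reduce: FakeAllReduce) -> List[float]:
    """Accumulate gradients over all mini-batches locally, then synchronize once."""
    local_grads = []
    for samples in samples_per_worker:  # one entry per simulated worker
        grad = 0.0
        for batch in samples:
            # Gradients add up across mini-batches, as with repeated backward() calls;
            # dividing by len(samples) keeps the result on the scale of one batch.
            grad += batch_grad(w, batch) / len(samples)
        local_grads.append(grad)
    # A single all-reduce after the last mini-batch instead of one per mini-batch.
    return all_reduce(local_grads)
```

Skipping the per-mini-batch synchronization saves communication, and because averaging is linear, the final gradient is the same as if every mini-batch had been all-reduced separately.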
class pytext.trainers.trainer.Trainer(config: pytext.trainers.trainer.Trainer.Config, model: torch.nn.modules.module.Module)
Bases: pytext.trainers.trainer.TrainerBase
Base Trainer class that provides ways to:
1. Train a model, compute metrics against an eval set, and use the metrics for model selection.
2. Test a trained model, and compute and publish metrics against a blind test set.
- epochs (int) – number of training epochs
- early_stop_after (int) – number of epochs without improvement in the eval metric after which training stops
- max_clip_norm (Optional[float]) – if set, the gradient norm is clipped to this value
- report_train_metrics (bool) – whether metrics on training data should be computed and reported
- target_time_limit_seconds (float) – target time limit for training, in seconds; if the expected time to train another epoch exceeds this limit, training stops
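The stopping rules implied by `epochs`, `early_stop_after`, and `target_time_limit_seconds`, and the rescaling implied by `max_clip_norm`, can be sketched as plain functions. These are illustrative assumptions, not Trainer's code: a non-positive `early_stop_after` is treated as disabled, the next epoch's duration is estimated as the average of past epochs, and clipping uses the global L2 norm (the default of `torch.nn.utils.clip_grad_norm_`). `should_stop` and `clip_by_global_norm` are hypothetical names.

```python
import math
from typing import List, Optional


def should_stop(
    epoch: int,
    epochs: int,
    epochs_since_last_improvement: int,
    early_stop_after: int,
    elapsed_seconds: float,
    target_time_limit_seconds: Optional[float],
) -> bool:
    """Decide, after `epoch` completed epochs, whether to end training."""
    if epoch >= epochs:
        return True
    # Assumption: early_stop_after <= 0 disables early stopping.
    if early_stop_after > 0 and epochs_since_last_improvement >= early_stop_after:
        return True
    if target_time_limit_seconds is not None and epoch > 0:
        # Estimate the next epoch's duration as the average of the past epochs.
        expected_total = elapsed_seconds + elapsed_seconds / epoch
        if expected_total > target_time_limit_seconds:
            return True
    return False


def clip_by_global_norm(grads: List[float], max_clip_norm: Optional[float]) -> List[float]:
    """Scale gradients so that their global L2 norm is at most max_clip_norm."""
    if max_clip_norm is None:
        return grads
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = max_clip_norm / (total_norm + 1e-6)  # small epsilon avoids division by zero
    if scale < 1.0:
        return [g * scale for g in grads]
    return grads
```

For example, gradients [3.0, 4.0] have norm 5, so with `max_clip_norm=1.0` they are scaled to roughly [0.6, 0.8]; gradients already below the limit are returned unchanged.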
- classmethod from_config(config: pytext.trainers.trainer.Trainer.Config, model: torch.nn.modules.module.Module, *args, **kwargs)
- run_epoch(state: pytext.trainers.training_state.TrainingState, data: pytext.data.data_handler.BatchIterator, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter)
- run_step(samples: List[Any], state: pytext.trainers.training_state.TrainingState, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter, report_metric: bool)
- save_checkpoint(state: pytext.trainers.training_state.TrainingState, train_config: pytext.config.pytext_config.PyTextConfig) → str
- set_up_training(state: pytext.trainers.training_state.TrainingState, training_data: pytext.data.data_handler.BatchIterator)
- test(test_iter, model, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter)
- train(training_data: pytext.data.data_handler.BatchIterator, eval_data: pytext.data.data_handler.BatchIterator, model: pytext.models.model.Model, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter, train_config: pytext.config.pytext_config.PyTextConfig, rank: int = 0) → Tuple[torch.nn.modules.module.Module, Any]
Train and evaluate a model. The model state is modified in place.
Parameters:
- training_data (BatchIterator) – batch iterator of training data
- eval_data (BatchIterator) – batch iterator of evaluation data
- model (Model) – model to be trained
- metric_reporter (MetricReporter) – computes metrics based on the training output and reports results to the console, a file, etc.
- train_config (PyTextConfig) – training config
- rank (int) – only used in distributed training; the rank of the current training thread. Evaluation is only done in rank 0.
Returns: the trained model together with the best metric
Return type: model, best_metric
- train_from_state(state: pytext.trainers.training_state.TrainingState, training_data: pytext.data.data_handler.BatchIterator, eval_data: pytext.data.data_handler.BatchIterator, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter, train_config: pytext.config.pytext_config.PyTextConfig) → Tuple[torch.nn.modules.module.Module, Any]
Train and evaluate a model starting from a given training state; the state is modified in place. This function iterates over the epochs specified in the config and, for each epoch:
- Trains the model on the training data, then aggregates and reports training results
- Adjusts the learning rate if a scheduler is specified
- Evaluates the model on the evaluation data
- Calculates metrics from the evaluation results and selects the best model
Parameters:
- state (TrainingState) – contains the stateful information needed to restore a training job
- training_data (BatchIterator) – batch iterator of training data
- eval_data (BatchIterator) – batch iterator of evaluation data
- metric_reporter (MetricReporter) – computes metrics based on the training output and reports results to the console, a file, etc.
- train_config (PyTextConfig) – training config
Returns: the trained model together with the best metric
Return type: model, best_metric
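The per-epoch procedure of train_from_state can be sketched as a loop over hypothetical callables (`train_one_epoch`, `evaluate`, `step_scheduler`). The sketch assumes that a higher metric is better, that the best model is snapshotted with `copy.deepcopy`, and that early stopping follows the `early_stop_after` attribute; it mirrors the four listed steps only and is not PyText's implementation.

```python
import copy
from typing import Any, Callable, Optional, Tuple


def train_from_state_sketch(
    model: Any,
    epochs: int,
    train_one_epoch: Callable[[Any], float],
    evaluate: Callable[[Any], float],
    step_scheduler: Optional[Callable[[int], None]] = None,
    early_stop_after: int = 0,
) -> Tuple[Any, Optional[float]]:
    best_metric: Optional[float] = None
    best_model = copy.deepcopy(model)
    epochs_since_last_improvement = 0
    for epoch in range(1, epochs + 1):
        # 1. Train on the training data and report the result.
        train_loss = train_one_epoch(model)
        print(f"epoch {epoch}: train loss {train_loss:.4f}")
        # 2. Adjust the learning rate if a scheduler is given.
        if step_scheduler is not None:
            step_scheduler(epoch)
        # 3. Evaluate on the evaluation data.
        metric = evaluate(model)
        # 4. Keep a snapshot of the best model seen so far (higher is better here).
        if best_metric is None or metric > best_metric:
            best_metric = metric
            best_model = copy.deepcopy(model)
            epochs_since_last_improvement = 0
        else:
            epochs_since_last_improvement += 1
            if early_stop_after > 0 and epochs_since_last_improvement >= early_stop_after:
                break
    return best_model, best_metric
```

The returned pair corresponds to the documented (model, best_metric) result; the model object passed in keeps its last-epoch state, while the returned copy holds the best one.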
pytext.trainers.training_state module
Module contents
class pytext.trainers.Trainer(config: pytext.trainers.trainer.Trainer.Config, model: torch.nn.modules.module.Module)
Bases: pytext.trainers.trainer.TrainerBase
Base Trainer class that provides ways to:
1. Train a model, compute metrics against an eval set, and use the metrics for model selection.
2. Test a trained model, and compute and publish metrics against a blind test set.
- epochs (int) – number of training epochs
- early_stop_after (int) – number of epochs without improvement in the eval metric after which training stops
- max_clip_norm (Optional[float]) – if set, the gradient norm is clipped to this value
- report_train_metrics (bool) – whether metrics on training data should be computed and reported
- target_time_limit_seconds (float) – target time limit for training, in seconds; if the expected time to train another epoch exceeds this limit, training stops
- classmethod from_config(config: pytext.trainers.trainer.Trainer.Config, model: torch.nn.modules.module.Module, *args, **kwargs)
- run_epoch(state: pytext.trainers.training_state.TrainingState, data: pytext.data.data_handler.BatchIterator, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter)
- run_step(samples: List[Any], state: pytext.trainers.training_state.TrainingState, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter, report_metric: bool)
- save_checkpoint(state: pytext.trainers.training_state.TrainingState, train_config: pytext.config.pytext_config.PyTextConfig) → str
- set_up_training(state: pytext.trainers.training_state.TrainingState, training_data: pytext.data.data_handler.BatchIterator)
- test(test_iter, model, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter)
- train(training_data: pytext.data.data_handler.BatchIterator, eval_data: pytext.data.data_handler.BatchIterator, model: pytext.models.model.Model, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter, train_config: pytext.config.pytext_config.PyTextConfig, rank: int = 0) → Tuple[torch.nn.modules.module.Module, Any]
Train and evaluate a model. The model state is modified in place.
Parameters:
- training_data (BatchIterator) – batch iterator of training data
- eval_data (BatchIterator) – batch iterator of evaluation data
- model (Model) – model to be trained
- metric_reporter (MetricReporter) – computes metrics based on the training output and reports results to the console, a file, etc.
- train_config (PyTextConfig) – training config
- rank (int) – only used in distributed training; the rank of the current training thread. Evaluation is only done in rank 0.
Returns: the trained model together with the best metric
Return type: model, best_metric
- train_from_state(state: pytext.trainers.training_state.TrainingState, training_data: pytext.data.data_handler.BatchIterator, eval_data: pytext.data.data_handler.BatchIterator, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter, train_config: pytext.config.pytext_config.PyTextConfig) → Tuple[torch.nn.modules.module.Module, Any]
Train and evaluate a model starting from a given training state; the state is modified in place. This function iterates over the epochs specified in the config and, for each epoch:
- Trains the model on the training data, then aggregates and reports training results
- Adjusts the learning rate if a scheduler is specified
- Evaluates the model on the evaluation data
- Calculates metrics from the evaluation results and selects the best model
Parameters:
- state (TrainingState) – contains the stateful information needed to restore a training job
- training_data (BatchIterator) – batch iterator of training data
- eval_data (BatchIterator) – batch iterator of evaluation data
- metric_reporter (MetricReporter) – computes metrics based on the training output and reports results to the console, a file, etc.
- train_config (PyTextConfig) – training config
Returns: the trained model together with the best metric
Return type: model, best_metric
class pytext.trainers.TrainingState(**kwargs)
Bases: object
- best_model_metric = None
- best_model_state = None
- epoch = 0
- epochs_since_last_improvement = 0
- rank = 0
- stage = 'Training'
- step_counter = 0
- tensorizers = None
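TrainingState is documented as taking only `**kwargs` and exposing the class-level defaults listed above. A minimal stand-in with the same defaults, assuming that each keyword argument simply overrides the attribute of the same name, could look like this; `ToyTrainingState` is hypothetical and covers only the documented attributes.

```python
from typing import Any


class ToyTrainingState:
    """Hypothetical stand-in mirroring the documented TrainingState defaults."""

    best_model_metric: Any = None
    best_model_state: Any = None
    epoch: int = 0
    epochs_since_last_improvement: int = 0
    rank: int = 0
    stage: str = "Training"
    step_counter: int = 0
    tensorizers: Any = None

    def __init__(self, **kwargs: Any) -> None:
        # Assumption: keyword arguments override the class-level defaults.
        for name, value in kwargs.items():
            setattr(self, name, value)
```

Because the defaults live on the class, an instance that never receives, say, `epoch` reads the shared class attribute until it assigns its own value.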
class pytext.trainers.EnsembleTrainer(real_trainers)
Bases: pytext.trainers.trainer.TrainerBase
Trainer for ensemble models.
- real_trainer (Trainer) – the actual trainer to run
class pytext.trainers.HogwildTrainer(real_trainer_config, num_workers, model: torch.nn.modules.module.Module, *args, **kwargs)
Bases: pytext.trainers.trainer.Trainer
- classmethod from_config(config: pytext.trainers.hogwild_trainer.HogwildTrainer.Config, model: torch.nn.modules.module.Module, *args, **kwargs)
- run_epoch(state: pytext.trainers.training_state.TrainingState, data_iter: torchtext.legacy.data.iterator.Iterator, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter)
- set_up_training(state: pytext.trainers.training_state.TrainingState, training_data)
class pytext.trainers.HogwildTrainer_Deprecated(real_trainer_config, num_workers, model: torch.nn.modules.module.Module, *args, **kwargs)
Bases: pytext.trainers.trainer.Trainer
- classmethod from_config(config: pytext.trainers.hogwild_trainer.HogwildTrainer_Deprecated.Config, model: torch.nn.modules.module.Module, *args, **kwargs)
class pytext.trainers.TaskTrainer(config: pytext.trainers.trainer.Trainer.Config, model: torch.nn.modules.module.Module)
Bases: pytext.trainers.trainer.Trainer
- run_step(samples: List[Any], state: pytext.trainers.training_state.TrainingState, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter, report_metric: bool)
This run_step differs from Trainer.run_step because it wraps the model's forward call with model.train_batch, which arranges the tensors and computes the loss, among other things.
Whenever samples contains more than one mini-batch (sample_size > 1), gradients are accumulated locally and all-reduce is called only in the last backward pass.