Source code for pytext.trainers.ensemble_trainer

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
from pytext.config import ConfigBase
from pytext.config.component import create_trainer

from .trainer import TaskTrainer, TrainerBase


class EnsembleTrainer(TrainerBase):
    """Trainer for ensemble models

    Attributes:
        real_trainer (Trainer): the actual trainer to run
    """
    class Config(ConfigBase):
        real_trainer: TaskTrainer.Config = TaskTrainer.Config()
    @classmethod
    def from_config(cls, config: Config, model: torch.nn.Module, *args, **kwargs):
        # Build one sub trainer per sub model; extra arguments are forwarded
        # to each sub trainer's constructor via create_trainer.
        trainers = [
            create_trainer(config.real_trainer, m, *args, **kwargs)
            for m in model.models
        ]
        return cls(trainers)
    def __init__(self, real_trainers):
        self.real_trainers = real_trainers
        # Delegate the optimizer, test, and single-model training entry points
        # to the first sub trainer.
        self.optimizer = real_trainers[0].optimizer
        self.test = real_trainers[0].test
        self.train_single_model = real_trainers[0].train

    def train(self, train_iter, eval_iter, model, *args, **kwargs):
        """
        Train and evaluate an ensemble model. Each sub model is trained
        separately, in sequence, and ``model.merge_sub_models`` is then called
        to merge states across sub models (e.g. the transition matrix of a
        CRF). To train sub models in parallel, use the ``train_single_model``
        method instead.

        Args:
            train_iter (BatchIterator): batch iterator of training data
            eval_iter (BatchIterator): batch iterator of evaluation data
            model (Model): model to be trained
            metric_reporter (MetricReporter): computes metrics on training
                output and reports results to the console, a file, etc.
            train_config (PyTextConfig): training config
            optimizers (List[torch.optim.Optimizer]): a list of torch
                optimizers; in most cases it contains only one optimizer
            scheduler (Optional[torch.optim.lr_scheduler]): learning rate
                scheduler, default is None
            rank (int): only used in distributed training; the rank of the
                current training thread. Evaluation is done only on rank 0.

        Returns:
            model, None: only the trained ensemble model; no best metric is
            returned, since there is no clear way to aggregate metrics from
            the sub models.
        """
        for (m, t) in zip(model.models, self.real_trainers):
            t.train(train_iter, eval_iter, m, *args, **kwargs)
        model.merge_sub_models()
        return model, None
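
# ---------------------------------------------------------------------------
# Usage sketch (not part of the module above): a minimal, hedged illustration
# of how the pieces fit together. ``ensemble_model``, ``train_iter``,
# ``eval_iter``, ``metric_reporter``, and ``train_config`` are hypothetical
# stand-ins for an ensemble model exposing ``models``/``merge_sub_models``,
# the batch iterators, and the objects described in the ``train`` docstring;
# any further positional or keyword arguments are simply forwarded to each
# sub trainer.
#
#     trainer = EnsembleTrainer.from_config(
#         EnsembleTrainer.Config(),  # one TaskTrainer.Config() per sub model
#         ensemble_model,            # hypothetical model with a ``models`` list
#     )
#     trained_model, best_metric = trainer.train(
#         train_iter, eval_iter, ensemble_model, metric_reporter, train_config
#     )
#     # best_metric is always None for the ensemble trainer.
# ---------------------------------------------------------------------------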