Source code for pytext.trainers.training_state

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


from typing import Any, Dict

from pytext.common.constants import Stage
from pytext.data.tensorizers import Tensorizer
from pytext.models.model import Model
from pytext.optimizer import Optimizer, PrivacyEngine
from pytext.optimizer.scheduler import Scheduler
from pytext.optimizer.sparsifiers.sparsifier import Sparsifier


class TrainingState:
    """Keyword-initialised record of the attributes tracked during training.

    The annotated names in the class body define which attributes may be
    supplied to the constructor. Names annotated with a value (``epoch``,
    ``step_counter``, ``rank``, ``stage``, ...) are class-level defaults and
    are readable on every instance. Names annotated without a value
    (``model``, ``optimizer``, ``scheduler``, ``sparsifier``, ``start_time``,
    ``privacy_engine``) have no default: they exist on an instance only if
    they were passed to ``__init__`` or assigned afterwards.
    """

    model: Model
    optimizer: Optimizer
    scheduler: Scheduler
    sparsifier: Sparsifier
    start_time: float
    privacy_engine: PrivacyEngine
    # epoch counter
    epoch: int = 0
    # step counter: each optimizer.step() increments step_counter
    step_counter: int = 0
    rank: int = 0
    stage: Stage = Stage.TRAIN
    epochs_since_last_improvement: int = 0
    best_model_state: Any = None
    best_model_metric: Any = None
    # Defaults to None (no tensorizers) despite the non-Optional annotation.
    tensorizers: Dict[str, Tensorizer] = None

    def __init__(self, **kwargs):
        """Store the given keyword arguments as instance attributes.

        Args:
            **kwargs: Attribute values keyed by attribute name. Every key
                must be a name annotated on this class or on one of the
                classes it inherits from.

        Raises:
            TypeError: If any key is not an annotated attribute name.
        """
        # Collect annotated names across the whole MRO so that attributes
        # declared by subclasses are accepted as well. For TrainingState
        # itself this is exactly TrainingState.__annotations__.
        allowed_keys = set()
        for klass in type(self).__mro__:
            # getattr with a default: classes such as `object` may not
            # expose __annotations__ at all.
            allowed_keys.update(getattr(klass, "__annotations__", {}))

        unknown_keys = kwargs.keys() - allowed_keys
        if unknown_keys:
            raise TypeError(
                f"{type(self).__name__} unexpected attributes {unknown_keys}"
            )
        # Write straight into the instance dict; values shadow the
        # class-level defaults of the same name.
        vars(self).update(kwargs)