Source code for pytext.task.serialize

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import io
import logging
import os
from typing import Dict, List, Optional

import torch
from pytext.config import PyTextConfig, config_to_json, pytext_config_from_json
from pytext.data import CommonMetadata
from pytext.data.tensorizers import Tensorizer
from pytext.models import Model
from pytext.trainers.training_state import TrainingState
from pytext.utils.file_io import PathManager


# Keys of the state dict that save_checkpoint() hands to torch.save() and that
# the load_v* functions read back.
DATA_STATE = "data_state"
CONFIG_JSON = "config_json"
MODEL_STATE = "model_state"
SERIALIZE_VERSION_KEY = "pytext_serialization_version"
TENSORIZERS = "tensorizers"
TRAINING_STATE = "training_state"


# Version stamped into new snapshots by save_checkpoint(); raised by
# register_snapshot_loader() whenever a loader for a higher version is registered.
LATEST_SERIALIZE_VERSION = 3
# Maps a serialization version to the loader function registered for it via
# register_snapshot_loader().
LOADER_VERSION_MAP = {}


logger = logging.getLogger(__name__)


def register_snapshot_loader(version):
    """
    Build a decorator that registers a function as the snapshot loader for
    the given serialization version.

    Args:
        version: serialization version the decorated function can load

    Returns:
        a decorator that stores the function in LOADER_VERSION_MAP, raises
        LATEST_SERIALIZE_VERSION to `version` if it is higher, and returns the
        function unchanged
    """

    def _register(loader_fn):
        global LATEST_SERIALIZE_VERSION
        LOADER_VERSION_MAP[version] = loader_fn
        # Keep the "latest" marker at the highest version seen so far.
        if version > LATEST_SERIALIZE_VERSION:
            LATEST_SERIALIZE_VERSION = version
        return loader_fn

    return _register
@register_snapshot_loader(1)
def load_v1(state, overwrite_config=None):
    """
    Rebuild a task from a version-1 snapshot state dict.

    Args:
        state: snapshot dict holding CONFIG_JSON, DATA_STATE and MODEL_STATE
        overwrite_config: optional config used instead of the one saved in
            the snapshot. Accepted because load_checkpoint() invokes every
            registered loader as ``loader(state, overwrite_config)``; without
            this parameter that call raises TypeError.

    Returns:
        task, config and None (version-1 snapshots carry no training state)
    """
    saved_config = pytext_config_from_json(state[CONFIG_JSON])
    # Same precedence as load_v3: an explicitly supplied config wins.
    config = overwrite_config if overwrite_config else saved_config

    # Importing at file level generates circular import/dependency failures;
    # this needs refactoring later to fix.
    from .task import create_task

    task = create_task(
        config.task, metadata=state[DATA_STATE], model_state=state[MODEL_STATE]
    )
    return task, config, None
@register_snapshot_loader(2)
def load_v2(state, overwrite_config=None):
    """
    Rebuild a task from a version-2 snapshot state dict.

    Args:
        state: snapshot dict holding CONFIG_JSON, DATA_STATE, MODEL_STATE and
            TENSORIZERS
        overwrite_config: optional config used instead of the one saved in
            the snapshot. Accepted because load_checkpoint() invokes every
            registered loader as ``loader(state, overwrite_config)``; without
            this parameter loading any version-2 snapshot raises TypeError.

    Returns:
        task, config and None (version-2 snapshots carry no training state)
    """
    saved_config = pytext_config_from_json(state[CONFIG_JSON])
    # Same precedence as load_v3: an explicitly supplied config wins.
    config = overwrite_config if overwrite_config else saved_config

    model_state = state[MODEL_STATE]
    tensorizers = state[TENSORIZERS]

    # Importing at file level generates circular import/dependency failures;
    # this needs refactoring later to fix.
    from .task import create_task

    task = create_task(
        config.task,
        metadata=state[DATA_STATE],
        model_state=model_state,
        tensorizers=tensorizers,
    )
    return task, config, None
@register_snapshot_loader(3)
def load_v3(state, overwrite_config=None):
    """
    Rebuild a task, and its training state if present, from a version-3
    snapshot state dict.

    Args:
        state: snapshot dict holding CONFIG_JSON, DATA_STATE, MODEL_STATE,
            TENSORIZERS and TRAINING_STATE
        overwrite_config: when truthy, used to build the task instead of the
            config stored in the snapshot

    Returns:
        task, config and training_state (whatever was stored under
        TRAINING_STATE, possibly None)
    """
    # The stored config is always parsed, even when it ends up unused.
    snapshot_config = pytext_config_from_json(state[CONFIG_JSON])
    if not overwrite_config:
        config = snapshot_config
        print(f"Use config saved in snapshot")
    else:
        config = overwrite_config
        print(f"Use config from current task")

    model_state = state[MODEL_STATE]
    training_state = state[TRAINING_STATE]

    # Tensorizers kept on the training state take precedence over the ones
    # stored at the top level of the snapshot.
    tensorizers = None
    if training_state:
        tensorizers = training_state.tensorizers
    if not tensorizers:
        tensorizers = state[TENSORIZERS]

    # importing in file level generates circular import/dependency failures,
    # that need refactoring later to fix
    from .task import create_task

    task = create_task(
        config.task,
        metadata=state[DATA_STATE],
        model_state=model_state,
        tensorizers=tensorizers,
    )

    # TODO: T53664090 @stevenliu save & load state_dict() of optimizer and scheduler
    if not training_state:
        return task, config, training_state

    # save_checkpoint() blanks out training_state.model before pickling, so
    # re-attach the freshly built model here.
    if training_state.model is None and task.model:
        training_state.model = task.model

    if training_state.optimizer and task.trainer.optimizer:
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html
        # Unpickling the optimizer object from a checkpoint could result in a
        # parameter copy that differs from the model parameters, especially in
        # mixed precision training, where optimizer param_groups maintain a
        # master-weights copy instead of the model parameters.
        # The suggested loading mechanism is:
        #     model = TheModelClass(*args, **kwargs)
        #     optimizer = TheOptimizerClass(model.parameters(), *args, **kwargs)
        #     checkpoint = torch.load(PATH)
        #     model.load_state_dict(checkpoint['model_state_dict'])
        #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        live_optimizer = task.trainer.optimizer
        live_optimizer.load_state_dict(training_state.optimizer.state_dict())
        training_state.optimizer = live_optimizer

    return task, config, training_state
def load_checkpoint(f: io.IOBase, overwrite_config=None):
    """
    Deserialize a snapshot from a file-like object and dispatch it to the
    loader registered for its serialization version.

    Args:
        f: binary file-like object to read the snapshot from
        overwrite_config: forwarded as second argument to the versioned loader

    Returns:
        whatever the selected loader returns (task, config, training_state)
    """
    # The identity map_location keeps every storage where deserialization put
    # it instead of moving it back to the device it was saved from.
    state = torch.load(f, map_location=lambda storage, loc: storage)
    print(f"Loaded checkpoint...")

    if SERIALIZE_VERSION_KEY in state:
        loader = LOADER_VERSION_MAP[state[SERIALIZE_VERSION_KEY]]
        return loader(state, overwrite_config)

    # Snapshots without a version key are handed to the version-1 loader.
    return load_v1(state)
def save_checkpoint(
    f: io.IOBase,
    config: PyTextConfig,
    model: Model,
    meta: Optional[CommonMetadata],
    tensorizers: Dict[str, Tensorizer],
    training_state: Optional[TrainingState] = None,
) -> None:
    """
    Serialize config, model weights, metadata, tensorizers and training state
    into a single snapshot written to `f` with torch.save().

    Args:
        f: binary file-like object to write the snapshot to
        config: task config, stored as JSON
        model: model whose state_dict() is stored
        meta: data metadata stored under DATA_STATE
        tensorizers: tensorizers stored under TENSORIZERS
        training_state: optional training state stored under TRAINING_STATE;
            its `model` attribute is temporarily set to None while saving

    Returns:
        None. (The signature previously advertised `str`, but nothing was
        ever returned.)
    """
    # Currently torch.save() has error pickling certain models when not saving
    # by model.state_dict(), thus currently overriding the model in
    # training_state with None, and putting it back after saving
    # https://github.com/pytorch/pytorch/issues/15116
    model_in_training_state = None
    if training_state:
        model_in_training_state, training_state.model = training_state.model, None
    try:
        state = {
            DATA_STATE: meta,
            CONFIG_JSON: config_to_json(PyTextConfig, config),
            MODEL_STATE: model.state_dict(),
            SERIALIZE_VERSION_KEY: LATEST_SERIALIZE_VERSION,
            TENSORIZERS: tensorizers,
            TRAINING_STATE: training_state,
        }
        torch.save(state, f)
    finally:
        # Restore the caller's training state even if serialization failed.
        if training_state:
            training_state.model = model_in_training_state
def get_latest_checkpoint_path(dir_path: Optional[str] = None) -> Optional[str]:
    """
    Get the latest checkpoint path

    Args:
        dir_path: the dir to scan for existing checkpoint files. Default: if
            None, the latest checkpoint path saved in memory will be returned

    Returns:
        checkpoint_path, or None when `dir_path` does not exist or contains no
        entry named ``checkpoint-<integer>...``
    """
    if not dir_path:
        return _CHECKPOINT_MANAGER.get_latest_checkpoint_path()

    if not PathManager.exists(dir_path):
        return None

    checkpoint_indices = []
    for file_name in PathManager.ls(dir_path):
        if not file_name.startswith("checkpoint"):
            continue
        try:
            checkpoint_indices.append(int(file_name.split("-")[1]))
        except (IndexError, ValueError):
            # Entries such as "checkpoint" or "checkpoint-best" carry no
            # numeric index; skip them instead of aborting the whole scan.
            continue

    if not checkpoint_indices:
        return None

    latest_checkpoint_path = f"{dir_path}/checkpoint-{max(checkpoint_indices)}"
    logger.info(f"find the latest checkpoint: {latest_checkpoint_path}")
    return latest_checkpoint_path
def get_post_training_snapshot_path() -> str:
    """
    Return the post-training snapshot path reported by the module-level
    checkpoint manager.
    """
    manager = _CHECKPOINT_MANAGER
    return manager.get_post_training_snapshot_path()
class CheckpointManager:
    """
    CheckpointManager is class abstraction to manage training job's
    checkpoints with different IO and storage, using two functions:
    save() and load().
    """

    # Separator between the "checkpoint" prefix and the identifier in
    # during-training checkpoint file names.
    DELIMITER = "-"

    def __init__(self):
        # keep a list of saved checkpoint path
        self._saved_paths: List[str] = []
        self._post_training_snapshot_path: Optional[str] = None

    # generate per epoch checkpoint save path
    def generate_checkpoint_path(self, config: PyTextConfig, identifier: str):
        """
        Build the path of a during-training checkpoint, placed next to
        config.save_snapshot_path.

        Args:
            config: config whose save_snapshot_path determines the directory
            identifier: suffix identifying the checkpoint

        Returns:
            "<dir>/checkpoint<DELIMITER><identifier>", or just the file name
            when save_snapshot_path has no directory component
        """
        dir_name = os.path.dirname(config.save_snapshot_path)
        file_name = f"checkpoint{self.DELIMITER}{identifier}"
        # An empty dir_name must not be joined with "/": that would silently
        # redirect the checkpoint to the filesystem root.
        return f"{dir_name}/{file_name}" if dir_name else file_name

    def save(
        self,
        config: PyTextConfig,
        model: Model,
        meta: Optional[CommonMetadata],
        tensorizers: Dict[str, Tensorizer],
        training_state: Optional[TrainingState] = None,
        identifier: Optional[str] = None,
    ) -> str:
        """
        Save a checkpoint to given path; config, model and training_state
        together represent the checkpoint. When identifier is None, this
        function is used to save the post-training snapshot.

        Returns:
            the path the checkpoint was written to
        """
        if identifier:
            # saving during-training checkpoints
            saved_path = self.generate_checkpoint_path(config, identifier)
            print("Saving checkpoint to ", saved_path)
        else:
            # saving post-training snapshot if no identifier given
            saved_path = config.save_snapshot_path
            print(f"Saving pytorch model to: {saved_path}")

        saved_folder = os.path.dirname(saved_path)
        # A bare file name has no folder to create; skip mkdirs("") for it.
        if saved_folder and not PathManager.exists(saved_folder):
            PathManager.mkdirs(saved_folder)
            print(f"created {saved_folder}")

        with PathManager.open(saved_path, "wb") as checkpoint_f:
            save_checkpoint(
                checkpoint_f, config, model, meta, tensorizers, training_state
            )
            # Record the path only once the write above did not raise.
            if identifier:
                self._saved_paths.append(saved_path)
            else:
                self._post_training_snapshot_path = saved_path
        return saved_path

    def load(self, load_path: str, overwrite_config=None):
        """
        Loads a checkpoint from disk.

        Args:
            load_path (str): the file path to load for checkpoint
            overwrite_config: forwarded to load_checkpoint()

        Returns:
            task (Task), config (PyTextConfig) and training_state (TrainingState)

        Raises:
            ValueError: if load_path is empty or is not an existing file
        """
        if not (load_path and PathManager.isfile(load_path)):
            raise ValueError(f"Invalid snapshot path: {load_path}")
        print(f"Loading model from {load_path}")
        with PathManager.open(load_path, "rb") as checkpoint_f:
            return load_checkpoint(checkpoint_f, overwrite_config)

    def list(self) -> List[str]:
        """
        Return all existing checkpoint path in str

        Returns:
            checkpoint_path_list (List[str]), list elements are in the same
            order of checkpoint saving. This is the manager's own list, not a
            copy.
        """
        return self._saved_paths

    def get_latest_checkpoint_path(self) -> Optional[str]:
        """
        Return most recent saved checkpoint path in str

        Returns:
            checkpoint_path (str), or None if no checkpoint has been saved
        """
        return self._saved_paths[-1] if self._saved_paths else None

    def get_post_training_snapshot_path(self) -> Optional[str]:
        """
        Return the path of the post-training snapshot, or None if save() has
        not been called without an identifier yet.
        """
        return self._post_training_snapshot_path
# Module-level manager instance that the free functions in this module
# delegate to; replaceable through set_checkpoint_manager().
_CHECKPOINT_MANAGER = CheckpointManager()
def set_checkpoint_manager(manager: CheckpointManager) -> None:
    """
    Replace the module-level checkpoint manager.

    Args:
        manager: the CheckpointManager that subsequent module-level calls
            will delegate to
    """
    global _CHECKPOINT_MANAGER
    _CHECKPOINT_MANAGER = manager
def save(
    config: PyTextConfig,
    model: Model,
    meta: Optional[CommonMetadata],
    tensorizers: Dict[str, Tensorizer],
    training_state: Optional[TrainingState] = None,
    identifier: Optional[str] = None,
) -> str:
    """
    Save all stateful information of a training task by delegating to the
    module-level checkpoint manager.

    Args:
        config (PytextConfig): contains all raw parameter/hyper-parameters
            for training task
        model (Model): actual model in training
        meta: data metadata, passed through to the manager
        tensorizers: tensorizers, passed through to the manager
        training_state (TrainingState): stateful information during training
        identifier (str): used to identify a checkpoint within a training
            job; passed through to the manager unchanged

    Returns:
        the value returned by the manager's save()
    """
    manager = _CHECKPOINT_MANAGER
    saved_path = manager.save(
        config, model, meta, tensorizers, training_state, identifier
    )
    return saved_path
def load(load_path: str, overwrite_config=None):
    """
    Load a saved snapshot by delegating to the module-level checkpoint
    manager.

    Args:
        load_path (str): path of the snapshot file to load
        overwrite_config: passed through to the manager's load()

    Returns:
        the value returned by the manager's load()
    """
    manager = _CHECKPOINT_MANAGER
    return manager.load(load_path, overwrite_config)