#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import abc
import io
import logging
import os
from typing import Dict, List, Optional
import torch
from pytext.config import PyTextConfig, config_to_json, pytext_config_from_json
from pytext.data import CommonMetadata
from pytext.data.tensorizers import Tensorizer
from pytext.models import Model
from pytext.trainers.training_state import TrainingState
from pytext.utils.file_io import PathManager
from pytext.utils.usage import log_class_usage
DATA_STATE = "data_state"
CONFIG_JSON = "config_json"
MODEL_STATE = "model_state"
SERIALIZE_VERSION_KEY = "pytext_serialization_version"
TENSORIZERS = "tensorizers"
TRAINING_STATE = "training_state"
LATEST_SERIALIZE_VERSION = 3
LOADER_VERSION_MAP = {}
logger = logging.getLogger(__name__)
[docs]def register_snapshot_loader(version):
def decorator(fn):
LOADER_VERSION_MAP[version] = fn
global LATEST_SERIALIZE_VERSION
LATEST_SERIALIZE_VERSION = max(LATEST_SERIALIZE_VERSION, version)
return fn
return decorator
[docs]@register_snapshot_loader(1)
def load_v1(state):
config = pytext_config_from_json(state[CONFIG_JSON])
# importing in file level generates circular import/dependency failures,
# that need refator later to fix
from .task import create_task
task = create_task(
config.task, metadata=state[DATA_STATE], model_state=state[MODEL_STATE]
)
return task, config, None
[docs]@register_snapshot_loader(2)
def load_v2(state):
config = pytext_config_from_json(state[CONFIG_JSON])
model_state = state[MODEL_STATE]
tensorizers = state[TENSORIZERS]
# importing in file level generates circular import/dependency failures,
# that need refator later to fix
from .task import create_task
task = create_task(
config.task,
metadata=state[DATA_STATE],
model_state=model_state,
tensorizers=tensorizers,
)
return task, config, None
[docs]@register_snapshot_loader(3)
def load_v3(state, overwrite_config=None, rank=0, world_size=1):
saved_config = pytext_config_from_json(state[CONFIG_JSON])
if overwrite_config:
config = overwrite_config
print(f"Use config from current task")
else:
config = saved_config
print(f"Use config saved in snapshot")
model_state = state[MODEL_STATE]
training_state = state[TRAINING_STATE]
if training_state and training_state.tensorizers:
tensorizers = training_state.tensorizers
else:
tensorizers = state[TENSORIZERS]
# importing in file level generates circular import/dependency failures,
# that need refator later to fix
from .task import create_task
task = create_task(
config.task,
metadata=state[DATA_STATE],
model_state=model_state,
tensorizers=tensorizers,
rank=rank,
world_size=world_size,
)
# TODO: T53664090 @stevenliu save & load state_dict() of optimizer and scheduler
if training_state:
if training_state.model is None and task.model:
training_state.model = task.model
if training_state.optimizer and task.trainer.optimizer:
"""
https://pytorch.org/tutorials/beginner/saving_loading_models.html
Unpickling optimizer object from checkpoint could result in a
different parameter copy from model parameters. Especially in
mixied precision training, which optimizer param_groups maintains
master weights copy instead of the model parameters.
The suggested loading mechanism is
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(model.parameters(), *args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
"""
optimizer = task.trainer.optimizer
optimizer.load_state_dict(training_state.optimizer.state_dict())
training_state.optimizer = optimizer
return task, config, training_state
[docs]def load_checkpoint(state, overwrite_config=None, rank=0, world_size=1):
print(f"Loaded checkpoint...")
if SERIALIZE_VERSION_KEY not in state:
return load_v1(state)
else:
return LOADER_VERSION_MAP[state[SERIALIZE_VERSION_KEY]](
state, overwrite_config, rank, world_size
)
[docs]def save_checkpoint(
f: io.IOBase,
config: PyTextConfig,
model: Model,
meta: Optional[CommonMetadata],
tensorizers: Dict[str, Tensorizer],
training_state: Optional[TrainingState] = None,
) -> str:
# Currently torch.save() has error pickling certain models when not saving
# by model.state_dict(), thus currently overriding the model in
# training_state with None, and put back saving
# https://github.com/pytorch/pytorch/issues/15116
model_in_training_state = None
if training_state:
model_in_training_state, training_state.model = training_state.model, None
try:
state = {
DATA_STATE: meta,
CONFIG_JSON: config_to_json(PyTextConfig, config),
MODEL_STATE: model.state_dict(),
SERIALIZE_VERSION_KEY: LATEST_SERIALIZE_VERSION,
TENSORIZERS: tensorizers,
TRAINING_STATE: training_state,
}
torch.save(state, f)
finally:
if training_state:
training_state.model = model_in_training_state
[docs]def get_latest_checkpoint_path(dir_path: Optional[str] = None) -> str:
"""
Get the latest checkpoint path
args:
dir_path: the dir to scan for existing checkpoint files. Default: if None,
the latest checkpoint path saved in momery will be returned
Returns: checkpoint_path
"""
if not dir_path:
return _CHECKPOINT_MANAGER.get_latest_checkpoint_path()
if PathManager.exists(dir_path):
checkpoint_indices = [
int(file_path.split("-")[1])
for file_path in PathManager.ls(dir_path)
if file_path.startswith("checkpoint")
]
if checkpoint_indices:
latest_checkpoint_path = f"{dir_path}/checkpoint-{max(checkpoint_indices)}"
logger.info(f"find the latest checkpoint: {latest_checkpoint_path}")
return latest_checkpoint_path
return None
[docs]def get_post_training_snapshot_path() -> str:
return _CHECKPOINT_MANAGER.get_post_training_snapshot_path()
DELIMITER = "-"
# generate per epoch checkpoint save path
[docs]def generate_checkpoint_path(config: PyTextConfig, identifier: str):
dir_name = os.path.dirname(config.save_snapshot_path)
return f"{dir_name}/checkpoint{DELIMITER}{identifier}"
[docs]class PyTextCheckpointManagerInterface(abc.ABC):
"""
CheckpointManager is a class abstraction to manage a training job's
checkpoints/snapshots with different IO and storage.
"""
[docs] @abc.abstractmethod
def save_checkpoint(self, state, checkpoint_path):
"""
Serialize and save checkpoint to given path. State is a dictionary that
represents the all data to be saved.
Args:
state: Dictionary containing data to be saved
checkpoint_path: path of file to save checkpoint
"""
raise NotImplementedError("Not implemented in interface class")
[docs] @abc.abstractmethod
def save_snapshot(self, state, snapshot_path):
"""
Serialize and save post-training model snapshot to given path. State
is a dictionary that represents the all data to be saved.
Having a separate method for snapshots enables future optimizations like
quantization to be applied to snapshots.
Args:
state: Dictionary containing data to be saved
snapshot_path: path of file to save snapshot
"""
raise NotImplementedError("Not implemented in interface class")
[docs] @abc.abstractmethod
def load(self, load_path: str):
"""
Loads a checkpoint/snapshot from disk.
Args:
load_path (str): the file path from which to load
Returns: De-serialized state (dictionary) that was saved
"""
raise NotImplementedError("Not implemented in interface class")
[docs] @abc.abstractmethod
def list(self) -> List[str]:
"""
Return all existing checkpoint paths
Returns: checkpoint_path_list (List[str]), list elements are in the same
order of checkpoint saving
"""
raise NotImplementedError("Not implemented in interface class")
[docs] @abc.abstractmethod
def get_latest_checkpoint_path(self) -> str:
"""
Return most recent saved checkpoint path in str
Returns: checkpoint_path (str)
"""
raise NotImplementedError("Not implemented in interface class")
[docs] @abc.abstractmethod
def get_post_training_snapshot_path(self) -> str:
raise NotImplementedError("Not implemented in interface class")
[docs]class CheckpointManager(PyTextCheckpointManagerInterface):
def __init__(self):
# keep a list of saved checkpoint path
self._saved_paths: List[str] = []
self._post_training_snapshot_path = None
log_class_usage(__class__)
[docs] def save(self, state, save_path):
with PathManager.open(save_path, "wb") as f:
torch.save(state, f)
[docs] def save_checkpoint(self, state, checkpoint_path):
self.save(state, checkpoint_path)
self._saved_paths.append(checkpoint_path)
[docs] def save_snapshot(self, state, snapshot_path):
self.save(state, snapshot_path)
self._post_training_snapshot_path = snapshot_path
[docs] def load(self, load_path: str):
if not (load_path and PathManager.isfile(load_path)):
raise ValueError(f"Invalid snapshot path: {load_path}")
with PathManager.open(load_path, "rb") as checkpoint_f:
state = torch.load(checkpoint_f, map_location=lambda storage, loc: storage)
return state
[docs] def list(self) -> List[str]:
return self._saved_paths
[docs] def get_latest_checkpoint_path(self) -> str:
return self._saved_paths[-1] if len(self._saved_paths) > 0 else None
[docs] def get_post_training_snapshot_path(self) -> str:
return self._post_training_snapshot_path
_CHECKPOINT_MANAGER = CheckpointManager()
[docs]def set_checkpoint_manager(manager: PyTextCheckpointManagerInterface) -> None:
global _CHECKPOINT_MANAGER
_CHECKPOINT_MANAGER = manager
[docs]def save(
config: PyTextConfig,
model: Model,
meta: Optional[CommonMetadata],
tensorizers: Dict[str, Tensorizer],
training_state: Optional[TrainingState] = None,
identifier: Optional[str] = None,
) -> str:
"""
Save all stateful information of a training task to a specified file-like
object, will save the original config, model state, metadata,
training state if training is not completed
Args:
identifier (str): used to identify a checkpoint within a training job,
used as a suffix for save path
config (PytextConfig): contains all raw parameter/hyper-parameters
for training task
model (Model): actual model in training
training_state (TrainingState): stateful infomation during training
Returns:
identifier (str): if identifier is not specified, will save to
config.save_snapshot_path to be consistent to post-training snapshot;
if specified, will be used to save checkpoint during training,
identifier is used to identify checkpoints in the same training
"""
saved_path = ""
if identifier:
# saving during-training checkpoints
saved_path = generate_checkpoint_path(config, identifier)
else:
# saving post-training snapshot if no identifer given
saved_path = config.save_snapshot_path
print(f"Saving pytorch model to: {saved_path}")
saved_folder = os.path.dirname(saved_path)
if not PathManager.exists(saved_folder):
PathManager.mkdirs(saved_folder)
print(f"created {saved_folder}")
# Currently torch.save() has error pickling certain models when not saving
# by model.state_dict(), thus currently overriding the model in
# training_state with None, and put back saving
# https://github.com/pytorch/pytorch/issues/15116
model_in_training_state = None
if training_state:
model_in_training_state, training_state.model = training_state.model, None
try:
state = {
DATA_STATE: meta,
CONFIG_JSON: config_to_json(PyTextConfig, config),
MODEL_STATE: model.state_dict(),
SERIALIZE_VERSION_KEY: LATEST_SERIALIZE_VERSION,
TENSORIZERS: tensorizers,
TRAINING_STATE: training_state,
}
if identifier is not None:
_CHECKPOINT_MANAGER.save_checkpoint(state, saved_path)
else:
_CHECKPOINT_MANAGER.save_snapshot(state, saved_path)
finally:
if training_state:
training_state.model = model_in_training_state
return saved_path
[docs]def load(load_path: str, overwrite_config=None, rank=0, world_size=1):
"""
Load task, config and training state from a saved snapshot
by default, it will construct the task using the saved config then load
metadata and model state.
if overwrite_task is specified, it will construct the task using
overwrite_task then load metadata and model state.
"""
state = _CHECKPOINT_MANAGER.load(load_path)
return load_checkpoint(state, overwrite_config, rank, world_size)