#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import contextlib
from collections import namedtuple
from sys import stderr
from typing import Optional
import torch
from fairseq.optim.fp16_optimizer import DynamicLossScaler as Fairseq_DynamicLossScaler
from pytext.config.component import create_optimizer
from pytext.optimizer.optimizers import Optimizer
from pytext.utils import cuda, precision
_APEX_DISABLED = False
try:
from apex import amp
except ImportError:
print("Install apex from https://github.com/NVIDIA/apex/.", file=stderr)
_APEX_DISABLED = True
except AttributeError as e:
print(f"Fail to import apex: {e}", file=stderr)
_APEX_DISABLED = True
try:
from fairseq.optim.fp16_optimizer import (
_FP16OptimizerMixin as Fairseq_FP16OptimizerMixin,
)
except ImportError:
# TODO: temporary fix fairseq dependency, remove after fairseq new release.
from .fairseq_fp16_utils import Fairseq_FP16OptimizerMixin
# TODO: remove this try block after the new release by fairseq that
# contains the dependency
try:
from fairseq.optim.fp16_optimizer import (
_MemoryEfficientFP16OptimizerMixin as Fairseq_MemoryEfficientFP16OptimizerMixin,
)
except ImportError:
from .fairseq_fp16_utils import Fairseq_MemoryEfficientFP16OptimizerMixin
"""
Tips:
1. Recommended: run fp16 on latest generation (Volta V100) GPU, CUDA 9.1 or newer
to leverage tensor cores, which provide 8x more throughput than single
precision math pipelines.
2. Additionally:
- Batch size should be a multiple of 8
- Tokens size should be a multiple of 8
- Embedding layers should be padded to be a multiple of 8
- Ideally, everything should be a multiple of 8 (e.g padding, etc)
3. Larger batch_size might increase GPU utilization and better performance.
"""
class FP16Optimizer(Optimizer):
    """Base class for fp16 (mixed precision) optimizer wrappers.

    Holds the real fp32 optimizer in ``fp32_optimizer`` and declares the
    fp16-specific API (state handling, backward, step, clipping, export)
    that concrete subclasses implement.
    """

    __EXPANSIBLE__ = True

    def __init__(self, fp32_optimizer):
        # the wrapped optimizer that performs the actual parameter updates
        self.fp32_optimizer: torch.optim.Optimizer = fp32_optimizer

    @property
    def param_groups(self):
        # expose the wrapped optimizer's param_groups unchanged
        return self.fp32_optimizer.param_groups

    def finalize(self) -> bool:
        # delegate end-of-training finalization to the wrapped optimizer
        return self.fp32_optimizer.finalize()

    # methods to implement
    def state_dict(self):
        raise NotImplementedError

    def load_state_dict(self, state_dict):
        raise NotImplementedError

    def zero_grad(self):
        raise NotImplementedError

    def step(self, closure=None):
        raise NotImplementedError

    def backward(self, loss):
        raise NotImplementedError

    def clip_grad_norm(self, max_norm, model):
        raise NotImplementedError

    def pre_export(self, model):
        raise NotImplementedError
"""
Apex amp: https://github.com/NVIDIA/apex/tree/master/apex/amp
FP32 Master Weights <--(step)-- FP32 Gradients <--(unscale)-- Scaled FP16 Gradients
| |
(copy) | | (backprop)
| |
FP16 Weights --(forward)--> FP32 Loss --(loss scaling)--> Scaled FP32 Loss
Using amp require adding three lines of code.
https://nvidia.github.io/apex/amp.html
1. Allow Amp to perform casts as required by the opt_level:
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
2. loss.backward() replace with:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
3. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) replace with:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm)
Opt level explanation (from Nvidia Apex):
* O1: Insert automatic casts around Pytorch functions and Tensor methods
- The type of your model's weights is not altered. However, internally,
Pytorch functions are patched to cast any Tensor Core-friendly ops to FP16
for speed, while operations that might benefit from the additional stability
of FP32 are patched to cast their inputs to fp32.
- O1 is the safest way to try mixed precision training, and is recommended when
trying mixed precision training for the first time.
* O2: FP16 training with FP32 batchnorm and FP32 master weights.
- Calls .half() on your model, converting the entire model (except for batchnorms)
to FP16. Batchnorms are retained in FP32 for additional stability.
- The forward pass is patched to cast incoming Tensors to FP16, so you don't
need to change your data pipeline.
- O2 creates FP32 master weights outside the model and patches any optimizers
to update these master weights, then copy the master weights into the FP16
model weights.
"""
class FP16OptimizerApex(FP16Optimizer):
    """Mixed-precision optimizer that delegates all fp16 handling to NVIDIA Apex amp."""

    class Config(FP16Optimizer.Config):
        # O1: insert automatic casts around Pytorch functions and Tensor methods
        # O2: FP16 training with FP32 batchnorm and FP32 master weights (recommended)
        opt_level: str = "O2"
        # initial loss scale; None keeps the default loss_scale of the chosen
        # opt_level (for example: "dynamic" for O2)
        init_loss_scale: Optional[int] = None
        # lower bound for the dynamic loss scale
        min_loss_scale: Optional[float] = None

    def __init__(
        self,
        fp32_optimizer: Optimizer,
        model: torch.nn.Module,
        opt_level: str,
        init_loss_scale: Optional[int],
        min_loss_scale: Optional[float],
    ):
        """Patch *model* and *fp32_optimizer* in place through amp.initialize."""
        assert precision.FP16_ENABLED and not _APEX_DISABLED
        model, fp32_optimizer = amp.initialize(
            model,
            fp32_optimizer,
            opt_level=opt_level,
            loss_scale=init_loss_scale,
            min_loss_scale=min_loss_scale,
        )
        self.opt_level = opt_level
        super().__init__(fp32_optimizer)

    @classmethod
    def from_config(
        cls,
        fp16_config: Config,
        model: torch.nn.Module,
        fp32_config: Optimizer.Config,
        *unused,
    ):
        """Build the fp32 optimizer from its config, then wrap it with amp."""
        return cls(
            create_optimizer(fp32_config, model),
            model,
            fp16_config.opt_level,
            fp16_config.init_loss_scale,
            fp16_config.min_loss_scale,
        )

    def state_dict(self):
        """Delegate to the amp-patched fp32 optimizer."""
        return self.fp32_optimizer.state_dict()

    def load_state_dict(self, state_dict):
        return self.fp32_optimizer.load_state_dict(state_dict)

    def zero_grad(self):
        self.fp32_optimizer.zero_grad()

    def step(self, closure=None):
        self.fp32_optimizer.step(closure)

    def backward(self, loss):
        """Scale the loss before backprop so fp16 grads do not underflow."""
        with amp.scale_loss(
            loss, self.fp32_optimizer, delay_unscale=precision.DELAY_UNSCALE
        ) as scaled_loss:
            scaled_loss.backward()

    def clip_grad_norm(self, max_norm, model):
        """Clip the fp32 master gradients; no-op when max_norm is None."""
        if max_norm is None:
            return None
        return torch.nn.utils.clip_grad_norm_(
            amp.master_params(self.fp32_optimizer), max_norm
        )

    def pre_export(self, model):
        """Undo amp's patching so the model can be exported in fp32."""
        if self.opt_level == "O2":
            # O2 stores fp16 weights in the model itself: cast them back
            model.float()
            if hasattr(model, "old_forward"):
                model.forward = model.old_forward
        else:
            # O1 monkey-patches torch functions; restore the uncasted versions
            amp._amp_state.handle._deactivate()
        precision.FP16_ENABLED = False
class MemoryEfficientFP16OptimizerFairseq(
    Fairseq_MemoryEfficientFP16OptimizerMixin, FP16Optimizer
):
    """
    Wrap the mem efficient *optimizer* to support FP16 (mixed precision) training.
    """

    class Config(FP16Optimizer.Config):
        # initial loss scale
        init_loss_scale: int = 2 ** 7
        # consecutive non-overflow steps before the loss scale is increased
        scale_window: Optional[int] = None
        # fraction (0..1) of overflows since the last rescale that triggers
        # a loss-scale decrease
        scale_tolerance: float = 0.0
        # minimum-value threshold applied when the loss scale shrinks
        threshold_loss_scale: Optional[float] = None
        # loss-explosion guard: an exception is raised if the loss scale
        # ever has to drop below this value
        min_loss_scale: float = 0.0001

    def __init__(
        self,
        fp16_params,
        optimizer,
        init_loss_scale,
        scale_window,
        scale_tolerance,
        threshold_loss_scale,
        min_loss_scale,
        num_accumulated_batches,
    ):
        """Wrap *optimizer* (already holding fp16 params) with dynamic loss scaling."""
        assert precision.FP16_ENABLED
        super().__init__(optimizer)
        self.wrapped_optimizer = optimizer
        if scale_window is None:
            # fairseq's default heuristic, scaled by world size and accumulation
            scale_window = (
                2 ** 14 / cuda.DISTRIBUTED_WORLD_SIZE / num_accumulated_batches
            )
        self.scaler = Fairseq_DynamicLossScaler(
            init_scale=init_loss_scale,
            scale_window=scale_window,
            tolerance=scale_tolerance,
            threshold=threshold_loss_scale,
        )
        self.min_loss_scale = min_loss_scale

    @classmethod
    def from_config(
        cls,
        fp16_config: Config,
        model: torch.nn.Module,
        fp32_config: Optimizer.Config,
        num_accumulated_batches: int,
    ):
        """Halve the model, then build and wrap the fp32-config optimizer."""
        model = model.half()
        fp16_params = [p for p in model.parameters() if p.requires_grad]
        fp32_optimizer = create_optimizer(fp32_config, model)
        print(
            "| Fairseq MemoryEfficientFP16Optimizer with init_loss_scale={}".format(
                fp16_config.init_loss_scale
            )
        )
        return cls(
            fp16_params=fp16_params,
            optimizer=fp32_optimizer,
            init_loss_scale=fp16_config.init_loss_scale,
            scale_window=fp16_config.scale_window,
            scale_tolerance=fp16_config.scale_tolerance,
            threshold_loss_scale=fp16_config.threshold_loss_scale,
            min_loss_scale=fp16_config.min_loss_scale,
            num_accumulated_batches=num_accumulated_batches,
        )

    def clip_grad_norm(self, max_norm, unused_model):
        """Clip gradients; fairseq interprets max_norm == 0 as "no clipping"."""
        return super().clip_grad_norm(0.0 if max_norm is None else max_norm)

    def pre_export(self, model):
        """Restore a fp32 model and disable fp16 mode before export."""
        model.float()
        precision.FP16_ENABLED = False
class FP16OptimizerFairseq(Fairseq_FP16OptimizerMixin, FP16Optimizer):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.
    """

    class Config(FP16Optimizer.Config):
        # initial loss scale
        init_loss_scale: int = 2 ** 7
        # consecutive non-overflow steps before the loss scale is increased
        scale_window: Optional[int] = None
        # fraction (0..1) of overflows since the last rescale that triggers
        # a loss-scale decrease
        scale_tolerance: float = 0.0
        # minimum-value threshold applied when the loss scale shrinks
        threshold_loss_scale: Optional[float] = None
        # loss-explosion guard: an exception is raised if the loss scale
        # ever has to drop below this value
        min_loss_scale: float = 0.0001

    def __init__(
        self,
        fp16_params,
        fp32_optimizer,
        init_loss_scale,
        scale_window,
        scale_tolerance,
        threshold_loss_scale,
        min_loss_scale,
        num_accumulated_batches,
    ):
        """Build fp32 master weights for *fp16_params* and rewire the optimizer to them."""
        assert precision.FP16_ENABLED
        super().__init__(fp32_optimizer)
        self.fp16_params = fp16_params
        # fairseq's build_fp32_params expects an args-style namespace
        fairseq_args = namedtuple(
            "args", ["pipeline_model_parallel", "distributed_no_spawn"]
        )(False, False)
        self.fp32_params = self.build_fp32_params(
            args=fairseq_args, params=fp16_params, flatten=True
        )
        if scale_window is None:
            # fairseq's default heuristic, scaled by world size and accumulation
            scale_window = (
                2 ** 14 / cuda.DISTRIBUTED_WORLD_SIZE / num_accumulated_batches
            )
        self.scaler = Fairseq_DynamicLossScaler(
            init_scale=init_loss_scale,
            scale_window=scale_window,
            tolerance=scale_tolerance,
            threshold=threshold_loss_scale,
        )
        self.min_loss_scale = min_loss_scale
        # point the wrapped fp32 optimizer at the flattened master weights
        fp32_param_group = self.fp32_optimizer.param_groups[0]
        fp32_param_group["params"] = [self.fp32_params[torch.cuda.current_device()]]
        self.fp32_optimizer.reset_param_groups()
        self.fp32_optimizer.add_param_group(fp32_param_group)

    @classmethod
    def from_config(
        cls,
        fp16_config: Config,
        model: torch.nn.Module,
        fp32_config: Optimizer.Config,
        num_accumulated_batches: int,
    ):
        """Halve the model, then build and wrap the fp32-config optimizer."""
        model = model.half()
        fp16_params = [p for p in model.parameters() if p.requires_grad]
        fp32_optimizer = create_optimizer(fp32_config, model)
        print(
            "| Fairseq FP16Optimizer with init_loss_scale={}".format(
                fp16_config.init_loss_scale
            )
        )
        return cls(
            fp16_params=fp16_params,
            fp32_optimizer=fp32_optimizer,
            init_loss_scale=fp16_config.init_loss_scale,
            scale_window=fp16_config.scale_window,
            scale_tolerance=fp16_config.scale_tolerance,
            threshold_loss_scale=fp16_config.threshold_loss_scale,
            min_loss_scale=fp16_config.min_loss_scale,
            num_accumulated_batches=num_accumulated_batches,
        )

    def clip_grad_norm(self, max_norm, unused_model):
        """Clip master gradients; fairseq interprets max_norm == 0 as "no clipping"."""
        return super().clip_grad_norm(0.0 if max_norm is None else max_norm)

    def pre_export(self, model):
        """Restore a fp32 model and disable fp16 mode before export."""
        model.float()
        precision.FP16_ENABLED = False
"""fp16 optimizer wraps torch.optim to support mixed precision training
structure of fp16Optimizer:
property
fp16_optimizer.param_groups ----------> inner_optimizer.param_groups
| |
___ __ |__ __ __ __ __ __ | __ __ __
| fp16 | after backward | fp32 |
zero_grad ----|-> grads --|-----------------|--> grads <--|-- check overflow
loss --->| weights <-|-----------------|-- weights |
model --->|_ __ __ __ __ __| after step |__ __ __ __ __ __ |
Usage Example:
1 optim.zero_grad()
2 for i in range(N):
3 model.forward() ---- fp16 weights
4 pre_process ---- fp16 grads upscale
5 optim.backward() ---- upscaled fp16 grads
6 post_process ---- downscale and float to fp32 grads
7 optim.step() ---- fp32 weights and grads
class FP16_Optimizer:
Properties:
inner_optimizer(torch.optim): optimizer in pytext (eg. Adam)
which is initialized with fp16 params already
param_groups (list): list of dictionaries: key(string), value (list)
loss_scaler(DynamicLossScaler): handle upscale, unscale, check_overflow
weights_update_needed(bool): whether copying weights from master to model is needed
grads_update_needed(bool): whether copying grads from model to master is needed
class DynamicLossScaler:
properties:
init_scale(int): beginning value of loss scale
scale_factor(int): the step length that we use to increase the scale
scale_window(int): the upper bound of iterations among which no overflow is triggered
is_overflow(bool): indicate whether overflow happens in this step
is_scaled(bool): whether grads are scaled
"""
class DynamicLossScaler(object):
    """Dynamically adjust the fp16 loss scale based on gradient overflow.

    The scale shrinks by ``scale_factor`` whenever an overflow is detected
    and grows by the same factor after ``scale_window`` consecutive
    overflow-free steps.
    """

    def __init__(self, init_scale, scale_factor, scale_window):
        # current loss scale applied to the loss before backward()
        self.scale = init_scale
        # multiplicative step used to grow/shrink the scale
        self.scale_factor = scale_factor
        # consecutive non-overflow iterations required before growing
        self.scale_window = scale_window
        self._iter = 0
        self._last_overflow_iter = 0
        self.is_overflow = False

    def upscale(self, loss):
        """Return *loss* cast to fp32 and multiplied by the current scale."""
        return loss.float() * self.scale

    def unscale(self, grad):
        """Divide *grad* in place by the current scale."""
        grad.div_(self.scale)

    def unscale_grads(self, param_groups):
        """Unscale every gradient in *param_groups*.

        Parameters without a gradient are skipped (previously this raised
        ``AttributeError`` by calling ``div_`` on ``None``).
        """
        for group in param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    self.unscale(p.grad)

    def check_overflow_(self, grad):
        """Set ``is_overflow`` from a single gradient; no-op when grad is None."""
        if grad is not None:
            cpu_sum = float(grad.float().sum())
            # inf check plus NaN check (NaN != NaN)
            if (
                cpu_sum == float("inf")
                or cpu_sum == -float("inf")
                or cpu_sum != cpu_sum
            ):
                self.is_overflow = True
            else:
                self.is_overflow = False

    def check_overflow(self, params):
        """Scan all grads in *params*, stopping at the first overflow found."""
        self.is_overflow = False
        for group in params:
            for p in group["params"]:
                self.check_overflow_(p.grad)
                if self.is_overflow:
                    return

    def update_scale(self):
        r"""According to overflow situation, adjust loss scale.
        Once overflow happened, we decrease the scale by scale_factor.
        Setting tolerance is another approach depending on cases.
        If we haven't had overflows for #scale_window times, we should increase
        the scale by scale_factor.
        """
        self._iter += 1
        if self.is_overflow:
            self._last_overflow_iter = self._iter
            # never let the scale drop below 1
            self.scale = max(self.scale / self.scale_factor, 1)
            print(
                "overflow happens, skip step, new loss scale is {}".format(self.scale)
            )
        elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
            self.scale *= self.scale_factor
class FP16OptimizerDeprecated(object):
    """Deprecated hand-rolled fp16 wrapper.

    Keeps the fp16 model parameters in ``self.param_groups`` while the
    inner optimizer is rewired to hold fp32 master copies of the same
    parameters.
    """

    def __init__(self, init_optimizer, init_scale, scale_factor, scale_window):
        r"""Initialize master weights maintaining optimizer.
        Args:
            init_optimizer(torch.optim.Optimizer): an initialized optimizer
            init_scale(int): beginning value of loss scale
            scale_factor(int): step that we adjust loss scale
            scale_window(int): tolerance for non-overflows
        Effects:
            Initialize the optimizer, create master weights copy and loss scaler.
        Modifies:
            Record the reference of model params (fp16).
            Change the inner optimizer's params to fp32.
            Initialized the scaler, state and default
        """
        self.inner_optimizer = init_optimizer
        self.param_groups = []
        for group in self.inner_optimizer.param_groups:
            fp16_group = {}
            for key, value in group.items():
                if key == "params":
                    fp16_param = []
                    for j, p in enumerate(value):
                        # keep the fp16 model param here; replace it inside the
                        # inner optimizer with a fp32 master copy
                        fp16_param.append(p)
                        master_p = p.detach().clone().float()
                        master_p.requires_grad_(True)
                        group["params"][j] = master_p
                        # re-key any existing optimizer state onto the master copy
                        if p in self.inner_optimizer.state:
                            self.inner_optimizer.state[
                                master_p
                            ] = self.inner_optimizer.state.pop(p)
                    fp16_group["params"] = fp16_param
                else:
                    # non-param options (lr, weight_decay, ...) are shared by reference
                    fp16_group[key] = value
            self.param_groups.append(fp16_group)
        self.loss_scaler = DynamicLossScaler(init_scale, scale_factor, scale_window)
        self.state = self.inner_optimizer.state
        # lazy-sync flags: set when weights/grads must be copied between copies
        self.weights_update_needed = False
        self.grads_update_needed = False

    def zero_grad(self):
        # clear grads on the fp16 (model-side) params only
        for p in generate_params(self.param_groups):
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

    def scale_loss(self, loss):
        # upscale the loss; mark that fp16 grads will need syncing to master
        self.grads_update_needed = True
        return self.loss_scaler.upscale(loss)

    def step(self):
        r"""Realize weights update.
        Update the grads from model to master. During iteration for parameters,
        we check overflow after floating grads and copy. Then do unscaling.
        If overflow doesn't happen, call inner optimizer's step() and copy
        back the updated weights from inner optimizer to model.
        Update loss scale according to overflow checking result.
        """
        self._grads_from_model_to_master()
        if not self.loss_scaler.is_overflow:
            self.inner_optimizer.step()
            self.weights_update_needed = True
            self._weights_from_master_to_model()
        self.loss_scaler.update_scale()

    def _grads_from_model_to_master(self):
        r"""Sync grads from model to inner optimizer
        During each iteration, check overflow of grads.
        If not overflow, float the grads and copy to inner optimizer, unscale.
        """
        if self.grads_update_needed:
            for model_param, master_param in zip(
                generate_params(self.param_groups),
                generate_params(self.inner_optimizer.param_groups),
            ):
                # check the fp16 model grad for overflow before copying it over
                self.loss_scaler.check_overflow_(model_param.grad)
                if self.loss_scaler.is_overflow:
                    break
                if master_param.grad is None:
                    master_param.grad = torch.empty_like(master_param)
                master_param.grad.copy_(model_param.grad)
                self.loss_scaler.unscale(master_param.grad)
            self.grads_update_needed = False

    def _weights_from_master_to_model(self):
        # copy updated fp32 master weights back into the fp16 model params
        if self.weights_update_needed:
            for model_param, master_param in zip(
                generate_params(self.param_groups),
                generate_params(self.inner_optimizer.param_groups),
            ):
                model_param.data.copy_(master_param.data)
            self.weights_update_needed = False

    def state_dict(self):
        # serialize loss-scaler status, fp16 groups, and inner optimizer state
        state_dict = {}
        state_dict["loss_scale"] = self.loss_scaler.scale
        state_dict["overflow"] = self.loss_scaler.is_overflow
        state_dict["param_groups"] = self.param_groups
        state_dict["optimizer_state_dict"] = self.inner_optimizer.state_dict()
        return state_dict

    def load_state_dict(self, state_dict):
        # restore scaler status first, then the inner optimizer and fp16 groups
        self.loss_scaler.scale = state_dict["loss_scale"]
        self.loss_scaler.is_overflow = state_dict["overflow"]
        self.inner_optimizer.load_state_dict(state_dict["optimizer_state_dict"])
        self.param_groups = state_dict["param_groups"]

    def finalize(self):
        return self.inner_optimizer.finalize()

    def __getstate__(self):
        # pickling delegates to state_dict/load_state_dict
        return self.state_dict()

    def __setstate__(self, state):
        self.load_state_dict(state)
def initialize(
    model,
    optimizer,
    opt_level,
    init_scale=2 ** 16,
    scale_factor=2.0,
    scale_window=2000,
    memory_efficient=False,
):
    """Halve the model and wrap *optimizer* for fp16 training.

    Mirrors apex's ``amp.initialize`` entry point; *opt_level* is accepted
    for API compatibility but not consulted here.
    """
    if memory_efficient:
        wrapped = PureFP16Optimizer(optimizer, init_scale, scale_factor, scale_window)
    else:
        wrapped = FP16OptimizerDeprecated(
            optimizer, init_scale, scale_factor, scale_window
        )
    return (model.half(), wrapped)
@contextlib.contextmanager
def scale_loss(loss, optimizer, delay_unscale=False):
    """Context manager mirroring apex's ``amp.scale_loss`` API."""
    scaled = optimizer.scale_loss(loss)
    yield scaled
def master_params(optimizer):
    """Yield the inner optimizer's (master) params, mirroring amp.master_params."""
    for group in optimizer.inner_optimizer.param_groups:
        yield from group["params"]
def generate_params(param_groups):
    """Flatten *param_groups* into a single stream of parameters."""
    for group in param_groups:
        yield from group["params"]
"""PureFP16Optimizer
No maintenance of fp32 weights.
Internally maintain the chain:
loss.backward() float() step() half()
----------------->fp16 grads------>fp32 grads------> fp32 weights -----> fp16 weights
"""
class PureFP16Optimizer(FP16OptimizerDeprecated):
    """Memory-efficient fp16 optimizer: no persistent fp32 master weights.

    Weights and grads are floated to fp32 only for the duration of step()
    and cast back to fp16 afterwards.
    """

    def __init__(
        self, init_optimizer, init_scale=2.0 ** 16, scale_factor=2, scale_window=2000
    ):
        r"""Initialize the memory-efficient optimizer
        Args:
            init_optimizer(torch.optim.Optimizer): an initialized optimizer
            init_scale(int): beginning value of loss scale
            scale_factor(int): step that we adjust loss scale
            scale_window(int): tolerance for non-overflows
        Effects:
            initialize this optimizer wrapper and loss scaling tools,
            initialized the scaler and state
        """
        self.inner_optimizer = init_optimizer
        # shared by reference: both wrappers see the same (fp16) params
        self.param_groups = self.inner_optimizer.param_groups
        self.loss_scaler = DynamicLossScaler(init_scale, scale_factor, scale_window)
        self.state = self.inner_optimizer.state
        # True while grads are still multiplied by the loss scale
        self.is_scaled = False
        print("===============Pure Memory Efficient Optimizer===============")

    def scale_loss(self, loss):
        r"""Scale the loss.
        Args:
            loss(pytext.Loss): loss function object
        """
        self.is_scaled = True
        return self.loss_scaler.upscale(loss)

    def step(self):
        r"""Updates the weights in inner optimizer.
        If inner optimizer supports memory efficient, check overflow,
        unscale and call advanced step.
        Otherwise, float weights and grads, check whether grads are overflow
        during the iteration, if not overflow, unscale grads and call inner
        optimizer's step; If overflow happens, do nothing, wait to the end
        to call half weights and grads (grads will be eliminated in zero_grad)
        """
        support = getattr(self.inner_optimizer, "supports_memory_efficient_fp16", False)
        if support:
            self.loss_scaler.check_overflow(self.param_groups)
            if not self.loss_scaler.is_overflow:
                self._unscale()
                self.inner_optimizer.step()
        else:
            self._fp16_to_fp32()
            if not self.loss_scaler.is_overflow:
                self.inner_optimizer.step()
            # always cast back, even when step was skipped due to overflow
            self._fp32_to_fp16()
        self.loss_scaler.update_scale()

    def _unscale(self):
        # divide grads by the loss scale at most once per backward pass
        if self.is_scaled:
            self.loss_scaler.unscale_grads(self.param_groups)
            self.is_scaled = False

    def _fp16_to_fp32(self):
        # float weights and grads in place; stop early once overflow is seen
        for p in generate_params(self.param_groups):
            p.data = p.data.float()
            if p.grad is not None:
                p.grad.data = p.grad.data.float()
                self.loss_scaler.check_overflow_(p.grad)
                if self.loss_scaler.is_overflow:
                    break
                self.loss_scaler.unscale(p.grad)

    def _fp32_to_fp16(self):
        # cast weights and grads back to half precision
        for p in generate_params(self.param_groups):
            p.data = p.data.half()
            if p.grad is not None:
                p.grad.data = p.grad.data.half()

    def load_state_dict(self, state_dict):
        r"""Load an optimizer state dict.
        We prefer the configuration of the existing optimizer instance.
        Realize the same logic as in init() -- point the param_groups of outer
        optimizer to that of the inner_optimizer.
        """
        self.loss_scaler.scale = state_dict["loss_scale"]
        self.loss_scaler.is_overflow = state_dict["overflow"]
        self.inner_optimizer.load_state_dict(state_dict["optimizer_state_dict"])
        self.param_groups = self.inner_optimizer.param_groups
        self.state = self.inner_optimizer.state
class GeneratorFP16Optimizer(PureFP16Optimizer):
    """Memory-efficient fp16 optimizer using per-parameter generators.

    Instead of bulk-casting all params, step() hands the inner optimizer a
    generator that floats/unscales each param just-in-time and re-halves it
    right after it is consumed (see ``convert_generator``).
    """

    def __init__(
        self, init_optimizer, init_scale=2.0 ** 16, scale_factor=2, scale_window=2000
    ):
        r"""Initialize the generator implementation method of memory efficient optimizer.
        Args:
            init_optimizer(torch.optim.Optimizer): an initialized optimizer
            init_scale(int): beginning value of loss scale
            scale_factor(int): step that we adjust loss scale
            scale_window(int): tolerance for non-overflows
        Effects:
            We create another copy of references of parameters in self.param_groups
            to keep trace of changed weights and grads.
        """
        self.inner_optimizer = init_optimizer
        # shallow-copy each group dict so we keep references to the real param
        # lists even after _preprocess_step swaps the inner ones for generators
        self.param_groups = []
        for group in self.inner_optimizer.param_groups:
            fp16_group = {}
            for key, value in group.items():
                fp16_group[key] = value
            self.param_groups.append(fp16_group)
        self.loss_scaler = DynamicLossScaler(init_scale, scale_factor, scale_window)
        self.state = self.inner_optimizer.state
        self.is_scaled = False
        print("=============Generator Memory Efficient Optimizer==============")

    def step(self):
        r"""Updates weights.
        Effects:
            Check overflow, if not, when inner_optimizer supports memory-efficient
            step, do overall unscale and call memory-efficient step.
            If it doesn't support, modify each parameter list in param_groups
            of inner_optimizer to a generator of the tensors. Call normal step
            then, data type changing will be added automatically in that function.
            No matter whether it is overflow, we need to update scale at the
            last step.
        """
        support = getattr(self.inner_optimizer, "supports_memory_efficient_fp16", False)
        self.loss_scaler.check_overflow(self.param_groups)
        if not self.loss_scaler.is_overflow:
            if support:
                self._unscale()
                self.inner_optimizer.step()
            else:
                self._preprocess_step()
                self.inner_optimizer.step()
        self.loss_scaler.update_scale()

    def _preprocess_step(self):
        r"""Change the parameter list to a generator."""
        for i, group in enumerate(self.param_groups):
            # the inner optimizer will consume this generator during step()
            self.inner_optimizer.param_groups[i]["params"] = convert_generator(
                group["params"], self.loss_scaler.scale
            )

    def load_state_dict(self, state_dict):
        r"""Load an optimizer state dict.
        We prefer the configuration of the existing optimizer instance.
        After we load state dict to inner_optimizer, we create the copy of
        references of parameters again as in init().
        """
        self.loss_scaler.scale = state_dict["loss_scale"]
        self.loss_scaler.is_overflow = state_dict["overflow"]
        self.inner_optimizer.load_state_dict(state_dict["optimizer_state_dict"])
        self.param_groups = []
        for group in self.inner_optimizer.param_groups:
            fp16_group = {}
            for key, value in group.items():
                fp16_group[key] = value
            self.param_groups.append(fp16_group)
        self.state = self.inner_optimizer.state
def convert_generator(params, scale):
    """Yield each param floated and unscaled, restoring fp16 after consumption.

    Every parameter's weights (and grad, if present) are cast to fp32 and
    the grad divided by *scale* before it is yielded; once the consumer
    advances the generator, the parameter is cast back to fp16.
    """
    for param in params:
        param.data = param.data.float()
        grad = param.grad
        if grad is not None:
            grad.data = grad.data.float()
            grad.div_(scale)
        yield param
        param.data = param.data.half()
        if param.grad is not None:
            param.grad.data = param.grad.data.half()