#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math
from collections import defaultdict
from typing import Optional
import torch
from torch.optim.optimizer import Optimizer as PT_Optimizer
from .optimizers import Optimizer
class MADGRAD(Optimizer, PT_Optimizer):
    """
    `MADGRAD Optimizer`: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
    Optimization.

    Paper: https://arxiv.org/abs/2101.11075

    Implementation has been copied over from the original author
    (https://github.com/facebookresearch/madgrad/blob/master/madgrad/madgrad.py)

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate (constructor default: 1e-2; ``Config`` default: 1e-3).
        momentum (float):
            Momentum value in the range [0,1) (default: 0.9).
        weight_decay (float):
            Weight decay, i.e. a L2 penalty (default: 0).
        eps (float):
            Term added to the denominator outside of the root operation to improve
            numerical stability (default: 1e-6).
        k (int):
            Initial value of the per-group step counter (default: 0).
    """

    class Config(Optimizer.Config):
        # Values used when the optimizer is built through ``from_config``.
        # NOTE: ``lr`` here (1e-3) differs from the constructor default (1e-2).
        lr: float = 1e-3
        eps: float = 1e-6
        momentum: float = 0.9
        weight_decay: float = 0.0

    @classmethod
    def from_config(cls, config: Config, model: torch.nn.Module):
        """Build the optimizer over ``model.parameters()`` from ``config``."""
        return cls(
            params=model.parameters(),
            lr=config.lr,
            momentum=config.momentum,
            weight_decay=config.weight_decay,
            eps=config.eps,
        )

    def __init__(
        self,
        params,
        lr: float = 1e-2,
        momentum: float = 0.9,
        weight_decay: float = 0,
        eps: float = 1e-6,
        k: int = 0,
    ):
        """Validate the hyper-parameters and register the parameter groups.

        Raises:
            ValueError: if ``momentum`` is outside [0,1), ``lr`` is not
                positive, or ``weight_decay`` / ``eps`` is negative.
        """
        if momentum < 0 or momentum >= 1:
            raise ValueError(f"Momentum {momentum} must be in the range [0,1)")
        if lr <= 0:
            raise ValueError(f"Learning rate {lr} must be positive")
        if weight_decay < 0:
            raise ValueError(f"Weight decay {weight_decay} must be non-negative")
        if eps < 0:
            raise ValueError("Eps must be non-negative")

        defaults = {
            "lr": lr,
            "eps": eps,
            "momentum": momentum,
            "weight_decay": weight_decay,
            "k": k,
        }
        # Must be set before PT_Optimizer.__init__: that call goes through
        # add_param_group -> initialize_state, which may fall back to it.
        self.momentum = momentum
        PT_Optimizer.__init__(self, params, defaults)
        self.initialize_state()

    def _ensure_param_state(self, p, momentum):
        """Return the state dict of ``p``, creating the MADGRAD buffers if absent.

        ``x0`` (a copy of the current parameter values) is only stored when
        ``momentum`` is non-zero; with zero momentum ``step`` recomputes it.
        """
        state = self.state[p]
        if "grad_sum_sq" not in state:
            state["grad_sum_sq"] = torch.zeros_like(p.data).detach()
            state["s"] = torch.zeros_like(p.data).detach()
            if momentum != 0:
                state["x0"] = torch.clone(p.data).detach()
        return state

    def initialize_state(self):
        """Create optimizer state for every parameter that does not have it yet."""
        for group in self.param_groups:
            # Use the group's own momentum so that a group overriding the
            # constructor default gets the buffers that ``step`` will read.
            momentum = group.get("momentum", self.momentum)
            for p in group["params"]:
                self._ensure_param_state(p, momentum)

    @property
    def supports_memory_efficient_fp16(self) -> bool:
        return False

    @property
    def supports_flat_params(self) -> bool:
        return True

    def step(self, closure=None, **kwargs) -> Optional[float]:
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if a gradient is sparse while ``momentum`` or
                ``weight_decay`` is non-zero for its group.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            eps = group["eps"]
            k = group["k"]
            lr = group["lr"] + eps
            decay = group["weight_decay"]
            momentum = group["momentum"]

            ck = 1 - momentum
            # Step-dependent weight: lr * sqrt(k + 1).
            lamb = lr * math.pow(k + 1, 0.5)

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                # Lazily create state so parameters without an initialized
                # entry do not fail with a KeyError below.
                state = self._ensure_param_state(p, momentum)

                if momentum != 0.0 and grad.is_sparse:
                    raise RuntimeError(
                        "momentum != 0 is not compatible with sparse gradients"
                    )

                grad_sum_sq = state["grad_sum_sq"]
                s = state["s"]

                # Apply weight decay
                if decay != 0:
                    if grad.is_sparse:
                        raise RuntimeError(
                            "weight_decay option is not compatible with sparse gradients"
                        )
                    # Out-of-place on purpose: leaves ``p.grad`` untouched so
                    # the decay term is not accumulated into the caller's
                    # gradient across repeated ``step`` calls.
                    grad = grad.add(p.data, alpha=decay)

                if grad.is_sparse:
                    grad = grad.coalesce()
                    grad_val = grad._values()

                    # Restrict the dense buffers to the gradient's non-zero pattern.
                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

                    # Compute x_0 from other known quantities
                    rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                    x0_masked_vals = p_masked._values().addcdiv(
                        s_masked._values(), rms_masked_vals, value=1
                    )

                    # Dense + sparse op
                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=lamb)
                    grad_sum_sq_masked.add_(grad_sq, alpha=lamb)

                    rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)

                    s.add_(grad, alpha=lamb)
                    s_masked._values().add_(grad_val, alpha=lamb)

                    # update masked copy of p
                    p_kp1_masked_vals = x0_masked_vals.addcdiv(
                        s_masked._values(), rms_masked_vals, value=-1
                    )
                    # Copy updated masked p to dense p using an add operation
                    p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
                    p.data.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0:
                        # Compute x_0 from other known quantities
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.data.addcdiv(s, rms, value=1)
                    else:
                        x0 = state["x0"]

                    # Accumulate second moments
                    grad_sum_sq.addcmul_(grad, grad, value=lamb)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    # Update s
                    s.data.add_(grad, alpha=lamb)

                    # Step
                    if momentum == 0:
                        p.data.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        z = x0.addcdiv(s, rms, value=-1)

                        # p is a moving average of z
                        p.data.mul_(1 - ck).add_(z, alpha=ck)

            # Advance this group's step counter once per call to ``step``.
            group["k"] = group["k"] + 1
        return loss

    def add_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.

        This can be useful when fine tuning a pre-trained network as frozen
        layers can be made trainable and added to the :class:`Optimizer` as
        training progresses.

        Args:
            param_group (dict): Specifies what Tensors should be optimized along
            with group specific optimization options.
        """
        super().add_param_group(param_group)
        # Newly added parameters need their MADGRAD buffers.
        self.initialize_state()

    def reset_param_groups(self):
        """Drop all parameter groups and all per-parameter optimizer state."""
        self.param_groups = []
        self.state = defaultdict(dict)

    def clip_grad_norm(self, max_norm, model=None):
        """Delegate gradient-norm clipping to ``Optimizer.clip_grad_norm``."""
        return Optimizer.clip_grad_norm(self, max_norm, model)