# Source code for pytext.optimizer.madgrad

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math
from collections import defaultdict
from typing import Optional

import torch
from torch.optim.optimizer import Optimizer as PT_Optimizer

from .optimizers import Optimizer


class MADGRAD(Optimizer, PT_Optimizer):
    """
    `MADGRAD Optimizer`: A Momentumized, Adaptive, Dual Averaged Gradient Method
    for Stochastic Optimization.

    Paper: https://arxiv.org/abs/2101.11075
    Implementation has been copied over from the original author
    (https://github.com/facebookresearch/madgrad/blob/master/madgrad/madgrad.py)

    Arguments:
        params (iterable): Iterable of parameters to optimize
            or dicts defining parameter groups.
        lr (float): Learning rate (default: 1e-2).
        momentum (float): Momentum value in the range [0,1) (default: 0.9).
        weight_decay (float): Weight decay, i.e. a L2 penalty (default: 0).
        eps (float): Term added to the denominator outside of the root operation
            to improve numerical stability. (default: 1e-6).
    """

    class Config(Optimizer.Config):
        # Config defaults intentionally differ from __init__ defaults
        # (lr: 1e-3 here vs 1e-2 positional default below).
        lr: float = 1e-3
        eps: float = 1e-6
        momentum: float = 0.9
        weight_decay: float = 0.0

    @classmethod
    def from_config(cls, config: Config, model: torch.nn.Module):
        """Build a MADGRAD instance over all of *model*'s parameters from a Config."""
        return cls(
            params=model.parameters(),
            lr=config.lr,
            momentum=config.momentum,
            weight_decay=config.weight_decay,
            eps=config.eps,
        )

    def __init__(
        self,
        params,
        lr: float = 1e-2,
        momentum: float = 0.9,
        weight_decay: float = 0,
        eps: float = 1e-6,
        k: int = 0,
    ):
        # Valid range is [0,1): momentum == 1 would make ck == 0 and freeze updates.
        if momentum < 0 or momentum >= 1:
            raise ValueError(f"Momentum {momentum} must be in the range [0,1)")
        if lr <= 0:
            raise ValueError(f"Learning rate {lr} must be positive")
        if weight_decay < 0:
            raise ValueError(f"Weight decay {weight_decay} must be non-negative")
        if eps < 0:
            raise ValueError("Eps must be non-negative")

        defaults = {
            "lr": lr,
            "eps": eps,
            "momentum": momentum,
            "weight_decay": weight_decay,
            "k": k,  # per-group step counter, incremented once per step()
        }
        self.momentum = momentum
        PT_Optimizer.__init__(self, params, defaults)
        self.initialize_state()

    def initialize_state(self):
        """Eagerly create per-parameter state for every parameter in every group.

        Idempotent: parameters that already have state are left untouched, so it
        is safe to call again after `add_param_group`.
        """
        for group in self.param_groups:
            for p in group["params"]:
                if p not in self.state:
                    state = self.state[p]
                    # Running sum of lambda-weighted squared gradients (denominator).
                    state["grad_sum_sq"] = torch.zeros_like(p.data).detach()
                    # Dual-averaging accumulator of lambda-weighted gradients.
                    state["s"] = torch.zeros_like(p.data).detach()
                    if self.momentum != 0:
                        # x0 anchors the dual-averaged iterate when momentum is used.
                        state["x0"] = torch.clone(p.data).detach()

    @property
    def supports_memory_efficient_fp16(self) -> bool:
        return False

    @property
    def supports_flat_params(self) -> bool:
        return True

    def step(self, closure=None, **kwargs) -> Optional[float]:
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            eps = group["eps"]
            k = group["k"]
            lr = group["lr"] + eps
            decay = group["weight_decay"]
            momentum = group["momentum"]

            ck = 1 - momentum
            # lambda_k = lr * sqrt(k + 1): the dual-averaging step weight.
            lamb = lr * math.pow(k + 1, 0.5)

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                if momentum != 0.0 and grad.is_sparse:
                    raise RuntimeError(
                        "momentum != 0 is not compatible with sparse gradients"
                    )

                grad_sum_sq = state["grad_sum_sq"]
                s = state["s"]

                # Apply weight decay (L2 penalty folded into the gradient).
                if decay != 0:
                    if grad.is_sparse:
                        raise RuntimeError(
                            "weight_decay option is not compatible with sparse gradients"
                        )
                    grad.add_(p.data, alpha=decay)

                if grad.is_sparse:
                    grad = grad.coalesce()
                    grad_val = grad._values()

                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

                    # Compute x_0 from other known quantities (only the masked
                    # entries are touched by a sparse update).
                    rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                    x0_masked_vals = p_masked._values().addcdiv(
                        s_masked._values(), rms_masked_vals, value=1
                    )

                    # Dense + sparse op: accumulate second moments.
                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=lamb)
                    grad_sum_sq_masked.add_(grad_sq, alpha=lamb)

                    rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)

                    s.add_(grad, alpha=lamb)
                    s_masked._values().add_(grad_val, alpha=lamb)

                    # Update masked copy of p.
                    p_kp1_masked_vals = x0_masked_vals.addcdiv(
                        s_masked._values(), rms_masked_vals, value=-1
                    )
                    # Copy updated masked p to dense p using an add operation.
                    p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
                    p.data.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0:
                        # Compute x_0 from other known quantities.
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.data.addcdiv(s, rms, value=1)
                    else:
                        x0 = state["x0"]

                    # Accumulate second moments.
                    grad_sum_sq.addcmul_(grad, grad, value=lamb)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    # Update s.
                    s.data.add_(grad, alpha=lamb)

                    # Step.
                    if momentum == 0:
                        p.data.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        z = x0.addcdiv(s, rms, value=-1)
                        # p is a moving average of z.
                        p.data.mul_(1 - ck).add_(z, alpha=ck)

            group["k"] = group["k"] + 1
        return loss

    def add_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.

        This can be useful when fine tuning a pre-trained network as frozen
        layers can be made trainable and added to the :class:`Optimizer` as
        training progresses.

        Args:
            param_group (dict): Specifies what Tensors should be optimized
                along with group specific optimization options.
        """
        super().add_param_group(param_group)
        # Newly added parameters need their MADGRAD state tensors created.
        self.initialize_state()

    def reset_param_groups(self):
        """Drop all parameter groups and their optimizer state."""
        self.param_groups = []
        self.state = defaultdict(dict)

    def clip_grad_norm(self, max_norm, model=None):
        # Delegate explicitly to the pytext Optimizer mixin implementation.
        return Optimizer.clip_grad_norm(self, max_norm, model)