#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Optional
import torch
from pytext.optimizer.optimizers import Optimizer
from torch.optim import Optimizer as PT_Optimizer
class Lamb(Optimizer, PT_Optimizer):
r"""Implements Lamb algorithm.
THIS WAS DIRECTLY COPIED OVER FROM pytorch/contrib:
https://github.com/cybertronai/pytorch-lamb
It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`.
https://arxiv.org/abs/1904.00962
Has the option for minimum trust LAMB as described in "Single Headed
Attention RNN: Stop Thinking With Your Head" section 6.3
https://arxiv.org/abs/1911.11423
"""
    class Config(Optimizer.Config):
        lr: float = 0.001
        weight_decay: float = 0.00001
        eps: float = 1e-8
        min_trust: Optional[float] = None
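
    # A minimal sketch of how the Config above might be set in a PyText task
    # config JSON (field names come from Config; the surrounding "optimizer"
    # key is an assumption about the task config layout):
    #
    #   "optimizer": {
    #       "Lamb": {"lr": 0.001, "weight_decay": 1e-5, "min_trust": 0.25}
    #   }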
    @classmethod
    def from_config(cls, config: Config, model: torch.nn.Module):
        return cls(
            model.parameters(),
            lr=config.lr,
            weight_decay=config.weight_decay,
            eps=config.eps,
            min_trust=config.min_trust,
        )
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-6,
        weight_decay=0,
        min_trust=None,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if min_trust is not None and not 0.0 <= min_trust <= 1.0:
            raise ValueError("Invalid min_trust value: {}".format(min_trust))
        PT_Optimizer.__init__(
            self,
            params,
            {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay},
        )
        self.min_trust = min_trust
    def step(self, closure=None, **kwargs):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "Lamb does not support sparse gradients, consider SparseAdam instead."
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Paper v3 does not use debiasing.
                # bias_correction1 = 1 - beta1 ** state['step']
                # bias_correction2 = 1 - beta2 ** state['step']
                # Apply bias to lr to avoid broadcast.
                step_size = group["lr"]
                # * math.sqrt(bias_correction2) / bias_correction1

                # Clamp the weight norm to [0, 10], per the paper's phi function.
                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)

                adam_step = exp_avg / exp_avg_sq.sqrt().add(group["eps"])
                if group["weight_decay"] != 0:
                    adam_step.add_(p.data, alpha=group["weight_decay"])

                adam_norm = adam_step.pow(2).sum().sqrt()
                # Scale the update by ||w|| / ||update||, falling back to 1
                # when either norm is zero.
                if weight_norm == 0 or adam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / adam_norm
                # Minimum-trust LAMB: floor the trust ratio at min_trust.
                if self.min_trust:
                    trust_ratio = max(self.min_trust, trust_ratio)
                state["weight_norm"] = weight_norm
                state["adam_norm"] = adam_norm
                state["trust_ratio"] = trust_ratio

                p.data.add_(adam_step, alpha=-step_size * trust_ratio)

        return loss
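

# A minimal usage sketch: builds a tiny model, runs one LAMB step, and prints
# the trust ratio stored per parameter. The model, data, and min_trust value
# here are illustrative assumptions, not part of the PyText API.
if __name__ == "__main__":
    model = torch.nn.Linear(10, 2)
    optimizer = Lamb(
        model.parameters(), lr=1e-3, weight_decay=1e-5, min_trust=0.25
    )
    inputs, targets = torch.randn(8, 10), torch.randn(8, 2)
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
    for group in optimizer.param_groups:
        for p in group["params"]:
            print(optimizer.state[p]["trust_ratio"])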