Source code for pytext.optimizer.lamb

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Optional

import torch
from pytext.optimizer.optimizers import Optimizer
from torch.optim import Optimizer as PT_Optimizer


[docs]class Lamb(Optimizer, PT_Optimizer): r"""Implements Lamb algorithm. THIS WAS DIRECTLY COPIED OVER FROM pytorch/contrib: https://github.com/cybertronai/pytorch-lamb It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`. https://arxiv.org/abs/1904.00962 Has the option for minimum trust LAMB as described in "Single Headed Attention RNN: Stop Thinking With Your Head" section 6.3 https://arxiv.org/abs/1911.11423 """
[docs] class Config(Optimizer.Config): lr: float = 0.001 weight_decay: float = 0.00001 eps: float = 1e-8 min_trust: Optional[float] = None
[docs] @classmethod def from_config(cls, config: Config, model: torch.nn.Module): return cls( model.parameters(), lr=config.lr, weight_decay=config.weight_decay, eps=config.eps, min_trust=config.min_trust, )
def __init__( self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0, min_trust=None, ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) PT_Optimizer.__init__( self, params, {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay}, ) self.min_trust = min_trust
[docs] def step(self, closure=None, **kwargs): """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group["params"]: if p.grad is None: continue grad = p.grad.data if grad.is_sparse: raise RuntimeError( "Lamb does not support sparse gradients, consider SparseAdam instad." ) state = self.state[p] # State initialization if len(state) == 0: state["step"] = 0 # Exponential moving average of gradient values state["exp_avg"] = torch.zeros_like(p.data) # Exponential moving average of squared gradient values state["exp_avg_sq"] = torch.zeros_like(p.data) exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] beta1, beta2 = group["betas"] state["step"] += 1 # Decay the first and second moment running average coefficient # m_t exp_avg.mul_(beta1).add_(1 - beta1, grad) # v_t exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) # Paper v3 does not use debiasing. # bias_correction1 = 1 - beta1 ** state['step'] # bias_correction2 = 1 - beta2 ** state['step'] # Apply bias to lr to avoid broadcast. step_size = group["lr"] # * math.sqrt(bias_correction2) / bias_correction1 weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) adam_step = exp_avg / exp_avg_sq.sqrt().add(group["eps"]) if group["weight_decay"] != 0: adam_step.add_(group["weight_decay"], p.data) adam_norm = adam_step.pow(2).sum().sqrt() if weight_norm == 0 or adam_norm == 0: trust_ratio = 1 else: trust_ratio = weight_norm / adam_norm if self.min_trust: trust_ratio = max(self.min_trust, trust_ratio) state["weight_norm"] = weight_norm state["adam_norm"] = adam_norm state["trust_ratio"] = trust_ratio p.data.add_(-step_size * trust_ratio, adam_step) return loss