#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
from fairseq.utils import clip_grad_norm_
from pytext.config import ConfigBase
from pytext.config.component import Component, ComponentType
class Optimizer(Component):
    __COMPONENT_TYPE__ = ComponentType.OPTIMIZER
    __EXPANSIBLE__ = True

    class Config(ConfigBase):
        pass

    def backward(self, loss):
        loss.backward()
    def clip_grad_norm(self, max_norm, model=None):
        if max_norm is None:
            # If max_norm is None we skip gradient clipping entirely.
            return None
        elif model is None:
            # Some callers pass only max_norm and no model, e.g.
            # optimizer.clip_grad_norm(max_norm). For those we clip the
            # parameters held by the optimizer itself.
            return clip_grad_norm_(self.params, max_norm)
        else:
            return clip_grad_norm_(model.parameters(), max_norm)
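    # A minimal sketch of the supported call forms (assuming `optimizer` is an
    # instance of a concrete subclass and `model` is a torch.nn.Module):
    #
    #     optimizer.clip_grad_norm(None)        # no clipping, returns None
    #     optimizer.clip_grad_norm(5.0)         # clips optimizer.params
    #     optimizer.clip_grad_norm(5.0, model)  # clips model.parameters()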
    def pre_export(self, model):
        pass

    def finalize(self) -> bool:
        return False

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        for p in self.params:
            if p.grad is not None:
                p.grad.data.mul_(c)
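    # Hypothetical usage sketch: when gradients are accumulated over several
    # backward passes, they can be rescaled before the step, e.g.
    #
    #     optimizer.multiply_grads(1.0 / num_accumulation_steps)
    #
    # where `num_accumulation_steps` is a caller-defined value, not part of
    # this module.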
    @property
    def params(self):
        """Return an iterable of the parameters held by the optimizer."""
        for param_group in self.param_groups:
            for p in param_group["params"]:
                yield p

    def reset_param_groups(self):
        self.param_groups = []

class Adagrad(torch.optim.Adagrad, Optimizer):
    class Config(Optimizer.Config):
        lr: float = 1e-2
        weight_decay: float = 0.00001

    def __init__(self, parameters, lr, weight_decay):
        super().__init__(parameters, lr=lr, weight_decay=weight_decay)

    @classmethod
    def from_config(cls, config: Config, model: torch.nn.Module):
        return cls(model.parameters(), config.lr, config.weight_decay)
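
# A construction sketch for the Config/from_config pattern shared by the
# optimizers in this module (assuming Config() can be instantiated with its
# declared defaults, and `model` is any torch.nn.Module):
#
#     optimizer = Adagrad.from_config(Adagrad.Config(), model)
#
# which amounts to Adagrad(model.parameters(), lr=1e-2, weight_decay=1e-5).
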
class Adam(torch.optim.Adam, Optimizer):
    class Config(Optimizer.Config):
        lr: float = 0.001
        weight_decay: float = 0.00001
        eps: float = 1e-8

    def __init__(self, parameters, lr, weight_decay, eps):
        super().__init__(parameters, lr=lr, weight_decay=weight_decay, eps=eps)

    @classmethod
    def from_config(cls, config: Config, model: torch.nn.Module):
        return cls(model.parameters(), config.lr, config.weight_decay, config.eps)

class SGD(torch.optim.SGD, Optimizer):
    class Config(Optimizer.Config):
        lr: float = 0.001
        momentum: float = 0.0

    def __init__(self, parameters, lr, momentum):
        super().__init__(parameters, lr=lr, momentum=momentum)

    @classmethod
    def from_config(cls, config: Config, model: torch.nn.Module):
        return cls(model.parameters(), config.lr, config.momentum)

class AdamW(torch.optim.AdamW, Optimizer):
    """Adds PyText support for decoupled weight decay regularization in Adam,
    as described in the paper https://arxiv.org/abs/1711.05101. For more
    information, see the fast.ai blog post on this optimization method:
    https://www.fast.ai/2018/07/02/adam-weight-decay/
    """
    class Config(Optimizer.Config):
        lr: float = 0.001
        weight_decay: float = 1e-2
        eps: float = 1e-8

    def __init__(self, parameters, lr, weight_decay, eps):
        super().__init__(parameters, lr=lr, weight_decay=weight_decay, eps=eps)

    @classmethod
    def from_config(cls, config: Config, model: torch.nn.Module):
        return cls(model.parameters(), config.lr, config.weight_decay, config.eps)
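
# A minimal construction sketch for AdamW (the values shown are its Config
# defaults, not recommendations):
#
#     optimizer = AdamW(model.parameters(), lr=0.001, weight_decay=1e-2, eps=1e-8)
#
# Unlike Adam with L2 regularization, torch.optim.AdamW decouples the weight
# decay from the gradient update, as described in the paper linked above.
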
def learning_rates(optimizer):
    """Yield the learning rate of each of the optimizer's param groups."""
    for param_group in optimizer.param_groups:
        yield param_group["lr"]
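

if __name__ == "__main__":
    # A minimal, self-contained usage sketch. The linear model and random data
    # below are placeholders; any torch.nn.Module works the same way.
    model = torch.nn.Linear(4, 2)
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)

    inputs = torch.randn(8, 4)
    targets = torch.randn(8, 2)
    loss = torch.nn.functional.mse_loss(model(inputs), targets)

    optimizer.backward(loss)  # delegates to loss.backward()
    optimizer.clip_grad_norm(1.0)  # clips the optimizer's own params
    optimizer.step()  # inherited from torch.optim.SGD

    print(list(learning_rates(optimizer)))  # -> [0.01]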