Source code for pytext.optimizer.optimizers

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
from fairseq.utils import clip_grad_norm_
from pytext.config import ConfigBase
from pytext.config.component import Component, ComponentType


class Optimizer(Component):
    __COMPONENT_TYPE__ = ComponentType.OPTIMIZER
    __EXPANSIBLE__ = True

    class Config(ConfigBase):
        pass

    def backward(self, loss):
        loss.backward()
    def clip_grad_norm(self, max_norm, model=None):
        if max_norm is None:
            # In case max_norm is None we don't compute clip_grad_norm.
            return None
        elif model is None:
            # Some callers pass only max_norm instead of both arguments,
            # e.g. optimizer.clip_grad_norm(max_norm); for those we clip the
            # optimizer's own params.
            return clip_grad_norm_(self.params, max_norm)
        else:
            return clip_grad_norm_(model.parameters(), max_norm)
    def pre_export(self, model):
        pass

    def finalize(self) -> bool:
        return False

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        for p in self.params:
            if p.grad is not None:
                p.grad.data.mul_(c)

    @property
    def params(self):
        """Return an iterable of the parameters held by the optimizer."""
        for param_group in self.param_groups:
            for p in param_group["params"]:
                yield p

    def reset_param_groups(self):
        self.param_groups = []
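
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library source), assuming a hypothetical
# `model` (any torch.nn.Module) and a computed `loss` tensor. It shows the
# calling pattern the methods above support: backward(), optional clipping,
# optional gradient scaling, then the underlying torch step().
#
#     optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
#     optimizer.backward(loss)           # runs loss.backward()
#     optimizer.clip_grad_norm(5.0)      # model omitted, so self.params is clipped
#     optimizer.multiply_grads(0.25)     # e.g. average over 4 accumulation steps
#     optimizer.step()                   # step() comes from torch.optim.SGD
# ---------------------------------------------------------------------------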

class Adagrad(torch.optim.Adagrad, Optimizer):
    class Config(Optimizer.Config):
        lr: float = 1e-2
        weight_decay: float = 0.00001

    def __init__(self, parameters, lr, weight_decay):
        super().__init__(parameters, lr=lr, weight_decay=weight_decay)

    @classmethod
    def from_config(cls, config: Config, model: torch.nn.Module):
        return cls(model.parameters(), config.lr, config.weight_decay)

class Adam(torch.optim.Adam, Optimizer):
    class Config(Optimizer.Config):
        lr: float = 0.001
        weight_decay: float = 0.00001
        eps: float = 1e-8

    def __init__(self, parameters, lr, weight_decay, eps):
        super().__init__(parameters, lr=lr, weight_decay=weight_decay, eps=eps)

    @classmethod
    def from_config(cls, config: Config, model: torch.nn.Module):
        return cls(model.parameters(), config.lr, config.weight_decay, config.eps)

class SGD(torch.optim.SGD, Optimizer):
    class Config(Optimizer.Config):
        lr: float = 0.001
        momentum: float = 0.0

    def __init__(self, parameters, lr, momentum):
        super().__init__(parameters, lr=lr, momentum=momentum)

    @classmethod
    def from_config(cls, config: Config, model: torch.nn.Module):
        return cls(model.parameters(), config.lr, config.momentum)

class AdamW(torch.optim.AdamW, Optimizer):
    """Adds PyText support for Decoupled Weight Decay Regularization for Adam,
    as described in the paper: https://arxiv.org/abs/1711.05101

    For more information, read the fast.ai blog post on this optimization
    method: https://www.fast.ai/2018/07/02/adam-weight-decay/
    """

    class Config(Optimizer.Config):
        lr: float = 0.001
        weight_decay: float = 1e-2
        eps: float = 1e-8

    def __init__(self, parameters, lr, weight_decay, eps):
        super().__init__(parameters, lr=lr, weight_decay=weight_decay, eps=eps)

    @classmethod
    def from_config(cls, config: Config, model: torch.nn.Module):
        return cls(model.parameters(), config.lr, config.weight_decay, config.eps)
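
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library source): building AdamW through
# from_config, assuming a hypothetical small model. The Config defaults above
# (lr=0.001, weight_decay=1e-2, eps=1e-8) are used unchanged.
#
#     model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
#     optimizer = AdamW.from_config(AdamW.Config(), model)
# ---------------------------------------------------------------------------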

def learning_rates(optimizer):
    for param_group in optimizer.param_groups:
        yield param_group["lr"]
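
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library source): learning_rates() yields
# the current "lr" of each parameter group, which is convenient for logging,
# e.g. when a scheduler updates the rates during training.
#
#     current_lrs = list(learning_rates(optimizer))
# ---------------------------------------------------------------------------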