Source code for fairseq.optim.nag

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.optim.optimizer import Optimizer, required

from . import FairseqOptimizer, register_optimizer


@register_optimizer('nag')
class FairseqNAG(FairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args, params)
        self._optimizer = NAG(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--momentum', default=0.99, type=float, metavar='M',
                            help='momentum factor')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            'lr': self.args.lr[0],
            'momentum': self.args.momentum,
            'weight_decay': self.args.weight_decay,
        }
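
A minimal sketch of how this wrapper maps parsed arguments onto NAG keyword arguments. The Namespace fields mirror the flags registered in add_args plus fairseq's --lr option (a list of learning rates); the toy model and the concrete values are made up for illustration, and the sketch assumes FairseqOptimizer accepts (args, params) as the constructor above suggests.

# Illustration only (hypothetical values, not part of this module).
import argparse

import torch

from fairseq.optim.nag import FairseqNAG

model = torch.nn.Linear(16, 4)
args = argparse.Namespace(lr=[0.25], momentum=0.99, weight_decay=0.0)

opt = FairseqNAG(args, list(model.parameters()))
# opt.optimizer_config -> {'lr': 0.25, 'momentum': 0.99, 'weight_decay': 0.0}
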
class NAG(Optimizer):
    def __init__(self, params, lr=required, momentum=0, weight_decay=0):
        defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay)
        super(NAG, self).__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return True

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            lr = group['lr']
            lr_old = group.get('lr_old', lr)
            # The momentum buffer stores an lr-scaled velocity, so rescale it
            # whenever the learning rate has changed since the previous step.
            lr_correct = lr / lr_old

            for p in group['params']:
                if p.grad is None:
                    continue

                # Perform the update in fp32, even if params/grads are fp16.
                p_data_fp32 = p.data.float()
                d_p = p.grad.data.float()

                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    param_state['momentum_buffer'] = torch.zeros_like(d_p)
                else:
                    param_state['momentum_buffer'] = param_state['momentum_buffer'].type_as(d_p)
                buf = param_state['momentum_buffer']

                # Weight decay is applied as a direct rescaling of the parameters.
                if weight_decay != 0:
                    p_data_fp32.mul_(1 - lr * weight_decay)

                # Nesterov-style step: combine the stored velocity and the
                # fresh gradient into the parameter update.
                p_data_fp32.add_(buf, alpha=momentum * momentum * lr_correct)
                p_data_fp32.add_(d_p, alpha=-(1 + momentum) * lr)

                # Update the lr-scaled velocity buffer.
                buf.mul_(momentum * lr_correct).add_(d_p, alpha=-lr)

                # Copy the fp32 result back into the original (possibly fp16) params.
                p.data.copy_(p_data_fp32)

            group['lr_old'] = lr

        return loss
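
NAG itself is a standard torch.optim-style optimizer and can be exercised outside fairseq. Below is a minimal, self-contained sketch; the toy linear model, random data, and hyperparameter values are assumptions for illustration, not anything prescribed by this module.

# Toy usage sketch: drive NAG with an ordinary PyTorch training loop.
import torch
import torch.nn.functional as F

from fairseq.optim.nag import NAG

model = torch.nn.Linear(4, 1)
optimizer = NAG(model.parameters(), lr=0.1, momentum=0.99, weight_decay=0.0)

x = torch.randn(32, 4)   # random inputs
y = torch.randn(32, 1)   # random targets

for _ in range(10):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()     # applies the momentum update defined above

Note that the momentum buffer holds an lr-scaled velocity, which is why step() rescales it by lr / lr_old whenever a scheduler changes the learning rate between updates.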