Source code for fairseq.optim.fairseq_optimizer

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch


class FairseqOptimizer(object):

    def __init__(self, args, params):
        super().__init__()
        self.args = args
        self.params = list(params)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        pass

    @property
    def optimizer(self):
        """Return a torch.optim.optimizer.Optimizer instance."""
        if not hasattr(self, '_optimizer'):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError('_optimizer must be an instance of torch.optim.Optimizer')
        return self._optimizer

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        raise NotImplementedError

    def __getstate__(self):
        return self._optimizer.__getstate__()

    def get_lr(self):
        """Return the current learning rate."""
        return self.optimizer.param_groups[0]['lr']

    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def state_dict(self):
        """Return the optimizer's state dict."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing
        optimizer instance (e.g., learning rate) over that found in the
        state_dict. This allows us to resume training from a checkpoint using
        a new set of optimizer args.
        """
        self.optimizer.load_state_dict(state_dict)

        if optimizer_overrides is not None and len(optimizer_overrides) > 0:
            # override learning rate, momentum, etc. with latest values
            for group in self.optimizer.param_groups:
                group.update(optimizer_overrides)

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves."""
        loss.backward()

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        for p in self.params:
            if p.grad is not None:
                p.grad.data.mul_(c)

    def clip_grad_norm(self, max_norm):
        """Clips gradient norm."""
        if max_norm > 0:
            return torch.nn.utils.clip_grad_norm_(self.params, max_norm)
        else:
            return math.sqrt(sum(p.grad.data.norm()**2 for p in self.params if p.grad is not None))

    def step(self, closure=None):
        """Performs a single optimization step."""
        self.optimizer.step(closure)

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for group in self.optimizer.param_groups:
            for p in group['params']:
                p.grad = None
        self.optimizer.zero_grad()

    @property
    def supports_memory_efficient_fp16(self):
        if hasattr(self.optimizer, 'supports_memory_efficient_fp16'):
            return self.optimizer.supports_memory_efficient_fp16
        return False
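

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of fairseq): a minimal subclass showing how
# FairseqOptimizer is typically specialized around a torch.optim optimizer.
# The class name `SGDExample`, the argument fields `args.lr` / `args.momentum`,
# and the training-loop snippet below are assumptions for illustration only;
# real implementations live under fairseq/optim/ and are registered with
# fairseq's register_optimizer decorator.

class SGDExample(FairseqOptimizer):

    def __init__(self, args, params):
        super().__init__(args, params)
        # The wrapped torch.optim.Optimizer; the `optimizer` property above
        # returns this instance after checking its type.
        self._optimizer = torch.optim.SGD(self.params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        # kwargs used to (re)build the optimizer; these also override values
        # stored in checkpoints on load. Assumes `args.lr` is a list of
        # learning rates, as fairseq passes it (hypothetical here).
        return {
            'lr': self.args.lr[0],
            'momentum': getattr(self.args, 'momentum', 0.0),
        }


# A typical single update step with such a wrapper (hypothetical usage):
#
#     optimizer = SGDExample(args, model.parameters())
#     loss = criterion(model(batch), target)
#     optimizer.backward(loss)                  # calls loss.backward()
#     optimizer.clip_grad_norm(args.clip_norm)  # 0 disables clipping
#     optimizer.step()
#     optimizer.zero_grad()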