Source code for fairseq.optim.fairseq_optimizer

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from fairseq import utils
from fairseq.dataclass.utils import gen_parser_from_dataclass

class FairseqOptimizer(object):
    """Wrapper exposing a uniform interface over a ``torch.optim.Optimizer``.

    The wrapped optimizer is read from ``self._optimizer``; this base class
    never assigns it in ``__init__``, so whoever builds the wrapper must set
    it. Until then the ``optimizer`` property raises ``NotImplementedError``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

    @classmethod
    def add_args(cls, parser):
        """Add optimizer-specific arguments to the parser."""
        # The attribute name is looked up as a plain string, so no name
        # mangling applies to the leading double underscore.
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    @property
    def optimizer(self):
        """Return a torch.optim.optimizer.Optimizer instance.

        Raises:
            NotImplementedError: if ``_optimizer`` has not been assigned.
            ValueError: if ``_optimizer`` is not a ``torch.optim.Optimizer``.
        """
        if not hasattr(self, "_optimizer"):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError("_optimizer must be an instance of torch.optim.Optimizer")
        return self._optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        """Reset optimizer instance.

        Raises:
            NotImplementedError: if ``_optimizer`` has not been assigned yet.
            ValueError: if *optimizer* is not a ``torch.optim.Optimizer``.
        """
        if not hasattr(self, "_optimizer"):
            raise NotImplementedError
        # Validate the incoming value (not the one being replaced) so that a
        # bad assignment fails here rather than at the next property read.
        if not isinstance(optimizer, torch.optim.Optimizer):
            raise ValueError("_optimizer must be an instance of torch.optim.Optimizer")
        self._optimizer = optimizer

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        raise NotImplementedError

    @property
    def params(self):
        """Return an iterable of the parameters held by the optimizer."""
        for param_group in self.param_groups:
            for p in param_group["params"]:
                yield p

    @property
    def param_groups(self):
        """Return the wrapped optimizer's ``param_groups``."""
        return self.optimizer.param_groups

    def __getstate__(self):
        """Delegate pickling state to the wrapped optimizer."""
        return self._optimizer.__getstate__()

    def get_lr(self):
        """Return the current learning rate (taken from the first param group)."""
        return self.param_groups[0]["lr"]

    def set_lr(self, lr):
        """Set the learning rate on every param group."""
        for param_group in self.param_groups:
            param_group["lr"] = lr

    def state_dict(self):
        """Return the optimizer's state dict."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        self.optimizer.load_state_dict(state_dict)

        if optimizer_overrides is not None and len(optimizer_overrides) > 0:
            # override learning rate, momentum, etc. with latest values
            for group in self.param_groups:
                group.update(optimizer_overrides)

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves."""
        loss.backward()

    def all_reduce_grads(self, module):
        """Manually all-reduce gradients (if required)."""
        if hasattr(module, "all_reduce_grads"):
            module.all_reduce_grads()

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*.

        *c* may be a Python number or a tensor; parameters whose ``grad`` is
        ``None`` are skipped.
        """
        for p in self.params:
            if p.grad is not None:
                if torch.is_tensor(c):
                    # Keep the multiplier on the same device as the gradient.
                    c = c.to(p.grad.device)
                p.grad.data.mul_(c)

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn)

    def step(self, closure=None, scale=1.0, groups=None):
        """Performs a single optimization step.

        When the wrapped optimizer does not accept a ``scale`` argument, the
        gradients are multiplied by ``1 / scale`` before stepping.
        """
        if self.supports_step_with_scale:
            if self.supports_groups:
                self.optimizer.step(closure, scale=scale, groups=groups)
            else:
                self.optimizer.step(closure, scale=scale)
        else:
            if scale != 1.0:
                self.multiply_grads(1.0 / scale)
            if self.supports_groups:
                self.optimizer.step(closure, groups=groups)
            else:
                self.optimizer.step(closure)

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for p in self.params:
            p.grad = None
        self.optimizer.zero_grad()

    @property
    def supports_memory_efficient_fp16(self):
        """Forward the wrapped optimizer's flag of the same name (default False)."""
        if hasattr(self.optimizer, "supports_memory_efficient_fp16"):
            return self.optimizer.supports_memory_efficient_fp16
        return False

    @property
    def supports_step_with_scale(self):
        """Forward the wrapped optimizer's flag of the same name (default False)."""
        if hasattr(self.optimizer, "supports_step_with_scale"):
            return self.optimizer.supports_step_with_scale
        return False

    @property
    def supports_groups(self):
        """Forward the wrapped optimizer's flag of the same name (default False)."""
        if hasattr(self.optimizer, "supports_groups"):
            return self.optimizer.supports_groups
        return False

    @property
    def supports_flat_params(self):
        """
        Whether the optimizer supports collapsing of the model
        parameters/gradients into a single contiguous Tensor.
        """
        if hasattr(self.optimizer, "supports_flat_params"):
            return self.optimizer.supports_flat_params
        return False

    def average_params(self):
        """No-op in this base class."""
        pass

    def broadcast_global_state_dict(self, state_dict):
        """
        Broadcasts a global state dict to all ranks.
        Useful for optimizers that shard state between ranks.
        """
        if hasattr(self.optimizer, "broadcast_global_state_dict"):
            return self.optimizer.broadcast_global_state_dict(state_dict)
        else:
            return state_dict
class LegacyFairseqOptimizer(FairseqOptimizer):
    """Optimizer wrapper that is constructed from an ``args`` object.

    The constructor only records ``args``; it does not invoke
    ``FairseqOptimizer.__init__``, so ``self.cfg`` is not assigned here.
    """

    def __init__(self, args):
        # Store the argument object as-is for later use by the optimizer.
        self.args = args