Source code for fairseq.optim.lr_scheduler.fairseq_lr_scheduler

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from argparse import Namespace

from fairseq.dataclass.utils import gen_parser_from_dataclass

from .. import FairseqOptimizer


class FairseqLRScheduler(object):
    def __init__(self, cfg, optimizer):
        super().__init__()
        if not isinstance(optimizer, FairseqOptimizer):
            raise ValueError("optimizer must be an instance of FairseqOptimizer")
        self.cfg = cfg
        self.optimizer = optimizer
        self.best = None

    @classmethod
    def add_args(cls, parser):
        """Add arguments to the parser for this LR scheduler."""
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {"best": self.best}

    def load_state_dict(self, state_dict):
        """Load an LR scheduler state dict."""
        self.best = state_dict["best"]

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        pass

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        if val_loss is not None:
            if self.best is None:
                self.best = val_loss
            else:
                self.best = min(self.best, val_loss)

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.optimizer.get_lr()
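
A concrete scheduler subclasses FairseqLRScheduler and overrides the step hooks. The following is a minimal sketch, not part of the fairseq module above: the HypotheticalFixedLRSchedule name is invented for illustration, and it assumes cfg carries an lr value and that the wrapped FairseqOptimizer exposes set_lr()/get_lr(), as step_update above already relies on.

# Illustrative sketch only -- not part of fairseq.
class HypotheticalFixedLRSchedule(FairseqLRScheduler):
    def __init__(self, cfg, optimizer):
        super().__init__(cfg, optimizer)
        # fairseq configs often store lr as a list; take the first entry.
        self.lr = cfg.lr[0] if isinstance(cfg.lr, (list, tuple)) else cfg.lr
        self.optimizer.set_lr(self.lr)

    def step(self, epoch, val_loss=None):
        # Let the base class track the best validation loss, then keep lr fixed.
        super().step(epoch, val_loss)
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        # Called after every optimizer update; a fixed schedule just reports lr.
        return self.optimizer.get_lr()
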
class LegacyFairseqLRScheduler(FairseqLRScheduler):
    def __init__(self, args: Namespace, optimizer):
        if not isinstance(optimizer, FairseqOptimizer):
            raise ValueError("optimizer must be an instance of FairseqOptimizer")
        self.args = args
        self.optimizer = optimizer
        self.best = None
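
The legacy variant differs only in construction: it is configured from an argparse Namespace (stored as self.args) rather than a dataclass config, and it deliberately skips the base-class __init__. A hedged usage sketch, where fairseq_optimizer stands in for an already-constructed FairseqOptimizer:

# Illustrative sketch only; "fairseq_optimizer" is a hypothetical placeholder.
legacy_args = Namespace(lr=[0.001])
scheduler = LegacyFairseqLRScheduler(legacy_args, fairseq_optimizer)
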