# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from collections.abc import Collection
from dataclasses import dataclass, field
from typing import List

from fairseq.dataclass import FairseqDataclass
from omegaconf import II, DictConfig

from . import FairseqLRScheduler, register_lr_scheduler


@dataclass
class CosineConfig(FairseqDataclass):
    warmup_updates: int = field(
        default=0,
        metadata={"help": "warmup the learning rate linearly for the first N updates"},
    )
    warmup_init_lr: float = field(
        default=-1,
        metadata={
            "help": "initial learning rate during warmup phase; default is args.lr"
        },
    )
    max_lr: float = field(
        default=1.0, metadata={"help": "max learning rate, must be more than args.lr"}
    )
    t_mult: float = field(
        default=1.0, metadata={"help": "factor to grow the length of each period"}
    )
    lr_period_updates: float = field(
        default=-1, metadata={"help": "initial number of updates per period"}
    )
    lr_shrink: float = field(
        default=0.1, metadata={"help": "shrink factor for annealing"}
    )
    # TODO common var for parent class
    lr: List[float] = II("optimization.lr")
    max_update: int = II("optimization.max_update")
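

# Illustrative note (not from the fairseq source): these dataclass fields are
# exposed as command-line flags by fairseq's config system, so a cosine run
# might look like the following, with all values chosen purely for example:
#
#   fairseq-train ... \
#       --lr-scheduler cosine --lr 1e-5 --max-lr 1e-3 \
#       --warmup-updates 4000 --warmup-init-lr 1e-7 \
#       --lr-period-updates 20000 --lr-shrink 0.5 --t-mult 2
#
# Here --lr sets the minimum (post-anneal) rate, per the schedule below.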
@register_lr_scheduler("cosine", dataclass=CosineConfig)
class CosineSchedule(FairseqLRScheduler):
    """Assign LR based on a cyclical schedule that follows the cosine function.

    See https://arxiv.org/pdf/1608.03983.pdf for details.

    We also support a warmup phase where we linearly increase the learning rate
    from some initial learning rate (``--warmup-init-lr``) until the configured
    max learning rate (``--max-lr``).

    During warmup::

      lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
      lr = lrs[update_num]

    After warmup::

      lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(pi * t_curr / t_i))

    where ``t_curr`` is the number of updates elapsed within the current period
    and ``t_i`` is the length of the current period, which is scaled by
    ``t_mult`` after every period.
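
    For example (illustrative values, not defaults): with ``lr_min = 0.1``,
    ``lr_max = 1.0`` and a period of 100 updates, halfway through a period
    (``t_curr / t_i = 0.5``) the learning rate is::

      lr = 0.1 + 0.5 * (1.0 - 0.1) * (1 + cos(pi * 0.5))  # = 0.55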
"""
    def __init__(self, cfg: DictConfig, fairseq_optimizer):
        super().__init__(cfg, fairseq_optimizer)
        if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
            raise ValueError(
                "Cannot use a fixed learning rate schedule with cosine."
                f" Consider --lr-scheduler=fixed instead. ({cfg.lr})"
            )

        warmup_end_lr = cfg.max_lr
        lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
        if cfg.warmup_init_lr < 0:
            cfg.warmup_init_lr = lr

        # the base lr acts as the minimum (post-anneal) learning rate
        self.min_lr = lr
        self.max_lr = cfg.max_lr
        assert self.max_lr > self.min_lr, "max_lr must be more than lr"

        self.t_mult = cfg.t_mult
        self.period = cfg.lr_period_updates

        if self.period <= 0:
            assert (
                cfg.max_update > 0
            ), "Either --max-update or --lr-period-updates must be set"
            self.period = cfg.max_update - cfg.warmup_updates

        if cfg.warmup_updates > 0:
            # linearly warmup for the first cfg.warmup_updates
            self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates
        else:
            self.lr_step = 1

        self.warmup_updates = cfg.warmup_updates
        self.lr_shrink = cfg.lr_shrink

        # initial learning rate
        self.lr = cfg.warmup_init_lr
        self.optimizer.set_lr(self.lr)

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        # we don't change the learning rate at epoch boundaries
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if num_updates < self.cfg.warmup_updates:
            self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
        else:
            curr_updates = num_updates - self.cfg.warmup_updates
            if self.t_mult != 1:
                # closed form for the index of the current period when each
                # period is t_mult times longer than the previous one
                i = math.floor(
                    math.log(
                        1 - curr_updates / self.period * (1 - self.t_mult),
                        self.t_mult,
                    )
                )
                t_i = self.t_mult ** i * self.period
                t_curr = (
                    curr_updates
                    - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period
                )
            else:
                i = math.floor(curr_updates / self.period)
                t_i = self.period
                t_curr = curr_updates - (self.period * i)

            # shrink both endpoints of the LR range once per completed period
            lr_shrink = self.lr_shrink ** i
            min_lr = self.min_lr * lr_shrink
            max_lr = self.max_lr * lr_shrink

            self.lr = min_lr + 0.5 * (max_lr - min_lr) * (
                1 + math.cos(math.pi * t_curr / t_i)
            )

        self.optimizer.set_lr(self.lr)
        return self.lr
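

# A minimal standalone sketch (not part of fairseq) that mirrors the annealing
# math in CosineSchedule.step_update() above, so the post-warmup schedule can
# be inspected without constructing an optimizer. The function name and all
# default values here are hypothetical, chosen only for illustration.
def _sketch_cosine_lr(
    num_updates: int,
    period: float = 100.0,
    min_lr: float = 0.1,
    max_lr: float = 1.0,
    t_mult: float = 1.0,
    lr_shrink: float = 0.1,
) -> float:
    if t_mult != 1:
        # index of the current period when each period is t_mult times longer
        i = math.floor(
            math.log(1 - num_updates / period * (1 - t_mult), t_mult)
        )
        t_i = t_mult ** i * period
        t_curr = num_updates - (1 - t_mult ** i) / (1 - t_mult) * period
    else:
        # fixed-length periods: simple division gives the period index
        i = math.floor(num_updates / period)
        t_i = period
        t_curr = num_updates - period * i
    shrink = lr_shrink ** i
    return min_lr * shrink + 0.5 * (max_lr - min_lr) * shrink * (
        1 + math.cos(math.pi * t_curr / t_i)
    )


# e.g. (approximately, up to float rounding):
#   _sketch_cosine_lr(0)   -> 1.0   start of the first period (lr == max_lr)
#   _sketch_cosine_lr(50)  -> 0.55  halfway through: cos(pi / 2) == 0
#   _sketch_cosine_lr(100) -> 0.1   second period, range shrunk by lr_shrink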