Overview

Fairseq can be extended through user-supplied plug-ins. We support five kinds of plug-ins:

  • Models define the neural network architecture and encapsulate all of the learnable parameters.
  • Criterions compute the loss function given the model outputs and targets.
  • Tasks store dictionaries and provide helpers for loading/iterating over Datasets, initializing the Model/Criterion and calculating the loss.
  • Optimizers update the Model parameters based on the gradients.
  • Learning Rate Schedulers update the learning rate over the course of training.

Training Flow

Given a model, criterion, task, optimizer and lr_scheduler, fairseq implements the following high-level training flow:

for epoch in range(num_epochs):
    itr = task.get_batch_iterator(task.dataset('train'))
    for num_updates, batch in enumerate(itr):
        task.train_step(batch, model, criterion, optimizer)
        average_and_clip_gradients()
        optimizer.step()
        lr_scheduler.step_update(num_updates)
    lr_scheduler.step(epoch)

where the default implementation for task.train_step is roughly:

def train_step(self, batch, model, criterion, optimizer):
    loss = criterion(model, batch)
    optimizer.backward(loss)
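To make the order of calls in the loop above concrete, here is a minimal runnable sketch built from toy stand-ins. Every class and function in it (ToyTask, ToyOptimizer, ToyScheduler, toy_criterion, run) is hypothetical and written only for illustration; none of them is a fairseq class, and no real gradients are computed.

```python
# Toy stand-ins (hypothetical, NOT fairseq classes) that record the call order.

class ToyOptimizer:
    def __init__(self, log):
        self.log = log

    def backward(self, loss):
        self.log.append("backward")

    def step(self):
        self.log.append("step")


class ToyScheduler:
    def __init__(self, log):
        self.log = log

    def step_update(self, num_updates):
        self.log.append(f"step_update({num_updates})")

    def step(self, epoch):
        self.log.append(f"lr_step({epoch})")


class ToyTask:
    def dataset(self, split):
        return [[1.0, 2.0], [3.0]]  # two fake batches

    def get_batch_iterator(self, dataset):
        return iter(dataset)

    def train_step(self, batch, model, criterion, optimizer):
        loss = criterion(model, batch)
        optimizer.backward(loss)
        return loss


def toy_criterion(model, batch):
    # Pretend loss: sum of the model outputs over the batch.
    return sum(model(x) for x in batch)


def run(num_epochs=1):
    log = []
    task, optimizer, scheduler = ToyTask(), ToyOptimizer(log), ToyScheduler(log)
    model = lambda x: 2 * x  # stand-in for a network
    losses = []
    for epoch in range(num_epochs):
        itr = task.get_batch_iterator(task.dataset("train"))
        for num_updates, batch in enumerate(itr):
            losses.append(task.train_step(batch, model, toy_criterion, optimizer))
            log.append("clip")  # placeholder for gradient averaging/clipping
            optimizer.step()
            scheduler.step_update(num_updates)
        scheduler.step(epoch)
    return losses, log
```

Calling run() returns the per-batch losses together with the recorded call sequence, which shows that the per-update scheduler hook runs after every optimizer step while the per-epoch hook runs once at the end of the epoch.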

Registering new plug-ins

New plug-ins are registered through a set of @register function decorators, for example:

from fairseq.models import FairseqModel, register_model

@register_model('my_lstm')
class MyLSTM(FairseqModel):
    (...)

Once registered, new plug-ins can be used with the existing Command-line Tools. See the Tutorial sections for more detailed walkthroughs of how to add new plug-ins.
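For readers unfamiliar with the pattern, the sketch below shows how a name-based registration decorator can work in general. It is an illustration only and not fairseq's actual implementation: MODEL_REGISTRY, toy_register_model, BaseModel and build_model are hypothetical names chosen for this example.

```python
# A toy registry (hypothetical, NOT fairseq's code) mapping names to classes.
MODEL_REGISTRY = {}


def toy_register_model(name):
    """Return a class decorator that records the class under `name`."""
    def decorator(cls):
        if name in MODEL_REGISTRY:
            raise ValueError(f"Cannot register duplicate model ({name})")
        MODEL_REGISTRY[name] = cls
        return cls  # the class itself is returned unchanged
    return decorator


class BaseModel:
    pass


@toy_register_model("my_lstm")
class MyLSTM(BaseModel):
    pass


def build_model(name):
    # Look the class up by its registered name and instantiate it.
    return MODEL_REGISTRY[name]()
```

Because registration happens as a side effect of the class definition, a command-line tool only needs the string name (here "my_lstm") to find and construct the class, provided the defining module has been imported.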

Loading plug-ins from another directory

New plug-ins can also be defined in a custom module stored anywhere on the user's system. To import such a module and make its plug-ins available to fairseq, pass its location to the command-line tools with the --user-dir flag.

For example, assuming this directory tree:

/home/user/my-module/
└── __init__.py

with __init__.py:

from fairseq.models import register_model_architecture
from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big

@register_model_architecture('transformer', 'my_transformer')
def transformer_mmt_big(args):
    transformer_vaswani_wmt_en_de_big(args)

the fairseq-train script can then be invoked with the new architecture:

fairseq-train ... --user-dir /home/user/my-module -a my_transformer --task translation
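The reason the example works is that importing the module runs its __init__.py, which executes the registration decorator. The sketch below shows one plausible way a directory could be imported by path using only the standard library. It is an assumption about the general mechanism, not fairseq's actual code; import_user_dir, my_toy_module and REGISTERED are hypothetical names.

```python
import importlib
import os
import sys
import tempfile


def import_user_dir(user_dir):
    """Import the package located at `user_dir` and return the module object."""
    user_dir = os.path.abspath(user_dir)
    parent, name = os.path.split(user_dir)
    if name not in sys.modules:
        # Temporarily expose the parent directory so the package is importable.
        sys.path.insert(0, parent)
        try:
            importlib.invalidate_caches()
            importlib.import_module(name)
        finally:
            sys.path.remove(parent)
    return sys.modules[name]


# Demo: create a throwaway package whose __init__.py has an import-time effect.
with tempfile.TemporaryDirectory() as tmp:
    pkg = os.path.join(tmp, "my_toy_module")
    os.makedirs(pkg)
    with open(os.path.join(pkg, "__init__.py"), "w") as f:
        f.write("REGISTERED = ['my_transformer']\n")
    module = import_user_dir(pkg)
    registered = module.REGISTERED
```

In the demo, the import-time side effect is just a module-level list; in a real plug-in module it would be the decorator calls that add entries to the registry.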