fairseq documentation

Fairseq is a sequence modeling toolkit written in PyTorch that allows researchers and developers to train custom models for translation, summarization, language modeling and other text generation tasks.

Evaluating Pre-trained Models

First, download a pre-trained model along with its vocabularies:

> curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -

This model uses a Byte Pair Encoding (BPE) vocabulary, so we’ll have to apply the encoding to the source text before it can be translated. This can be done with the apply_bpe.py script using the wmt14.en-fr.fconv-py/bpecodes file. @@ is used as a continuation marker, and the original text can be easily recovered with, e.g., sed s/@@ //g or by passing the --remove-bpe flag to fairseq-generate. Prior to BPE, input text needs to be tokenized using tokenizer.perl from mosesdecoder.
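The sed command above can be mirrored in Python. The helper name remove_bpe is hypothetical: it deletes every "@@ " continuation marker, which joins each subword to the token that follows it. Appending a space before replacing means a dangling "@@" at the very end of a line is removed too.

```python
def remove_bpe(line: str, marker: str = "@@ ") -> str:
    """Undo BPE segmentation by deleting continuation markers.

    Equivalent in spirit to `sed s/@@ //g`; the appended space lets a
    trailing "@@" at the end of the line be stripped as well.
    """
    return (line + " ").replace(marker, "").rstrip()


segmented = "Why is it rare to discover new marine mam@@ mal species ?"
print(remove_bpe(segmented))
# → Why is it rare to discover new marine mammal species ?
```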

Let’s use fairseq-interactive to generate translations interactively. Here, we use a beam size of 5:

> MODEL_DIR=wmt14.en-fr.fconv-py
> fairseq-interactive \
    --path $MODEL_DIR/model.pt $MODEL_DIR \
    --beam 5 --source-lang en --target-lang fr
| loading model(s) from wmt14.en-fr.fconv-py/model.pt
| [en] dictionary: 44206 types
| [fr] dictionary: 44463 types
| Type the input sentence and press return:
> Why is it rare to discover new marine mam@@ mal species ?
O       Why is it rare to discover new marine mam@@ mal species ?
H       -0.1525060087442398     Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?
P       -0.2221 -0.3122 -0.1289 -0.2673 -0.1711 -0.1930 -0.1101 -0.1660 -0.1003 -0.0740 -0.1101 -0.0814 -0.1238 -0.0985 -0.1288

This generation script produces three types of outputs: a line prefixed with O is a copy of the original source sentence; H is the hypothesis along with an average log-likelihood; and P is the positional score per token position, including the end-of-sentence marker, which is omitted from the text.
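As a worked check of this format, the sketch below parses the three lines from the session above. The parsing helper is hypothetical and assumes fields are separated by whitespace (tabs or spaces). The hypothesis has 14 tokens while P has 15 scores, the extra one belonging to the end-of-sentence marker, and the mean of the P scores reproduces the H score up to the 4-decimal rounding of P.

```python
from statistics import mean

OUTPUT = """\
O       Why is it rare to discover new marine mam@@ mal species ?
H       -0.1525060087442398     Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?
P       -0.2221 -0.3122 -0.1289 -0.2673 -0.1711 -0.1930 -0.1101 -0.1660 -0.1003 -0.0740 -0.1101 -0.0814 -0.1238 -0.0985 -0.1288
"""


def parse_interactive_output(text):
    """Split O/H/P lines into source, hypothesis score, hypothesis and per-token scores."""
    result = {}
    for line in text.splitlines():
        if line.startswith("O"):
            result["source"] = line.split(None, 1)[1]
        elif line.startswith("H"):
            _, score, hypothesis = line.split(None, 2)
            result["score"] = float(score)
            result["hypothesis"] = hypothesis
        elif line.startswith("P"):
            result["positional"] = [float(s) for s in line.split()[1:]]
    return result


out = parse_interactive_output(OUTPUT)
num_tokens = len(out["hypothesis"].split())
print(num_tokens, len(out["positional"]))  # → 14 15  (one extra score for </s>)
print(round(mean(out["positional"]), 4), round(out["score"], 4))  # → -0.1525 -0.1525
```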

See the README for a full list of pre-trained models available.

Training a New Model

The following tutorial is for machine translation. For an example of how to use fairseq for other tasks, such as language modeling, please see the examples/ directory.

Data Pre-processing

Fairseq contains example pre-processing scripts for several translation datasets: IWSLT 2014 (German-English), WMT 2014 (English-French) and WMT 2014 (English-German). To pre-process and binarize the IWSLT dataset:

> cd examples/translation/
> bash prepare-iwslt14.sh
> cd ../..
> TEXT=examples/translation/iwslt14.tokenized.de-en
> fairseq-preprocess --source-lang de --target-lang en \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --destdir data-bin/iwslt14.tokenized.de-en

This will write binarized data that can be used for model training to data-bin/iwslt14.tokenized.de-en.

Training

Use fairseq-train to train a new model. Here are a few example settings that work well for the IWSLT 2014 dataset:

> mkdir -p checkpoints/fconv
> CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
    --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
    --arch fconv_iwslt_de_en --save-dir checkpoints/fconv

By default, fairseq-train will use all available GPUs on your machine. Use the CUDA_VISIBLE_DEVICES environment variable to select specific GPUs and/or to change the number of GPU devices that will be used.

Also note that the batch size is specified in terms of the maximum number of tokens per batch (--max-tokens). You may need to use a smaller value depending on the available GPU memory on your system.
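To illustrate what a token budget means, here is a simplified sketch, not fairseq's actual batching code. It assumes every sentence is padded to the longest sentence in its batch and greedily closes a batch before its padded size would exceed max_tokens, so a smaller --max-tokens yields more, smaller batches.

```python
def batch_by_max_tokens(lengths, max_tokens):
    """Group sentence indices into batches whose padded size fits in max_tokens.

    Padded size = (number of sentences) * (longest sentence in the batch).
    """
    batches, current, longest = [], [], 0
    for idx, n in enumerate(lengths):
        if n > max_tokens:
            raise ValueError(f"sentence {idx} has {n} tokens > max_tokens={max_tokens}")
        new_longest = max(longest, n)
        if current and (len(current) + 1) * new_longest > max_tokens:
            batches.append(current)          # close the batch before it overflows
            current, new_longest = [], n
        current.append(idx)
        longest = new_longest
    if current:
        batches.append(current)
    return batches


print(batch_by_max_tokens([3, 5, 2, 7, 4], max_tokens=12))  # → [[0, 1], [2], [3], [4]]
print(batch_by_max_tokens([2, 3, 4, 5, 7], max_tokens=12))  # → [[0, 1, 2], [3], [4]]
```

The second call uses the same lengths sorted in ascending order: grouping sentences of similar length wastes less space on padding and produces fewer batches.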

Generation

Once your model is trained, you can generate translations using fairseq-generate (for binarized data) or fairseq-interactive (for raw text):

> fairseq-generate data-bin/iwslt14.tokenized.de-en \
    --path checkpoints/fconv/checkpoint_best.pt \
    --batch-size 128 --beam 5
| [de] dictionary: 35475 types
| [en] dictionary: 24739 types
| data-bin/iwslt14.tokenized.de-en test 6750 examples
| model fconv
| loaded checkpoint checkpoints/fconv/checkpoint_best.pt
S-721   danke .
T-721   thank you .
...

To generate translations with only a CPU, use the --cpu flag. BPE continuation markers can be removed with the --remove-bpe flag.

Advanced Training Options

Large mini-batch training with delayed updates

The --update-freq option can be used to accumulate gradients from multiple mini-batches and delay updating, creating a larger effective batch size. Delayed updates can also improve training speed by reducing inter-GPU communication costs and by saving idle time caused by variance in workload across GPUs. See Ott et al. (2018) for more details.

To train on a single GPU with an effective batch size that is equivalent to training on 8 GPUs:

> CUDA_VISIBLE_DEVICES=0 fairseq-train --update-freq 8 (...)
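Why delayed updates emulate a larger batch can be seen with a toy one-parameter model in plain Python (not fairseq code). Assuming the loss is summed over examples, adding up the gradients of 8 mini-batches and then taking a single step gives the same update as one step on a batch 8 times as large.

```python
def grad_of_summed_loss(w, batch):
    """Gradient of sum((w*x - y)**2) over the batch with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch)


data = [(float(i), 2.0 * i + 1.0) for i in range(16)]
mini_batches = [data[i:i + 2] for i in range(0, len(data), 2)]  # 8 mini-batches

w, lr = 0.5, 1e-3

# Delayed update: accumulate gradients from 8 mini-batches, then step once.
accumulated = 0.0
for mb in mini_batches:
    accumulated += grad_of_summed_loss(w, mb)
w_delayed = w - lr * accumulated

# Reference: a single update on the full 16-example batch.
w_full = w - lr * grad_of_summed_loss(w, data)

print(abs(w_delayed - w_full) < 1e-9)  # → True
```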

Training with half precision floating point (FP16)

Note

FP16 training requires a Volta GPU and CUDA 9.1 or greater.

Recent GPUs enable efficient half precision floating point computation, e.g., using NVIDIA Tensor Cores. Fairseq supports FP16 training with the --fp16 flag:

> fairseq-train --fp16 (...)
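FP16 training is commonly paired with dynamic loss scaling, which is what the --fp16-init-scale, --fp16-scale-window, --fp16-scale-tolerance and --min-loss-scale options (listed under Command-line Tools) configure. The class below is a simplified illustration of that technique, not fairseq's implementation: the scale is divided by a factor when gradients overflow (once the overflow rate reaches the tolerance), multiplied by it after scale_window overflow-free updates, and training stops if it drops below the minimum.

```python
class DynamicLossScaler:
    """Simplified dynamic loss scaling for FP16 training (illustrative only)."""

    def __init__(self, init_scale=128.0, scale_window=2000, tolerance=0.0,
                 min_scale=1e-4, factor=2.0):
        self.scale = init_scale
        self.scale_window = scale_window  # overflow-free updates before growing
        self.tolerance = tolerance        # fraction of overflowing updates allowed
        self.min_scale = min_scale
        self.factor = factor
        self._clean_steps = 0             # consecutive updates without overflow
        self._steps = 0                   # updates since the last rescale
        self._overflows = 0               # overflows since the last rescale

    def update(self, overflow: bool) -> None:
        """Adjust the scale after one update; `overflow` means inf/NaN gradients."""
        self._steps += 1
        if overflow:
            self._clean_steps = 0
            self._overflows += 1
            if self._overflows / self._steps >= self.tolerance:
                self.scale /= self.factor
                self._steps = self._overflows = 0
                if self.scale < self.min_scale:
                    raise FloatingPointError(
                        f"loss scale {self.scale} fell below {self.min_scale}")
        else:
            self._clean_steps += 1
            if self._clean_steps >= self.scale_window:
                self.scale *= self.factor
                self._clean_steps = 0
                self._steps = self._overflows = 0


scaler = DynamicLossScaler(init_scale=128.0, scale_window=3)
scaler.update(overflow=True)        # 128 -> 64
for _ in range(3):
    scaler.update(overflow=False)   # third clean update: 64 -> 128
print(scaler.scale)  # → 128.0
```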

Lazily loading large training datasets

By default, fairseq loads the entire training set into system memory. For large datasets, the --lazy-load option can be used to instead load batches on-demand. For optimal performance, use the --num-workers option to control the number of background processes that will load batches.
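The idea behind on-demand loading can be sketched with a line-offset index: only byte offsets stay in memory, and each example is read from disk when it is requested. The LazyLines class is a hypothetical illustration of the technique on a plain text file, not fairseq's lazy dataset (which operates on its binarized files).

```python
import os
import tempfile


class LazyLines:
    """Index a text file once, then read individual lines only when accessed."""

    def __init__(self, path):
        self.path = path
        self.offsets = []                  # byte offset at which each line starts
        with open(path, "rb") as f:
            offset = 0
            for line in f:
                self.offsets.append(offset)
                offset += len(line)

    def __len__(self):
        return len(self.offsets)

    def __getitem__(self, i):
        with open(self.path, "rb") as f:
            f.seek(self.offsets[i])        # jump straight to the requested line
            return f.readline().decode("utf-8").rstrip("\r\n")


with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False,
                                 encoding="utf-8", newline="\n") as tmp:
    tmp.write("danke .\nthank you .\nguten morgen .\n")

dataset = LazyLines(tmp.name)              # only offsets are held in memory
n, second = len(dataset), dataset[1]
print(n, second)  # → 3 thank you .
os.remove(tmp.name)
```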

Distributed training

Distributed training in fairseq is implemented on top of torch.distributed. The easiest way to launch jobs is with the torch.distributed.launch tool.

For example, to train a large English-German Transformer model on 2 nodes each with 8 GPUs (in total 16 GPUs), run the following command on each node, replacing node_rank=0 with node_rank=1 on the second node:

> python -m torch.distributed.launch --nproc_per_node=8 \
    --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" \
    --master_port=1234 \
    $(which fairseq-train) data-bin/wmt16_en_de_bpe32k \
    --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
    --lr 0.0005 --min-lr 1e-09 \
    --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 3584 \
    --fp16
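With torch.distributed.launch, each of the --nproc_per_node processes on a node gets a local rank, and its global rank is offset by --node_rank. The arithmetic below sketches that bookkeeping for the 2-node, 8-GPU example, plus an upper bound on tokens per update. The helper name is illustrative, and the bound assumes every batch is filled to --max-tokens with --update-freq left at 1.

```python
def global_rank(node_rank: int, local_rank: int, nproc_per_node: int) -> int:
    """Rank of a worker among all processes across all nodes."""
    return node_rank * nproc_per_node + local_rank


nnodes, nproc_per_node, max_tokens, update_freq = 2, 8, 3584, 1
world_size = nnodes * nproc_per_node
ranks = [global_rank(n, l, nproc_per_node)
         for n in range(nnodes) for l in range(nproc_per_node)]

print(world_size)                             # → 16
print(ranks[:3], ranks[-3:])                  # → [0, 1, 2] [13, 14, 15]
print(world_size * max_tokens * update_freq)  # → 57344 (tokens per update, at most)
```

The eight processes on the second node (node_rank=1) therefore cover global ranks 8 through 15.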

Command-line Tools

Fairseq provides several command-line tools for training and evaluating models:

fairseq-preprocess

Data pre-processing: build vocabularies and binarize training data.

usage: fairseq-preprocess [-h] [--no-progress-bar] [--log-interval N]
                          [--log-format {json,none,simple,tqdm}]
                          [--tensorboard-logdir DIR] [--tbmf-wrapper]
                          [--seed N] [--cpu] [--fp16]
                          [--memory-efficient-fp16]
                          [--fp16-init-scale FP16_INIT_SCALE]
                          [--fp16-scale-window FP16_SCALE_WINDOW]
                          [--fp16-scale-tolerance FP16_SCALE_TOLERANCE]
                          [--min-loss-scale D]
                          [--threshold-loss-scale THRESHOLD_LOSS_SCALE]
                          [--user-dir USER_DIR]
                          [--criterion {adaptive_loss,label_smoothed_cross_entropy,composite_loss,masked_lm_loss,cross_entropy}]
                          [--optimizer {adadelta,adam,adafactor,adagrad,nag,lamb,sgd}]
                          [--lr-scheduler {cosine,reduce_lr_on_plateau,fixed,triangular,polynomial_decay,inverse_sqrt}]
                          [--task TASK] [-s SRC] [-t TARGET] [--trainpref FP]
                          [--validpref FP] [--testpref FP] [--destdir DIR]
                          [--thresholdtgt N] [--thresholdsrc N] [--tgtdict FP]
                          [--srcdict FP] [--nwordstgt N] [--nwordssrc N]
                          [--alignfile ALIGN] [--dataset-impl FORMAT]
                          [--joined-dictionary] [--only-source]
                          [--padding-factor N] [--workers N]

Named Arguments

--no-progress-bar

disable progress bar

Default: False

--log-interval

log progress every N batches (when progress bar is disabled)

Default: 1000

--log-format

Possible choices: json, none, simple, tqdm

log format to use

--tensorboard-logdir

path to save logs for tensorboard, should match --logdir of running tensorboard (default: no tensorboard logging)

Default: “”

--tbmf-wrapper

[FB only]

Default: False

--seed

pseudo random number generator seed

Default: 1

--cpu

use CPU instead of CUDA

Default: False

--fp16

use FP16

Default: False

--memory-efficient-fp16

use a memory-efficient version of FP16 training; implies --fp16

Default: False

--fp16-init-scale

default FP16 loss scale

Default: 128

--fp16-scale-window

number of updates before increasing loss scale

--fp16-scale-tolerance

percentage of updates that can overflow before decreasing the loss scale

Default: 0.0

--min-loss-scale

minimum FP16 loss scale, after which training is stopped

Default: 0.0001

--threshold-loss-scale

threshold FP16 loss scale from below

--user-dir

path to a python module containing custom extensions (tasks and/or architectures)

--criterion

Possible choices: adaptive_loss, label_smoothed_cross_entropy, composite_loss, masked_lm_loss, cross_entropy

Default: “cross_entropy”

--optimizer

Possible choices: adadelta, adam, adafactor, adagrad, nag, lamb, sgd

Default: “nag”

--lr-scheduler

Possible choices: cosine, reduce_lr_on_plateau, fixed, triangular, polynomial_decay, inverse_sqrt

Default: “fixed”

--task

Possible choices: translation, translation_from_pretrained_xlm, multilingual_translation, semisupervised_translation, cross_lingual_lm, masked_lm, translation_moe, language_modeling

task

Default: “translation”

--dataset-impl

Possible choices: raw, lazy, cached, mmap

output dataset implementation

Default: “cached”

Preprocessing

-s, --source-lang

source language

-t, --target-lang

target language

--trainpref

train file prefix

--validpref

comma-separated valid file prefixes

--testpref

comma-separated test file prefixes

--destdir

destination dir

Default: “data-bin”

--thresholdtgt

map words appearing less than threshold times to unknown

Default: 0

--thresholdsrc

map words appearing less than threshold times to unknown

Default: 0

--tgtdict

reuse given target dictionary

--srcdict

reuse given source dictionary

--nwordstgt

number of target words to retain

Default: -1

--nwordssrc

number of source words to retain

Default: -1

--alignfile

an alignment file (optional)

--joined-dictionary

Generate joined dictionary

Default: False

--only-source

Only process the source language

Default: False

--padding-factor

Pad dictionary size to be multiple of N

Default: 8

--workers

number of parallel workers

Default: 1
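The --thresholdsrc/--thresholdtgt, --nwordssrc/--nwordstgt and --padding-factor options interact when a vocabulary is built. The function below is a simplified sketch of that interaction, not fairseq's dictionary code. It assumes four special symbols come first, keeps the most frequent words whose count meets the threshold (up to nwords non-special words, with -1 meaning no limit), and appends filler symbols with hypothetical names until the size is a multiple of the padding factor.

```python
from collections import Counter

SPECIALS = ["<s>", "<pad>", "</s>", "<unk>"]   # assumed special symbols


def build_vocab(tokens, threshold=0, nwords=-1, padding_factor=8):
    """Return vocabulary symbols; words left out map to <unk> at lookup time."""
    counts = Counter(tokens)
    # Most frequent first; ties broken alphabetically for a deterministic order.
    ranked = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    kept = [w for w, c in ranked if c >= threshold]
    if nwords > 0:                    # in this sketch nwords excludes specials
        kept = kept[:nwords]
    vocab = SPECIALS + kept
    filler = 0
    while len(vocab) % padding_factor != 0:
        vocab.append(f"<filler{filler}>")      # hypothetical filler symbol
        filler += 1
    return vocab


corpus = "the cat sat on the mat the cat ran".split()
vocab = build_vocab(corpus, threshold=2)
print(vocab)
# → ['<s>', '<pad>', '</s>', '<unk>', 'the', 'cat', '<filler0>', '<filler1>']
```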

fairseq-train

Train a new model on one or across multiple GPUs.

usage: fairseq-train [-h] [--no-progress-bar] [--log-interval N]
                     [--log-format {json,none,simple,tqdm}]
                     [--tensorboard-logdir DIR] [--tbmf-wrapper] [--seed N]
                     [--cpu] [--fp16] [--memory-efficient-fp16]
                     [--fp16-init-scale FP16_INIT_SCALE]
                     [--fp16-scale-window FP16_SCALE_WINDOW]
                     [--fp16-scale-tolerance FP16_SCALE_TOLERANCE]
                     [--min-loss-scale D]
                     [--threshold-loss-scale THRESHOLD_LOSS_SCALE]
                     [--user-dir USER_DIR]
                     [--criterion {adaptive_loss,label_smoothed_cross_entropy,composite_loss,masked_lm_loss,cross_entropy}]
                     [--optimizer {adadelta,adam,adafactor,adagrad,nag,lamb,sgd}]
                     [--lr-scheduler {cosine,reduce_lr_on_plateau,fixed,triangular,polynomial_decay,inverse_sqrt}]
                     [--task TASK] [--num-workers N]
                     [--skip-invalid-size-inputs-valid-test] [--max-tokens N]
                     [--max-sentences N] [--required-batch-size-multiple N]
                     [--dataset-impl FORMAT] [--train-subset SPLIT]
                     [--valid-subset SPLIT] [--validate-interval N]
                     [--disable-validation] [--max-sentences-valid N]
                     [--curriculum N] [--distributed-world-size N]
                     [--distributed-rank DISTRIBUTED_RANK]
                     [--distributed-backend DISTRIBUTED_BACKEND]
                     [--distributed-init-method DISTRIBUTED_INIT_METHOD]
                     [--distributed-port DISTRIBUTED_PORT]
                     [--device-id DEVICE_ID] [--distributed-no-spawn]
                     [--ddp-backend {c10d,no_c10d}] [--bucket-cap-mb MB]
                     [--fix-batches-to-gpus] [--find-unused-parameters] --arch
                     ARCH [--max-epoch N] [--max-update N] [--clip-norm NORM]
                     [--sentence-avg] [--update-freq N1,N2,...,N_K]
                     [--lr LR_1,LR_2,...,LR_N] [--min-lr LR] [--use-bmuf]
                     [--global-sync-iter GLOBAL_SYNC_ITER] [--save-dir DIR]
                     [--restore-file RESTORE_FILE] [--reset-dataloader]
                     [--reset-lr-scheduler] [--reset-meters]
                     [--reset-optimizer] [--optimizer-overrides DICT]
                     [--save-interval N] [--save-interval-updates N]
                     [--keep-interval-updates N] [--keep-last-epochs N]
                     [--no-save] [--no-epoch-checkpoints]

Named Arguments

--no-progress-bar

disable progress bar

Default: False

--log-interval

log progress every N batches (when progress bar is disabled)

Default: 1000

--log-format

Possible choices: json, none, simple, tqdm

log format to use

--tensorboard-logdir

path to save logs for tensorboard, should match --logdir of running tensorboard (default: no tensorboard logging)

Default: “”

--tbmf-wrapper

[FB only]

Default: False

--seed

pseudo random number generator seed

Default: 1

--cpu

use CPU instead of CUDA

Default: False

--fp16

use FP16

Default: False

--memory-efficient-fp16

use a memory-efficient version of FP16 training; implies --fp16

Default: False

--fp16-init-scale

default FP16 loss scale

Default: 128

--fp16-scale-window

number of updates before increasing loss scale

--fp16-scale-tolerance

percentage of updates that can overflow before decreasing the loss scale

Default: 0.0

--min-loss-scale

minimum FP16 loss scale, after which training is stopped

Default: 0.0001

--threshold-loss-scale

threshold FP16 loss scale from below

--user-dir

path to a python module containing custom extensions (tasks and/or architectures)

--criterion

Possible choices: adaptive_loss, label_smoothed_cross_entropy, composite_loss, masked_lm_loss, cross_entropy

Default: “cross_entropy”

--optimizer

Possible choices: adadelta, adam, adafactor, adagrad, nag, lamb, sgd

Default: “nag”

--lr-scheduler

Possible choices: cosine, reduce_lr_on_plateau, fixed, triangular, polynomial_decay, inverse_sqrt

Default: “fixed”

--task

Possible choices: translation, translation_from_pretrained_xlm, multilingual_translation, semisupervised_translation, cross_lingual_lm, masked_lm, translation_moe, language_modeling

task

Default: “translation”

--dataset-impl

Possible choices: raw, lazy, cached, mmap

output dataset implementation

Default: “cached”

Dataset and data loading

--num-workers

how many subprocesses to use for data loading

Default: 0

--skip-invalid-size-inputs-valid-test

ignore too long or too short lines in valid and test set

Default: False

--max-tokens

maximum number of tokens in a batch

--max-sentences, --batch-size

maximum number of sentences in a batch

--required-batch-size-multiple

batch size will be a multiple of this value

Default: 8

--train-subset

Possible choices: train, valid, test

data subset to use for training (train, valid, test)

Default: “train”

--valid-subset

comma separated list of data subsets to use for validation (train, valid, valid1, test, test1)

Default: “valid”

--validate-interval

validate every N epochs

Default: 1

--disable-validation

disable validation

Default: False

--max-sentences-valid

maximum number of sentences in a validation batch (defaults to --max-sentences)

--curriculum

don’t shuffle batches for first N epochs

Default: 0

Distributed training

--distributed-world-size

total number of GPUs across all nodes (default: all visible GPUs)

Default: 1

--distributed-rank

rank of the current worker

Default: 0

--distributed-backend

distributed backend

Default: “nccl”

--distributed-init-method

typically tcp://hostname:port that will be used to establish initial connection

--distributed-port

port number (not required if using --distributed-init-method)

Default: -1

--device-id, --local_rank

which GPU to use (usually configured automatically)

Default: 0

--distributed-no-spawn

do not spawn multiple processes even if multiple GPUs are visible

Default: False

--ddp-backend

Possible choices: c10d, no_c10d

DistributedDataParallel backend

Default: “c10d”

--bucket-cap-mb

bucket size for reduction

Default: 25

--fix-batches-to-gpus

don’t shuffle batches between GPUs; this reduces overall randomness and may affect precision but avoids the cost of re-reading the data

Default: False

--find-unused-parameters

enable unused parameter detection (not applicable to the no_c10d ddp-backend)

Default: False

Model configuration

--arch, -a

Possible choices: transformer, transformer_iwslt_de_en, transformer_wmt_en_de, transformer_vaswani_wmt_en_de_big, transformer_vaswani_wmt_en_fr_big, transformer_wmt_en_de_big, transformer_wmt_en_de_big_t2t, transformer_from_pretrained_xlm, transformer_lm, transformer_lm_big, transformer_lm_baevski_wiki103, transformer_lm_wiki103, transformer_lm_baevski_gbw, transformer_lm_gbw, transformer_lm_gpt, transformer_lm_gpt2_small, transformer_lm_gpt2_medium, transformer_lm_gpt2_big, lightconv, lightconv_iwslt_de_en, lightconv_wmt_en_de, lightconv_wmt_en_de_big, lightconv_wmt_en_fr_big, lightconv_wmt_zh_en_big, masked_lm, bert_base, bert_large, xlm_base, fconv, fconv_iwslt_de_en, fconv_wmt_en_ro, fconv_wmt_en_de, fconv_wmt_en_fr, fconv_lm, fconv_lm_dauphin_wikitext103, fconv_lm_dauphin_gbw, lightconv_lm, lightconv_lm_gbw, fconv_self_att, fconv_self_att_wp, lstm, lstm_wiseman_iwslt_de_en, lstm_luong_wmt_en_de, multilingual_transformer, multilingual_transformer_iwslt_de_en

Model Architecture

Default: “fconv”

Optimization

--max-epoch, --me

force stop training at specified epoch

Default: 0

--max-update, --mu

force stop training at specified update

Default: 0

--clip-norm

clip threshold of gradients

Default: 25

--sentence-avg

normalize gradients by the number of sentences in a batch (default is to normalize by number of tokens)

Default: False

--update-freq

update parameters every N_i batches, when in epoch i

Default: 1

--lr, --learning-rate

learning rate for the first N epochs; all epochs >N using LR_N (note: this may be interpreted differently depending on --lr-scheduler)

Default: 0.25

--min-lr

stop training when the learning rate reaches this minimum

Default: -1

--use-bmuf

specify global optimizer for syncing models on different GPUs/Shards

Default: False

--global-sync-iter

Iteration for syncing global model

Default: 10
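Both --update-freq and --lr accept comma-separated lists indexed by epoch: epoch i uses the i-th value, and every epoch past the end of the list reuses the last value (for --lr, subject to the --lr-scheduler caveat above). A minimal sketch of that lookup, with a hypothetical helper name:

```python
def value_for_epoch(values, epoch):
    """Pick the schedule entry for a 1-based epoch, reusing the last entry afterwards."""
    if epoch < 1:
        raise ValueError("epochs are numbered from 1")
    return values[min(epoch, len(values)) - 1]


update_freq = [int(v) for v in "4,2,1".split(",")]   # e.g. --update-freq 4,2,1
print([value_for_epoch(update_freq, e) for e in range(1, 6)])  # → [4, 2, 1, 1, 1]
```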

Checkpointing

--save-dir

path to save checkpoints

Default: “checkpoints”

--restore-file

filename in save-dir from which to load checkpoint

Default: “checkpoint_last.pt”

--reset-dataloader

if set, does not reload dataloader state from the checkpoint

Default: False

--reset-lr-scheduler

if set, does not load lr scheduler state from the checkpoint

Default: False

--reset-meters

if set, does not load meters from the checkpoint

Default: False

--reset-optimizer

if set, does not load optimizer state from the checkpoint

Default: False

--optimizer-overrides

a dictionary used to override optimizer args when loading a checkpoint

Default: “{}”

--save-interval

save a checkpoint every N epochs

Default: 1

--save-interval-updates

save a checkpoint (and validate) every N updates

Default: 0

--keep-interval-updates

keep the last N checkpoints saved with --save-interval-updates

Default: -1

--keep-last-epochs

keep last N epoch checkpoints

Default: -1

--no-save

don’t save models or checkpoints

Default: False

--no-epoch-checkpoints

only store last and best checkpoints

Default: False
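--keep-last-epochs limits how many per-epoch checkpoints stay on disk. The sketch below shows one way such pruning can work. It assumes per-epoch files are named checkpoint<N>.pt, alongside checkpoint_best.pt and checkpoint_last.pt, which are never pruned; it is an illustration, not fairseq's checkpoint code.

```python
import re

EPOCH_CKPT = re.compile(r"^checkpoint(\d+)\.pt$")   # assumed naming scheme


def checkpoints_to_delete(filenames, keep_last_epochs):
    """Return epoch checkpoints beyond the newest `keep_last_epochs` (-1 keeps all)."""
    if keep_last_epochs < 0:
        return []
    epochs = []
    for name in filenames:
        m = EPOCH_CKPT.match(name)
        if m:                                        # skips _best and _last
            epochs.append((int(m.group(1)), name))
    epochs.sort(reverse=True)                        # newest epoch first
    return [name for _, name in epochs[keep_last_epochs:]]


files = ["checkpoint_best.pt", "checkpoint_last.pt"] + [f"checkpoint{e}.pt" for e in range(1, 6)]
print(checkpoints_to_delete(files, keep_last_epochs=2))
# → ['checkpoint3.pt', 'checkpoint2.pt', 'checkpoint1.pt']
```

Parsing the epoch as an integer keeps checkpoint10.pt ordered after checkpoint9.pt, which a plain string sort would not.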

fairseq-generate

Translate pre-processed data with a trained model.

fairseq-interactive

Translate raw text with a trained model. Batches data on-the-fly.

usage: fairseq-interactive [-h] [--no-progress-bar] [--log-interval N]
                           [--log-format {json,none,simple,tqdm}]
                           [--tensorboard-logdir DIR] [--tbmf-wrapper]
                           [--seed N] [--cpu] [--fp16]
                           [--memory-efficient-fp16]
                           [--fp16-init-scale FP16_INIT_SCALE]
                           [--fp16-scale-window FP16_SCALE_WINDOW]
                           [--fp16-scale-tolerance FP16_SCALE_TOLERANCE]
                           [--min-loss-scale D]
                           [--threshold-loss-scale THRESHOLD_LOSS_SCALE]
                           [--user-dir USER_DIR]
                           [--criterion {adaptive_loss,label_smoothed_cross_entropy,composite_loss,masked_lm_loss,cross_entropy}]
                           [--optimizer {adadelta,adam,adafactor,adagrad,nag,lamb,sgd}]
                           [--lr-scheduler {cosine,reduce_lr_on_plateau,fixed,triangular,polynomial_decay,inverse_sqrt}]
                           [--task TASK] [--num-workers N]
                           [--skip-invalid-size-inputs-valid-test]
                           [--max-tokens N] [--max-sentences N]
                           [--required-batch-size-multiple N]
                           [--dataset-impl FORMAT] [--gen-subset SPLIT]
                           [--num-shards N] [--shard-id ID] [--path FILE]
                           [--remove-bpe [REMOVE_BPE]] [--quiet]
                           [--model-overrides DICT] [--results-path RESDIR]
                           [--beam N] [--nbest N] [--max-len-a N]
                           [--max-len-b N] [--min-len N] [--match-source-len]
                           [--no-early-stop] [--unnormalized]
                           [--no-beamable-mm] [--lenpen LENPEN]
                           [--unkpen UNKPEN] [--replace-unk [REPLACE_UNK]]
                           [--sacrebleu] [--score-reference]
                           [--prefix-size PS] [--no-repeat-ngram-size N]
                           [--sampling] [--sampling-topk PS] [--temperature N]
                           [--diverse-beam-groups N]
                           [--diverse-beam-strength N] [--print-alignment]
                           [--buffer-size N] [--input FILE]

Named Arguments

--no-progress-bar

disable progress bar

Default: False

--log-interval

log progress every N batches (when progress bar is disabled)

Default: 1000

--log-format

Possible choices: json, none, simple, tqdm

log format to use

--tensorboard-logdir

path to save logs for tensorboard, should match --logdir of running tensorboard (default: no tensorboard logging)

Default: “”

--tbmf-wrapper

[FB only]

Default: False

--seed

pseudo random number generator seed

Default: 1

--cpu

use CPU instead of CUDA

Default: False

--fp16

use FP16

Default: False

--memory-efficient-fp16

use a memory-efficient version of FP16 training; implies --fp16

Default: False

--fp16-init-scale

default FP16 loss scale

Default: 128

--fp16-scale-window

number of updates before increasing loss scale

--fp16-scale-tolerance

percentage of updates that can overflow before decreasing the loss scale

Default: 0.0

--min-loss-scale

minimum FP16 loss scale, after which training is stopped

Default: 0.0001

--threshold-loss-scale

threshold FP16 loss scale from below

--user-dir

path to a python module containing custom extensions (tasks and/or architectures)

--criterion

Possible choices: adaptive_loss, label_smoothed_cross_entropy, composite_loss, masked_lm_loss, cross_entropy

Default: “cross_entropy”

--optimizer

Possible choices: adadelta, adam, adafactor, adagrad, nag, lamb, sgd

Default: “nag”

--lr-scheduler

Possible choices: cosine, reduce_lr_on_plateau, fixed, triangular, polynomial_decay, inverse_sqrt

Default: “fixed”

--task

Possible choices: translation, translation_from_pretrained_xlm, multilingual_translation, semisupervised_translation, cross_lingual_lm, masked_lm, translation_moe, language_modeling

task

Default: “translation”

--dataset-impl

Possible choices: raw, lazy, cached, mmap

output dataset implementation

Default: “cached”

Dataset and data loading

--num-workers

how many subprocesses to use for data loading

Default: 0

--skip-invalid-size-inputs-valid-test

ignore too long or too short lines in valid and test set

Default: False

--max-tokens

maximum number of tokens in a batch

--max-sentences, --batch-size

maximum number of sentences in a batch

--required-batch-size-multiple

batch size will be a multiple of this value

Default: 8

--gen-subset

data subset to generate (train, valid, test)

Default: “test”

--num-shards

shard generation over N shards

Default: 1

--shard-id

id of the shard to generate (id < num_shards)

Default: 0

Generation

--path

path(s) to model file(s), colon separated

--remove-bpe

remove BPE tokens before scoring (can be set to sentencepiece)

--quiet

only print final scores

Default: False

--model-overrides

a dictionary used to override model args at generation that were used during model training

Default: “{}”

--results-path

path to save eval results (optional)

--beam

beam size

Default: 5

--nbest

number of hypotheses to output

Default: 1

--max-len-a

generate sequences of maximum length ax + b, where x is the source length

Default: 0

--max-len-b

generate sequences of maximum length ax + b, where x is the source length

Default: 200

--min-len

minimum generation length

Default: 1

--match-source-len

generations should match the source length

Default: False

--no-early-stop

continue searching even after finalizing k=beam hypotheses; this is more correct, but increases generation time by 50%

Default: False

--unnormalized

compare unnormalized hypothesis scores

Default: False

--no-beamable-mm

don’t use BeamableMM in attention layers

Default: False

--lenpen

length penalty: <1.0 favors shorter, >1.0 favors longer sentences

Default: 1

--unkpen

unknown word penalty: <0 produces more unks, >0 produces fewer

Default: 0

--replace-unk

perform unknown replacement (optionally with alignment dictionary)

--sacrebleu

score with sacrebleu

Default: False

--score-reference

just score the reference translation

Default: False

--prefix-size

initialize generation by target prefix of given length

Default: 0

--no-repeat-ngram-size

ngram blocking such that this size ngram cannot be repeated in the generation

Default: 0

--sampling

sample hypotheses instead of using beam search

Default: False

--sampling-topk

sample from top K likely next words instead of all words

Default: -1

--temperature

temperature for generation

Default: 1.0

--diverse-beam-groups

number of groups for Diverse Beam Search

Default: -1

--diverse-beam-strength

strength of diversity penalty for Diverse Beam Search

Default: 0.5

--print-alignment

if set, uses attention feedback to compute and print alignment to source tokens

Default: False
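Two of the options above reduce to simple formulas. --max-len-a and --max-len-b bound the output length at a·x + b for a source of length x, and --lenpen changes how hypotheses of different lengths are compared. The sketch assumes, as we understand fairseq's beam search, that a hypothesis's summed log-probability is divided by its length raised to the power lenpen (--unnormalized skips this normalization); treat it as an illustration, not the exact implementation.

```python
def max_output_len(src_len, max_len_a=0.0, max_len_b=200):
    """Upper bound on generated length: a * x + b."""
    return int(max_len_a * src_len + max_len_b)


def normalized_score(logprobs, lenpen=1.0):
    """Summed log-probability divided by length ** lenpen."""
    return sum(logprobs) / (len(logprobs) ** lenpen)


short_hyp = [-0.5] * 4    # 4 tokens, total log-prob -2.0
long_hyp = [-0.375] * 8   # 8 tokens, total log-prob -3.0


def best(lenpen):
    """Name of the hypothesis that wins under the given length penalty."""
    if normalized_score(short_hyp, lenpen) > normalized_score(long_hyp, lenpen):
        return "short"
    return "long"


print(max_output_len(10, max_len_a=1.5, max_len_b=5))  # → 20
print(best(0.0), best(1.0), best(2.0))                  # → short long long
```

With lenpen=0 the raw sums are compared and the shorter hypothesis wins; raising lenpen shrinks the penalty on longer outputs, consistent with the option's description.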

Interactive

--buffer-size

read this many sentences into a buffer before processing them

Default: 0

--input

file to read from; use - for stdin

Default: “-”

fairseq-score

fairseq-eval-lm

Evaluate the perplexity of a trained language model.

usage: fairseq-eval-lm [-h] [--no-progress-bar] [--log-interval N]
                       [--log-format {json,none,simple,tqdm}]
                       [--tensorboard-logdir DIR] [--tbmf-wrapper] [--seed N]
                       [--cpu] [--fp16] [--memory-efficient-fp16]
                       [--fp16-init-scale FP16_INIT_SCALE]
                       [--fp16-scale-window FP16_SCALE_WINDOW]
                       [--fp16-scale-tolerance FP16_SCALE_TOLERANCE]
                       [--min-loss-scale D]
                       [--threshold-loss-scale THRESHOLD_LOSS_SCALE]
                       [--user-dir USER_DIR]
                       [--criterion {adaptive_loss,label_smoothed_cross_entropy,composite_loss,masked_lm_loss,cross_entropy}]
                       [--optimizer {adadelta,adam,adafactor,adagrad,nag,lamb,sgd}]
                       [--lr-scheduler {cosine,reduce_lr_on_plateau,fixed,triangular,polynomial_decay,inverse_sqrt}]
                       [--task TASK] [--num-workers N]
                       [--skip-invalid-size-inputs-valid-test]
                       [--max-tokens N] [--max-sentences N]
                       [--required-batch-size-multiple N]
                       [--dataset-impl FORMAT] [--gen-subset SPLIT]
                       [--num-shards N] [--shard-id ID] [--path FILE]
                       [--remove-bpe [REMOVE_BPE]] [--quiet]
                       [--model-overrides DICT] [--results-path RESDIR]
                       [--output-word-probs] [--output-word-stats]
                       [--context-window N] [--softmax-batch N]

Named Arguments

--no-progress-bar

disable progress bar

Default: False

--log-interval

log progress every N batches (when progress bar is disabled)

Default: 1000

--log-format

Possible choices: json, none, simple, tqdm

log format to use

--tensorboard-logdir

path to save logs for tensorboard, should match --logdir of running tensorboard (default: no tensorboard logging)

Default: “”

--tbmf-wrapper

[FB only]

Default: False

--seed

pseudo random number generator seed

Default: 1

--cpu

use CPU instead of CUDA

Default: False

--fp16

use FP16

Default: False

--memory-efficient-fp16

use a memory-efficient version of FP16 training; implies --fp16

Default: False

--fp16-init-scale

default FP16 loss scale

Default: 128

--fp16-scale-window

number of updates before increasing loss scale

--fp16-scale-tolerance

percentage of updates that can overflow before decreasing the loss scale

Default: 0.0

--min-loss-scale

minimum FP16 loss scale, after which training is stopped

Default: 0.0001

--threshold-loss-scale

threshold FP16 loss scale from below

--user-dir

path to a python module containing custom extensions (tasks and/or architectures)

--criterion

Possible choices: adaptive_loss, label_smoothed_cross_entropy, composite_loss, masked_lm_loss, cross_entropy

Default: “cross_entropy”

--optimizer

Possible choices: adadelta, adam, adafactor, adagrad, nag, lamb, sgd

Default: “nag”

--lr-scheduler

Possible choices: cosine, reduce_lr_on_plateau, fixed, triangular, polynomial_decay, inverse_sqrt

Default: “fixed”

--task

Possible choices: translation, translation_from_pretrained_xlm, multilingual_translation, semisupervised_translation, cross_lingual_lm, masked_lm, translation_moe, language_modeling

task

Default: “language_modeling”

--dataset-impl

Possible choices: raw, lazy, cached, mmap

output dataset implementation

Default: “cached”

Dataset and data loading

--num-workers

how many subprocesses to use for data loading

Default: 0

--skip-invalid-size-inputs-valid-test

ignore too long or too short lines in valid and test set

Default: False

--max-tokens

maximum number of tokens in a batch

--max-sentences, --batch-size

maximum number of sentences in a batch

--required-batch-size-multiple

batch size will be a multiple of this value

Default: 8

--gen-subset

data subset to generate (train, valid, test)

Default: “test”

--num-shards

shard generation over N shards

Default: 1

--shard-id

id of the shard to generate (id < num_shards)

Default: 0

LM Evaluation

--path

path(s) to model file(s), colon separated

--remove-bpe

remove BPE tokens before scoring (can be set to sentencepiece)

--quiet

only print final scores

Default: False

--model-overrides

a dictionary used to override, at generation time, model args that were set during model training

Default: “{}”

--results-path path to save eval results (optional)
--output-word-probs

if set, outputs words and their predicted log probabilities to standard output

Default: False

--output-word-stats

if set, outputs word statistics such as word count, average probability, etc.

Default: False

--context-window

ensures that every evaluated token has access to a context of at least this size, if possible

Default: 0

--softmax-batch

if BxT (batch size times sequence length) exceeds this value, the softmax over the vocabulary is computed in chunks of this many tokens so that it fits into GPU memory

Default: 9223372036854775807
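Two of the LM evaluation options can be illustrated with plain Python. This is a simplified sketch under stated assumptions: `remove_bpe` assumes the `@@ ` continuation marker used by the pre-trained models described earlier, and `context_windows` assumes the token stream is scored in fixed-size blocks, each prefixed with up to `context_window` preceding tokens that are fed to the model but not scored. Both function names are hypothetical.

```python
def remove_bpe(line, bpe_symbol="@@ "):
    """Undo BPE segmentation, e.g. 'mam@@ mal' -> 'mammal'."""
    # The appended space also handles a marker at the very end of the line.
    return (line + " ").replace(bpe_symbol, "").rstrip()


def context_windows(tokens, block_size, context_window):
    """Split `tokens` into scoring blocks with extra left context.

    Returns (model_input, num_context) pairs: the first `num_context`
    tokens of each input are context only and should not be scored.
    """
    windows = []
    for start in range(0, len(tokens), block_size):
        end = min(start + block_size, len(tokens))
        ctx_start = max(0, start - context_window)
        windows.append((tokens[ctx_start:end], start - ctx_start))
    return windows
```

Every token is still scored exactly once; the context only changes what the model can condition on for the first tokens of each block.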

Overview

Fairseq can be extended through user-supplied plug-ins. We support five kinds of plug-ins:

  • Models define the neural network architecture and encapsulate all of the learnable parameters.
  • Criterions compute the loss function given the model outputs and targets.
  • Tasks store dictionaries and provide helpers for loading/iterating over Datasets, initializing the Model/Criterion and calculating the loss.
  • Optimizers update the Model parameters based on the gradients.
  • Learning Rate Schedulers update the learning rate over the course of training.

Training Flow

Given a model, criterion, task, optimizer and lr_scheduler, fairseq implements the following high-level training flow:

num_updates = 0
for epoch in range(num_epochs):
    itr = task.get_batch_iterator(task.dataset('train'))
    for batch in itr:
        optimizer.zero_grad()
        task.train_step(batch, model, criterion, optimizer)
        average_and_clip_gradients()
        optimizer.step()
        num_updates += 1  # counted across epochs, not reset each epoch
        lr_scheduler.step_update(num_updates)
    lr_scheduler.step(epoch)

where the default implementation for task.train_step is roughly:

def train_step(self, batch, model, criterion, optimizer):
    loss = criterion(model, batch)
    optimizer.backward(loss)
    return loss
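The training flow can be mimicked end to end without fairseq or PyTorch. The following stdlib-only sketch uses hypothetical stand-ins (a one-parameter linear model, a mean-squared-error criterion with a hand-written gradient, plain SGD, and an inverse-square-root learning rate schedule) purely to show where each component is called. None of these classes are fairseq APIs, and gradient averaging and clipping are omitted.

```python
import math


class ToyModel:
    """One-parameter linear model y = w * x (hypothetical stand-in)."""
    def __init__(self):
        self.w = 0.0
        self.grad = 0.0


def mse_criterion(model, batch):
    """Return (loss, d loss / d w) for a batch of (x, y) pairs."""
    n = len(batch)
    loss = sum((model.w * x - y) ** 2 for x, y in batch) / n
    grad = sum(2.0 * (model.w * x - y) * x for x, y in batch) / n
    return loss, grad


class ToySGD:
    def __init__(self, model, lr):
        self.model, self.lr = model, lr

    def zero_grad(self):
        self.model.grad = 0.0

    def backward(self, grad):
        # Accumulate the gradient, as a real backward pass would.
        self.model.grad += grad

    def step(self):
        self.model.w -= self.lr * self.model.grad


class ToyInverseSqrtScheduler:
    """lr = base_lr / sqrt(num_updates + 1); nothing happens per epoch."""
    def __init__(self, optimizer, base_lr):
        self.optimizer, self.base_lr = optimizer, base_lr

    def step_update(self, num_updates):
        self.optimizer.lr = self.base_lr / math.sqrt(num_updates + 1)

    def step(self, epoch):
        pass


def train_step(batch, model, criterion, optimizer):
    loss, grad = criterion(model, batch)
    optimizer.backward(grad)
    return loss


def train(data, num_epochs=50, batch_size=2, lr=0.1):
    model = ToyModel()
    optimizer = ToySGD(model, lr)
    lr_scheduler = ToyInverseSqrtScheduler(optimizer, lr)
    num_updates = 0  # counted across epochs
    for epoch in range(num_epochs):
        batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
        for batch in batches:
            optimizer.zero_grad()
            train_step(batch, model, mse_criterion, optimizer)
            optimizer.step()
            num_updates += 1
            lr_scheduler.step_update(num_updates)
        lr_scheduler.step(epoch)
    return model, optimizer, num_updates
```

Trained on points from y = 3x, e.g. `train([(0.5, 1.5), (1.0, 3.0), (1.5, 4.5), (2.0, 6.0)])`, the weight converges close to 3 after 100 updates.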

Registering new plug-ins

New plug-ins are registered through a set of @register function decorators, for example:

@register_model('my_lstm')
class MyLSTM(FairseqEncoderDecoderModel):
    (...)

Once registered, new plug-ins can be used with the existing Command-line Tools. See the Tutorial sections for more detailed walkthroughs of how to add new plug-ins.
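The decorator-based registration pattern itself is ordinary Python. As a rough sketch of the idea (the names `TOY_MODEL_REGISTRY`, `register_toy_model` and `build_toy_model` are hypothetical and are not fairseq's internals), a decorator can record each class under a name so that a string given on the command line can later be mapped back to it:

```python
TOY_MODEL_REGISTRY = {}


def register_toy_model(name):
    """Return a class decorator that records the class under `name`."""
    def decorator(cls):
        if name in TOY_MODEL_REGISTRY:
            raise ValueError("Cannot register duplicate model ({})".format(name))
        TOY_MODEL_REGISTRY[name] = cls
        return cls  # the class itself is returned unchanged
    return decorator


def build_toy_model(name, **kwargs):
    """Look up a registered class by name and instantiate it."""
    return TOY_MODEL_REGISTRY[name](**kwargs)


@register_toy_model('my_lstm')
class MyLSTM:
    def __init__(self, hidden_dim=128):
        self.hidden_dim = hidden_dim
```

Because registration happens when the decorator runs, simply importing the module that defines the class is enough to make it available by name.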

Loading plug-ins from another directory

New plug-ins can also be defined in a custom module stored elsewhere on the user’s system. To import such a module and make its plug-ins available to fairseq, use the --user-dir command-line flag to specify the location of the module.

For example, assuming this directory tree:

/home/user/my-module/
└── __init__.py

with __init__.py:

from fairseq.models import register_model_architecture
from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big

@register_model_architecture('transformer', 'my_transformer')
def transformer_mmt_big(args):
    transformer_vaswani_wmt_en_de_big(args)

the new architecture can be used by invoking fairseq-train with:

fairseq-train ... --user-dir /home/user/my-module -a my_transformer --task translation
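Conceptually, --user-dir only needs to make the directory importable so that its @register decorators run. The stdlib-only sketch below shows one way to do that with `importlib`; the helper `import_user_dir` is hypothetical, and fairseq's own loader may differ in its details (for example in how it handles module name clashes).

```python
import importlib
import os
import sys


def import_user_dir(path):
    """Import the package at `path` (a directory containing __init__.py).

    Importing runs any decorators in the module, which is what registers
    new plug-ins. Returns the imported module object.
    """
    path = os.path.abspath(path)  # also strips a trailing slash
    parent, name = os.path.split(path)
    if parent not in sys.path:
        sys.path.insert(0, parent)
    importlib.invalidate_caches()  # in case the directory was just created
    return importlib.import_module(name)
```

For the directory tree above, `import_user_dir('/home/user/my-module')` would put /home/user on `sys.path` and import the package, which registers `my_transformer` as a side effect.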

Tutorial: Simple LSTM

In this tutorial we will extend fairseq by adding a new FairseqEncoderDecoderModel that encodes a source sentence with an LSTM and then passes the final hidden state to a second LSTM that decodes the target sentence (without attention).

This tutorial covers:

  1. Writing an Encoder and Decoder to encode/decode the source/target sentence, respectively.
  2. Registering a new Model so that it can be used with the existing Command-line Tools.
  3. Training the Model using the existing command-line tools.
  4. Making generation faster by modifying the Decoder to use Incremental decoding.

1. Building an Encoder and Decoder

In this section we’ll define a simple LSTM Encoder and Decoder. All Encoders should implement the FairseqEncoder interface and Decoders should implement the FairseqDecoder interface. These interfaces themselves extend torch.nn.Module, so FairseqEncoders and FairseqDecoders can be written and used in the same ways as ordinary PyTorch Modules.

Encoder

Our Encoder will embed the tokens in the source sentence, feed them to a torch.nn.LSTM and return the final hidden state. To create our encoder save the following in a new file named fairseq/models/simple_lstm.py:

import torch.nn as nn
from fairseq import utils
from fairseq.models import FairseqEncoder

class SimpleLSTMEncoder(FairseqEncoder):

    def __init__(
        self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
    ):
        super().__init__(dictionary)
        self.args = args

        # Our encoder will embed the inputs before feeding them to the LSTM.
        self.embed_tokens = nn.Embedding(
            num_embeddings=len(dictionary),
            embedding_dim=embed_dim,
            padding_idx=dictionary.pad(),
        )
        self.dropout = nn.Dropout(p=dropout)

        # We'll use a single-layer, unidirectional LSTM for simplicity.
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=False,
        )

    def forward(self, src_tokens, src_lengths):
        # The inputs to the ``forward()`` function are determined by the
        # Task, and in particular the ``'net_input'`` key in each
        # mini-batch. We discuss Tasks in the next tutorial, but for now just
        # know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*
        # has shape `(batch)`.

        # Note that the source is typically padded on the left. This can be
        # configured by adding the `--left-pad-source "False"` command-line
        # argument, but here we'll make the Encoder handle either kind of
        # padding by converting everything to be right-padded.
        if self.args.left_pad_source:
            # Convert left-padding to right-padding.
            src_tokens = utils.convert_padding_direction(
                src_tokens,
                padding_idx=self.dictionary.pad(),
                left_to_right=True
            )

        # Embed the source.
        x = self.embed_tokens(src_tokens)

        # Apply dropout.
        x = self.dropout(x)

        # Pack the sequence into a PackedSequence object to feed to the LSTM.
        # (Recent PyTorch versions require the lengths to be a CPU tensor.)
        x = nn.utils.rnn.pack_padded_sequence(
            x, src_lengths.cpu(), batch_first=True,
        )

        # Get the output from the LSTM.
        _outputs, (final_hidden, _final_cell) = self.lstm(x)

        # Return the Encoder's output. This can be any object and will be
        # passed directly to the Decoder.
        return {
            # this will have shape `(bsz, hidden_dim)`
            'final_hidden': final_hidden.squeeze(0),
        }

    # Encoders are required to implement this method so that we can rearrange
    # the order of the batch elements during inference (e.g., beam search).
    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to `new_order`.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            `encoder_out` rearranged according to `new_order`
        """
        final_hidden = encoder_out['final_hidden']
        return {
            'final_hidden': final_hidden.index_select(0, new_order),
        }
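The call to `utils.convert_padding_direction()` in the Encoder moves padding from the left to the right. As a plain-Python illustration of what that conversion means (not fairseq's tensor implementation; `left_to_right_pad` is a hypothetical helper that works on lists and assumes the padding index never appears as a real token), each row keeps its tokens in order and moves its padding to the end:

```python
def left_to_right_pad(batch, pad_idx):
    """Convert left-padded rows to right-padded rows.

    e.g. [pad, pad, 5, 6] -> [5, 6, pad, pad]; the row length is unchanged.
    """
    converted = []
    for row in batch:
        tokens = [t for t in row if t != pad_idx]
        converted.append(tokens + [pad_idx] * (len(row) - len(tokens)))
    return converted
```

Right-padded rows are the layout `pack_padded_sequence()` expects, with each row's real length given separately.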

Decoder

Our Decoder will predict the next word, conditioned on the Encoder’s final hidden state and an embedded representation of the previous target word (feeding the previous target word as input is sometimes called input feeding or teacher forcing). More specifically, we’ll use a torch.nn.LSTM to produce a sequence of hidden states that we’ll project to the size of the output vocabulary to predict each target word.

import torch
from fairseq.models import FairseqDecoder

class SimpleLSTMDecoder(FairseqDecoder):

    def __init__(
        self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
        dropout=0.1,
    ):
        super().__init__(dictionary)

        # Our decoder will embed the inputs before feeding them to the LSTM.
        self.embed_tokens = nn.Embedding(
            num_embeddings=len(dictionary),
            embedding_dim=embed_dim,
            padding_idx=dictionary.pad(),
        )
        self.dropout = nn.Dropout(p=dropout)

        # We'll use a single-layer, unidirectional LSTM for simplicity.
        self.lstm = nn.LSTM(
            # For the first layer we'll concatenate the Encoder's final hidden
            # state with the embedded target tokens.
            input_size=encoder_hidden_dim + embed_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=False,
        )

        # Define the output projection.
        self.output_projection = nn.Linear(hidden_dim, len(dictionary))

    # During training Decoders are expected to take the entire target sequence
    # (shifted right by one position) and produce logits over the vocabulary.
    # The *prev_output_tokens* tensor begins with the end-of-sentence symbol,
    # ``dictionary.eos()``, followed by the target sequence.
    def forward(self, prev_output_tokens, encoder_out):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention

        Returns:
            tuple:
                - the last decoder layer's output of shape
                  `(batch, tgt_len, vocab)`
                - the last decoder layer's attention weights of shape
                  `(batch, tgt_len, src_len)`
        """
        bsz, tgt_len = prev_output_tokens.size()

        # Extract the final hidden state from the Encoder.
        final_encoder_hidden = encoder_out['final_hidden']

        # Embed the target sequence, which has been shifted right by one
        # position and now starts with the end-of-sentence symbol.
        x = self.embed_tokens(prev_output_tokens)

        # Apply dropout.
        x = self.dropout(x)

        # Concatenate the Encoder's final hidden state to *every* embedded
        # target token.
        x = torch.cat(
            [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
            dim=2,
        )

        # Using PackedSequence objects in the Decoder is harder than in the
        # Encoder, since the targets are not sorted in descending length order,
        # which is a requirement of ``pack_padded_sequence()``. Instead we'll
        # feed nn.LSTM directly.
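        # The LSTM starts from the Encoder's final hidden state and a zero
        # cell state, so *hidden_dim* must equal *encoder_hidden_dim*.
        initial_state = (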
            final_encoder_hidden.unsqueeze(0),  # hidden
            torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
        )
        output, _ = self.lstm(
            x.transpose(0, 1),  # convert to shape `(tgt_len, bsz, dim)`
            initial_state,
        )
        x = output.transpose(0, 1)  # convert to shape `(bsz, tgt_len, hidden)`

        # Project the outputs to the size of the vocabulary.
        x = self.output_projection(x)

        # Return the logits and ``None`` for the attention weights
        return x, None
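The Decoder comments describe *prev_output_tokens* as the target sequence shifted right by one position, beginning with the end-of-sentence symbol. A plain-Python sketch of that shift for a single sentence (the helper name and the `EOS = 2` value in the usage are illustrative assumptions):

```python
def make_prev_output_tokens(target, eos):
    """Shift `target` (which ends with eos) right by one position.

    The eos symbol moves to the front, so at step t the decoder sees the
    gold token from step t - 1 and must predict target[t].
    """
    assert target and target[-1] == eos
    return [eos] + target[:-1]
```

For a target `[10, 11, 12, EOS]` this yields `[EOS, 10, 11, 12]`, which has the same length as the target, so the logits line up position by position with the tokens to predict.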

2. Registering the Model

Now that we’ve defined our Encoder and Decoder we must register our model with fairseq using the register_model() function decorator. Once the model is registered we’ll be able to use it with the existing Command-line Tools.

All registered models must implement the BaseFairseqModel interface. For sequence-to-sequence models (i.e., any model with a single Encoder and Decoder), we can instead implement the FairseqEncoderDecoderModel interface.

Create a small wrapper class in the same file and register it in fairseq with the name 'simple_lstm':

from fairseq.models import FairseqEncoderDecoderModel, register_model

# Note: the register_model "decorator" should immediately precede the
# definition of the Model class.

@register_model('simple_lstm')
class SimpleLSTMModel(FairseqEncoderDecoderModel):

    @staticmethod
    def add_args(parser):
        # Models can override this method to add new command-line arguments.
        # Here we'll add some new command-line arguments to configure dropout
        # and the dimensionality of the embeddings and hidden states.
        parser.add_argument(
            '--encoder-embed-dim', type=int, metavar='N',
            help='dimensionality of the encoder embeddings',
        )
        parser.add_argument(
            '--encoder-hidden-dim', type=int, metavar='N',
            help='dimensionality of the encoder hidden state',
        )
        parser.add_argument(
            '--encoder-dropout', type=float, default=0.1,
            help='encoder dropout probability',
        )
        parser.add_argument(
            '--decoder-embed-dim', type=int, metavar='N',
            help='dimensionality of the decoder embeddings',
        )
        parser.add_argument(
            '--decoder-hidden-dim', type=int, metavar='N',
            help='dimensionality of the decoder hidden state',
        )
        parser.add_argument(
            '--decoder-dropout', type=float, default=0.1,
            help='decoder dropout probability',
        )

    @classmethod
    def build_model(cls, args, task):
        # Fairseq initializes models by calling the ``build_model()``
        # function. This provides more flexibility, since the returned model
        # instance can be of a different type than the one that was called.
        # In this case we'll just return a SimpleLSTMModel instance.

        # Initialize our Encoder and Decoder.
        encoder = SimpleLSTMEncoder(
            args=args,
            dictionary=task.source_dictionary,
            embed_dim=args.encoder_embed_dim,
            hidden_dim=args.encoder_hidden_dim,
            dropout=args.encoder_dropout,
        )
        decoder = SimpleLSTMDecoder(
            dictionary=task.target_dictionary,
            encoder_hidden_dim=args.encoder_hidden_dim,
            embed_dim=args.decoder_embed_dim,
            hidden_dim=args.decoder_hidden_dim,
            dropout=args.decoder_dropout,
        )
        model = SimpleLSTMModel(encoder, decoder)

        # Print the model architecture.
        print(model)

        return model

    # We could override the ``forward()`` if we wanted more control over how
    # the encoder and decoder interact, but it's not necessary for this
    # tutorial since we can inherit the default implementation provided by
    # the FairseqEncoderDecoderModel base class, which looks like:
    #
    # def forward(self, src_tokens, src_lengths, prev_output_tokens):
    #     encoder_out = self.encoder(src_tokens, src_lengths)
    #     decoder_out = self.decoder(prev_output_tokens, encoder_out)
    #     return decoder_out

Finally let’s define a named architecture with the configuration for our model. This is done with the register_model_architecture() function decorator. Thereafter this named architecture can be used with the --arch command-line argument, e.g., --arch tutorial_simple_lstm:

from fairseq.models import register_model_architecture

# The first argument to ``register_model_architecture()`` should be the name
# of the model we registered above (i.e., 'simple_lstm'). The function we
# register here should take a single argument *args* and modify it in-place
# to match the desired architecture.

@register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
def tutorial_simple_lstm(args):
    # We use ``getattr()`` to prioritize arguments that are explicitly given
    # on the command-line, so that the defaults defined below are only used
    # when no other value has been specified.
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
    args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
    args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)
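The `getattr()` idiom only lets explicit command-line values win if options that were not given are absent from *args*. An argument parsed with a `None` default would otherwise exist and shadow the fallback. The sketch below assumes that behaviour by creating the parser with `argument_default=argparse.SUPPRESS` (a standard argparse feature). The `parse` helper is hypothetical, and the architecture function is a copy of the one above without the decorator.

```python
import argparse


def tutorial_simple_lstm(args):
    # Same idiom as above: keep explicit values, otherwise use defaults.
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
    args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
    args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)


def parse(argv):
    # SUPPRESS: options that are not given do not appear in the namespace,
    # so getattr() falls back to the architecture defaults.
    parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
    for name in ('encoder-embed-dim', 'encoder-hidden-dim',
                 'decoder-embed-dim', 'decoder-hidden-dim'):
        parser.add_argument('--' + name, type=int)
    args = parser.parse_args(argv)
    tutorial_simple_lstm(args)
    return args
```

For example, `parse(['--decoder-hidden-dim', '512'])` keeps 512 for the decoder hidden size while the other three dimensions take the architecture default of 256.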

3. Training the Model

Now we’re ready to train the model. We can use the existing fairseq-train command-line tool for this, making sure to specify our new Model architecture (--arch tutorial_simple_lstm).

Note

Make sure you’ve already preprocessed the data from the IWSLT example in the examples/translation/ directory.

> fairseq-train data-bin/iwslt14.tokenized.de-en \
  --arch tutorial_simple_lstm \
  --encoder-dropout 0.2 --decoder-dropout 0.2 \
  --optimizer adam --lr 0.005 --lr-shrink 0.5 \
  --max-tokens 12000
(...)
| epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
| epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954

The model files should appear in the checkpoints/ directory. While this model architecture is not very good, we can use the fairseq-generate script to generate translations and compute our BLEU score over the test set:

> fairseq-generate data-bin/iwslt14.tokenized.de-en \
  --path checkpoints/checkpoint_best.pt \
  --beam 5 \
  --remove-bpe
(...)
| Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
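As a sanity check on the log line above, BLEU4 is the brevity penalty (BP) times the geometric mean of the four n-gram precisions (38.8/12.1/4.7/2.0). BP is 1 here because the system output is longer than the reference (syslen > reflen, ratio > 1). Recomputing the score from the rounded precisions gives about 8.15, close to the reported 8.18; the small gap is most likely because the precisions are printed with only one decimal.

```python
import math


def bleu(precisions, sys_len, ref_len):
    """Corpus BLEU from n-gram precisions given in percent."""
    log_mean = sum(math.log(p / 100.0) for p in precisions) / len(precisions)
    # Brevity penalty: only penalize outputs shorter than the reference.
    bp = 1.0 if sys_len > ref_len else math.exp(1.0 - ref_len / sys_len)
    return 100.0 * bp * math.exp(log_mean)


score = bleu([38.8, 12.1, 4.7, 2.0], sys_len=139865, ref_len=131146)
print(round(score, 2))  # prints 8.15
```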

4. Making generation faster

While autoregressive generation from sequence-to-sequence models is inherently slow, our implementation above is especially slow because it recomputes the entire sequence of Decoder hidden states for every output token (i.e., it is O(n^2)). We can make this significantly faster by instead caching the previous hidden states.

In fairseq this is called Incremental decoding. Incremental decoding is a special mode at inference time where the Model only receives a single timestep of input corresponding to the immediately previous output token (for input feeding) and must produce the next output incrementally. Thus the model must cache any long-term state that is needed about the sequence, e.g., hidden states, convolutional states, etc.

To implement incremental decoding we will modify our model to implement the FairseqIncrementalDecoder interface. Compared to the standard FairseqDecoder interface, the incremental decoder interface allows forward() methods to take an extra keyword argument (incremental_state) that can be used to cache state across time-steps.
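To see why caching helps, here is a toy, stdlib-only sketch (not fairseq code) of a scalar recurrence h_t = 0.5 * h_{t-1} + x_t decoded in both modes. The non-incremental version re-runs the recurrence over the whole prefix at every step, n(n+1)/2 recurrence steps for n outputs. The incremental version keeps h in a dict playing the role of *incremental_state* and does one step per output. The 'prev_state' key mirrors the tutorial code, but all functions here are hypothetical.

```python
def rnn_step(h, x):
    return 0.5 * h + x


def decode_full(prefix):
    """Recompute the hidden state from scratch over the whole prefix."""
    h, steps = 0.0, 0
    for x in prefix:
        h = rnn_step(h, x)
        steps += 1
    return h, steps


def decode_incremental(prefix, incremental_state):
    """Use only the last token; reuse the cached state for the rest."""
    h = incremental_state.get('prev_state', 0.0)
    h = rnn_step(h, prefix[-1])
    incremental_state['prev_state'] = h
    return h, 1


def generate(tokens, incremental):
    state, outputs, total_steps = {}, [], 0
    for t in range(1, len(tokens) + 1):
        prefix = tokens[:t]  # the decoded prefix so far
        if incremental:
            h, steps = decode_incremental(prefix, state)
        else:
            h, steps = decode_full(prefix)
        outputs.append(h)
        total_steps += steps
    return outputs, total_steps
```

For five tokens, both modes produce identical outputs, but the full recomputation performs 15 recurrence steps versus 5 for the cached version.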

Let’s replace our SimpleLSTMDecoder with an incremental one:

import torch
from fairseq.models import FairseqIncrementalDecoder

class SimpleLSTMDecoder(FairseqIncrementalDecoder):

    def __init__(
        self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
        dropout=0.1,
    ):
        # This remains the same as before.
        super().__init__(dictionary)
        self.embed_tokens = nn.Embedding(
            num_embeddings=len(dictionary),
            embedding_dim=embed_dim,
            padding_idx=dictionary.pad(),
        )
        self.dropout = nn.Dropout(p=dropout)
        self.lstm = nn.LSTM(
            input_size=encoder_hidden_dim + embed_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=False,
        )
        self.output_projection = nn.Linear(hidden_dim, len(dictionary))

    # We now take an additional kwarg (*incremental_state*) for caching the
    # previous hidden and cell states.
    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        if incremental_state is not None:
            # If the *incremental_state* argument is not ``None`` then we are
            # in incremental inference mode. While *prev_output_tokens* will
            # still contain the entire decoded prefix, we will only use the
            # last step and assume that the rest of the state is cached.
            prev_output_tokens = prev_output_tokens[:, -1:]

        # This remains the same as before.
        bsz, tgt_len = prev_output_tokens.size()
        final_encoder_hidden = encoder_out['final_hidden']
        x = self.embed_tokens(prev_output_tokens)
        x = self.dropout(x)
        x = torch.cat(
            [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
            dim=2,
        )

        # We will now check the cache and load the cached previous hidden and
        # cell states, if they exist. Otherwise we initialize them as before:
        # the hidden state from the Encoder's final hidden state and the cell
        # state with zeros. We use the ``utils.get_incremental_state()``
        # and ``utils.set_incremental_state()`` helpers.
        initial_state = utils.get_incremental_state(
            self, incremental_state, 'prev_state',
        )
        if initial_state is None:
            # first time initialization, same as the original version
            initial_state = (
                final_encoder_hidden.unsqueeze(0),  # hidden
                torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
            )

        # Run one step of our LSTM.
        output, latest_state = self.lstm(x.transpose(0, 1), initial_state)

        # Update the cache with the latest hidden and cell states.
        utils.set_incremental_state(
            self, incremental_state, 'prev_state', latest_state,
        )

        # This remains the same as before
        x = output.transpose(0, 1)
        x = self.output_projection(x)
        return x, None

    # The ``FairseqIncrementalDecoder`` interface also requires implementing a
    # ``reorder_incremental_state()`` method, which is used during beam search
    # to select and reorder the incremental state.
    def reorder_incremental_state(self, incremental_state, new_order):
        # Load the cached state.
        prev_state = utils.get_incremental_state(
            self, incremental_state, 'prev_state',
        )

        # Reorder batches according to *new_order*.
        reordered_state = (
            prev_state[0].index_select(1, new_order),  # hidden
            prev_state[1].index_select(1, new_order),  # cell
        )

        # Update the cached state.
        utils.set_incremental_state(
            self, incremental_state, 'prev_state', reordered_state,
        )

Finally, we can rerun generation and observe the speedup:

# Before

> fairseq-generate data-bin/iwslt14.tokenized.de-en \
  --path checkpoints/checkpoint_best.pt \
  --beam 5 \
  --remove-bpe
(...)
| Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

# After

> fairseq-generate data-bin/iwslt14.tokenized.de-en \
  --path checkpoints/checkpoint_best.pt \
  --beam 5 \
  --remove-bpe
(...)
| Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

Tutorial: Classifying Names with a Character-Level RNN

In this tutorial we will extend fairseq to support classification tasks. In particular we will re-implement the PyTorch tutorial for Classifying Names with a Character-Level RNN in fairseq. It is recommended to quickly skim that tutorial before beginning this one.

This tutorial covers:

  1. Preprocessing the data to create dictionaries.
  2. Registering a new Model that encodes an input sentence with a simple RNN and predicts the output label.
  3. Registering a new Task that loads our dictionaries and dataset.
  4. Training the Model using the existing command-line tools.
  5. Writing an evaluation script that imports fairseq and allows us to interactively evaluate our model on new inputs.

1. Preprocessing the data

The original tutorial provides raw data, but we’ll work with a modified version of the data that is already tokenized into characters and split into separate train, valid and test sets.

Download and extract the data from here: tutorial_names.tar.gz

Once extracted, let’s preprocess the data using the fairseq-preprocess command-line tool to create the dictionaries. While this tool is primarily intended for sequence-to-sequence problems, we’re able to reuse it here by treating the label as a “target” sequence of length 1. We’ll also output the preprocessed files in “raw” format using the --dataset-impl option to enhance readability:

> fairseq-preprocess \
  --trainpref names/train --validpref names/valid --testpref names/test \
  --source-lang input --target-lang label \
  --destdir names-bin --dataset-impl raw

After running the above command you should see a new directory, names-bin/, containing the dictionaries for inputs and labels.
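For intuition about what ends up in names-bin/, here is a stdlib-only sketch that counts symbols in space-separated, character-tokenized lines and returns them as one "symbol count" entry per line, most frequent first. The function name and the alphabetical tie-breaking are assumptions. fairseq's Dictionary also reserves special symbols (such as padding and end-of-sentence), which this sketch does not model.

```python
import collections


def build_dict_lines(lines):
    """Count space-separated symbols and return 'symbol count' lines.

    Sorted by descending count, ties broken alphabetically (an assumption).
    """
    counts = collections.Counter()
    for line in lines:
        counts.update(line.split())
    ordered = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    return ['{} {}'.format(sym, n) for sym, n in ordered]
```

For example, the two inputs `'A n n a'` and `'A l i'` give `['A 2', 'n 2', 'a 1', 'i 1', 'l 1']`.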

2. Registering a new Model

Next we’ll register a new model in fairseq that will encode an input sentence with a simple RNN and predict the output label. Compared to the original PyTorch tutorial, our version will also work with batches of data and GPU Tensors.

First let’s copy the simple RNN module implemented in the PyTorch tutorial. Create a new file named fairseq/models/rnn_classifier.py with the following contents:

import torch
import torch.nn as nn

class RNN(nn.Module):

    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

We must also register this model with fairseq using the register_model() function decorator. Once the model is registered we’ll be able to use it with the existing Command-line Tools.

All registered models must implement the BaseFairseqModel interface, so we’ll create a small wrapper class in the same file and register it in fairseq with the name 'rnn_classifier':

from fairseq.models import BaseFairseqModel, register_model

# Note: the register_model "decorator" should immediately precede the
# definition of the Model class.

@register_model('rnn_classifier')
class FairseqRNNClassifier(BaseFairseqModel):

    @staticmethod
    def add_args(parser):
        # Models can override this method to add new command-line arguments.
        # Here we'll add a new command-line argument to configure the
        # dimensionality of the hidden state.
        parser.add_argument(
            '--hidden-dim', type=int, metavar='N',
            help='dimensionality of the hidden state',
        )

    @classmethod
    def build_model(cls, args, task):
        # Fairseq initializes models by calling the ``build_model()``
        # function. This provides more flexibility, since the returned model
        # instance can be of a different type than the one that was called.
        # In this case we'll just return a FairseqRNNClassifier instance.

        # Initialize our RNN module
        rnn = RNN(
            # We'll define the Task in the next section, but for now just
            # notice that the task holds the dictionaries for the "source"
            # (i.e., the input sentence) and "target" (i.e., the label).
            input_size=len(task.source_dictionary),
            hidden_size=args.hidden_dim,
            output_size=len(task.target_dictionary),
        )

        # Return the wrapped version of the module
        return FairseqRNNClassifier(
            rnn=rnn,
            input_vocab=task.source_dictionary,
        )

    def __init__(self, rnn, input_vocab):
        super(FairseqRNNClassifier, self).__init__()

        self.rnn = rnn
        self.input_vocab = input_vocab

        # The RNN module in the tutorial expects one-hot inputs, so we can
        # precompute the identity matrix to help convert from indices to
        # one-hot vectors. We register it as a buffer so that it is moved to
        # the GPU when ``cuda()`` is called.
        self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab)))

    def forward(self, src_tokens, src_lengths):
        # The inputs to the ``forward()`` function are determined by the
        # Task, and in particular the ``'net_input'`` key in each
        # mini-batch. We'll define the Task in the next section, but for
        # now just know that *src_tokens* has shape `(batch, src_len)` and
        # *src_lengths* has shape `(batch)`.
        bsz, max_src_len = src_tokens.size()

        # Initialize the RNN hidden state. Compared to the original PyTorch
        # tutorial we'll also handle batched inputs and work on the GPU.
        hidden = self.rnn.initHidden()
        hidden = hidden.repeat(bsz, 1)  # expand for batched inputs
        hidden = hidden.to(src_tokens.device)  # move to GPU

        for i in range(max_src_len):
            # WARNING: The inputs have padding, so we should mask those
            # elements here so that padding doesn't affect the results.
            # This is left as an exercise for the reader. The padding symbol
            # is given by ``self.input_vocab.pad()`` and the unpadded length
            # of each input is given by *src_lengths*.

            # One-hot encode a batch of input characters.
            input = self.one_hot_inputs[src_tokens[:, i].long()]

            # Feed the input to our RNN.
            output, hidden = self.rnn(input, hidden)

        # Return the final output state for making a prediction
        return output
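One way to approach the padding exercise is to leave the hidden state untouched whenever the current token is padding. The sketch below shows this with stdlib Python on plain lists rather than batched tensors. The helper `final_outputs` and the recurrence in `toy_step` are hypothetical; `step` stands in for a call like `self.rnn(input, hidden)`. Because padding never changes the state, left-padded and right-padded versions of the same input produce the same final output. A batched tensor version could express the same idea with a padding mask.

```python
def final_outputs(batch_tokens, pad_idx, step, init_hidden):
    """Run `step(token, hidden) -> (output, hidden)` over each row,
    skipping padding so it cannot affect the result."""
    results = []
    for tokens in batch_tokens:
        hidden, output = init_hidden, None
        for tok in tokens:
            if tok == pad_idx:
                continue  # padding: keep the previous state and output
            output, hidden = step(tok, hidden)
        results.append(output)
    return results


def toy_step(tok, hidden):
    # Toy recurrence standing in for the RNN module.
    new_hidden = 0.5 * hidden + tok
    return 2.0 * new_hidden, new_hidden
```

For the input `[5, 6]`, the hidden state goes 5 then 8.5, so the final output is 17.0 whether the row is written as `[PAD, PAD, 5, 6]` or `[5, 6, PAD, PAD]`.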

Finally let’s define a named architecture with the configuration for our model. This is done with the register_model_architecture() function decorator. Thereafter this named architecture can be used with the --arch command-line argument, e.g., --arch pytorch_tutorial_rnn:

from fairseq.models import register_model_architecture

# The first argument to ``register_model_architecture()`` should be the name
# of the model we registered above (i.e., 'rnn_classifier'). The function we
# register here should take a single argument *args* and modify it in-place
# to match the desired architecture.

@register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn')
def pytorch_tutorial_rnn(args):
    # We use ``getattr()`` to prioritize arguments that are explicitly given
    # on the command-line, so that the defaults defined below are only used
    # when no other value has been specified.
    args.hidden_dim = getattr(args, 'hidden_dim', 128)

3. Registering a new Task

Now we’ll register a new FairseqTask that will load our dictionaries and dataset. Tasks can also control how the data is batched into mini-batches, but in this tutorial we’ll reuse the batching provided by fairseq.data.LanguagePairDataset.

Create a new file named fairseq/tasks/simple_classification.py with the following contents:

import os
import torch

from fairseq.data import Dictionary, LanguagePairDataset
from fairseq.tasks import FairseqTask, register_task


@register_task('simple_classification')
class SimpleClassificationTask(FairseqTask):

    @staticmethod
    def add_args(parser):
        # Add some command-line arguments for specifying where the data is
        # located and the maximum supported input length.
        parser.add_argument('data', metavar='FILE',
                            help='file prefix for data')
        parser.add_argument('--max-positions', default=1024, type=int,
                            help='max input length')

    @classmethod
    def setup_task(cls, args, **kwargs):
        # Here we can perform any setup required for the task. This may include
        # loading Dictionaries, initializing shared Embedding layers, etc.
        # In this case we'll just load the Dictionaries.
        input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt'))
        label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt'))
        print('| [input] dictionary: {} types'.format(len(input_vocab)))
        print('| [label] dictionary: {} types'.format(len(label_vocab)))

        return SimpleClassificationTask(args, input_vocab, label_vocab)

    def __init__(self, args, input_vocab, label_vocab):
        super().__init__(args)
        self.input_vocab = input_vocab
        self.label_vocab = label_vocab

    def load_dataset(self, split, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""

        prefix = os.path.join(self.args.data, '{}.input-label'.format(split))

        # Read input sentences.
        sentences, lengths = [], []
        with open(prefix + '.input', encoding='utf-8') as file:
            for line in file:
                sentence = line.strip()

                # Tokenize the sentence, splitting on spaces
                tokens = self.input_vocab.encode_line(
                    sentence, add_if_not_exist=False,
                )

                sentences.append(tokens)
                lengths.append(tokens.numel())

        # Read labels.
        labels = []
        with open(prefix + '.label', encoding='utf-8') as file:
            for line in file:
                label = line.strip()
                labels.append(
                    # Convert label to a numeric ID.
                    torch.LongTensor([self.label_vocab.add_symbol(label)])
                )

        assert len(sentences) == len(labels)
        print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))

        # We reuse LanguagePairDataset since classification can be modeled as a
        # sequence-to-sequence task where the target sequence has length 1.
        self.datasets[split] = LanguagePairDataset(
            src=sentences,
            src_sizes=lengths,
            src_dict=self.input_vocab,
            tgt=labels,
            tgt_sizes=torch.ones(len(labels)),  # targets have length 1
            tgt_dict=self.label_vocab,
            left_pad_source=False,
            max_source_positions=self.args.max_positions,
            max_target_positions=1,
            # Since our target is a single class label, there's no need for
            # input feeding. If we set this to ``True`` then our Model's
            # ``forward()`` method would receive an additional argument called
            # *prev_output_tokens* that would contain a shifted version of the
            # target sequence.
            input_feeding=False,
        )

    def max_positions(self):
        """Return the max input length allowed by the task."""
        # The source should be less than *args.max_positions* and the "target"
        # has max length 1.
        return (self.args.max_positions, 1)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.input_vocab

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.label_vocab

    # We could override this method if we wanted more control over how batches
    # are constructed, but it's not necessary for this tutorial since we can
    # reuse the batching provided by LanguagePairDataset.
    #
    # def get_batch_iterator(
    #     self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
    #     ignore_invalid_inputs=False, required_batch_size_multiple=1,
    #     seed=1, num_shards=1, shard_id=0,
    # ):
    #     (...)
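load_dataset() expects two line-aligned files per split: {split}.input-label.input and {split}.input-label.label. Line *i* of the first file holds the space-separated input tokens, and line *i* of the second holds that example's label.

The stdlib-only sketch below writes a toy split to a temporary directory and reads it back the same way. It makes the following substitutions so that it runs without fairseq installed:

- Dictionary.encode_line() is replaced by str.split().
- add_symbol() is replaced by a hypothetical label-to-ID dict.
- read_split() is also hypothetical.

Two differences from the real code follow from this. encode_line() appends an end-of-sentence token by default, so real lengths would be one larger. A fairseq Dictionary also reserves its first IDs for special symbols, so real label IDs would not start at 0.

```python
import os
import tempfile


def read_split(data_dir, split):
    """Read aligned .input/.label files the way the tutorial's load_dataset() does."""
    prefix = os.path.join(data_dir, '{}.input-label'.format(split))

    sentences, lengths = [], []
    with open(prefix + '.input', encoding='utf-8') as f:
        for line in f:
            tokens = line.strip().split()  # stand-in for encode_line()
            sentences.append(tokens)
            lengths.append(len(tokens))

    label_ids = {}  # hypothetical stand-in for the label Dictionary
    labels = []
    with open(prefix + '.label', encoding='utf-8') as f:
        for line in f:
            label = line.strip()
            labels.append(label_ids.setdefault(label, len(label_ids)))

    # Both files must contain exactly one line per example.
    assert len(sentences) == len(labels)
    return sentences, lengths, labels, label_ids


with tempfile.TemporaryDirectory() as data_dir:
    prefix = os.path.join(data_dir, 'train.input-label')
    with open(prefix + '.input', 'w', encoding='utf-8') as f:
        f.write('S a t o s h i\nS i n b a d\n')
    with open(prefix + '.label', 'w', encoding='utf-8') as f:
        f.write('Japanese\nArabic\n')
    sentences, lengths, labels, label_ids = read_split(data_dir, 'train')

print(lengths, labels, label_ids)
# [7, 6] [0, 1] {'Japanese': 0, 'Arabic': 1}
```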

4. Training the Model

Now we’re ready to train the model. We can use the existing fairseq-train command-line tool for this, making sure to specify our new Task (--task simple_classification) and Model architecture (--arch pytorch_tutorial_rnn):

Note

You can also configure the dimensionality of the hidden state by passing the --hidden-dim argument to fairseq-train.

> fairseq-train names-bin \
  --task simple_classification \
  --arch pytorch_tutorial_rnn \
  --optimizer adam --lr 0.001 --lr-shrink 0.5 \
  --max-tokens 1000
(...)
| epoch 027 | loss 1.200 | ppl 2.30 | wps 15728 | ups 119.4 | wpb 116 | bsz 116 | num_updates 3726 | lr 1.5625e-05 | gnorm 1.290 | clip 0% | oom 0 | wall 32 | train_wall 21
| epoch 027 | valid on 'valid' subset | valid_loss 1.41304 | valid_ppl 2.66 | num_updates 3726 | best 1.41208
| done training in 31.6 seconds

The model files should appear in the checkpoints/ directory.
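As a quick sanity check on the log above, the final learning rate of 1.5625e-05 is exactly the initial --lr 0.001 halved six times (0.001 × 0.5^6 = 0.001 / 64). This is consistent with --lr-shrink 0.5 having been applied six times during training. The snippet below does only that arithmetic and makes no assumption about which epochs triggered the shrinks.

```python
import math

initial_lr = 0.001      # --lr
lr_shrink = 0.5         # --lr-shrink
logged_lr = 1.5625e-05  # lr reported at epoch 027

# Smallest number of shrinks k such that initial_lr * lr_shrink**k matches the log.
num_shrinks = next(
    k for k in range(50) if math.isclose(initial_lr * lr_shrink ** k, logged_lr)
)
print(num_shrinks)  # 6
```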

5. Writing an evaluation script

Finally we can write a short script to evaluate our model on new inputs. Create a new file named eval_classifier.py with the following contents:

from fairseq import checkpoint_utils, data, options, tasks

# Parse command-line arguments for generation
parser = options.get_generation_parser(default_task='simple_classification')
args = options.parse_args_and_arch(parser)

# Setup task
task = tasks.setup_task(args)

# Load model and switch it to evaluation mode (disables dropout)
print('| loading model from {}'.format(args.path))
models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
model = models[0]
model.eval()

while True:
    sentence = input('\nInput: ')

    # Tokenize into characters
    chars = ' '.join(list(sentence.strip()))
    tokens = task.source_dictionary.encode_line(
        chars, add_if_not_exist=False,
    )

    # Build mini-batch to feed to the model
    batch = data.language_pair_dataset.collate(
        samples=[{'id': -1, 'source': tokens}],  # bsz = 1
        pad_idx=task.source_dictionary.pad(),
        eos_idx=task.source_dictionary.eos(),
        left_pad_source=False,
        input_feeding=False,
    )

    # Feed batch to the model and get predictions
    preds = model(**batch['net_input'])

    # Print top 3 predictions and their log-probabilities
    top_scores, top_labels = preds[0].topk(k=3)
    for score, label_idx in zip(top_scores, top_labels):
        label_name = task.target_dictionary.string([label_idx])
        print('({:.2f})\t{}'.format(score, label_name))
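The script performs two data-handling steps: it splits a name into space-separated characters, and it picks the three highest-scoring labels. Both can be tried without a trained model. In the sketch below, the score list and label names are made up for illustration, and heapq.nlargest() stands in for Tensor.topk().

```python
import heapq

sentence = 'Satoshi'
# Same character tokenization as the script: one space between characters.
chars = ' '.join(list(sentence.strip()))
print(chars)  # S a t o s h i

# Hypothetical per-label scores for one input (index = label ID).
label_names = ['Arabic', 'English', 'Italian', 'Japanese', 'Russian']
scores = [-1.20, -3.10, -2.86, -0.61, -4.50]

# Equivalent of preds[0].topk(k=3): the 3 largest scores with their indices.
top = heapq.nlargest(3, enumerate(scores), key=lambda pair: pair[1])
for label_idx, score in top:
    print('({:.2f})\t{}'.format(score, label_names[label_idx]))
```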

Now we can evaluate our model interactively. Note that we have included the original data path (names-bin/) so that the dictionaries can be loaded:

> python eval_classifier.py names-bin --path checkpoints/checkpoint_best.pt
| [input] dictionary: 64 types
| [label] dictionary: 24 types
| loading model from checkpoints/checkpoint_best.pt

Input: Satoshi
(-0.61) Japanese
(-1.20) Arabic
(-2.86) Italian

Input: Sinbad
(-0.30) Arabic
(-1.76) English
(-4.08) Russian
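The script's comment describes these scores as log-probabilities. Taking that at face value, exponentiating them gives probabilities. For Satoshi, the three listed labels then cover about 90% of the probability mass, with Japanese at roughly 54%:

```python
import math

# Top-3 scores printed for the input "Satoshi" above.
top3 = {'Japanese': -0.61, 'Arabic': -1.20, 'Italian': -2.86}

# exp(log p) = p
probs = {label: math.exp(score) for label, score in top3.items()}
for label, p in probs.items():
    print('{}: {:.1%}'.format(label, p))
print('top-3 total: {:.1%}'.format(sum(probs.values())))
```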

Tasks

Tasks store dictionaries and provide helpers for loading/iterating over Datasets, initializing the Model/Criterion and calculating the loss.

Tasks can be selected via the --task command-line argument. Once selected, a task may expose additional command-line arguments for further configuration.

Example usage:

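from fairseq import tasks

# *args* is an argparse.Namespace of parsed command-line arguments
# setup the task (e.g., load dictionaries)
task = tasks.setup_task(args)

# build model and criterion
model = task.build_model(args)
criterion = task.build_criterion(args)

# load datasets
task.load_dataset('train')
task.load_dataset('valid')

# iterate over mini-batches of data; get_batch_iterator() returns an
# EpochBatchIterator, from which we request an iterator over one epoch
batch_itr = task.get_batch_iterator(
    task.dataset('train'), max_tokens=4096,
)
for batch in batch_itr.next_epoch_itr():
    # compute the loss; a FairseqCriterion returns the loss, the sample
    # size and a dictionary of logging outputs
    loss, sample_size, logging_output = criterion(model, batch)
    loss.backward()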

Translation

class fairseq.tasks.translation.TranslationTask(args, src_dict, tgt_dict)

Translate from one (source) language to another (target) language.

Parameters:
  • src_dict (Dictionary) – dictionary for the source language
  • tgt_dict (Dictionary) – dictionary for the target language

Note

The translation task is compatible with fairseq-train, fairseq-generate and fairseq-interactive.

The translation task provides the following additional command-line arguments:

usage:  [--task translation] [-s SRC] [-t TARGET] [--lazy-load] [--raw-text]
        [--left-pad-source BOOL] [--left-pad-target BOOL]
        [--max-source-positions N] [--max-target-positions N]
        [--upsample-primary UPSAMPLE_PRIMARY]
        data

Task name

--task Enable this task with: --task=translation

Additional command-line arguments

data colon-separated list of data directory paths, iterated over in round-robin fashion across epochs
-s, --source-lang source language
-t, --target-lang target language
--lazy-load

load the dataset lazily

Default: False

--raw-text

load raw text dataset

Default: False

--left-pad-source

pad the source on the left

Default: “True”

--left-pad-target

pad the target on the left

Default: “False”

--max-source-positions

max number of tokens in the source sequence

Default: 1024

--max-target-positions

max number of tokens in the target sequence

Default: 1024

--upsample-primary

amount to upsample primary dataset

Default: 1
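--left-pad-source and --left-pad-target decide whether padding symbols go before or after each sequence when a batch is built. The stdlib sketch below uses made-up token IDs and a hypothetical pad_batch() helper; fairseq's own collation works on tensors.

```python
def pad_batch(sequences, pad, left_pad):
    """Pad token lists to the longest length, on the left or on the right."""
    width = max(len(s) for s in sequences)
    padded = []
    for s in sequences:
        filler = [pad] * (width - len(s))
        padded.append(filler + s if left_pad else s + filler)
    return padded


batch = [[11, 12, 13, 2], [21, 2]]  # hypothetical token IDs; 2 marks end of sentence
PAD = 1                             # hypothetical padding ID
print(pad_batch(batch, PAD, left_pad=True))   # [[11, 12, 13, 2], [1, 1, 21, 2]]
print(pad_batch(batch, PAD, left_pad=False))  # [[11, 12, 13, 2], [21, 2, 1, 1]]
```

With left padding, every sequence in the batch ends at the same position. With right padding, every sequence starts at the same position.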

Language Modeling

class fairseq.tasks.language_modeling.LanguageModelingTask(args, dictionary, output_dictionary=None, targets=None)

Train a language model.

Parameters:
  • dictionary (Dictionary) – the dictionary for the input of the language model
  • output_dictionary (Dictionary) – the dictionary for the output of the language model. In most cases it will be the same as dictionary, but could possibly be a more limited version of the dictionary (if --output-dictionary-size is used).
  • targets (List[str]) – list of the target types that the language model should predict. Each entry can be one of “self”, “future”, or “past”. Defaults to “future”.

Note

The language modeling task is compatible with fairseq-train, fairseq-generate, fairseq-interactive and fairseq-eval-lm.

The language modeling task provides the following additional command-line arguments:

usage:  [--task language_modeling] [--sample-break-mode {none,complete,eos}]
        [--tokens-per-sample TOKENS_PER_SAMPLE] [--lazy-load] [--raw-text]
        [--output-dictionary-size OUTPUT_DICTIONARY_SIZE] [--self-target]
        [--future-target] [--past-target] [--add-bos-token]
        [--max-target-positions N]
        data

Task name

--task Enable this task with: --task=language_modeling

Additional command-line arguments

data path to data directory
--sample-break-mode

Possible choices: none, complete, eos

If omitted or “none”, fills each sample with tokens-per-sample tokens. If set to “complete”, splits samples only at the end of sentence, but may include multiple sentences per sample. If set to “eos”, includes only one sentence per sample.
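The three modes can be illustrated with a simplified, stdlib-only sketch. It assumes each sentence is already a list of tokens ending in an end-of-sentence token </s>. This is not fairseq's implementation, which works on token-ID tensors and handles more edge cases. break_into_samples() is a hypothetical helper.

```python
def break_into_samples(sentences, tokens_per_sample, mode):
    """Simplified sketch of --sample-break-mode; each sentence ends with '</s>'."""
    if mode == 'none':
        # Concatenate everything and cut fixed-size blocks, ignoring sentence ends.
        stream = [tok for sent in sentences for tok in sent]
        return [stream[i:i + tokens_per_sample]
                for i in range(0, len(stream), tokens_per_sample)]
    if mode == 'complete':
        # Pack whole sentences into a sample until the next one would not fit.
        samples, current = [], []
        for sent in sentences:
            if current and len(current) + len(sent) > tokens_per_sample:
                samples.append(current)
                current = []
            current = current + list(sent)
        if current:
            samples.append(current)
        return samples
    if mode == 'eos':
        # Exactly one sentence per sample.
        return [list(sent) for sent in sentences]
    raise ValueError('unknown break mode: {}'.format(mode))


sentences = [['a', 'b', '</s>'], ['c', '</s>'], ['d', 'e', 'f', '</s>']]
for mode in ('none', 'complete', 'eos'):
    print(mode, break_into_samples(sentences, tokens_per_sample=6, mode=mode))
```

With 6 tokens per sample, each mode behaves differently:

- "none" splits the third sentence across two samples.
- "complete" keeps sentences intact but packs the first two together.
- "eos" yields one sample per sentence.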

--tokens-per-sample

max number of tokens per sample for LM dataset

Default: 1024

--lazy-load

load the dataset lazily

Default: False

--raw-text

load raw text dataset

Default: False

--output-dictionary-size

limit the size of output dictionary

Default: -1

--self-target

include self target

Default: False

--future-target

include future target

Default: False

--past-target

include past target

Default: False

--add-bos-token

prepend beginning of sentence token (<s>)

Default: False

--max-target-positions max number of tokens in the target sequence

Adding new tasks

fairseq.tasks.register_task(name)

New tasks can be added to fairseq with the register_task() function decorator.

For example:

@register_task('classification')
class ClassificationTask(FairseqTask):
    (...)

Note

All Tasks must implement the FairseqTask interface.

Please see the

Parameters: name (str) – the name of the task
class fairseq.tasks.FairseqTask(args)

Tasks store dictionaries and provide helpers for loading/iterating over Datasets, initializing the Model/Criterion and calculating the loss.

static add_args(parser)

Add task-specific arguments to the parser.

aggregate_logging_outputs(logging_outputs, criterion)
build_criterion(args)

Build the FairseqCriterion instance for this task.

Parameters: args (argparse.Namespace) – parsed command-line arguments
Returns: a FairseqCriterion instance
classmethod build_dictionary(filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8)

Build the dictionary

Parameters:
  • filenames (list) – list of filenames
  • workers (int) – number of concurrent workers
  • threshold (int) – defines the minimum word count
  • nwords (int) – defines the total number of words in the final dictionary, including special symbols
  • padding_factor (int) – can be used to pad the dictionary size to be a multiple of 8, which is important on some hardware (e.g., Nvidia Tensor Cores).
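Padding the dictionary to a multiple of padding_factor amounts to rounding its size up. The sketch below computes only the padded size; it does not model how the extra entries are filled (fairseq adds placeholder symbols). padded_size() is a hypothetical helper. The example inputs are the dictionary sizes printed by fairseq-interactive earlier in this document.

```python
def padded_size(num_symbols, padding_factor=8):
    """Round num_symbols up to the next multiple of padding_factor."""
    if padding_factor <= 1:
        return num_symbols
    return -(-num_symbols // padding_factor) * padding_factor  # ceiling division


for n in (44206, 44463, 64):
    print(n, '->', padded_size(n))
# 44206 -> 44208
# 44463 -> 44464
# 64 -> 64
```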
build_generator(args)
build_model(args)

Build the BaseFairseqModel instance for this task.

Parameters: args (argparse.Namespace) – parsed command-line arguments
Returns: a BaseFairseqModel instance
dataset(split)

Return a loaded dataset split.

Parameters: split (str) – name of the split (e.g., train, valid, test)
Returns: a FairseqDataset corresponding to split
get_batch_iterator(dataset, max_tokens=None, max_sentences=None, max_positions=None, ignore_invalid_inputs=False, required_batch_size_multiple=1, seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=0)

Get an iterator that yields batches of data from the given dataset.

Parameters:
  • dataset (FairseqDataset) – dataset to batch
  • max_tokens (int, optional) – max number of tokens in each batch (default: None).
  • max_sentences (int, optional) – max number of sentences in each batch (default: None).
  • max_positions (optional) – max sentence length supported by the model (default: None).
  • ignore_invalid_inputs (bool, optional) – don’t raise Exception for sentences that are too long (default: False).
  • required_batch_size_multiple (int, optional) – require batch size to be a multiple of N (default: 1).
  • seed (int, optional) – seed for random number generator for reproducibility (default: 1).
  • num_shards (int, optional) – shard the data iterator into N shards (default: 1).
  • shard_id (int, optional) – which shard of the data iterator to return (default: 0).
  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means the data will be loaded in the main process (default: 0).
  • epoch (int, optional) – the epoch to start the iterator from (default: 0).
Returns:

a batched iterator over the given dataset split

Return type:

EpochBatchIterator
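The batching controlled by max_tokens can be sketched as a greedy pass over examples sorted by length. In this sketch, a batch's cost is its number of sentences times its longest sentence, so padding is counted. The stdlib code below also adds round-robin sharding for num_shards/shard_id.

This is a simplification, not fairseq's implementation. batch_by_size() and shard() are hypothetical helpers. The sketch does not model max_sentences, required_batch_size_multiple, shuffling, or the filtering of over-long inputs.

```python
def batch_by_size(lengths, max_tokens):
    """Greedily group example indices so that len(batch) * longest <= max_tokens."""
    # Sort by length so that similar-length examples share a batch (less padding).
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, batch, max_len = [], [], 0
    for idx in order:
        new_max = max(max_len, lengths[idx])
        if batch and new_max * (len(batch) + 1) > max_tokens:
            # Adding idx would exceed the token budget: close the current batch.
            batches.append(batch)
            batch, new_max = [], lengths[idx]
        batch.append(idx)
        max_len = new_max
    if batch:
        batches.append(batch)
    return batches


def shard(batches, num_shards, shard_id):
    """Give worker shard_id every num_shards-th batch."""
    return batches[shard_id::num_shards]


lengths = [5, 3, 8, 2, 7, 4]  # tokens per example (hypothetical)
batches = batch_by_size(lengths, max_tokens=12)
print(batches)                                   # [[3, 1, 5], [0], [4], [2]]
print(shard(batches, num_shards=2, shard_id=0))  # [[3, 1, 5], [4]]
print(shard(batches, num_shards=2, shard_id=1))  # [[0], [2]]
```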

grad_denom(sample_sizes, criterion)
inference_step(generator, models, sample, prefix_tokens=None)
load_dataset(split, combine=False, **kwargs)

Load a given dataset split.

Parameters: split (str) – name of the split (e.g., train, valid, test)
classmethod load_dictionary(filename)

Load the dictionary from the filename

Parameters: filename (str) – the filename
max_positions()

Return the max input length allowed by the task.

classmethod setup_task(args, **kwargs)

Setup the task (e.g., load dictionaries).

Parameters: args (argparse.Namespace) – parsed command-line arguments
source_dictionary

Return the source Dictionary (if applicable for this task).

target_dictionary

Return the target Dictionary (if applicable for this task).

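train_step(sample, model, criterion, optimizer, ignore_grad=False)

Run the forward and backward passes, and return the loss as computed by criterion for the given model and sample.

Parameters:
  • sample (dict) – the mini-batch. The format is defined by the FairseqDataset.
  • model (BaseFairseqModel) – the model
  • criterion (FairseqCriterion) – the criterion
  • optimizer (FairseqOptimizer) – the optimizer
  • ignore_grad (bool) – multiply loss by 0 if this is set to True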
Returns:

  • the loss
  • the sample size, which is used as the denominator for the gradient
  • logging outputs to display while training

Return type:

tuple

update_step(num_updates)

Task-level update hook, called whenever the number of updates increases, i.e., after the optimization step and learning-rate update of each training step.

valid_step(sample, model, criterion)

Models

A Model defines the neural network’s forward() method and encapsulates all of the learnable parameters in the network. Each model also provides a set of named architectures that define the precise network configuration (e.g., embedding dimension, number of layers, etc.).

Both the model type and architecture are selected via the --arch command-line argument. Once selected, a model may expose additional command-line arguments for further configuration.

Note

All fairseq Models extend BaseFairseqModel, which in turn extends torch.nn.Module. Thus any fairseq Model can be used as a stand-alone Module in other PyTorch code.

Convolutional Neural Networks (CNN)

class fairseq.models.fconv.FConvModel(encoder, decoder)

A fully convolutional model, i.e. a convolutional encoder and a convolutional decoder, as described in “Convolutional Sequence to Sequence Learning” (Gehring et al., 2017).

Parameters:
  • encoder (FConvEncoder) – the encoder
  • decoder (FConvDecoder) – the decoder

The Convolutional model provides the following named architectures and command-line arguments:

usage: 
        [--arch {fconv,fconv_iwslt_de_en,fconv_wmt_en_ro,fconv_wmt_en_de,fconv_wmt_en_fr}]
        [--dropout D] [--encoder-embed-dim N] [--encoder-embed-path STR]
        [--encoder-layers EXPR] [--decoder-embed-dim N]
        [--decoder-embed-path STR] [--decoder-layers EXPR]
        [--decoder-out-embed-dim N] [--decoder-attention EXPR]
        [--share-input-output-embed]

Named architectures

--arch Possible choices: fconv, fconv_iwslt_de_en, fconv_wmt_en_ro, fconv_wmt_en_de, fconv_wmt_en_fr

Additional command-line arguments

--dropout dropout probability
--encoder-embed-dim encoder embedding dimension
--encoder-embed-path path to pre-trained encoder embedding
--encoder-layers encoder layers [(dim, kernel_size), …]
--decoder-embed-dim decoder embedding dimension
--decoder-embed-path path to pre-trained decoder embedding
--decoder-layers decoder layers [(dim, kernel_size), …]
--decoder-out-embed-dim decoder output embedding dimension
--decoder-attention decoder attention [True, …]
--share-input-output-embed

share input and output embeddings (requires --decoder-out-embed-dim and --decoder-embed-dim to be equal)

Default: False

static add_args(parser)[source]

Add model-specific arguments to the parser.

classmethod build_model(args, task)[source]

Build a new model instance.

class fairseq.models.fconv.FConvEncoder(dictionary, embed_dim=512, embed_dict=None, max_positions=1024, convolutions=((512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3)), dropout=0.1)[source]

Convolutional encoder consisting of len(convolutions) layers.

Parameters:
  • dictionary (Dictionary) – encoding dictionary
  • embed_dim (int, optional) – embedding dimension
  • embed_dict (str, optional) – filename from which to load pre-trained embeddings
  • max_positions (int, optional) – maximum supported input sequence length
  • convolutions (list, optional) – the convolutional layer structure. Each list item i corresponds to convolutional layer i. Layers are given as (out_channels, kernel_width, [residual]). Residual connections are added between layers when residual=1 (which is the default behavior).
  • dropout (float, optional) – dropout to be applied before each conv layer
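The convolutions argument accepts both 2- and 3-element tuples. A hypothetical helper that fills in the documented default residual=1 for 2-element specs might look like the sketch below; it also checks the default 20-layer spec from the signature above. This is an illustration, not fairseq's own code. The mixed example assumes residual=0 means no residual connection for that layer.

```python
def normalize_conv_spec(convolutions):
    """Return (out_channels, kernel_width, residual) triples, defaulting residual to 1."""
    normalized = []
    for spec in convolutions:
        if len(spec) == 2:
            normalized.append(tuple(spec) + (1,))
        elif len(spec) == 3:
            normalized.append(tuple(spec))
        else:
            raise ValueError('invalid conv spec: {!r}'.format(spec))
    return tuple(normalized)


# The encoder's default: twenty (512, 3) layers.
default_spec = ((512, 3),) * 20
layers = normalize_conv_spec(default_spec)
print(len(layers), layers[0])  # 20 (512, 3, 1)

# A mixed spec whose last layer sets residual=0 explicitly.
print(normalize_conv_spec([(256, 3), (256, 3), (512, 5, 0)]))
# ((256, 3, 1), (256, 3, 1), (512, 5, 0))
```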
forward(src_tokens, src_lengths)
Parameters:
  • src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
  • src_lengths (LongTensor) – lengths of each source sentence of shape (batch)
Returns:

  • encoder_out (tuple): a tuple with two elements, where the first element is the last encoder layer’s output and the second element is the same quantity summed with the input embedding (used for attention). The shape of both tensors is (batch, src_len, embed_dim).
  • encoder_padding_mask (ByteTensor): the positions of padding elements of shape (batch, src_len)

Return type:

dict

max_positions()

Maximum input length supported by the encoder.

reorder_encoder_out(encoder_out, new_order)

Reorder encoder output according to new_order.

Parameters:
  • encoder_out – output from the forward() method
  • new_order (LongTensor) – desired order
Returns:

encoder_out rearranged according to new_order

class fairseq.models.fconv.FConvDecoder(dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256, max_positions=1024, convolutions=((512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3)), attention=True, dropout=0.1, share_embed=False, positional_embeddings=True, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0)

Convolutional decoder.

forward(prev_output_tokens, encoder_out=None, incremental_state=None, **unused)
Parameters:
  • prev_output_tokens (LongTensor) – shifted output tokens of shape (batch, tgt_len), for input feeding/teacher forcing
  • encoder_out (dict, optional) – output from the encoder, used for encoder-side attention
  • incremental_state (dict, optional) – dictionary used for storing state during Incremental decoding
Returns:

  • the decoder’s output of shape (batch, tgt_len, vocab)
  • a dictionary with any model-specific outputs

Return type:

tuple

max_positions()

Maximum output length supported by the decoder.

reorder_incremental_state(incremental_state, new_order)

Reorder incremental state.

This should be called when the order of the input has changed from the previous time step. A typical use case is beam search, where the input order changes between time steps based on the selection of beams.
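Conceptually, reordering applies the same index selection along the batch (beam) dimension to every cached value: row *i* of the new state comes from row new_order[i] of the old one. The sketch below uses lists of tuples in place of tensors; real implementations typically use Tensor.index_select. reorder_state() is hypothetical.

```python
def reorder_state(state, new_order):
    """Return a copy of state where row i comes from old row new_order[i]."""
    return {key: [rows[i] for i in new_order] for key, rows in state.items()}


# Cached decoder state for 3 hypotheses in a beam (one row per hypothesis).
state = {
    'prev_tokens': [('Le',), ('La',), ('Les',)],
    'hidden': [(0.1, 0.2), (0.3, 0.4), (0.5, 0.6)],
}

# Suppose the next search step keeps two extensions of hypothesis 0 and one of
# hypothesis 2, while hypothesis 1 is dropped.
new_order = [0, 0, 2]
reordered = reorder_state(state, new_order)
print(reordered['prev_tokens'])  # [('Le',), ('Le',), ('Les',)]
print(reordered['hidden'])       # [(0.1, 0.2), (0.1, 0.2), (0.5, 0.6)]
```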

upgrade_state_dict(state_dict)

Upgrade a (possibly old) state dict for new versions of fairseq.

Long Short-Term Memory (LSTM) networks

class fairseq.models.lstm.LSTMModel(encoder, decoder)
static add_args(parser)

Add model-specific arguments to the parser.

classmethod build_model(args, task)

Build a new model instance.

class fairseq.models.lstm.LSTMEncoder(dictionary, embed_dim=512, hidden_size=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, left_pad=True, pretrained_embed=None, padding_value=0.0)

LSTM encoder.

forward(src_tokens, src_lengths)
Parameters:
  • src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
  • src_lengths (LongTensor) – lengths of each source sentence of shape (batch)
max_positions()

Maximum input length supported by the encoder.

reorder_encoder_out(encoder_out, new_order)

Reorder encoder output according to new_order.

Parameters:
  • encoder_out – output from the forward() method
  • new_order (LongTensor) – desired order
Returns:

encoder_out rearranged according to new_order

class fairseq.models.lstm.LSTMDecoder(dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True, encoder_output_units=512, pretrained_embed=None, share_input_output_embed=False, adaptive_softmax_cutoff=None)

LSTM decoder.

forward(prev_output_tokens, encoder_out, incremental_state=None)
Parameters:
  • prev_output_tokens (LongTensor) – shifted output tokens of shape (batch, tgt_len), for input feeding/teacher forcing
  • encoder_out (dict, optional) – output from the encoder, used for encoder-side attention
  • incremental_state (dict, optional) – dictionary used for storing state during Incremental decoding
Returns:

  • the decoder’s output of shape (batch, tgt_len, vocab)
  • a dictionary with any model-specific outputs

Return type:

tuple

max_positions()

Maximum output length supported by the decoder.

reorder_incremental_state(incremental_state, new_order)

Reorder incremental state.

This should be called when the order of the input has changed from the previous time step. A typical use case is beam search, where the input order changes between time steps based on the selection of beams.

Transformer (self-attention) networks

class fairseq.models.transformer.TransformerModel(encoder, decoder)

Transformer model from “Attention Is All You Need” (Vaswani et al., 2017).

Parameters:
  • encoder (TransformerEncoder) – the encoder
  • decoder (TransformerDecoder) – the decoder

The Transformer model provides the following named architectures and command-line arguments:

usage: 
        [--arch {transformer,transformer_iwslt_de_en,transformer_wmt_en_de,transformer_vaswani_wmt_en_de_big,transformer_vaswani_wmt_en_fr_big,transformer_wmt_en_de_big,transformer_wmt_en_de_big_t2t}]
        [--activation-fn {relu,gelu,gelu_fast,gelu_accurate,tanh}]
        [--dropout D] [--attention-dropout D] [--activation-dropout D]
        [--encoder-embed-path STR] [--encoder-embed-dim N]
        [--encoder-ffn-embed-dim N] [--encoder-layers N]
        [--encoder-attention-heads N] [--encoder-normalize-before]
        [--encoder-learned-pos] [--decoder-embed-path STR]
        [--decoder-embed-dim N] [--decoder-ffn-embed-dim N]
        [--decoder-layers N] [--decoder-attention-heads N]
        [--decoder-learned-pos] [--decoder-normalize-before]
        [--share-decoder-input-output-embed] [--share-all-embeddings]
        [--no-token-positional-embeddings] [--adaptive-softmax-cutoff EXPR]
        [--adaptive-softmax-dropout D]

Named architectures

--arch Possible choices: transformer, transformer_iwslt_de_en, transformer_wmt_en_de, transformer_vaswani_wmt_en_de_big, transformer_vaswani_wmt_en_fr_big, transformer_wmt_en_de_big, transformer_wmt_en_de_big_t2t

Additional command-line arguments

--activation-fn

Possible choices: relu, gelu, gelu_fast, gelu_accurate, tanh

activation function to use

--dropout dropout probability
--attention-dropout dropout probability for attention weights
--activation-dropout, --relu-dropout dropout probability after activation in FFN.
--encoder-embed-path path to pre-trained encoder embedding
--encoder-embed-dim encoder embedding dimension
--encoder-ffn-embed-dim encoder embedding dimension for FFN
--encoder-layers num encoder layers
--encoder-attention-heads num encoder attention heads
--encoder-normalize-before

apply layernorm before each encoder block

Default: False

--encoder-learned-pos

use learned positional embeddings in the encoder

Default: False

--decoder-embed-path path to pre-trained decoder embedding
--decoder-embed-dim decoder embedding dimension
--decoder-ffn-embed-dim decoder embedding dimension for FFN
--decoder-layers num decoder layers
--decoder-attention-heads num decoder attention heads
--decoder-learned-pos

use learned positional embeddings in the decoder

Default: False

--decoder-normalize-before

apply layernorm before each decoder block

Default: False

--share-decoder-input-output-embed

share decoder input and output embeddings

Default: False

--share-all-embeddings

share encoder, decoder and output embeddings (requires shared dictionary and embed dim)

Default: False

--no-token-positional-embeddings

if set, disables positional embeddings (outside self attention)

Default: False

--adaptive-softmax-cutoff comma-separated list of adaptive softmax cutoff points. Must be used with the adaptive_loss criterion
--adaptive-softmax-dropout sets adaptive softmax dropout for the tail projections
static add_args(parser)

Add model-specific arguments to the parser.

classmethod build_model(args, task)

Build a new model instance.

class fairseq.models.transformer.TransformerEncoder(args, dictionary, embed_tokens)

Transformer encoder consisting of args.encoder_layers layers. Each layer is a TransformerEncoderLayer.

Parameters:
  • args (argparse.Namespace) – parsed command-line arguments
  • dictionary (Dictionary) – encoding dictionary
  • embed_tokens (torch.nn.Embedding) – input embedding
forward(src_tokens, src_lengths)
Parameters:
  • src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
  • src_lengths (torch.LongTensor) – lengths of each source sentence of shape (batch)
Returns:

  • encoder_out (Tensor): the last encoder layer’s output of shape (src_len, batch, embed_dim)
  • encoder_padding_mask (ByteTensor): the positions of padding elements of shape (batch, src_len)

Return type:

dict
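Note the two different layouts: encoder_padding_mask is batch-first and marks padding positions, while encoder_out puts the time dimension first. The stdlib sketch below uses made-up token IDs and treats PAD = 1 as an assumption. It builds the mask from a right-padded batch with True marking padding, where the documented ByteTensor uses 1. It then transposes the tokens to show the (src_len, batch) index order. A real encoder typically compares a tensor against its padding index rather than looping over lists.

```python
PAD = 1  # hypothetical padding index

# (batch=2, src_len=4) source tokens; the second sentence is right-padded.
src_tokens = [
    [11, 12, 13, 2],
    [21, 2, PAD, PAD],
]

# encoder_padding_mask: True where the token is padding, shape (batch, src_len).
encoder_padding_mask = [[tok == PAD for tok in row] for row in src_tokens]
print(encoder_padding_mask)  # [[False, False, False, False], [False, False, True, True]]

# Time-first order, as used by encoder_out: index [t][b] instead of [b][t].
time_first = [list(column) for column in zip(*src_tokens)]
print(time_first)  # [[11, 21], [12, 2], [13, 1], [2, 1]]
```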

max_positions()

Maximum input length supported by the encoder.

reorder_encoder_out(encoder_out, new_order)

Reorder encoder output according to new_order.

Parameters:
  • encoder_out – output from the forward() method
  • new_order (LongTensor) – desired order
Returns:

encoder_out rearranged according to new_order

upgrade_state_dict_named(state_dict, name)

Upgrade a (possibly old) state dict for new versions of fairseq.

class fairseq.models.transformer.TransformerEncoderLayer(args)

Encoder layer block.

In the original paper, each operation (multi-head attention or FFN) is postprocessed with: dropout -> add residual -> layernorm. In the tensor2tensor code, they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: dropout -> add residual. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting args.encoder_normalize_before to True.
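The two orderings differ only in where layer normalization sits relative to the residual connection. The stdlib-only sketch below applies one sublayer to a single vector, with these simplifications:

- layer_norm() has no learned scale or bias.
- Dropout is passed in as a function and is the identity here, as at evaluation time.
- toy_fn() is a made-up stand-in for attention or the FFN.
- sublayer() is a hypothetical helper, not fairseq code.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale/bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def sublayer(x, fn, dropout, normalize_before):
    """Apply one attention/FFN sublayer with a residual connection."""
    residual = x
    if normalize_before:  # tensor2tensor style: layernorm before the sublayer
        x = layer_norm(x)
    x = dropout(fn(x))
    x = [r + v for r, v in zip(residual, x)]  # add residual
    if not normalize_before:  # paper style: layernorm after the residual add
        x = layer_norm(x)
    return x


def no_dropout(v):
    return v  # dropout is the identity at evaluation time


def toy_fn(v):
    return [2.0 * a for a in v]  # hypothetical stand-in for attention or FFN


x = [1.0, 2.0, 3.0, 6.0]
post_norm = sublayer(x, toy_fn, no_dropout, normalize_before=False)
pre_norm = sublayer(x, toy_fn, no_dropout, normalize_before=True)
print([round(v, 3) for v in post_norm])
print([round(v, 3) for v in pre_norm])
```

In the post-norm (paper) variant, the sublayer output is always normalized. In the pre-norm variant, the residual path carries x through unnormalized, which is one commonly cited reason it trains more robustly.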

Parameters: args (argparse.Namespace) – parsed command-line arguments
forward(x, encoder_padding_mask)
Parameters:
  • x (Tensor) – input to the layer of shape (seq_len, batch, embed_dim)
  • encoder_padding_mask (ByteTensor) – binary ByteTensor of shape (batch, src_len) where padding elements are indicated by 1.
Returns:

encoded output of shape (seq_len, batch, embed_dim)

upgrade_state_dict_named(state_dict, name)

Rename layer norm states from …layer_norms.0.weight to …self_attn_layer_norm.weight and …layer_norms.1.weight to …final_layer_norm.weight

class fairseq.models.transformer.TransformerDecoder(args, dictionary, embed_tokens, no_encoder_attn=False)

Transformer decoder consisting of args.decoder_layers layers. Each layer is a TransformerDecoderLayer.

Parameters:
  • args (argparse.Namespace) – parsed command-line arguments
  • dictionary (Dictionary) – decoding dictionary
  • embed_tokens (torch.nn.Embedding) – output embedding
  • no_encoder_attn (bool, optional) – if True, do not attend to encoder outputs (default: False).
extract_features(prev_output_tokens, encoder_out=None, incremental_state=None, **unused)

Similar to forward but only return features.

Returns:
  • the decoder’s features of shape (batch, tgt_len, embed_dim)
  • a dictionary with any model-specific outputs
Return type: tuple
forward(prev_output_tokens, encoder_out=None, incremental_state=None, **unused)
Parameters:
  • prev_output_tokens (LongTensor) – previous decoder outputs of shape (batch, tgt_len), for input feeding/teacher forcing
  • encoder_out (Tensor, optional) – output from the encoder, used for encoder-side attention
  • incremental_state (dict) – dictionary used for storing state during Incremental decoding
Returns:

  • the decoder’s output of shape (batch, tgt_len, vocab)
  • a dictionary with any model-specific outputs

Return type:

tuple

max_positions()

Maximum output length supported by the decoder.

output_layer(features, **kwargs)

Project features to the vocabulary size.

upgrade_state_dict_named(state_dict, name)

Upgrade a (possibly old) state dict for new versions of fairseq.

class fairseq.models.transformer.TransformerDecoderLayer(args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False)

Decoder layer block.

In the original paper, each operation (multi-head attention, encoder attention or FFN) is postprocessed with: dropout -> add residual -> layernorm. In the tensor2tensor code, they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: dropout -> add residual. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting args.decoder_normalize_before to True.

Parameters:
  • args (argparse.Namespace) – parsed command-line arguments
  • no_encoder_attn (bool, optional) – if True, do not attend to encoder outputs (default: False).
forward(x, encoder_out=None, encoder_padding_mask=None, incremental_state=None, prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None)
Parameters:
  • x (Tensor) – input to the layer of shape (seq_len, batch, embed_dim)
  • encoder_padding_mask (ByteTensor) – binary ByteTensor of shape (batch, src_len) where padding elements are indicated by 1.
Returns:

encoded output of shape (seq_len, batch, embed_dim)

Adding new models

fairseq.models.register_model(name)[source]

New model types can be added to fairseq with the register_model() function decorator.

For example:

from fairseq.models import FairseqEncoderDecoderModel, register_model

@register_model('lstm')
class LSTM(FairseqEncoderDecoderModel):
    (...)

Note

All models must implement the BaseFairseqModel interface. Typically you will extend FairseqEncoderDecoderModel for sequence-to-sequence tasks or FairseqLanguageModel for language modeling tasks.

Parameters: name (str) – the name of the model
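Decorators like this one usually follow a simple registry pattern. The standalone sketch below re-implements that pattern for illustration only; MODEL_REGISTRY, the toy_lstm name and the ToyLSTM class are defined locally and are not fairseq objects. The decorator stores the class under its name, rejects duplicate names, and returns the class unchanged.

```python
from typing import Callable, Dict

# Local, illustrative registry (not fairseq's internal one).
MODEL_REGISTRY: Dict[str, type] = {}


def register_model(name: str) -> Callable[[type], type]:
    """Return a class decorator that records the class under `name`."""

    def wrapper(cls: type) -> type:
        if name in MODEL_REGISTRY:
            raise ValueError(f"Cannot register duplicate model ({name})")
        MODEL_REGISTRY[name] = cls
        return cls  # the decorated class itself is returned unchanged

    return wrapper


@register_model("toy_lstm")
class ToyLSTM:
    pass


print(MODEL_REGISTRY["toy_lstm"].__name__)  # ToyLSTM
```

Keeping the name-to-class mapping in one place is what allows a string given on the command line to select the class to build.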
fairseq.models.register_model_architecture(model_name, arch_name)[source]

New model architectures can be added to fairseq with the register_model_architecture() function decorator. After registration, model architectures can be selected with the --arch command-line argument.

For example:

from fairseq.models import register_model_architecture

@register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
def lstm_luong_wmt_en_de(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000)
    (...)

The decorated function should take a single argument args, which is an argparse.Namespace of arguments parsed from the command line. The decorated function should modify these arguments in place to match the desired architecture.

Parameters:
  • model_name (str) – the name of the Model (Model must already be registered)
  • arch_name (str) – the name of the model architecture (--arch)
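The getattr pattern in the example above only fills in attributes that are absent from args, so explicitly passed command-line values take precedence over the architecture's defaults. The sketch below assumes unset options are left out of the Namespace entirely (argparse.SUPPRESS as the default); the architecture name, options and default values are hypothetical.

```python
import argparse


def toy_arch(args: argparse.Namespace) -> None:
    # Fill in defaults only for attributes the user did not set (modifies args in place).
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1000)
    args.encoder_layers = getattr(args, "encoder_layers", 4)
    args.dropout = getattr(args, "dropout", 0.2)


# SUPPRESS keeps unspecified options out of the Namespace instead of setting them to None.
parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
parser.add_argument("--encoder-embed-dim", type=int)
parser.add_argument("--dropout", type=float)

args = parser.parse_args(["--dropout", "0.3"])
toy_arch(args)  # returns None; args is updated in place
print(args.dropout, args.encoder_embed_dim, args.encoder_layers)  # 0.3 1000 4
```

If an unset option were stored as None instead, getattr would return None and the architecture default would never apply.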
class fairseq.models.BaseFairseqModel[source]

Base class for fairseq models.

static add_args(parser)[source]

Add model-specific arguments to the parser.

classmethod build_model(args, task)[source]

Build a new model instance.

extract_features(*args, **kwargs)[source]

Similar to forward() but only returns features.

classmethod from_pretrained(model_name_or_path, checkpoint_file='model.pt', data_name_or_path=None, **kwargs)[source]

Load a FairseqModel from a pre-trained model file. Downloads and caches the pre-trained model file if needed.

The base implementation returns a fairseq.hub_utils.Generator, which can be used to generate translations or sample from language models. The underlying FairseqModel can be accessed via the generator.models attribute.

Other models may override this to implement custom PyTorch Hub APIs.

Parameters:
  • model_name_or_path (str) – either the name of a pre-trained model to load or a path/URL to a pre-trained model state dict
  • checkpoint_file (str, optional) – colon-separated list of checkpoint files in the model archive to ensemble (default: ‘model.pt’)
  • data_name_or_path (str, optional) – point args.data to the archive at the given path/URL. Can start with ‘.’ or ‘./’ to reuse the model archive path.
get_normalized_probs(net_output, log_probs, sample=None)[source]

Get normalized probabilities (or log probs) from a net’s output.
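As a rough illustration, normalizing means applying a softmax over the vocabulary scores, or a log-softmax when log_probs is True. The dependency-free sketch below works on a single list of logits and subtracts the maximum first for numerical stability; the real method operates on tensors taken from net_output, which this sketch does not model.

```python
import math
from typing import List


def normalized_probs(logits: List[float], log_probs: bool) -> List[float]:
    # Shift by the max logit so math.exp cannot overflow.
    m = max(logits)
    shifted = [z - m for z in logits]
    log_norm = math.log(sum(math.exp(s) for s in shifted))
    lprobs = [s - log_norm for s in shifted]  # log-softmax
    return lprobs if log_probs else [math.exp(lp) for lp in lprobs]


logits = [2.0, 1.0, 0.1]
probs = normalized_probs(logits, log_probs=False)
lprobs = normalized_probs(logits, log_probs=True)
print(round(sum(probs), 6))  # 1.0
```

Computing log-probabilities directly avoids taking log(softmax(x)) in two steps, which can underflow to log(0) for very unlikely tokens.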

get_targets(sample, net_output)[source]

Get targets from either the sample or the net’s output.

classmethod hub_models()[source]
load_state_dict(state_dict, strict=True)[source]

Copies parameters and buffers from state_dict into this module and its descendants.

Overrides the method in nn.Module. Compared with that method this additionally “upgrades” state_dicts from old checkpoints.

make_generation_fast_(**kwargs)[source]

Optimize model for faster generation.

max_positions()[source]

Maximum length supported by the model.

prepare_for_onnx_export_(**kwargs)[source]

Make model exportable via ONNX trace.

upgrade_state_dict(state_dict)[source]

Upgrade old state dicts to work with newer code.

upgrade_state_dict_named(state_dict, name)[source]

Upgrade old state dicts to work with newer code.

Parameters:
  • state_dict (dict) – state dictionary to upgrade, in place
  • name (str) – the state dict key corresponding to the current module
class fairseq.models.FairseqEncoderDecoderModel(encoder, decoder)[source]

Base class for encoder-decoder models.

Parameters:
  • encoder (FairseqEncoder) – the encoder
  • decoder (FairseqDecoder) – the decoder
extract_features(src_tokens, src_lengths, prev_output_tokens, **kwargs)[source]

Similar to forward() but only returns features.

Returns:
  • the decoder’s features of shape (batch, tgt_len, embed_dim)
  • a dictionary with any model-specific outputs
Return type: tuple
forward(src_tokens, src_lengths, prev_output_tokens, **kwargs)[source]

Run the forward pass for an encoder-decoder model.

First feed a batch of source tokens through the encoder. Then, feed the encoder output and previous decoder outputs (i.e., input feeding/teacher forcing) to the decoder to produce the next outputs:

encoder_out = self.encoder(src_tokens, src_lengths)
return self.decoder(prev_output_tokens, encoder_out)
Parameters:
  • src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
  • src_lengths (LongTensor) – source sentence lengths of shape (batch)
  • prev_output_tokens (LongTensor) – previous decoder outputs of shape (batch, tgt_len), for input feeding/teacher forcing
Returns:

  • the decoder’s output of shape (batch, tgt_len, vocab)
  • a dictionary with any model-specific outputs

Return type:

tuple

max_decoder_positions()[source]

Maximum length supported by the decoder.

max_positions()[source]

Maximum length supported by the model.

output_layer(features, **kwargs)[source]

Project features to the default output size (typically vocabulary size).

class fairseq.models.FairseqEncoderModel(encoder)[source]

Base class for encoder-only models.

Parameters: encoder (FairseqEncoder) – the encoder
forward(src_tokens, src_lengths, **kwargs)[source]

Run the forward pass for an encoder-only model.

Feeds a batch of tokens through the encoder to generate features.

Parameters:
  • src_tokens (LongTensor) – input tokens of shape (batch, src_len)
  • src_lengths (LongTensor) – source sentence lengths of shape (batch)
Returns:

the encoder’s output, typically of shape (batch, src_len, features)

get_normalized_probs(net_output, log_probs, sample=None)[source]

Get normalized probabilities (or log probs) from a net’s output.

max_positions()[source]

Maximum length supported by the model.

class fairseq.models.FairseqLanguageModel(decoder)[source]

Base class for decoder-only models.

Parameters: decoder (FairseqDecoder) – the decoder
extract_features(src_tokens, **kwargs)[source]

Similar to forward() but only returns features.

Returns:
  • the decoder’s features of shape (batch, seq_len, embed_dim)
  • a dictionary with any model-specific outputs
Return type: tuple
forward(src_tokens, **kwargs)[source]

Run the forward pass for a decoder-only model.

Feeds a batch of tokens through the decoder to predict the next tokens.

Parameters:
  • src_tokens (LongTensor) – tokens on which to condition the decoder, of shape (batch, seq_len)
  • src_lengths (LongTensor) – source sentence lengths of shape (batch)
Returns:

  • the decoder’s output of shape (batch, seq_len, vocab)
  • a dictionary with any model-specific outputs

Return type:

tuple

max_decoder_positions()[source]

Maximum length supported by the decoder.

max_positions()[source]

Maximum length supported by the model.

output_layer(features, **kwargs)[source]

Project features to the default output size (typically vocabulary size).

supported_targets
class fairseq.models.FairseqMultiModel(encoders, decoders)[source]

Base class for combining multiple encoder-decoder models.

static build_shared_embeddings(dicts: Dict[str, fairseq.data.dictionary.Dictionary], langs: List[str], embed_dim: int, build_embedding: callable, pretrained_embed_path: Optional[str] = None)[source]

Helper function to build shared embeddings for a set of languages after checking that all dicts corresponding to those languages are equivalent.

Parameters:
  • dicts – Dict of lang_id to its corresponding Dictionary
  • langs – languages that we want to share embeddings for
  • embed_dim – embedding dimension
  • build_embedding – callable function to actually build the embedding
  • pretrained_embed_path – Optional path to load pretrained embeddings
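The check-then-share logic can be sketched in plain Python. In this standalone sketch (not the fairseq source), a dictionary is modeled as a list of symbols and toy_build_embedding is a hypothetical builder playing the role of build_embedding; the key point is that all languages end up using one and the same embedding object.

```python
from typing import Callable, Dict, List, Optional

Vocab = List[str]


def build_shared_embeddings(
    dicts: Dict[str, Vocab],
    langs: List[str],
    embed_dim: int,
    build_embedding: Callable[[Vocab, int, Optional[str]], object],
    pretrained_embed_path: Optional[str] = None,
) -> object:
    shared_dict = dicts[langs[0]]
    # Sharing is only valid if every language uses an identical vocabulary.
    if any(dicts[lang] != shared_dict for lang in langs):
        raise ValueError("dictionaries differ; cannot share embeddings")
    return build_embedding(shared_dict, embed_dim, pretrained_embed_path)


def toy_build_embedding(vocab: Vocab, dim: int, path: Optional[str]) -> Dict[str, List[float]]:
    # Hypothetical builder: one zero vector per symbol (path is ignored here).
    return {sym: [0.0] * dim for sym in vocab}


vocab = ["<pad>", "</s>", "hello", "world"]
dicts = {"en": list(vocab), "de": list(vocab)}
shared = build_shared_embeddings(dicts, ["en", "de"], 4, toy_build_embedding)
embeddings = {lang: shared for lang in ["en", "de"]}  # same object for every language
print(embeddings["en"] is embeddings["de"])  # True
```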
decoder
encoder
forward(src_tokens, src_lengths, prev_output_tokens, **kwargs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

max_decoder_positions()[source]

Maximum length supported by the decoder.

max_positions()[source]

Maximum length supported by the model.

class fairseq.models.FairseqEncoder(dictionary)[source]

Base class for encoders.

forward(src_tokens, src_lengths=None, **kwargs)[source]
Parameters:
  • src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
  • src_lengths (LongTensor) – lengths of each source sentence of shape (batch)
max_positions()[source]

Maximum input length supported by the encoder.

reorder_encoder_out(encoder_out, new_order)[source]

Reorder encoder output according to new_order.

Parameters:
  • encoder_out – output from the forward() method
  • new_order (LongTensor) – desired order
Returns:

encoder_out rearranged according to new_order

upgrade_state_dict(state_dict)[source]

Upgrade a (possibly old) state dict for new versions of fairseq.

class fairseq.models.CompositeEncoder(encoders)[source]

A wrapper around a dictionary of FairseqEncoder objects.

We run forward on each encoder and return a dictionary of outputs. The first encoder’s dictionary is used for initialization.

Parameters: encoders (dict) – a dictionary of FairseqEncoder objects.
forward(src_tokens, src_lengths)[source]
Parameters:
  • src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
  • src_lengths (LongTensor) – lengths of each source sentence of shape (batch)
Returns:

the outputs from each Encoder

Return type:

dict

max_positions()[source]

Maximum input length supported by the encoder.

reorder_encoder_out(encoder_out, new_order)[source]

Reorder encoder output according to new_order.

upgrade_state_dict(state_dict)[source]

Upgrade a (possibly old) state dict for new versions of fairseq.

class fairseq.models.FairseqDecoder(dictionary)[source]

Base class for decoders.

extract_features(prev_output_tokens, encoder_out=None, **kwargs)[source]
Returns:
  • the decoder’s features of shape (batch, tgt_len, embed_dim)
  • a dictionary with any model-specific outputs
Return type: tuple
forward(prev_output_tokens, encoder_out=None, **kwargs)[source]
Parameters:
  • prev_output_tokens (LongTensor) – shifted output tokens of shape (batch, tgt_len), for input feeding/teacher forcing
  • encoder_out (dict, optional) – output from the encoder, used for encoder-side attention
Returns:

  • the decoder’s output of shape (batch, tgt_len, vocab)
  • a dictionary with any model-specific outputs

Return type:

tuple

get_normalized_probs(net_output, log_probs, sample)[source]

Get normalized probabilities (or log probs) from a net’s output.

max_positions()[source]

Maximum input length supported by the decoder.

output_layer(features, **kwargs)[source]

Project features to the default output size, e.g., vocabulary size.

Parameters: features (Tensor) – features returned by extract_features.
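The split between extract_features() and output_layer() means forward() can be read as their composition. The hypothetical, dependency-free ToyDecoder below sketches that contract with nested lists in place of tensors: extract_features() returns features of shape (batch, tgt_len, embed_dim) plus an extra dictionary, and output_layer() projects each feature vector to vocabulary size. It is not a fairseq class.

```python
from typing import Dict, List, Tuple

Matrix = List[List[float]]


class ToyDecoder:
    def __init__(self, embed: Matrix, out_proj: Matrix):
        self.embed = embed        # (vocab, embed_dim) lookup table
        self.out_proj = out_proj  # (embed_dim, vocab) projection

    def extract_features(self, prev_output_tokens: List[List[int]]) -> Tuple[List[Matrix], Dict]:
        # (batch, tgt_len) token ids -> (batch, tgt_len, embed_dim) features.
        feats = [[list(self.embed[tok]) for tok in seq] for seq in prev_output_tokens]
        return feats, {"attn": None}  # extra, model-specific outputs

    def output_layer(self, features: List[Matrix]) -> List[Matrix]:
        # (batch, tgt_len, embed_dim) -> (batch, tgt_len, vocab) via a linear projection.
        vocab = len(self.out_proj[0])
        return [
            [
                [sum(f[k] * self.out_proj[k][v] for k in range(len(f))) for v in range(vocab)]
                for f in seq
            ]
            for seq in features
        ]

    def forward(self, prev_output_tokens: List[List[int]]) -> Tuple[List[Matrix], Dict]:
        x, extra = self.extract_features(prev_output_tokens)
        return self.output_layer(x), extra


dec = ToyDecoder(embed=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
                 out_proj=[[1.0, 0.0, 2.0], [0.0, 1.0, 3.0]])
logits, extra = dec.forward([[0, 2]])
print(logits)  # [[[1.0, 0.0, 2.0], [1.0, 1.0, 5.0]]]
```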
upgrade_state_dict(state_dict)[source]

Upgrade a (possibly old) state dict for new versions of fairseq.

Incremental decoding

class fairseq.models.FairseqIncrementalDecoder(dictionary)[source]

Base class for incremental decoders.

Incremental decoding is a special mode at inference time where the Model only receives a single timestep of input corresponding to the previous output token (for input feeding) and must produce the next output incrementally. Thus the model must cache any long-term state that is needed about the sequence, e.g., hidden states, convolutional states, etc.

Compared to the standard FairseqDecoder interface, the incremental decoder interface allows forward() functions to take an extra keyword argument (incremental_state) that can be used to cache state across time-steps.

The FairseqIncrementalDecoder interface also defines the reorder_incremental_state() method, which is used during beam search to select and reorder the incremental state based on the selection of beams.

To learn more about how incremental decoding works, refer to this blog.
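The caching contract can be illustrated with a toy decoder whose only state is the running sum of the tokens it has consumed. This is a hypothetical, dependency-free sketch: the "sums" cache key, the list-based state and the class itself are illustrative assumptions, whereas real decoders cache tensors and reorder them along the batch (beam) dimension. The final check is that incremental decoding with reordering matches recomputing from the full prefix.

```python
from typing import Dict, List, Optional

State = Dict[str, List[int]]


class RunningSumDecoder:
    """Toy decoder: its output at each step is the sum of all tokens seen so far."""

    def forward(self, prev_output_tokens: List[List[int]],
                incremental_state: Optional[State] = None) -> List[int]:
        if incremental_state is None:
            # Standard mode: recompute from the full prefix of every row.
            return [sum(seq) for seq in prev_output_tokens]
        # Incremental mode: consume only the newest token of each row...
        last = [seq[-1] for seq in prev_output_tokens]
        cached = incremental_state.get("sums", [0] * len(last))
        sums = [c + t for c, t in zip(cached, last)]
        incremental_state["sums"] = sums  # ...and cache the state for the next step.
        return sums

    def reorder_incremental_state(self, incremental_state: State, new_order: List[int]) -> None:
        # Beam search may drop or duplicate hypotheses; keep the cache aligned with them.
        if "sums" in incremental_state:
            incremental_state["sums"] = [incremental_state["sums"][i] for i in new_order]


dec = RunningSumDecoder()
state: State = {}
beams = [[1], [2]]
dec.forward(beams, state)                     # cache is now [1, 2]
beams = [beams[1] + [5], beams[1] + [7]]      # both new hypotheses extend hypothesis 1
dec.reorder_incremental_state(state, [1, 1])  # cache is now [2, 2]
out = dec.forward(beams, state)
print(out, dec.forward(beams))                # [7, 9] [7, 9]
```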

extract_features(prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs)[source]
Returns:
  • the decoder’s features of shape (batch, tgt_len, embed_dim)
  • a dictionary with any model-specific outputs
Return type: tuple
forward(prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs)[source]
Parameters:
  • prev_output_tokens (LongTensor) – shifted output tokens of shape (batch, tgt_len), for input feeding/teacher forcing
  • encoder_out (dict, optional) – output from the encoder, used for encoder-side attention
  • incremental_state (dict, optional) – dictionary used for storing state during incremental decoding
Returns:

  • the decoder’s output of shape (batch, tgt_len, vocab)
  • a dictionary with any model-specific outputs

Return type:

tuple

reorder_incremental_state(incremental_state, new_order)[source]

Reorder incremental state.

This should be called when the order of the input has changed from the previous time step. A typical use case is beam search, where the input order changes between time steps based on the selection of beams.

set_beam_size(beam_size)[source]

Sets the beam size in the decoder and all children.

Criterions

Criterions compute the loss function given the model and batch, roughly:

loss = criterion(model, batch)
class fairseq.criterions.FairseqCriterion(args, task)[source]
static add_args(parser)[source]

Add criterion-specific arguments to the parser.

static aggregate_logging_outputs(logging_outputs)[source]

Aggregate logging outputs from data parallel training.

classmethod build_criterion(args, task)[source]
forward(model, sample, reduce=True)[source]

Compute the loss for the given sample.

Returns a tuple with three elements: 1) the loss, 2) the sample size, which is used as the denominator for the gradient, and 3) logging outputs to display while training.
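The three return values and the role of the sample size can be sketched without any dependencies. This sketch assumes token-level negative log-likelihood with reduce=True (a sum, not a mean) and a sample size equal to the number of target tokens; the function names and logging keys are illustrative, not fairseq's exact ones.

```python
import math
from typing import Dict, List, Tuple


def toy_criterion(lprobs: List[List[float]], target: List[int]) -> Tuple[float, int, Dict]:
    # 1) summed negative log-likelihood of the target tokens (reduce=True)
    loss = -sum(lp[t] for lp, t in zip(lprobs, target))
    # 2) sample size: here the number of target tokens, used as the gradient denominator
    sample_size = len(target)
    # 3) logging outputs (illustrative keys)
    logging_output = {"loss": loss, "ntokens": len(target), "sample_size": sample_size}
    return loss, sample_size, logging_output


def aggregate_logging_outputs(logging_outputs: List[Dict]) -> Dict[str, float]:
    # Sum over workers first, then normalize once by the total sample size.
    loss_sum = sum(log["loss"] for log in logging_outputs)
    sample_size = sum(log["sample_size"] for log in logging_outputs)
    return {"loss": loss_sum / sample_size, "sample_size": sample_size}


# Two data-parallel workers with different numbers of target tokens.
_, _, log_a = toy_criterion([[math.log(0.5), math.log(0.5)]], [0])
_, _, log_b = toy_criterion([[math.log(0.25), math.log(0.75)], [math.log(0.5), math.log(0.5)]], [0, 1])
print(aggregate_logging_outputs([log_a, log_b]))  # per-token loss over all 3 tokens
```

Summing before dividing weights every token equally; averaging each worker's mean instead would count worker A's single token as much as worker B's two.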

static grad_denom(sample_sizes)[source]

Compute the gradient denominator for a set of sample sizes.

class fairseq.criterions.adaptive_loss.AdaptiveLoss(args, task)[source]

This is an implementation of the loss function accompanying the adaptive softmax approximation for graphics processing units (GPUs), described in the paper “Efficient softmax approximation for GPUs” (http://arxiv.org/abs/1609.04309).

static aggregate_logging_outputs(logging_outputs)[source]

Aggregate logging outputs from data parallel training.

forward(model, sample, reduce=True)[source]

Compute the loss for the given sample.

Returns a tuple with three elements: 1) the loss, 2) the sample size, which is used as the denominator for the gradient, and 3) logging outputs to display while training.

class fairseq.criterions.composite_loss.CompositeLoss(args, task)[source]

This is a composite loss that, given a list of model outputs and a list of targets, computes an average of the losses for each output-target pair.

static add_args(parser)[source]

Add criterion-specific arguments to the parser.

classmethod build_criterion(args, task)[source]
static build_underlying_criterion(args, task)[source]
class fairseq.criterions.cross_entropy.CrossEntropyCriterion(args, task)[source]
static aggregate_logging_outputs(logging_outputs)[source]

Aggregate logging outputs from data parallel training.

compute_loss(model, net_output, sample, reduce=True)[source]
forward(model, sample, reduce=True)[source]

Compute the loss for the given sample.

Returns a tuple with three elements: 1) the loss, 2) the sample size, which is used as the denominator for the gradient, and 3) logging outputs to display while training.

class fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion(args, task)[source]
static add_args(parser)[source]

Add criterion-specific arguments to the parser.

static aggregate_logging_outputs(logging_outputs)[source]

Aggregate logging outputs from data parallel training.

compute_loss(model, net_output, sample, reduce=True)[source]
forward(model, sample, reduce=True)[source]

Compute the loss for the given sample.

Returns a tuple with three elements: 1) the loss, 2) the sample size, which is used as the denominator for the gradient, and 3) logging outputs to display while training.
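One common formulation of label-smoothed cross-entropy, assumed in the sketch below, mixes the usual negative log-likelihood with a uniform term: loss = (1 - epsilon) * nll + (epsilon / V) * smooth, where V is the vocabulary size and smooth is the negative sum of all log-probabilities. This equals cross-entropy against a target distribution that puts 1 - epsilon + epsilon / V on the gold token and epsilon / V on every other token. The function name is illustrative.

```python
import math
from typing import List


def label_smoothed_nll(lprobs: List[float], target: int, epsilon: float) -> float:
    nll = -lprobs[target]          # standard cross-entropy term for the gold token
    smooth = -sum(lprobs)          # uniform term over the whole vocabulary
    eps_i = epsilon / len(lprobs)  # smoothing mass given to each vocabulary entry
    return (1.0 - epsilon) * nll + eps_i * smooth


lprobs = [math.log(p) for p in (0.7, 0.2, 0.1)]
print(label_smoothed_nll(lprobs, target=0, epsilon=0.1))
```

With epsilon = 0 the expression reduces to the plain negative log-likelihood.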

Optimizers

Optimizers update the Model parameters based on the gradients.

class fairseq.optim.FP16Optimizer(args, params, fp32_optimizer, fp32_params)[source]

Wrap an optimizer to support FP16 (mixed precision) training.

backward(loss)[source]

Computes the sum of gradients of the given tensor w.r.t. graph leaves.

Compared to fairseq.optim.FairseqOptimizer.backward(), this function additionally dynamically scales the loss to avoid gradient underflow.
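Dynamic loss scaling multiplies the loss by a large factor before the backward pass so that small FP16 gradients do not underflow, then divides the gradients by the same factor before the update. The sketch below shows one simple policy, assumed for illustration (the initial scale, growth factor and window are illustrative choices): shrink the scale and skip the step when any gradient is inf or NaN, and grow it after a run of overflow-free steps. Gradients are plain lists of floats, and ToyLossScaler and unscale_or_skip are hypothetical names.

```python
import math
from typing import List, Optional


class ToyLossScaler:
    """Illustrative dynamic loss scaler (default values are assumptions)."""

    def __init__(self, init_scale: float = 2.0 ** 15, scale_factor: float = 2.0,
                 scale_window: int = 2000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window  # overflow-free steps before growing the scale
        self._good_steps = 0

    @staticmethod
    def has_overflow(grads: List[float]) -> bool:
        return any(math.isinf(g) or math.isnan(g) for g in grads)

    def update_scale(self, overflow: bool) -> None:
        if overflow:
            # Back off and restart the count of good steps.
            self.loss_scale /= self.scale_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps % self.scale_window == 0:
                self.loss_scale *= self.scale_factor


def unscale_or_skip(scaler: ToyLossScaler, scaled_grads: List[float]) -> Optional[List[float]]:
    # Gradients were computed from loss * scale_used, so divide by that same scale.
    scale_used = scaler.loss_scale
    overflow = scaler.has_overflow(scaled_grads)
    scaler.update_scale(overflow)
    if overflow:
        return None  # the caller should skip this optimizer step
    return [g / scale_used for g in scaled_grads]


scaler = ToyLossScaler(init_scale=8.0, scale_window=2)
print(unscale_or_skip(scaler, [float("inf"), 1.0]), scaler.loss_scale)  # None 4.0
print(unscale_or_skip(scaler, [4.0, 2.0]), scaler.loss_scale)           # [1.0, 0.5] 4.0
print(unscale_or_skip(scaler, [4.0, 2.0]), scaler.loss_scale)           # [1.0, 0.5] 8.0
```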

classmethod build_optimizer(args, params)[source]
Parameters:
  • args (argparse.Namespace) – fairseq args
  • params (iterable) – iterable of parameters to optimize
clip_grad_norm(max_norm)[source]

Clips gradient norm and updates dynamic loss scaler.

get_lr()[source]

Return the current learning rate.

load_state_dict(state_dict, optimizer_overrides=None)[source]

Load an optimizer state dict.

In general we should prefer the configuration of the existing optimizer instance (e.g., learning rate) over that found in the state_dict. This allows us to resume training from a checkpoint using a new set of optimizer args.

multiply_grads(c)[source]

Multiplies grads by a constant c.

optimizer

Return a torch.optim.optimizer.Optimizer instance.

optimizer_config

Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.
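The interplay between optimizer_config and load_state_dict(state_dict, optimizer_overrides) can be sketched with plain dictionaries standing in for a torch optimizer's param_groups. The helper name load_with_overrides and the shape of the saved state are assumptions: the checkpoint's hyperparameters are restored first, and then the current run's configuration is applied on top of every parameter group.

```python
import copy
from typing import Dict, List, Optional


def load_with_overrides(saved_param_groups: List[Dict],
                        optimizer_overrides: Optional[Dict] = None) -> List[Dict]:
    # Restore hyperparameters from the checkpoint without mutating it...
    groups = copy.deepcopy(saved_param_groups)
    # ...then let the current configuration win, e.g. a new learning rate.
    if optimizer_overrides:
        for group in groups:
            group.update(optimizer_overrides)
    return groups


saved = [{"lr": 0.25, "momentum": 0.99, "weight_decay": 0.0}]
current_config = {"lr": 0.1, "momentum": 0.99}  # the kind of dict optimizer_config returns
print(load_with_overrides(saved, current_config))
# [{'lr': 0.1, 'momentum': 0.99, 'weight_decay': 0.0}]
```

Settings that the override dictionary does not mention, such as weight_decay here, keep their checkpointed values.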

set_lr(lr)[source]

Set the learning rate.

state_dict()[source]

Return the optimizer’s state dict.

step(closure=None)[source]

Performs a single optimization step.

zero_grad()[source]

Clears the gradients of all optimized parameters.

class fairseq.optim.MemoryEfficientFP16Optimizer(args, params, optimizer)[source]

Wrap an optimizer to support FP16 (mixed precision) training.

Compared to fairseq.optim.FP16Optimizer, this version does not maintain an FP32 copy of the model. We instead expect the optimizer to convert the gradients to FP32 internally and sync the results back to the FP16 model params. This significantly reduces memory usage but slightly increases the time spent in the optimizer.

Since this wrapper depends on specific functionality in the wrapped optimizer (i.e., on-the-fly conversion of grads to FP32), only certain optimizers can be wrapped. This is determined by the supports_memory_efficient_fp16 property.

backward(loss)[source]

Computes the sum of gradients of the given tensor w.r.t. graph leaves.

Compared to fairseq.optim.FairseqOptimizer.backward(), this function additionally dynamically scales the loss to avoid gradient underflow.

classmethod build_optimizer(args, params)[source]
Parameters:
  • args (argparse.Namespace) – fairseq args
  • params (iterable) – iterable of parameters to optimize
clip_grad_norm(max_norm)[source]

Clips gradient norm and updates dynamic loss scaler.

get_lr()[source]

Return the current learning rate.

load_state_dict(state_dict, optimizer_overrides=None)[source]

Load an optimizer state dict.

In general we should prefer the configuration of the existing optimizer instance (e.g., learning rate) over that found in the state_dict. This allows us to resume training from a checkpoint using a new set of optimizer args.

multiply_grads(c)[source]

Multiplies grads by a constant c.

optimizer

Return a torch.optim.optimizer.Optimizer instance.

optimizer_config

Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.

set_lr(lr)[source]

Set the learning rate.

state_dict()[source]

Return the optimizer’s state dict.

step(closure=None)[source]

Performs a single optimization step.

zero_grad()[source]

Clears the gradients of all optimized parameters.

class fairseq.optim.FairseqOptimizer(args, params)[source]
static add_args(parser)[source]

Add optimizer-specific arguments to the parser.

backward(loss)[source]

Computes the sum of gradients of the given tensor w.r.t. graph leaves.

clip_grad_norm(max_norm)[source]

Clips gradient norm.
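Clipping by the global norm rescales all gradients by the same factor when their combined L2 norm exceeds max_norm, which preserves the update direction. The pure-Python sketch below works on a flat list of gradient values; returning the norm measured before clipping follows the convention of torch.nn.utils.clip_grad_norm_ and is an assumption of the sketch.

```python
import math
from typing import List, Tuple


def clip_grad_norm(grads: List[float], max_norm: float) -> Tuple[List[float], float]:
    total_norm = math.sqrt(sum(g * g for g in grads))  # global L2 norm
    clip_coef = max_norm / (total_norm + 1e-6)  # small epsilon avoids division by zero
    if clip_coef < 1.0:
        grads = [g * clip_coef for g in grads]   # same factor for every gradient
    return grads, total_norm


clipped, norm = clip_grad_norm([3.0, 4.0], max_norm=1.0)
print(norm)     # 5.0
print(clipped)  # approximately [0.6, 0.8]
```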

get_lr()[source]

Return the current learning rate.

load_state_dict(state_dict, optimizer_overrides=None)[source]

Load an optimizer state dict.

In general we should prefer the configuration of the existing optimizer instance (e.g., learning rate) over that found in the state_dict. This allows us to resume training from a checkpoint using a new set of optimizer args.

multiply_grads(c)[source]

Multiplies grads by a constant c.

optimizer

Return a torch.optim.optimizer.Optimizer instance.

optimizer_config

Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.

set_lr(lr)[source]

Set the learning rate.

state_dict()[source]

Return the optimizer’s state dict.

step(closure=None)[source]

Performs a single optimization step.

supports_memory_efficient_fp16
zero_grad()[source]

Clears the gradients of all optimized parameters.

class fairseq.optim.adadelta.Adadelta(args, params)[source]
static add_args(parser)[source]

Add optimizer-specific arguments to the parser.

optimizer_config

Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.

class fairseq.optim.adagrad.Adagrad(args, params)[source]
static add_args(parser)[source]

Add optimizer-specific arguments to the parser.

optimizer_config

Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.

class fairseq.optim.adafactor.FairseqAdafactor(args, params)[source]
static add_args(parser)[source]

Add optimizer-specific arguments to the parser.

optimizer_config

Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.

Note: convergence issues have been observed empirically with fp16 enabled, and finding an appropriate configuration might require a search.

class fairseq.optim.adam.FairseqAdam(args, params)[source]
static add_args(parser)[source]

Add optimizer-specific arguments to the parser.

optimizer_config

Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.

class fairseq.optim.fp16_optimizer.FP16Optimizer(args, params, fp32_optimizer, fp32_params)[source]

Wrap an optimizer to support FP16 (mixed precision) training.

backward(loss)[source]

Computes the sum of gradients of the given tensor w.r.t. graph leaves.

Compared to fairseq.optim.FairseqOptimizer.backward(), this function additionally dynamically scales the loss to avoid gradient underflow.

classmethod build_optimizer(args, params)[source]
Parameters:
  • args (argparse.Namespace) – fairseq args
  • params (iterable) – iterable of parameters to optimize
clip_grad_norm(max_norm)[source]

Clips gradient norm and updates dynamic loss scaler.

get_lr()[source]

Return the current learning rate.

load_state_dict(state_dict, optimizer_overrides=None)[source]

Load an optimizer state dict.

In general we should prefer the configuration of the existing optimizer instance (e.g., learning rate) over that found in the state_dict. This allows us to resume training from a checkpoint using a new set of optimizer args.

multiply_grads(c)[source]

Multiplies grads by a constant c.

optimizer

Return a torch.optim.optimizer.Optimizer instance.

optimizer_config

Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.

set_lr(lr)[source]

Set the learning rate.

state_dict()[source]

Return the optimizer’s state dict.

step(closure=None)[source]

Performs a single optimization step.

zero_grad()[source]

Clears the gradients of all optimized parameters.

class fairseq.optim.nag.FairseqNAG(args, params)[source]
static add_args(parser)[source]

Add optimizer-specific arguments to the parser.

optimizer_config

Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.

class fairseq.optim.sgd.SGD(args, params)[source]
static add_args(parser)[source]

Add optimizer-specific arguments to the parser.

optimizer_config

Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.

Learning Rate Schedulers

Learning Rate Schedulers update the learning rate over the course of training. Learning rates can be updated after each update via step_update() or at epoch boundaries via step().

class fairseq.optim.lr_scheduler.FairseqLRScheduler(args, optimizer)[source]
static add_args(parser)[source]

Add arguments to the parser for this LR scheduler.

load_state_dict(state_dict)[source]

Load an LR scheduler state dict.

state_dict()[source]

Return the LR scheduler state dict.

step(epoch, val_loss=None)[source]

Update the learning rate at the end of the given epoch.

step_update(num_updates)[source]

Update the learning rate after each update.

class fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule(args, optimizer)[source]

Assign LR based on a cyclical schedule that follows the cosine function.

See https://arxiv.org/pdf/1608.03983.pdf for details.

We also support a warmup phase where we linearly increase the learning rate from some initial learning rate (--warmup-init-lr) until the configured learning rate (--lr).

During warmup:

lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lr = lrs[update_num]

After warmup:

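lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(pi * t_curr / t_i))

where t_curr is the number of updates completed within the current period and t_i is the length of the current period in updates, which is scaled by t_mul after every period. Since t_curr / t_i runs from 0 to 1 within each period, the learning rate anneals from lr_max down to lr_min before the next restart.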
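The warmup and annealing phases can be combined in a small pure-Python function of the update number. This sketch assumes linear warmup ending at lr_max, periods measured in updates, t_mult >= 1, and the standard cosine annealing term cos(pi * t_curr / t_i) from the referenced paper; any per-period shrinking of lr_min and lr_max is omitted, and the parameter names are illustrative.

```python
import math


def cosine_lr(num_updates: int, *, lr_min: float, lr_max: float, period: int,
              t_mult: float = 1.0, warmup_updates: int = 0,
              warmup_init_lr: float = 0.0) -> float:
    assert period > 0 and t_mult >= 1.0
    if num_updates < warmup_updates:
        # Linear warmup from warmup_init_lr towards lr_max.
        return warmup_init_lr + num_updates * (lr_max - warmup_init_lr) / warmup_updates
    t_curr = num_updates - warmup_updates  # updates since warmup ended
    t_i = float(period)
    # Skip completed periods; each period is t_mult times longer than the previous one.
    while t_curr >= t_i:
        t_curr -= t_i
        t_i *= t_mult
    # Anneal from lr_max (t_curr = 0) towards lr_min (t_curr -> t_i).
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i))


for u in (0, 5, 10, 20):
    print(u, round(cosine_lr(u, lr_min=0.0, lr_max=1.0, period=10, t_mult=2.0), 3))
# 0 1.0 / 5 0.5 / 10 1.0 / 20 0.5
```

With t_mult=2.0 the first restart happens after 10 updates and the second period lasts 20 updates, so update 20 lands halfway through that period.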

static add_args(parser)[source]

Add arguments to the parser for this LR scheduler.

step(epoch, val_loss=None)[source]

Update the learning rate at the end of the given epoch.

step_update(num_updates)[source]

Update the learning rate after each update.

class fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule(args, optimizer)[source]

Decay the LR on a fixed schedule.

static add_args(parser)[source]

Add arguments to the parser for this LR scheduler.

get_next_lr(epoch)[source]
step(epoch, val_loss=None)[source]

Update the learning rate at the end of the given epoch.

step_update(num_updates)[source]

Update the learning rate after each update.

class fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule(args, optimizer)[source]

Decay the LR based on the inverse square root of the update number.

We also support a warmup phase where we linearly increase the learning rate from some initial learning rate (--warmup-init-lr) until the configured learning rate (--lr). Thereafter we decay the learning rate proportionally to the inverse square root of the update number, with a decay factor set so that the decayed learning rate equals the configured learning rate at the end of warmup.

During warmup:

lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lr = lrs[update_num]

After warmup:

decay_factor = args.lr * sqrt(args.warmup_updates)
lr = decay_factor / sqrt(update_num)
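Both phases fit in a short pure-Python function. The sketch assumes linear warmup and warmup_updates > 0, and the parameter names are illustrative; because decay_factor = lr * sqrt(warmup_updates), the two phases meet at exactly lr when update_num reaches warmup_updates.

```python
import math


def inverse_sqrt_lr(update_num: int, *, lr: float, warmup_updates: int,
                    warmup_init_lr: float = 0.0) -> float:
    assert warmup_updates > 0
    if update_num < warmup_updates:
        # Linear warmup from warmup_init_lr to lr.
        return warmup_init_lr + update_num * (lr - warmup_init_lr) / warmup_updates
    # Chosen so that the decay starts at exactly lr when update_num == warmup_updates.
    decay_factor = lr * math.sqrt(warmup_updates)
    return decay_factor / math.sqrt(update_num)


for u in (2000, 4000, 16000):
    print(u, inverse_sqrt_lr(u, lr=5e-4, warmup_updates=4000))
```

After warmup, quadrupling the update count halves the learning rate, e.g. going from 4000 to 16000 updates.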

static add_args(parser)[source]

Add arguments to the parser for this LR scheduler.

step(epoch, val_loss=None)[source]

Update the learning rate at the end of the given epoch.

step_update(num_updates)[source]

Update the learning rate after each update.

class fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau(args, optimizer)[source]

Decay the LR by a factor every time the validation loss plateaus.
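A minimal sketch of the idea: compare each epoch's validation loss with the best value seen so far, and multiply the learning rate by a shrink factor whenever it fails to improve. The shrink factor, the immediate reaction to a single non-improving epoch, the optional lower bound and the function name are assumptions of this sketch.

```python
from typing import List


def plateau_schedule(val_losses: List[float], lr: float, shrink: float = 0.5,
                     min_lr: float = 0.0) -> List[float]:
    """Return the learning rate in effect after each epoch's validation."""
    best = float("inf")
    lrs = []
    for loss in val_losses:
        if loss < best:
            best = loss  # improvement: keep the current learning rate
        else:
            lr = max(lr * shrink, min_lr)  # plateau: decay, but not below min_lr
        lrs.append(lr)
    return lrs


print(plateau_schedule([3.0, 2.5, 2.6, 2.4, 2.4], lr=1.0))
# [1.0, 1.0, 0.5, 0.5, 0.25]
```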

static add_args(parser)[source]

Add arguments to the parser for this LR scheduler.

load_state_dict(state_dict)[source]

Load an LR scheduler state dict.

state_dict()[source]

Return the LR scheduler state dict.

step(epoch, val_loss=None)[source]

Update the learning rate at the end of the given epoch.

class fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule(args, optimizer)[source]

Assign LR based on a triangular cyclical schedule.

See https://arxiv.org/pdf/1506.01186.pdf for details.
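The triangular policy from the referenced paper ramps the learning rate linearly from min_lr up to max_lr over stepsize updates, back down over the next stepsize updates, and then repeats. The sketch below uses the paper's cycle formula; the parameter names are illustrative, and any shrinking of the bounds between cycles is omitted.

```python
import math


def triangular_lr(num_updates: int, *, min_lr: float, max_lr: float, stepsize: int) -> float:
    # One cycle lasts 2 * stepsize updates: stepsize rising, then stepsize falling.
    cycle = math.floor(1 + num_updates / (2 * stepsize))
    # x is 1 at the start and end of a cycle and 0 at its peak.
    x = abs(num_updates / stepsize - 2 * cycle + 1)
    return min_lr + (max_lr - min_lr) * max(0.0, 1.0 - x)


print([triangular_lr(u, min_lr=0.0, max_lr=1.0, stepsize=10) for u in (0, 5, 10, 15, 20)])
# [0.0, 0.5, 1.0, 0.5, 0.0]
```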

static add_args(parser)[source]

Add arguments to the parser for this LR scheduler.

step(epoch, val_loss=None)[source]

Update the learning rate at the end of the given epoch.

step_update(num_updates)[source]

Update the learning rate after each update.

Data Loading and Utilities

Datasets

Datasets define the data format and provide helpers for creating mini-batches.

class fairseq.data.FairseqDataset[source]

A dataset that provides helpers for batching.

collater(samples)[source]

Merge a list of samples to form a mini-batch.

Parameters: samples (List[dict]) – samples to collate
Returns: a mini-batch suitable for forwarding with a Model
Return type: dict
num_tokens(index)[source]

Return the number of tokens in a sample. This value is used to enforce --max-tokens during batching.

ordered_indices()[source]

Return an ordered list of indices. Batches will be constructed based on this order.

prefetch(indices)[source]

Prefetch the data required for this epoch.

size(index)[source]

Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.

supports_prefetch

Whether this dataset supports prefetching.
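num_tokens() and ordered_indices() are what make token-budget batching possible. The greedy sketch below is a simplified, hypothetical helper (not fairseq's batching code): it sorts indices by length, as an ordered_indices() implementation might, and starts a new batch whenever the padded size, that is the longest example times the number of examples, would exceed max_tokens.

```python
from typing import List


def batch_by_tokens(sizes: List[int], max_tokens: int) -> List[List[int]]:
    # Sort by length so that examples of similar length share a batch.
    order = sorted(range(len(sizes)), key=lambda i: sizes[i])
    batches: List[List[int]] = []
    batch: List[int] = []
    max_len = 0
    for idx in order:
        n = sizes[idx]  # plays the role of num_tokens(idx)
        if n > max_tokens:
            raise ValueError(f"example {idx} has {n} tokens, more than max_tokens={max_tokens}")
        new_max = max(max_len, n)
        # Padded cost if idx joined the current batch: longest example x batch size.
        if batch and new_max * (len(batch) + 1) > max_tokens:
            batches.append(batch)
            batch, new_max = [], n
        batch.append(idx)
        max_len = new_max
    if batch:
        batches.append(batch)
    return batches


sizes = [5, 3, 8, 2, 7]  # num_tokens() of each example
batches = batch_by_tokens(sizes, max_tokens=16)
print(batches)  # [[3, 1, 0], [4, 2]]
```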

class fairseq.data.LanguagePairDataset(src, src_sizes, src_dict, tgt=None, tgt_sizes=None, tgt_dict=None, left_pad_source=True, left_pad_target=False, max_source_positions=1024, max_target_positions=1024, shuffle=True, input_feeding=True, remove_eos_from_source=False, append_eos_to_target=False)[source]

A pair of torch.utils.data.Datasets.

Parameters:
  • src (torch.utils.data.Dataset) – source dataset to wrap
  • src_sizes (List[int]) – source sentence lengths
  • src_dict (Dictionary) – source vocabulary
  • tgt (torch.utils.data.Dataset, optional) – target dataset to wrap
  • tgt_sizes (List[int], optional) – target sentence lengths
  • tgt_dict (Dictionary, optional) – target vocabulary
  • left_pad_source (bool, optional) – pad source tensors on the left side (default: True).
  • left_pad_target (bool, optional) – pad target tensors on the left side (default: False).
  • max_source_positions (int, optional) – max number of tokens in the source sentence (default: 1024).
  • max_target_positions (int, optional) – max number of tokens in the target sentence (default: 1024).
  • shuffle (bool, optional) – shuffle dataset elements before batching (default: True).
  • input_feeding (bool, optional) – create a shifted version of the targets to be passed into the model for input feeding/teacher forcing (default: True).
  • remove_eos_from_source (bool, optional) – if set, removes eos from end of source if it’s present (default: False).
  • append_eos_to_target (bool, optional) – if set, appends eos to end of target if it’s absent (default: False).
collater(samples)[source]

Merge a list of samples to form a mini-batch.

Parameters:samples (List[dict]) – samples to collate
Returns:a mini-batch with the following keys:
  • id (LongTensor): example IDs in the original input order
  • ntokens (int): total number of tokens in the batch
  • net_input (dict): the input to the Model, containing keys:
    • src_tokens (LongTensor): a padded 2D Tensor of tokens in the source sentence of shape (bsz, src_len). Padding will appear on the left if left_pad_source is True.
    • src_lengths (LongTensor): 1D Tensor of the unpadded lengths of each source sentence of shape (bsz)
    • prev_output_tokens (LongTensor): a padded 2D Tensor of tokens in the target sentence, shifted right by one position for input feeding/teacher forcing, of shape (bsz, tgt_len). This key will not be present if input_feeding is False. Padding will appear on the left if left_pad_target is True.
  • target (LongTensor): a padded 2D Tensor of tokens in the target sentence of shape (bsz, tgt_len). Padding will appear on the left if left_pad_target is True.
Return type:dict
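The padding and shifting rules above can be sketched on plain token-id lists. This assumes every sentence ends with EOS, uses the default left_pad_source=True and left_pad_target=False, and assumes pad=1 and eos=2 as indices. collate_tokens here is a hypothetical list-based helper: for prev_output_tokens it moves the trailing EOS to the front, which is the one-position right shift used for teacher forcing.

```python
from typing import List

PAD, EOS = 1, 2  # assumed pad/eos indices for this sketch


def collate_tokens(
    seqs: List[List[int]],
    pad_idx: int,
    eos_idx: int,
    left_pad: bool,
    move_eos_to_beginning: bool = False,
) -> List[List[int]]:
    """Hypothetical list-based collater: pad sequences to a common length."""
    size = max(len(s) for s in seqs)
    out = []
    for s in seqs:
        if move_eos_to_beginning:
            # Shift right by one: [a, b, </s>] -> [</s>, a, b]
            assert s[-1] == eos_idx, "expected each target to end with EOS"
            s = [eos_idx] + s[:-1]
        padding = [pad_idx] * (size - len(s))
        out.append(padding + s if left_pad else s + padding)
    return out


src = [[5, 6, 2], [7, 2]]
tgt = [[8, 9, 10, 2], [11, 2]]
src_tokens = collate_tokens(src, PAD, EOS, left_pad=True)   # [[5, 6, 2], [1, 7, 2]]
src_lengths = [len(s) for s in src]                         # [3, 2]
target = collate_tokens(tgt, PAD, EOS, left_pad=False)      # [[8, 9, 10, 2], [11, 2, 1, 1]]
prev_output_tokens = collate_tokens(
    tgt, PAD, EOS, left_pad=False, move_eos_to_beginning=True
)                                                           # [[2, 8, 9, 10], [2, 11, 1, 1]]
ntokens = sum(len(t) for t in tgt)                          # 6
```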
num_tokens(index)[source]

Return the number of tokens in a sample. This value is used to enforce --max-tokens during batching.

ordered_indices()[source]

Return an ordered list of indices. Batches will be constructed based on this order.

prefetch(indices)[source]

Prefetch the data required for this epoch.

size(index)[source]

Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.

supports_prefetch

Whether this dataset supports prefetching.

class fairseq.data.MonolingualDataset(dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle, targets=None, add_bos_token=False)[source]

A wrapper around torch.utils.data.Dataset for monolingual data.

Parameters:
  • dataset (torch.utils.data.Dataset) – dataset to wrap
  • sizes (List[int]) – sentence lengths
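  • src_vocab (Dictionary) – source vocabulary
  • tgt_vocab (Dictionary) – target vocabulary
  • shuffle (bool) – shuffle the elements before batching
  • targets (List[str], optional) – which targets to produce; each must be one of “self”, “future” or “past” (default: None).
  • add_bos_token (bool, optional) – prepend the beginning-of-sentence symbol to each sample (default: False).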
collater(samples)[source]

Merge a list of samples to form a mini-batch.

Parameters:samples (List[dict]) – samples to collate
Returns:a mini-batch with the following keys:
  • id (LongTensor): example IDs in the original input order
  • ntokens (int): total number of tokens in the batch
  • net_input (dict): the input to the Model, containing keys:
    • src_tokens (LongTensor): a padded 2D Tensor of tokens in the source sentence of shape (bsz, src_len). Padding will appear on the right.
  • target (LongTensor): a padded 2D Tensor of tokens in the target sentence of shape (bsz, tgt_len). Padding will appear on the right.
Return type:dict
num_tokens(index)[source]

Return the number of tokens in a sample. This value is used to enforce --max-tokens during batching.

ordered_indices()[source]

Return an ordered list of indices. Batches will be constructed based on this order.

prefetch(indices)[source]

Prefetch the data required for this epoch.

size(index)[source]

Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.

supports_prefetch

Whether this dataset supports prefetching.

Helper Datasets

These datasets wrap other fairseq.data.FairseqDataset instances and provide additional functionality:

class fairseq.data.BacktranslationDataset(tgt_dataset, src_dict, tgt_dict=None, backtranslation_fn=None, output_collater=None, cuda=True, **kwargs)[source]

Sets up a backtranslation dataset which takes a tgt batch, generates a src using a tgt-src backtranslation function (backtranslation_fn), and returns the corresponding {generated src, input tgt} batch.

Parameters:
  • tgt_dataset (FairseqDataset) – the dataset to be backtranslated. Only the source side of this dataset will be used. After backtranslation, the source sentences in this dataset will be returned as the targets.
  • src_dict (Dictionary) – the dictionary of backtranslated sentences.
  • tgt_dict (Dictionary, optional) – the dictionary of sentences to be backtranslated.
  • backtranslation_fn (callable, optional) – function to call to generate backtranslations. This is typically the generate method of a SequenceGenerator object. Pass in None when it is not available at initialization time, and use set_backtranslation_fn function to set it when available.
  • output_collater (callable, optional) – function to call on the backtranslated samples to create the final batch (default: tgt_dataset.collater).
  • cuda (bool, optional) – use GPU for generation (default: True).
collater(samples)[source]

Merge and backtranslate a list of samples to form a mini-batch.

Using the samples from tgt_dataset, load a collated target sample to feed to the backtranslation model. Then take the backtranslation with the best score as the source and the original input as the target.

Note: we expect tgt_dataset to provide a function collater() that will collate samples into the format expected by backtranslation_fn. After backtranslation, we will feed the new list of samples (i.e., the (backtranslated source, original source) pairs) to output_collater and return the result.

Parameters:samples (List[dict]) – samples to backtranslate and collate
Returns:a mini-batch with keys coming from output_collater
Return type:dict
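The collation flow described above can be outlined in plain Python. This sketch assumes a hypothetical backtranslation_fn that returns an n-best list per input, where each hypothesis is a dict with "tokens" and "score". The real function (e.g. the generate method of a SequenceGenerator) works on tensors, so this only illustrates the data flow.

```python
from typing import Callable, Dict, List

Sample = Dict[str, object]


def backtranslate_and_collate(
    samples: List[Sample],
    backtranslation_fn: Callable[[List[List[int]]], List[List[Dict[str, object]]]],
    output_collater: Callable[[List[Sample]], Dict[str, object]],
) -> Dict[str, object]:
    """Hypothetical outline of the collate flow, on plain token-id lists."""
    originals = [s["source"] for s in samples]  # only the source side is used
    nbest_lists = backtranslation_fn(originals)
    pairs = []
    for sample, nbest in zip(samples, nbest_lists):
        best = max(nbest, key=lambda h: h["score"])  # best-scoring hypothesis
        pairs.append({
            "id": sample["id"],
            "source": best["tokens"],    # generated sentence becomes the source
            "target": sample["source"],  # original sentence becomes the target
        })
    return output_collater(pairs)


def fake_backtranslate(batch):
    # Stand-in for a real tgt->src model: "translates" by reversing the tokens.
    return [[{"tokens": x[::-1], "score": -0.5}, {"tokens": x, "score": -3.0}] for x in batch]


def list_collater(pairs):
    return {"source": [p["source"] for p in pairs], "target": [p["target"] for p in pairs]}


print(backtranslate_and_collate([{"id": 0, "source": [4, 5, 2]}], fake_backtranslate, list_collater))
# {'source': [[2, 5, 4]], 'target': [[4, 5, 2]]}
```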
num_tokens(index)[source]

Return the number of tokens in a sample, delegating to tgt_dataset.num_tokens().

ordered_indices()[source]

Return an ordered list of indices, delegating to tgt_dataset.ordered_indices().

prefetch(indices)[source]

Prefetch the data required for this epoch.

size(index)[source]

Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.

Note: we use tgt_dataset to approximate the length of the source sentence, since we do not know the actual length until after backtranslation.

supports_prefetch

Whether this dataset supports prefetching.

class fairseq.data.ConcatDataset(datasets, sample_ratios=1)[source]
collater(samples)[source]

Merge a list of samples to form a mini-batch.

Parameters:samples (List[dict]) – samples to collate
Returns:a mini-batch suitable for forwarding with a Model
Return type:dict
num_tokens(index: int)[source]

Return the number of tokens in a sample. This value is used to enforce --max-tokens during batching.

ordered_indices()[source]

Return indices sorted by length, so that less padding is needed.

prefetch(indices)[source]

Prefetch the data required for this epoch.

size(idx: int)[source]

Return an example’s size as a float or tuple.

supports_prefetch

Whether this dataset supports prefetching.
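Conceptually, a concatenated dataset maps a global index to a (dataset, local index) pair through cumulative sizes. The sketch below assumes that an integer sample_ratios applies to every dataset and that a ratio of r repeats a dataset r times (upsampling). ConcatIndex is a hypothetical helper, not the class's actual code.

```python
import bisect
from itertools import accumulate
from typing import List, Sequence, Tuple, Union


class ConcatIndex:
    """Hypothetical index mapper for a concatenation of datasets with upsampling."""

    def __init__(self, lengths: Sequence[int], sample_ratios: Union[int, Sequence[int]] = 1):
        if isinstance(sample_ratios, int):
            sample_ratios = [sample_ratios] * len(lengths)
        self.lengths: List[int] = list(lengths)
        # End offset of each (upsampled) dataset in the concatenated index space.
        self.cumulative: List[int] = list(
            accumulate(n * r for n, r in zip(self.lengths, sample_ratios))
        )

    def __len__(self) -> int:
        return self.cumulative[-1] if self.cumulative else 0

    def locate(self, idx: int) -> Tuple[int, int]:
        """Map a global index to (dataset number, index inside that dataset)."""
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        ds = bisect.bisect_right(self.cumulative, idx)
        start = self.cumulative[ds - 1] if ds > 0 else 0
        # Repeated copies of an upsampled dataset wrap around with modulo.
        return ds, (idx - start) % self.lengths[ds]


index = ConcatIndex([3, 2], sample_ratios=[1, 2])
print(len(index), [index.locate(i) for i in range(len(index))])
# 7 [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 0), (1, 1)]
```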

class fairseq.data.RoundRobinZipDatasets(datasets, eval_key=None)[source]

Zip multiple FairseqDataset instances together.

Shorter datasets are repeated in a round-robin fashion to match the length of the longest one.

Parameters:
  • datasets (Dict[FairseqDataset]) – a dictionary of FairseqDataset instances.
  • eval_key (str, optional) – a key used at evaluation time that causes this instance to pass through batches from datasets[eval_key].
collater(samples)[source]

Merge a list of samples to form a mini-batch.

num_tokens(index)[source]

Return an example’s length (number of tokens), used for batching.

ordered_indices()[source]

Ordered indices for batching.

prefetch(indices)[source]

Prefetch the data required for this epoch.

size(index)[source]

Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.

supports_prefetch

Whether this dataset supports prefetching.
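The round-robin behavior can be sketched with plain lists: every key yields one sample per index, and a shorter dataset wraps around with modulo indexing, so the zipped length equals the longest dataset. round_robin_zip is a hypothetical helper, and the language-pair keys are illustrative.

```python
from typing import Dict, Sequence


def round_robin_zip(datasets: Dict[str, Sequence], index: int) -> Dict[str, object]:
    """Hypothetical helper: one sample per key; shorter datasets wrap around."""
    longest = max(len(ds) for ds in datasets.values())
    if not 0 <= index < longest:
        raise IndexError(index)
    return {key: ds[index % len(ds)] for key, ds in datasets.items()}


data = {"en-de": ["a0", "a1", "a2"], "en-fr": ["b0"]}
print([round_robin_zip(data, i) for i in range(3)])
# [{'en-de': 'a0', 'en-fr': 'b0'}, {'en-de': 'a1', 'en-fr': 'b0'}, {'en-de': 'a2', 'en-fr': 'b0'}]
```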

class fairseq.data.TransformEosDataset(dataset, eos, append_eos_to_src=False, remove_eos_from_src=False, append_eos_to_tgt=False, remove_eos_from_tgt=False, has_target=True)[source]

A FairseqDataset wrapper that appends/prepends/strips EOS.

Note that the transformation is applied in collater().

Parameters:
  • dataset (FairseqDataset) – dataset to wrap
  • eos (int) – index of the end-of-sentence symbol
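  • append_eos_to_src (bool, optional) – append EOS to the end of src (default: False).
  • remove_eos_from_src (bool, optional) – remove EOS from the end of src (default: False).
  • append_eos_to_tgt (bool, optional) – append EOS to the end of tgt (default: False).
  • remove_eos_from_tgt (bool, optional) – remove EOS from the end of tgt (default: False).
  • has_target (bool, optional) – whether the wrapped dataset provides a target side (default: True).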
collater(samples)[source]

Merge a list of samples to form a mini-batch.

Parameters:samples (List[dict]) – samples to collate
Returns:a mini-batch suitable for forwarding with a Model
Return type:dict
num_tokens(index)[source]

Return the number of tokens in a sample. This value is used to enforce --max-tokens during batching.

ordered_indices()[source]

Return an ordered list of indices. Batches will be constructed based on this order.

prefetch(indices)[source]

Prefetch the data required for this epoch.

size(index)[source]

Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.

supports_prefetch

Whether this dataset supports prefetching.
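The wrapper applies these changes to each sample in collater(). On a single list of token ids, appending or stripping EOS reduces to the following. transform_eos is a hypothetical helper, and the EOS index is assumed. Whether the real wrapper validates the input before appending is not stated above, so this version simply makes both operations idempotent.

```python
from typing import List


def transform_eos(tokens: List[int], eos: int, append: bool = False, remove: bool = False) -> List[int]:
    """Hypothetical helper: append EOS if absent, or strip a trailing EOS if present."""
    if append and remove:
        raise ValueError("append and remove are mutually exclusive")
    if append and (not tokens or tokens[-1] != eos):
        return tokens + [eos]
    if remove and tokens and tokens[-1] == eos:
        return tokens[:-1]
    return list(tokens)


EOS = 2  # assumed EOS index
print(transform_eos([5, 6], EOS, append=True))     # [5, 6, 2]
print(transform_eos([5, 6, 2], EOS, remove=True))  # [5, 6]
```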

Dictionary

class fairseq.data.Dictionary(pad='<pad>', eos='</s>', unk='<unk>', bos='<s>')[source]

A mapping from symbols to consecutive integers

add_symbol(word, n=1)[source]

Add a word to the dictionary with count n, or increase its count by n if it is already present.

bos()[source]

Helper to get index of beginning-of-sentence symbol

eos()[source]

Helper to get index of end-of-sentence symbol

finalize(threshold=-1, nwords=-1, padding_factor=8)[source]

Sort symbols by frequency in descending order, ignoring special ones.

Parameters:
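  • threshold (int, optional) – the minimum word count; symbols that occur fewer times are dropped (default: -1, i.e., no threshold).
  • nwords (int, optional) – the total number of words in the final dictionary, including special symbols (default: -1, i.e., keep all words).
  • padding_factor (int, optional) – pad the dictionary size to be a multiple of this value; a multiple of 8 is important on some hardware (e.g., Nvidia Tensor Cores) (default: 8).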
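A simplified, list-based sketch of what finalize() does with these three parameters. It assumes the special symbols occupy the first nspecial slots and fills the padding slots with made-up placeholder names; the helper and the placeholder naming are illustrative assumptions, not the method's actual code.

```python
from typing import List, Tuple


def finalize(
    symbols: List[str],
    counts: List[int],
    nspecial: int,
    threshold: int = -1,
    nwords: int = -1,
    padding_factor: int = 8,
) -> Tuple[List[str], List[int]]:
    """Hypothetical helper: sort non-special symbols by count, then trim and pad."""
    if nwords <= 0:
        nwords = len(symbols)
    new_symbols = symbols[:nspecial]
    new_counts = counts[:nspecial]
    # Sort the remaining symbols by descending count (ties keep their original order).
    rest = sorted(zip(symbols[nspecial:], counts[nspecial:]), key=lambda sc: -sc[1])
    for sym, count in rest:
        # Counts only decrease from here on, so the first miss ends the loop.
        if len(new_symbols) >= nwords or count < threshold:
            break
        new_symbols.append(sym)
        new_counts.append(count)
    # Pad with placeholder symbols until the size is a multiple of padding_factor.
    i = 0
    while padding_factor > 1 and len(new_symbols) % padding_factor != 0:
        new_symbols.append(f"madeupword{i:04d}")  # hypothetical placeholder name
        new_counts.append(0)
        i += 1
    return new_symbols, new_counts


symbols = ["<s>", "<pad>", "</s>", "<unk>", "a", "b", "c"]
counts = [1, 1, 1, 1, 2, 5, 1]
new_symbols, new_counts = finalize(symbols, counts, nspecial=4, threshold=2)
print(new_symbols)
# ['<s>', '<pad>', '</s>', '<unk>', 'b', 'a', 'madeupword0000', 'madeupword0001']
```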
index(sym)[source]

Returns the index of the specified symbol

classmethod load(f, ignore_utf_errors=False)[source]

Loads the dictionary from a text file with the format:

    <symbol0> <count0>
    <symbol1> <count1>
    ...

pad()[source]

Helper to get index of pad symbol

save(f)[source]

Stores dictionary into a text file

string(tensor, bpe_symbol=None, escape_unk=False)[source]

Helper for converting a tensor of token indices to a string.

Can optionally remove BPE symbols or escape <unk> words.
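Removing the @@ continuation markers used by the BPE models described earlier is a plain string substitution. A sketch, assuming bpe_symbol="@@ " (remove_bpe is a hypothetical helper, not the method's actual code):

```python
def remove_bpe(sentence: str, bpe_symbol: str = "@@ ") -> str:
    """Hypothetical helper: join BPE pieces by deleting the continuation marker."""
    # Append a space so that a marker at the very end of the line is removed too.
    return (sentence + " ").replace(bpe_symbol, "").rstrip()


print(remove_bpe("Why is it rare to discover new marine mam@@ mal species ?"))
# Why is it rare to discover new marine mammal species ?
```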

unk()[source]

Helper to get index of unk symbol

unk_string(escape=False)[source]

Return unknown string, optionally escaped as: <<unk>>

update(new_dict)[source]

Updates counts from new dictionary.

Iterators

class fairseq.data.CountingIterator(iterable, start=0)[source]

Wrapper around an iterable that maintains the iteration count.

Parameters:
  • iterable (iterable) – iterable to wrap
  • start (int, optional) – initial value of count (default: 0).
count

number of elements consumed from this iterator

Type:int
has_next()[source]

Return whether the iterator has more elements, i.e., has not yet been exhausted.

skip(num_to_skip)[source]

Fast-forward the iterator by skipping num_to_skip elements.
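A minimal sketch of the counting behavior over a sized sequence. CountingSketch is a hypothetical stand-in, not fairseq's class. It shows that count tracks consumed elements, that skip() fast-forwards by consuming elements, and that has_next() stays true until the sequence is exhausted.

```python
from itertools import islice
from typing import Generic, Iterator, Sequence, TypeVar

T = TypeVar("T")


class CountingSketch(Generic[T]):
    """Hypothetical stand-in for CountingIterator over a sized sequence."""

    def __init__(self, items: Sequence[T], start: int = 0):
        self.count = start                 # elements consumed so far
        self.total = start + len(items)
        self._itr = self._generate(items)

    def _generate(self, items: Sequence[T]) -> Iterator[T]:
        for x in items:
            self.count += 1
            yield x

    def __iter__(self) -> "CountingSketch[T]":
        return self

    def __next__(self) -> T:
        return next(self._itr)

    def has_next(self) -> bool:
        return self.count < self.total

    def skip(self, num_to_skip: int) -> "CountingSketch[T]":
        # Consume num_to_skip elements without returning them.
        next(islice(self._itr, num_to_skip, num_to_skip), None)
        return self


it = CountingSketch([10, 20, 30, 40])
first = next(it)  # 10, count == 1
it.skip(2)        # consumes 20 and 30, count == 3
print(first, it.count, it.has_next(), next(it), it.has_next())
# 10 3 True 40 False
```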

class fairseq.data.EpochBatchIterator(dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=0)[source]

A multi-epoch iterator over a torch.utils.data.Dataset.

Compared to torch.utils.data.DataLoader, this iterator:

  • can be reused across multiple epochs with the next_epoch_itr() method (optionally shuffled between epochs)
  • can be serialized/deserialized with the state_dict() and load_state_dict() methods
  • supports sharding with the num_shards and shard_id arguments
Parameters:
  • dataset (Dataset) – dataset from which to load the data
  • collate_fn (callable) – merges a list of samples to form a mini-batch
  • batch_sampler (Sampler) – an iterator over batches of indices
  • seed (int, optional) – seed for random number generator for reproducibility (default: 1).
  • num_shards (int, optional) – shard the data iterator into N shards (default: 1).
  • shard_id (int, optional) – which shard of the data iterator to return (default: 0).
  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means the data will be loaded in the main process (default: 0).
  • epoch (int, optional) – the epoch to start the iterator from (default: 0).
end_of_epoch()[source]

Returns whether the most recent epoch iterator has been exhausted

iterations_in_epoch

The number of consumed batches in the current epoch.

load_state_dict(state_dict)[source]

Copies the state of the iterator from the given state_dict.

next_epoch_itr(shuffle=True, fix_batches_to_gpus=False)[source]

Return a new iterator over the dataset.

Parameters:
  • shuffle (bool, optional) – shuffle batches before returning the iterator (default: True).
  • fix_batches_to_gpus (bool, optional) – ensure that batches are always allocated to the same shards across epochs. Requires that dataset supports prefetching (default: False).
state_dict()[source]

Returns a dictionary containing the whole state of the iterator.

class fairseq.data.GroupedIterator(iterable, chunk_size)[source]

Wrapper around an iterable that returns groups (chunks) of items.

Parameters:
  • iterable (iterable) – iterable to wrap
  • chunk_size (int) – size of each chunk
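Chunking an iterable can be sketched with itertools.islice; the last chunk may be shorter than chunk_size. To our understanding, fairseq's training loop groups consecutive batches in this way to implement --update-freq gradient accumulation. grouped is a hypothetical helper.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def grouped(iterable: Iterable[T], chunk_size: int) -> Iterator[List[T]]:
    """Hypothetical helper: yield lists of chunk_size items; the last may be shorter."""
    if chunk_size < 1:
        raise ValueError("chunk_size must be positive")
    it = iter(iterable)
    while True:
        chunk = list(islice(it, chunk_size))
        if not chunk:
            return
        yield chunk


print(list(grouped(range(7), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]
```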
class fairseq.data.ShardedIterator(iterable, num_shards, shard_id, fill_value=None)[source]

A sharded wrapper around an iterable, padded to length.

Parameters:
  • iterable (iterable) – iterable to wrap
  • num_shards (int) – number of shards to split the iterable into
  • shard_id (int) – which shard to iterate over
  • fill_value (Any, optional) – padding value when the iterable doesn’t evenly divide num_shards (default: None).
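The padding behavior is easiest to see with a small list-based sketch. Each shard takes every num_shards-th element starting at shard_id, and shorter shards are padded with fill_value so that all shards have the same length, presumably so that every worker runs the same number of steps. shard is a hypothetical helper.

```python
import itertools
import math
from typing import Any, List, Sequence


def shard(items: Sequence[Any], num_shards: int, shard_id: int, fill_value: Any = None) -> List[Any]:
    """Hypothetical helper: every num_shards-th item from shard_id, padded to equal length."""
    if not 0 <= shard_id < num_shards:
        raise ValueError("shard_id must be in [0, num_shards)")
    sharded_len = math.ceil(len(items) / num_shards)
    picked = itertools.islice(items, shard_id, len(items), num_shards)
    # zip_longest pads a short shard with fill_value up to sharded_len items.
    return [x for _, x in itertools.zip_longest(range(sharded_len), picked, fillvalue=fill_value)]


batches = ["b0", "b1", "b2", "b3", "b4"]
print(shard(batches, 2, 0), shard(batches, 2, 1))
# ['b0', 'b2', 'b4'] ['b1', 'b3', None]
```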

Modules

Fairseq provides several stand-alone torch.nn.Module classes that may be helpful when implementing a new BaseFairseqModel.

class fairseq.modules.AdaptiveInput(vocab_size: int, padding_idx: int, initial_dim: int, factor: float, output_dim: int, cutoff: List[int])[source]
forward(input: torch.Tensor)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

weights_for_band(band: int)[source]
class fairseq.modules.AdaptiveSoftmax(vocab_size, input_dim, cutoff, dropout, factor=4.0, adaptive_inputs=None, tie_proj=False)[source]

This is an implementation of the efficient softmax approximation for graphical processing units (GPU), described in the paper “Efficient softmax approximation for GPUs” (http://arxiv.org/abs/1609.04309).

adapt_target(target)[source]

In order to be efficient, AdaptiveSoftmax does not compute the scores for all the words of the vocabulary for all the examples. It is thus necessary to call the adapt_target method of the AdaptiveSoftmax layer inside each forward pass.
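The target remapping can be sketched on a flat list of ids. This assumes the layout of the original adaptive softmax: cutoff is increasing and ends with the vocabulary size, the head predicts ids below cutoff[0] plus one class per tail cluster, and ids in tail cluster i are re-indexed relative to cutoff[i]. adapt_targets is a hypothetical list-based helper, not the method itself.

```python
from typing import List, Optional, Tuple


def adapt_targets(
    targets: List[int], cutoff: List[int]
) -> Tuple[List[int], List[Optional[List[int]]]]:
    """Hypothetical helper: split flat targets into head targets and per-cluster tail targets."""
    head = list(targets)
    tails: List[Optional[List[int]]] = []
    for i in range(len(cutoff) - 1):
        lo, hi = cutoff[i], cutoff[i + 1]
        in_cluster = [t for t in targets if lo <= t < hi]
        for pos, t in enumerate(targets):
            if lo <= t < hi:
                head[pos] = cutoff[0] + i  # head predicts the "cluster i" class instead
        # Tail targets are re-indexed relative to the start of the cluster.
        tails.append([t - lo for t in in_cluster] if in_cluster else None)
    return head, tails


print(adapt_targets([1, 5, 11, 3, 9], cutoff=[4, 8, 12]))
# ([1, 4, 5, 3, 5], [[1], [3, 1]])
```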

forward(input, target)[source]
Parameters:
  • input – (b x t x d)
  • target – (b x t)
Returns:

a list of outputs, one per cutoff section, and a list of the targets remapped to each section (as returned by adapt_target)

Return type:

2 lists

get_log_prob(input, target)[source]

Computes the log probabilities for all the words of the vocabulary, given a 2D tensor of hidden vectors.

upgrade_state_dict_named(state_dict, name)[source]
class fairseq.modules.BeamableMM(beam_size=None)[source]

This module provides an optimized matrix multiplication (MM) for beam decoding with attention.

It leverages the fact that the source side of the input is replicated beam times and the target side of the input is of width one. This layer speeds up inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)} with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.

forward(input1, input2)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

set_beam_size(beam_size)[source]
class fairseq.modules.CharacterTokenEmbedder(vocab: fairseq.data.dictionary.Dictionary, filters: List[Tuple[int, int]], char_embed_dim: int, word_embed_dim: int, highway_layers: int, max_char_len: int = 50, char_inputs: bool = False)[source]
forward(input: torch.Tensor)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

padding_idx
prepare_for_onnx_export_()[source]
reset_parameters()[source]
set_vocab(vocab, max_char_len)[source]
class fairseq.modules.ConvTBC(in_channels, out_channels, kernel_size, padding=0)[source]

1D convolution over an input of shape (time x batch x channel)

The implementation uses gemm to perform the convolution. This implementation is faster than cuDNN for small kernel sizes.

forward(input)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class fairseq.modules.DownsampledMultiHeadAttention(out_channels, embed_dim, num_heads, dropout=0.0, bias=True, project_input=True, gated=False, downsample=False)[source]

Multi-headed attention with Gating and Downsampling

forward(query, key, value, mask_future_timesteps=False, key_padding_mask=None, use_scalar_bias=False)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class fairseq.modules.DynamicConv1dTBC(input_size, kernel_size=1, padding_l=None, num_heads=1, weight_dropout=0.0, weight_softmax=False, renorm_padding=False, bias=False, conv_bias=False, query_size=None, in_proj=False)[source]

Dynamic lightweight convolution taking T x B x C inputs.

Parameters:
  • input_size – number of channels of the input
  • kernel_size – width of the convolution kernel
  • padding_l – padding to the left when using “same” padding
  • num_heads – number of heads used. The weight is of shape (num_heads, 1, kernel_size)
  • weight_dropout – the drop rate of the DropConnect to drop the weight
  • weight_softmax – normalize the weight with softmax before the convolution
  • renorm_padding – re-normalize the filters to ignore the padded part (only the non-padding parts sum up to 1)
  • bias – use bias
  • conv_bias – bias of the convolution
  • query_size – specified when feeding a different input as the query
  • in_proj – project the input and generate the filter together

Shape:
  • Input: T x B x C, i.e. (timesteps, batch_size, input_size)
  • Output: T x B x C, i.e. (timesteps, batch_size, input_size)
weight

the learnable weights of the module of shape (num_heads, 1, kernel_size)

bias

the learnable bias of the module of shape (input_size)

extra_repr()[source]

Set the extra representation of the module

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x, incremental_state=None, query=None, unfold=None)[source]

Assuming the input, x, has shape T x B x C, produce an output of shape T x B x C.

Parameters:
  • x – input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
  • incremental_state – a dict to keep the state
  • query – use the specified query to predict the conv filters
  • unfold – unfold the input or not. If not, use the matrix trick instead

in_proj
reorder_incremental_state(incremental_state, new_order)[source]
reset_parameters()[source]
fairseq.modules.gelu(x: torch.Tensor) → torch.Tensor[source]
fairseq.modules.gelu_accurate(x)[source]
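No docstrings are given for gelu and gelu_accurate. Assuming, as is common in BERT/GPT-style code, that gelu computes the exact erf-based GELU and gelu_accurate the tanh approximation, scalar versions look like this. The function names are local to this sketch; check the source before relying on the mapping.

```python
import math


def gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    # tanh-based approximation (assumed here to correspond to gelu_accurate).
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


for v in (-2.0, 0.0, 1.0):
    print(v, gelu(v), gelu_tanh(v))
```

The two variants agree to within about 1e-3 around x = 1.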
class fairseq.modules.GradMultiply[source]
static backward(ctx, grad)[source]

Defines a formula for differentiating the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs a gradient computed w.r.t. the output.

static forward(ctx, x, scale)[source]

Performs the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store tensors that can be then retrieved during the backward pass.
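GradMultiply.forward takes (x, scale). Assuming it implements the common gradient-scaling trick, i.e. identity in the forward pass and the incoming gradient multiplied by scale in the backward pass (with no gradient for scale itself), its behavior can be summarized as:

```latex
y = \operatorname{GradMultiply}(x, s) = x,
\qquad
\frac{\partial \mathcal{L}}{\partial x} = s \, \frac{\partial \mathcal{L}}{\partial y}
```

Under this reading, the gradients flowing into a sub-network can be scaled up or down without changing its outputs.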

class fairseq.modules.Highway(input_dim: int, num_layers: int = 1)[source]

A Highway layer. Adopted from the AllenNLP implementation.

forward(x: torch.Tensor)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

reset_parameters()[source]
fairseq.modules.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, export=False)[source]
class fairseq.modules.LearnedPositionalEmbedding(num_embeddings: int, embedding_dim: int, padding_idx: int)[source]

This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to the forward function.
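One way to realize the padding_idx offsetting mentioned above is to number the non-padding tokens padding_idx + 1, padding_idx + 2, ..., and give padding tokens the position padding_idx itself. make_positions below is a hypothetical list-based helper and the padding index is assumed.

```python
from typing import List


def make_positions(tokens: List[List[int]], padding_idx: int) -> List[List[int]]:
    """Hypothetical helper: number non-pad tokens from padding_idx + 1; pads keep padding_idx."""
    positions = []
    for row in tokens:
        seen = 0
        row_positions = []
        for tok in row:
            if tok == padding_idx:
                row_positions.append(padding_idx)
            else:
                seen += 1
                row_positions.append(padding_idx + seen)
        positions.append(row_positions)
    return positions


PAD = 1  # assumed padding index
print(make_positions([[1, 1, 5, 6, 2], [7, 8, 9, 10, 2]], PAD))
# [[1, 1, 2, 3, 4], [2, 3, 4, 5, 6]]
```

Under this scheme, the embedding table needs padding_idx + 1 more rows than the maximum sequence length.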

forward(input, incremental_state=None, positions=None)[source]

Input is expected to be of size [bsz x seqlen].

max_positions()[source]

Maximum number of supported positions.

class fairseq.modules.LightweightConv1dTBC(input_size, kernel_size=1, padding_l=None, num_heads=1, weight_dropout=0.0, weight_softmax=False, bias=False)[source]

Lightweight convolution assuming the input is T x B x C.

Parameters:
  • input_size – number of channels of the input
  • kernel_size – width of the convolution kernel
  • padding_l – padding to the left when using “same” padding
  • num_heads – number of heads used. The weight is of shape (num_heads, 1, kernel_size)
  • weight_dropout – the drop rate of the DropConnect to drop the weight
  • weight_softmax – normalize the weight with softmax before the convolution
  • bias – use bias

Shape:
  • Input: T x B x C, i.e. (timesteps, batch_size, input_size)
  • Output: T x B x C, i.e. (timesteps, batch_size, input_size)
weight

the learnable weights of the module of shape (num_heads, 1, kernel_size)

bias

the learnable bias of the module of shape (input_size)

extra_repr()[source]

Set the extra representation of the module

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x, incremental_state=None, unfold=False)[source]

Assuming the input, x, has shape T x B x C, produce an output of shape T x B x C.

Parameters:
  • x – input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
  • incremental_state – a dict to keep the state
  • unfold – unfold the input or not. If not, use the matrix trick instead

prepare_for_onnx_export_()[source]
reorder_incremental_state(incremental_state, new_order)[source]
reset_parameters()[source]
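The idea behind lightweight convolution can be sketched for a single sequence (batch size 1) in pure Python. Each of the num_heads kernels of width kernel_size is shared by a contiguous group of input_size // num_heads channels, is optionally softmax-normalized over its width, and reads padding_l positions to the left. The channel-to-head grouping and the loop structure are assumptions for illustration; the module itself works on T x B x C tensors with unfolding or band-matrix tricks.

```python
import math
from typing import List


def softmax(row: List[float]) -> List[float]:
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def lightweight_conv(
    x: List[List[float]], weight: List[List[float]], padding_l: int, weight_softmax: bool = True
) -> List[List[float]]:
    """Hypothetical sketch. x: T x C input for one sequence; weight: num_heads x kernel_size."""
    T, C = len(x), len(x[0])
    H, K = len(weight), len(weight[0])
    assert C % H == 0, "input_size must be divisible by num_heads"
    group = C // H  # channels sharing one kernel (assumed contiguous)
    if weight_softmax:
        weight = [softmax(w) for w in weight]
    out = [[0.0] * C for _ in range(T)]
    for t in range(T):
        for c in range(C):
            w = weight[c // group]
            for k in range(K):
                src = t - padding_l + k  # time step read by kernel tap k
                if 0 <= src < T:         # positions outside the sequence act as zero padding
                    out[t][c] += w[k] * x[src][c]
    return out


x = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]  # T=3, C=2
print(lightweight_conv(x, weight=[[0.0, 0.0]], padding_l=1))
# [[0.5, 5.0], [1.5, 15.0], [2.5, 25.0]]
```

With kernel_size=2 and padding_l=1 the convolution is causal: each output depends only on the current and previous time steps.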
class fairseq.modules.LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)[source]

An optimized version of nn.Conv1d.

At training time, this module uses ConvTBC, which is an optimized version of Conv1d. At inference time, it optimizes incremental generation (i.e., one time step at a time) by replacing the convolutions with linear layers. Note that the input order changes from training to inference.

forward(input, incremental_state=None)[source]
Parameters:incremental_state – Used to buffer signal; if not None, then input is expected to contain a single frame. If the input order changes between time steps, call reorder_incremental_state.
Input:
  • Time x Batch x Channel during training
  • Batch x Time x Channel during inference
reorder_incremental_state(incremental_state, new_order)[source]
class fairseq.modules.LogSumExpMoE[source]

Standard LogSumExp forward pass, but use posterior for the backward.

See “Mixture Models for Diverse Machine Translation: Tricks of the Trade” (Shen et al., 2019).

static backward(ctx, grad_output)[source]

Defines a formula for differentiating the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs a gradient computed w.r.t. the output.

static forward(ctx, logp, posterior, dim=-1)[source]

Performs the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store tensors that can be then retrieved during the backward pass.

class fairseq.modules.MeanPoolGatingNetwork(embed_dim, num_experts, dropout=None)[source]

A simple mean-pooling gating network for selecting experts.

This module applies mean pooling over an encoder’s output and returns responsibilities for each expert. The encoder format is expected to match fairseq.models.transformer.TransformerEncoder.

forward(encoder_out)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class fairseq.modules.MultiheadAttention(embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False)[source]

Multi-headed attention.

See “Attention Is All You Need” for more details.

forward(query, key, value, key_padding_mask=None, incremental_state=None, need_weights=True, static_kv=False, attn_mask=None)[source]

Input shape: Time x Batch x Channel

Timesteps can be masked by supplying a T x T mask in the attn_mask argument. Padding elements can be excluded from the key by passing a binary ByteTensor (key_padding_mask) with shape: batch x src_len, where padding elements are indicated by 1s.
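Decoder self-attention typically uses a T x T attn_mask that lets each position attend only to itself and earlier positions. A minimal sketch, assuming the usual additive convention (0 where attention is allowed, -inf where it is blocked, added to the scores before the softmax); both helpers are hypothetical:

```python
import math
from typing import List


def future_mask(size: int) -> List[List[float]]:
    """Additive causal mask: 0 where attention is allowed, -inf above the diagonal."""
    return [[0.0 if j <= i else -math.inf for j in range(size)] for i in range(size)]


def masked_softmax(scores: List[float], mask_row: List[float]) -> List[float]:
    """Add the mask to raw attention scores, then normalize with a softmax."""
    vals = [s + m for s, m in zip(scores, mask_row)]
    top = max(vals)
    exps = [math.exp(v - top) for v in vals]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


mask = future_mask(3)
print(masked_softmax([1.0, 1.0, 5.0], mask[1]))  # [0.5, 0.5, 0.0]
```

Keys flagged in key_padding_mask can be excluded the same way, by setting their scores to -inf before the softmax.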

in_proj_k(key)[source]
in_proj_q(query)[source]
in_proj_qkv(query)[source]
in_proj_v(value)[source]
prepare_for_onnx_export_()[source]
reorder_incremental_state(incremental_state, new_order)[source]

Reorder buffered internal state (for incremental generation).

reset_parameters()[source]
fairseq.modules.PositionalEmbedding(num_embeddings: int, embedding_dim: int, padding_idx: int, learned: bool = False)[source]
class fairseq.modules.ScalarBias[source]

Adds a vector of scalars, used in the self-attention mechanism to allow the model to optionally attend to this vector instead of the past.

static backward(ctx, grad)[source]

Defines a formula for differentiating the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs a gradient computed w.r.t. the output.

static forward(ctx, input, dim, bias_init)[source]

Performs the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store tensors that can be then retrieved during the backward pass.

class fairseq.modules.SinusoidalPositionalEmbedding(embedding_dim, padding_idx, init_size=1024)[source]

This module produces sinusoidal positional embeddings of any length.

Padding symbols are ignored.

forward(input, incremental_state=None, timestep=None, **kwargs)[source]

Input is expected to be of size [bsz x seqlen].

static get_embedding(num_embeddings, embedding_dim, padding_idx=None)[source]

Build sinusoidal embeddings.

This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of “Attention Is All You Need”.
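A list-based sketch of the table built by get_embedding. It assumes the tensor2tensor layout: all sines followed by all cosines rather than interleaved pairs, with frequencies exp(-i * log(10000) / (half_dim - 1)). To our understanding, these are the "slight differences" from Section 3.5 of the paper, but treat the exact layout as an assumption. sinusoidal_embedding is a hypothetical helper.

```python
import math
from typing import List, Optional


def sinusoidal_embedding(
    num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
) -> List[List[float]]:
    """Hypothetical list-based version of the sinusoidal table (rows = positions)."""
    half_dim = embedding_dim // 2
    if half_dim < 2:
        raise ValueError("embedding_dim must be at least 4")
    scale = math.log(10000) / (half_dim - 1)
    freqs = [math.exp(-scale * i) for i in range(half_dim)]
    table = []
    for pos in range(num_embeddings):
        angles = [pos * f for f in freqs]
        # All sines first, then all cosines (concatenated, not interleaved).
        row = [math.sin(a) for a in angles] + [math.cos(a) for a in angles]
        if embedding_dim % 2 == 1:
            row.append(0.0)  # zero-pad an odd embedding_dim
        table.append(row)
    if padding_idx is not None:
        table[padding_idx] = [0.0] * embedding_dim  # the padding row is all zeros
    return table


emb = sinusoidal_embedding(4, 6, padding_idx=1)
print(emb[0])  # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
```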

max_positions()[source]

Maximum number of supported positions.

prepare_for_onnx_export_()[source]
class fairseq.modules.TransformerSentenceEncoderLayer(embedding_dim: float = 768, ffn_embedding_dim: float = 3072, num_attention_heads: float = 8, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, activation_fn: str = 'relu', add_bias_kv: bool = False, add_zero_attn: bool = False, export: bool = False)[source]

Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained models.

forward(x: torch.Tensor, self_attn_mask: torch.Tensor = None, self_attn_padding_mask: torch.Tensor = None)[source]

LayerNorm is applied either before or after the self-attention/FFN modules, similar to the original Transformer implementation.

class fairseq.modules.TransformerSentenceEncoder(padding_idx: int, vocab_size: int, num_encoder_layers: int = 6, embedding_dim: int = 768, ffn_embedding_dim: int = 3072, num_attention_heads: int = 8, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, max_seq_len: int = 256, num_segments: int = 2, use_position_embeddings: bool = True, offset_positions_by_padding: bool = True, encoder_normalize_before: bool = False, apply_bert_init: bool = False, activation_fn: str = 'relu', learned_pos_embedding: bool = True, add_bias_kv: bool = False, add_zero_attn: bool = False, embed_scale: float = None, freeze_embeddings: bool = False, n_trans_layers_to_freeze: int = 0, export: bool = False)[source]

Implementation for a Bi-directional Transformer based Sentence Encoder used in BERT/XLM style pre-trained models.

This first computes the token embedding using the token embedding matrix, position embeddings (if specified) and segment embeddings (if specified). After applying the specified number of TransformerEncoderLayers, it outputs all the internal states of the encoder as well as the final representation associated with the first token (usually CLS token).

Input:
  • tokens: B x T matrix representing sentences
  • segment_labels: B x T matrix representing segment label for tokens
Output:
  • a tuple of the following:
    • a list of internal model states used to compute the predictions where each tensor has shape B x T x C
    • sentence representation associated with first input token in format B x C.
forward(tokens: torch.Tensor, segment_labels: torch.Tensor, last_state_only: bool = False, positions: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor][source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

fairseq.modules.unfold1d(x, kernel_size, padding_l, pad_value=0)[source]

unfold T x B x C to T x B x C x K
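Applied to every (batch, channel) column of a T x B x C tensor, this unfolding gathers a window of kernel_size neighboring time steps, which produces the extra trailing K dimension. A single-column sketch, assuming padding_l pad values go on the left and the remaining kernel_size - 1 - padding_l on the right (unfold1d_list is hypothetical):

```python
from typing import List


def unfold1d_list(x: List[float], kernel_size: int, padding_l: int, pad_value: float = 0.0) -> List[List[float]]:
    """Hypothetical 1-D version: window t holds x[t - padding_l + k] for k in range(kernel_size)."""
    if not 0 <= padding_l < kernel_size:
        raise ValueError("expected 0 <= padding_l < kernel_size")
    padded = [pad_value] * padding_l + list(x) + [pad_value] * (kernel_size - 1 - padding_l)
    return [padded[t:t + kernel_size] for t in range(len(x))]


print(unfold1d_list([1.0, 2.0, 3.0], kernel_size=3, padding_l=1))
# [[0.0, 1.0, 2.0], [1.0, 2.0, 3.0], [2.0, 3.0, 0.0]]
```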

Indices and tables