fairseq documentation¶
Fairseq is a sequence modeling toolkit written in PyTorch that allows researchers and developers to train custom models for translation, summarization, language modeling and other text generation tasks.
Evaluating Pre-trained Models¶
First, download a pre-trained model along with its vocabularies:
> curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -
This model uses a Byte Pair Encoding (BPE) vocabulary, so we'll have to apply the encoding to the source text before it can be translated. This can be done with the apply_bpe.py script using the wmt14.en-fr.fconv-cuda/bpecodes file. @@ is used as a continuation marker and the original text can be easily recovered with e.g. sed s/@@ //g or by passing the --remove-bpe flag to fairseq-generate. Prior to BPE, input text needs to be tokenized using tokenizer.perl from mosesdecoder.
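As a quick illustration of the continuation markers, the following snippet (a toy example, not part of fairseq) undoes BPE the same way the sed command above does:
# A toy illustration of removing BPE continuation markers, equivalent to `sed s/@@ //g`.
# The example string is made up.
bpe_line = "new marine mam@@ mal species"
print(bpe_line.replace("@@ ", ""))  # -> "new marine mammal species"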
Let’s use fairseq-interactive to generate translations interactively. Here, we use a beam size of 5 and preprocess the input with the Moses tokenizer and the given Byte-Pair Encoding vocabulary. It will automatically remove the BPE continuation markers and detokenize the output.
> MODEL_DIR=wmt14.en-fr.fconv-py
> fairseq-interactive \
--path $MODEL_DIR/model.pt $MODEL_DIR \
--beam 5 --source-lang en --target-lang fr \
--tokenizer moses \
--bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
| loading model(s) from wmt14.en-fr.fconv-py/model.pt
| [en] dictionary: 44206 types
| [fr] dictionary: 44463 types
| Type the input sentence and press return:
Why is it rare to discover new marine mammal species?
S-0 Why is it rare to discover new marine mam@@ mal species ?
H-0 -0.0643349438905716 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?
P-0 -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015
This generation script produces three types of outputs: a line prefixed with S is a copy of the (BPE-encoded) source sentence; H is the hypothesis along with an average log-likelihood; and P is the positional score per token position, including the end-of-sentence marker, which is omitted from the text.
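If you need to post-process this output programmatically, hypotheses can be pulled out of the H lines; the sketch below (not part of fairseq) parses the sample output shown above:
# A minimal sketch (not part of fairseq) for extracting hypotheses from the
# generation output above. Each H line carries a tag, a score and the hypothesis text.
sample_output = [
    "S-0 Why is it rare to discover new marine mam@@ mal species ?",
    "H-0 -0.0643349438905716 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?",
]
for line in sample_output:
    if line.startswith("H-"):
        tag, score, hypothesis = line.split(maxsplit=2)
        print(tag, float(score), hypothesis)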
See the README for a full list of pre-trained models available.
Training a New Model¶
The following tutorial is for machine translation. For an example of how to use Fairseq for other tasks, such as Language Modeling, please see the examples/ directory.
Data Pre-processing¶
Fairseq contains example pre-processing scripts for several translation datasets: IWSLT 2014 (German-English), WMT 2014 (English-French) and WMT 2014 (English-German). To pre-process and binarize the IWSLT dataset:
> cd examples/translation/
> bash prepare-iwslt14.sh
> cd ../..
> TEXT=examples/translation/iwslt14.tokenized.de-en
> fairseq-preprocess --source-lang de --target-lang en \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/iwslt14.tokenized.de-en
This will write binarized data that can be used for model training to data-bin/iwslt14.tokenized.de-en.
Training¶
Use fairseq-train to train a new model. Here are a few example settings that work well for the IWSLT 2014 dataset:
> mkdir -p checkpoints/fconv
> CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
--lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
--arch fconv_iwslt_de_en --save-dir checkpoints/fconv
By default, fairseq-train will use all available GPUs on your machine. Use the CUDA_VISIBLE_DEVICES environment variable to select specific GPUs and/or to change the number of GPU devices that will be used.
Also note that the batch size is specified in terms of the maximum number of tokens per batch (--max-tokens). You may need to use a smaller value depending on the available GPU memory on your system.
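To see what a token-based batch size means in practice, here is a toy sketch of grouping sentences under a token budget (fairseq's actual batching also accounts for padding and sorts by length; this is illustrative only):
# A toy sketch of token-based batching, the idea behind --max-tokens.
# Sentence lengths below are made up.
def batch_by_tokens(lengths, max_tokens=4000):
    batches, current, current_tokens = [], [], 0
    for idx, n in enumerate(lengths):
        if current and current_tokens + n > max_tokens:
            batches.append(current)          # budget exceeded: start a new batch
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n
    if current:
        batches.append(current)
    return batches

print(batch_by_tokens([1500, 1200, 900, 2000, 300]))  # [[0, 1, 2], [3, 4]]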
Generation¶
Once your model is trained, you can generate translations using fairseq-generate (for binarized data) or fairseq-interactive (for raw text):
> fairseq-generate data-bin/iwslt14.tokenized.de-en \
--path checkpoints/fconv/checkpoint_best.pt \
--batch-size 128 --beam 5
| [de] dictionary: 35475 types
| [en] dictionary: 24739 types
| data-bin/iwslt14.tokenized.de-en test 6750 examples
| model fconv
| loaded checkpoint checkpoints/fconv/checkpoint_best.pt
S-721 danke .
T-721 thank you .
...
To generate translations using only a CPU, use the --cpu flag. BPE continuation markers can be removed with the --remove-bpe flag.
Advanced Training Options¶
Large mini-batch training with delayed updates¶
The --update-freq option can be used to accumulate gradients from multiple mini-batches and delay updating, creating a larger effective batch size. Delayed updates can also improve training speed by reducing inter-GPU communication costs and by saving idle time caused by variance in workload across GPUs. See Ott et al. (2018) for more details.
To train on a single GPU with an effective batch size that is equivalent to training on 8 GPUs:
> CUDA_VISIBLE_DEVICES=0 fairseq-train --update-freq 8 (...)
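Conceptually, delayed updates are just gradient accumulation. The self-contained PyTorch sketch below (a toy model and toy data, not fairseq code) shows the idea behind --update-freq 8:
import torch
import torch.nn as nn

# Toy model, loss and optimizer; only the accumulation pattern matters here.
model = nn.Linear(16, 4)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

update_freq = 8
optimizer.zero_grad()
for i in range(32):
    x = torch.randn(8, 16)
    y = torch.randint(0, 4, (8,))
    loss = criterion(model(x), y) / update_freq  # scale so gradients match one large batch
    loss.backward()                              # gradients accumulate across mini-batches
    if (i + 1) % update_freq == 0:
        optimizer.step()                         # one parameter update per 8 mini-batches
        optimizer.zero_grad()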
Training with half precision floating point (FP16)¶
Note
FP16 training requires a Volta GPU and CUDA 9.1 or greater
Recent GPUs enable efficient half precision floating point computation,
e.g., using Nvidia Tensor Cores.
Fairseq supports FP16 training with the --fp16 flag:
> fairseq-train --fp16 (...)
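The --fp16-* options described in the command-line reference below control loss scaling. As a rough illustration of the idea (an FP32 toy sketch, not fairseq's implementation): the loss is multiplied by a scale factor before the backward pass so that small gradients do not underflow in half precision, and the gradients are divided by the same factor before the optimizer step. Real FP16 training also adjusts the scale dynamically when overflows occur.
import torch
import torch.nn as nn

# Toy model; a fixed loss scale of 128 mirrors the --fp16-init-scale default.
model = nn.Linear(16, 4)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_scale = 128.0

x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
loss = criterion(model(x), y)
(loss * loss_scale).backward()      # amplified gradients would survive FP16 precision
for p in model.parameters():
    p.grad.div_(loss_scale)         # unscale before updating parameters
optimizer.step()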
Lazily loading large training datasets¶
By default fairseq loads the entire training set into system memory. For large datasets, the --lazy-load option can be used to instead load batches on demand. For optimal performance, use the --num-workers option to control the number of background processes that will load batches.
Distributed training¶
Distributed training in fairseq is implemented on top of torch.distributed. The easiest way to launch jobs is with the torch.distributed.launch tool. For example, to train a large English-German Transformer model on 2 nodes, each with 8 GPUs (16 GPUs in total), run the following command on each node, replacing node_rank=0 with node_rank=1 on the second node:
> python -m torch.distributed.launch --nproc_per_node=8 \
--nnodes=2 --node_rank=0 --master_addr="192.168.1.1" \
--master_port=1234 \
$(which fairseq-train) data-bin/wmt16_en_de_bpe32k \
--arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
--lr 0.0005 --min-lr 1e-09 \
--dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 3584 \
--fp16
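For reference, the launcher starts nproc_per_node processes on each node and each process is assigned a global rank of node_rank * nproc_per_node + local_rank. The small sketch below (illustrative only) prints that mapping for the 2-node, 8-GPU setup above:
# Print the node/GPU to global-rank mapping used by torch.distributed.launch.
nproc_per_node, nnodes = 8, 2
for node_rank in range(nnodes):
    for local_rank in range(nproc_per_node):
        global_rank = node_rank * nproc_per_node + local_rank
        print(f"node {node_rank}, GPU {local_rank} -> global rank {global_rank}")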
Command-line Tools¶
Fairseq provides several command-line tools for training and evaluating models:
- fairseq-preprocess: Data pre-processing: build vocabularies and binarize training data
- fairseq-train: Train a new model on one or multiple GPUs
- fairseq-generate: Translate pre-processed data with a trained model
- fairseq-interactive: Translate raw text with a trained model
- fairseq-score: BLEU scoring of generated translations against reference translations
- fairseq-eval-lm: Language model evaluation
fairseq-preprocess¶
Data pre-processing: build vocabularies and binarize training data.
usage: fairseq-preprocess [-h] [--no-progress-bar] [--log-interval N]
[--log-format {json,none,simple,tqdm}]
[--tensorboard-logdir DIR] [--tbmf-wrapper]
[--seed N] [--cpu] [--fp16]
[--memory-efficient-fp16]
[--fp16-init-scale FP16_INIT_SCALE]
[--fp16-scale-window FP16_SCALE_WINDOW]
[--fp16-scale-tolerance FP16_SCALE_TOLERANCE]
[--min-loss-scale D]
[--threshold-loss-scale THRESHOLD_LOSS_SCALE]
[--user-dir USER_DIR]
[--criterion {sentence_prediction,adaptive_loss,label_smoothed_cross_entropy,composite_loss,binary_cross_entropy,masked_lm,sentence_ranking,legacy_masked_lm_loss,cross_entropy}]
[--tokenizer {nltk,space,moses}]
[--bpe {gpt2,sentencepiece,subword_nmt,fastbpe}]
[--optimizer {adadelta,adam,adafactor,adagrad,nag,adamax,sgd}]
[--lr-scheduler {cosine,reduce_lr_on_plateau,fixed,triangular,polynomial_decay,inverse_sqrt}]
[--task TASK] [-s SRC] [-t TARGET] [--trainpref FP]
[--validpref FP] [--testpref FP] [--destdir DIR]
[--thresholdtgt N] [--thresholdsrc N] [--tgtdict FP]
[--srcdict FP] [--nwordstgt N] [--nwordssrc N]
[--alignfile ALIGN] [--dataset-impl FORMAT]
[--joined-dictionary] [--only-source]
[--padding-factor N] [--workers N]
Named Arguments¶
--no-progress-bar | disable progress bar Default: False |
--log-interval | log progress every N batches (when progress bar is disabled) Default: 1000 |
--log-format | Possible choices: json, none, simple, tqdm log format to use |
--tensorboard-logdir | path to save logs for tensorboard, should match --logdir of running tensorboard (default: no tensorboard logging) Default: "" |
--tbmf-wrapper | [FB only] Default: False |
--seed | pseudo random number generator seed Default: 1 |
--cpu | use CPU instead of CUDA Default: False |
--fp16 | use FP16 Default: False |
--memory-efficient-fp16 | use a memory-efficient version of FP16 training; implies --fp16 Default: False |
--fp16-init-scale | default FP16 loss scale Default: 128 |
--fp16-scale-window | number of updates before increasing loss scale |
--fp16-scale-tolerance | pct of updates that can overflow before decreasing the loss scale Default: 0.0 |
--min-loss-scale | minimum FP16 loss scale, after which training is stopped Default: 0.0001 |
--threshold-loss-scale | threshold FP16 loss scale from below |
--user-dir | path to a python module containing custom extensions (tasks and/or architectures) |
--criterion | Possible choices: sentence_prediction, adaptive_loss, label_smoothed_cross_entropy, composite_loss, binary_cross_entropy, masked_lm, sentence_ranking, legacy_masked_lm_loss, cross_entropy Default: “cross_entropy” |
--tokenizer | Possible choices: nltk, space, moses |
--bpe | Possible choices: gpt2, sentencepiece, subword_nmt, fastbpe |
--optimizer | Possible choices: adadelta, adam, adafactor, adagrad, nag, adamax, sgd Default: “nag” |
--lr-scheduler | Possible choices: cosine, reduce_lr_on_plateau, fixed, triangular, polynomial_decay, inverse_sqrt Default: “fixed” |
--task | Possible choices: sentence_prediction, translation, translation_from_pretrained_xlm, multilingual_translation, semisupervised_translation, cross_lingual_lm, masked_lm, sentence_ranking, audio_pretraining, legacy_masked_lm, translation_moe, language_modeling task Default: “translation” |
--dataset-impl | Possible choices: raw, lazy, cached, mmap output dataset implementation Default: “mmap” |
Preprocessing¶
-s, --source-lang | source language |
-t, --target-lang | target language |
--trainpref | train file prefix |
--validpref | comma separated, valid file prefixes |
--testpref | comma separated, test file prefixes |
--destdir | destination dir Default: “data-bin” |
--thresholdtgt | map words appearing less than threshold times to unknown Default: 0 |
--thresholdsrc | map words appearing less than threshold times to unknown Default: 0 |
--tgtdict | reuse given target dictionary |
--srcdict | reuse given source dictionary |
--nwordstgt | number of target words to retain Default: -1 |
--nwordssrc | number of source words to retain Default: -1 |
--alignfile | an alignment file (optional) |
--joined-dictionary | Generate joined dictionary Default: False |
--only-source | Only process the source language Default: False |
--padding-factor | Pad dictionary size to be multiple of N Default: 8 |
--workers | number of parallel workers Default: 1 |
fairseq-train¶
Train a new model on one or across multiple GPUs.
usage: fairseq-train [-h] [--no-progress-bar] [--log-interval N]
[--log-format {json,none,simple,tqdm}]
[--tensorboard-logdir DIR] [--tbmf-wrapper] [--seed N]
[--cpu] [--fp16] [--memory-efficient-fp16]
[--fp16-init-scale FP16_INIT_SCALE]
[--fp16-scale-window FP16_SCALE_WINDOW]
[--fp16-scale-tolerance FP16_SCALE_TOLERANCE]
[--min-loss-scale D]
[--threshold-loss-scale THRESHOLD_LOSS_SCALE]
[--user-dir USER_DIR]
[--criterion {sentence_prediction,adaptive_loss,label_smoothed_cross_entropy,composite_loss,binary_cross_entropy,masked_lm,sentence_ranking,legacy_masked_lm_loss,cross_entropy}]
[--tokenizer {nltk,space,moses}]
[--bpe {gpt2,sentencepiece,subword_nmt,fastbpe}]
[--optimizer {adadelta,adam,adafactor,adagrad,nag,adamax,sgd}]
[--lr-scheduler {cosine,reduce_lr_on_plateau,fixed,triangular,polynomial_decay,inverse_sqrt}]
[--task TASK] [--num-workers N]
[--skip-invalid-size-inputs-valid-test] [--max-tokens N]
[--max-sentences N] [--required-batch-size-multiple N]
[--dataset-impl FORMAT] [--train-subset SPLIT]
[--valid-subset SPLIT] [--validate-interval N]
[--disable-validation] [--max-tokens-valid N]
[--max-sentences-valid N] [--curriculum N]
[--distributed-world-size N]
[--distributed-rank DISTRIBUTED_RANK]
[--distributed-backend DISTRIBUTED_BACKEND]
[--distributed-init-method DISTRIBUTED_INIT_METHOD]
[--distributed-port DISTRIBUTED_PORT]
[--device-id DEVICE_ID] [--distributed-no-spawn]
[--ddp-backend {c10d,no_c10d}] [--bucket-cap-mb MB]
[--fix-batches-to-gpus] [--find-unused-parameters] --arch
ARCH [--max-epoch N] [--max-update N] [--clip-norm NORM]
[--sentence-avg] [--update-freq N1,N2,...,N_K]
[--lr LR_1,LR_2,...,LR_N] [--min-lr LR] [--use-bmuf]
[--save-dir DIR] [--restore-file RESTORE_FILE]
[--reset-dataloader] [--reset-lr-scheduler]
[--reset-meters] [--reset-optimizer]
[--optimizer-overrides DICT] [--save-interval N]
[--save-interval-updates N] [--keep-interval-updates N]
[--keep-last-epochs N] [--no-save]
[--no-epoch-checkpoints] [--no-last-checkpoints]
[--no-save-optimizer-state]
[--best-checkpoint-metric BEST_CHECKPOINT_METRIC]
[--maximize-best-checkpoint-metric]
Named Arguments¶
--no-progress-bar | disable progress bar Default: False |
--log-interval | log progress every N batches (when progress bar is disabled) Default: 1000 |
--log-format | Possible choices: json, none, simple, tqdm log format to use |
--tensorboard-logdir | path to save logs for tensorboard, should match --logdir of running tensorboard (default: no tensorboard logging) Default: "" |
--tbmf-wrapper | [FB only] Default: False |
--seed | pseudo random number generator seed Default: 1 |
--cpu | use CPU instead of CUDA Default: False |
--fp16 | use FP16 Default: False |
--memory-efficient-fp16 | use a memory-efficient version of FP16 training; implies --fp16 Default: False |
--fp16-init-scale | default FP16 loss scale Default: 128 |
--fp16-scale-window | number of updates before increasing loss scale |
--fp16-scale-tolerance | pct of updates that can overflow before decreasing the loss scale Default: 0.0 |
--min-loss-scale | minimum FP16 loss scale, after which training is stopped Default: 0.0001 |
--threshold-loss-scale | threshold FP16 loss scale from below |
--user-dir | path to a python module containing custom extensions (tasks and/or architectures) |
--criterion | Possible choices: sentence_prediction, adaptive_loss, label_smoothed_cross_entropy, composite_loss, binary_cross_entropy, masked_lm, sentence_ranking, legacy_masked_lm_loss, cross_entropy Default: “cross_entropy” |
--tokenizer | Possible choices: nltk, space, moses |
--bpe | Possible choices: gpt2, sentencepiece, subword_nmt, fastbpe |
--optimizer | Possible choices: adadelta, adam, adafactor, adagrad, nag, adamax, sgd Default: “nag” |
--lr-scheduler | Possible choices: cosine, reduce_lr_on_plateau, fixed, triangular, polynomial_decay, inverse_sqrt Default: “fixed” |
--task | Possible choices: sentence_prediction, translation, translation_from_pretrained_xlm, multilingual_translation, semisupervised_translation, cross_lingual_lm, masked_lm, sentence_ranking, audio_pretraining, legacy_masked_lm, translation_moe, language_modeling task Default: “translation” |
--dataset-impl | Possible choices: raw, lazy, cached, mmap output dataset implementation |
Dataset and data loading¶
--num-workers | how many subprocesses to use for data loading Default: 1 |
--skip-invalid-size-inputs-valid-test | ignore too long or too short lines in valid and test set Default: False |
--max-tokens | maximum number of tokens in a batch |
--max-sentences, --batch-size | maximum number of sentences in a batch |
--required-batch-size-multiple | batch size will be a multiplier of this value Default: 8 |
--train-subset | Possible choices: train, valid, test data subset to use for training (train, valid, test) Default: “train” |
--valid-subset | comma separated list of data subsets to use for validation (train, valid, valid1, test, test1) Default: “valid” |
--validate-interval | validate every N epochs Default: 1 |
--disable-validation | disable validation Default: False |
--max-tokens-valid | maximum number of tokens in a validation batch (defaults to --max-tokens) |
--max-sentences-valid | maximum number of sentences in a validation batch (defaults to --max-sentences) |
--curriculum | don’t shuffle batches for first N epochs Default: 0 |
Distributed training¶
--distributed-world-size | total number of GPUs across all nodes (default: all visible GPUs) Default: 1 |
--distributed-rank | rank of the current worker Default: 0 |
--distributed-backend | distributed backend Default: “nccl” |
--distributed-init-method | typically tcp://hostname:port that will be used to establish initial connection |
--distributed-port | port number (not required if using --distributed-init-method) Default: -1 |
--device-id, --local_rank | which GPU to use (usually configured automatically) Default: 0 |
--distributed-no-spawn | do not spawn multiple processes even if multiple GPUs are visible Default: False |
--ddp-backend | Possible choices: c10d, no_c10d DistributedDataParallel backend Default: “c10d” |
--bucket-cap-mb | bucket size for reduction Default: 25 |
--fix-batches-to-gpus | don’t shuffle batches between GPUs; this reduces overall randomness and may affect precision but avoids the cost of re-reading the data Default: False |
--find-unused-parameters | disable unused parameter detection (not applicable to no_c10d ddp-backend) Default: False |
Model configuration¶
--arch, -a | Possible choices: roberta, roberta_base, roberta_large, transformer, transformer_iwslt_de_en, transformer_wmt_en_de, transformer_vaswani_wmt_en_de_big, transformer_vaswani_wmt_en_fr_big, transformer_wmt_en_de_big, transformer_wmt_en_de_big_t2t, transformer_from_pretrained_xlm, transformer_lm, transformer_lm_big, transformer_lm_baevski_wiki103, transformer_lm_wiki103, transformer_lm_baevski_gbw, transformer_lm_gbw, transformer_lm_gpt, transformer_lm_gpt2_small, transformer_lm_gpt2_medium, transformer_lm_gpt2_big, lightconv, lightconv_iwslt_de_en, lightconv_wmt_en_de, lightconv_wmt_en_de_big, lightconv_wmt_en_fr_big, lightconv_wmt_zh_en_big, masked_lm, bert_base, bert_large, xlm_base, fconv, fconv_iwslt_de_en, fconv_wmt_en_ro, fconv_wmt_en_de, fconv_wmt_en_fr, fconv_lm, fconv_lm_dauphin_wikitext103, fconv_lm_dauphin_gbw, wav2vec, lightconv_lm, lightconv_lm_gbw, fconv_self_att, fconv_self_att_wp, lstm, lstm_wiseman_iwslt_de_en, lstm_luong_wmt_en_de, multilingual_transformer, multilingual_transformer_iwslt_de_en Model Architecture Default: “fconv” |
Optimization¶
--max-epoch, --me | force stop training at specified epoch Default: 0 |
--max-update, --mu | force stop training at specified update Default: 0 |
--clip-norm | clip threshold of gradients Default: 25 |
--sentence-avg | normalize gradients by the number of sentences in a batch (default is to normalize by number of tokens) Default: False |
--update-freq | update parameters every N_i batches, when in epoch i Default: 1 |
--lr, --learning-rate | learning rate for the first N epochs; all epochs >N using LR_N (note: this may be interpreted differently depending on --lr-scheduler) Default: 0.25 |
--min-lr | stop training when the learning rate reaches this minimum Default: -1 |
--use-bmuf | specify global optimizer for syncing models on different GPUs/shards Default: False |
Checkpointing¶
--save-dir | path to save checkpoints Default: “checkpoints” |
--restore-file | filename from which to load checkpoint (default: <save-dir>/checkpoint_last.pt) Default: "checkpoint_last.pt" |
--reset-dataloader | if set, does not reload dataloader state from the checkpoint Default: False |
--reset-lr-scheduler | if set, does not load lr scheduler state from the checkpoint Default: False |
--reset-meters | if set, does not load meters from the checkpoint Default: False |
--reset-optimizer | if set, does not load optimizer state from the checkpoint Default: False |
--optimizer-overrides | a dictionary used to override optimizer args when loading a checkpoint Default: “{}” |
--save-interval | save a checkpoint every N epochs Default: 1 |
--save-interval-updates | save a checkpoint (and validate) every N updates Default: 0 |
--keep-interval-updates | keep the last N checkpoints saved with --save-interval-updates Default: -1 |
--keep-last-epochs | keep last N epoch checkpoints Default: -1 |
--no-save | don’t save models or checkpoints Default: False |
--no-epoch-checkpoints | only store last and best checkpoints Default: False |
--no-last-checkpoints | don’t store last checkpoints Default: False |
--no-save-optimizer-state | don’t save optimizer-state as part of checkpoint Default: False |
--best-checkpoint-metric | metric to use for saving “best” checkpoints Default: “loss” |
--maximize-best-checkpoint-metric | select the largest metric value for saving “best” checkpoints Default: False |
fairseq-generate¶
fairseq-interactive¶
Translate raw text with a trained model. Batches data on-the-fly.
usage: fairseq-interactive [-h] [--no-progress-bar] [--log-interval N]
[--log-format {json,none,simple,tqdm}]
[--tensorboard-logdir DIR] [--tbmf-wrapper]
[--seed N] [--cpu] [--fp16]
[--memory-efficient-fp16]
[--fp16-init-scale FP16_INIT_SCALE]
[--fp16-scale-window FP16_SCALE_WINDOW]
[--fp16-scale-tolerance FP16_SCALE_TOLERANCE]
[--min-loss-scale D]
[--threshold-loss-scale THRESHOLD_LOSS_SCALE]
[--user-dir USER_DIR]
[--criterion {sentence_prediction,adaptive_loss,label_smoothed_cross_entropy,composite_loss,binary_cross_entropy,masked_lm,sentence_ranking,legacy_masked_lm_loss,cross_entropy}]
[--tokenizer {nltk,space,moses}]
[--bpe {gpt2,sentencepiece,subword_nmt,fastbpe}]
[--optimizer {adadelta,adam,adafactor,adagrad,nag,adamax,sgd}]
[--lr-scheduler {cosine,reduce_lr_on_plateau,fixed,triangular,polynomial_decay,inverse_sqrt}]
[--task TASK] [--num-workers N]
[--skip-invalid-size-inputs-valid-test]
[--max-tokens N] [--max-sentences N]
[--required-batch-size-multiple N]
[--dataset-impl FORMAT] [--gen-subset SPLIT]
[--num-shards N] [--shard-id ID] [--path FILE]
[--remove-bpe [REMOVE_BPE]] [--quiet]
[--model-overrides DICT] [--results-path RESDIR]
[--beam N] [--nbest N] [--max-len-a N]
[--max-len-b N] [--min-len N] [--match-source-len]
[--no-early-stop] [--unnormalized]
[--no-beamable-mm] [--lenpen LENPEN]
[--unkpen UNKPEN] [--replace-unk [REPLACE_UNK]]
[--sacrebleu] [--score-reference]
[--prefix-size PS] [--no-repeat-ngram-size N]
[--sampling] [--sampling-topk PS]
[--sampling-topp PS] [--temperature N]
[--diverse-beam-groups N]
[--diverse-beam-strength N] [--print-alignment]
[--buffer-size N] [--input FILE]
Named Arguments¶
--no-progress-bar | disable progress bar Default: False |
--log-interval | log progress every N batches (when progress bar is disabled) Default: 1000 |
--log-format | Possible choices: json, none, simple, tqdm log format to use |
--tensorboard-logdir | path to save logs for tensorboard, should match --logdir of running tensorboard (default: no tensorboard logging) Default: "" |
--tbmf-wrapper | [FB only] Default: False |
--seed | pseudo random number generator seed Default: 1 |
--cpu | use CPU instead of CUDA Default: False |
--fp16 | use FP16 Default: False |
--memory-efficient-fp16 | use a memory-efficient version of FP16 training; implies --fp16 Default: False |
--fp16-init-scale | default FP16 loss scale Default: 128 |
--fp16-scale-window | number of updates before increasing loss scale |
--fp16-scale-tolerance | pct of updates that can overflow before decreasing the loss scale Default: 0.0 |
--min-loss-scale | minimum FP16 loss scale, after which training is stopped Default: 0.0001 |
--threshold-loss-scale | threshold FP16 loss scale from below |
--user-dir | path to a python module containing custom extensions (tasks and/or architectures) |
--criterion | Possible choices: sentence_prediction, adaptive_loss, label_smoothed_cross_entropy, composite_loss, binary_cross_entropy, masked_lm, sentence_ranking, legacy_masked_lm_loss, cross_entropy Default: “cross_entropy” |
--tokenizer | Possible choices: nltk, space, moses |
--bpe | Possible choices: gpt2, sentencepiece, subword_nmt, fastbpe |
--optimizer | Possible choices: adadelta, adam, adafactor, adagrad, nag, adamax, sgd Default: “nag” |
--lr-scheduler | Possible choices: cosine, reduce_lr_on_plateau, fixed, triangular, polynomial_decay, inverse_sqrt Default: “fixed” |
--task | Possible choices: sentence_prediction, translation, translation_from_pretrained_xlm, multilingual_translation, semisupervised_translation, cross_lingual_lm, masked_lm, sentence_ranking, audio_pretraining, legacy_masked_lm, translation_moe, language_modeling task Default: “translation” |
--dataset-impl | Possible choices: raw, lazy, cached, mmap output dataset implementation |
Dataset and data loading¶
--num-workers | how many subprocesses to use for data loading Default: 1 |
--skip-invalid-size-inputs-valid-test | ignore too long or too short lines in valid and test set Default: False |
--max-tokens | maximum number of tokens in a batch |
--max-sentences, --batch-size | maximum number of sentences in a batch |
--required-batch-size-multiple | batch size will be a multiplier of this value Default: 8 |
--gen-subset | data subset to generate (train, valid, test) Default: “test” |
--num-shards | shard generation over N shards Default: 1 |
--shard-id | id of the shard to generate (id < num_shards) Default: 0 |
Generation¶
--path | path(s) to model file(s), colon separated |
--remove-bpe | remove BPE tokens before scoring (can be set to sentencepiece) |
--quiet | only print final scores Default: False |
--model-overrides | a dictionary used to override model args at generation that were used during model training Default: “{}” |
--results-path | path to save eval results (optional) |
--beam | beam size Default: 5 |
--nbest | number of hypotheses to output Default: 1 |
--max-len-a | generate sequences of maximum length ax + b, where x is the source length Default: 0 |
--max-len-b | generate sequences of maximum length ax + b, where x is the source length Default: 200 |
--min-len | minimum generation length Default: 1 |
--match-source-len | generations should match the source length Default: False |
--no-early-stop | deprecated Default: False |
--unnormalized | compare unnormalized hypothesis scores Default: False |
--no-beamable-mm | don’t use BeamableMM in attention layers Default: False |
--lenpen | length penalty: <1.0 favors shorter, >1.0 favors longer sentences Default: 1 |
--unkpen | unknown word penalty: <0 produces more unks, >0 produces fewer Default: 0 |
--replace-unk | perform unknown replacement (optionally with alignment dictionary) |
--sacrebleu | score with sacrebleu Default: False |
--score-reference | just score the reference translation Default: False |
--prefix-size | initialize generation by target prefix of given length Default: 0 |
--no-repeat-ngram-size | ngram blocking such that this size ngram cannot be repeated in the generation Default: 0 |
--sampling | sample hypotheses instead of using beam search Default: False |
--sampling-topk | sample from top K likely next words instead of all words Default: -1 |
--sampling-topp | sample from the smallest set whose cumulative probability mass exceeds p for next words Default: -1.0 |
--temperature | temperature for generation Default: 1.0 |
--diverse-beam-groups | number of groups for Diverse Beam Search Default: -1 |
--diverse-beam-strength | strength of diversity penalty for Diverse Beam Search Default: 0.5 |
--print-alignment | if set, uses attention feedback to compute and print alignment to source tokens Default: False |
Interactive¶
--buffer-size | read this many sentences into a buffer before processing them Default: 0 |
--input | file to read from; use - for stdin Default: “-“ |
fairseq-score¶
fairseq-eval-lm¶
Evaluate the perplexity of a trained language model.
usage: fairseq-eval-lm [-h] [--no-progress-bar] [--log-interval N]
[--log-format {json,none,simple,tqdm}]
[--tensorboard-logdir DIR] [--tbmf-wrapper] [--seed N]
[--cpu] [--fp16] [--memory-efficient-fp16]
[--fp16-init-scale FP16_INIT_SCALE]
[--fp16-scale-window FP16_SCALE_WINDOW]
[--fp16-scale-tolerance FP16_SCALE_TOLERANCE]
[--min-loss-scale D]
[--threshold-loss-scale THRESHOLD_LOSS_SCALE]
[--user-dir USER_DIR]
[--criterion {sentence_prediction,adaptive_loss,label_smoothed_cross_entropy,composite_loss,binary_cross_entropy,masked_lm,sentence_ranking,legacy_masked_lm_loss,cross_entropy}]
[--tokenizer {nltk,space,moses}]
[--bpe {gpt2,sentencepiece,subword_nmt,fastbpe}]
[--optimizer {adadelta,adam,adafactor,adagrad,nag,adamax,sgd}]
[--lr-scheduler {cosine,reduce_lr_on_plateau,fixed,triangular,polynomial_decay,inverse_sqrt}]
[--task TASK] [--num-workers N]
[--skip-invalid-size-inputs-valid-test]
[--max-tokens N] [--max-sentences N]
[--required-batch-size-multiple N]
[--dataset-impl FORMAT] [--gen-subset SPLIT]
[--num-shards N] [--shard-id ID] [--path FILE]
[--remove-bpe [REMOVE_BPE]] [--quiet]
[--model-overrides DICT] [--results-path RESDIR]
[--output-word-probs] [--output-word-stats]
[--context-window N] [--softmax-batch N]
Named Arguments¶
--no-progress-bar | disable progress bar Default: False |
--log-interval | log progress every N batches (when progress bar is disabled) Default: 1000 |
--log-format | Possible choices: json, none, simple, tqdm log format to use |
--tensorboard-logdir | path to save logs for tensorboard, should match --logdir of running tensorboard (default: no tensorboard logging) Default: "" |
--tbmf-wrapper | [FB only] Default: False |
--seed | pseudo random number generator seed Default: 1 |
--cpu | use CPU instead of CUDA Default: False |
--fp16 | use FP16 Default: False |
--memory-efficient-fp16 | use a memory-efficient version of FP16 training; implies --fp16 Default: False |
--fp16-init-scale | default FP16 loss scale Default: 128 |
--fp16-scale-window | number of updates before increasing loss scale |
--fp16-scale-tolerance | pct of updates that can overflow before decreasing the loss scale Default: 0.0 |
--min-loss-scale | minimum FP16 loss scale, after which training is stopped Default: 0.0001 |
--threshold-loss-scale | threshold FP16 loss scale from below |
--user-dir | path to a python module containing custom extensions (tasks and/or architectures) |
--criterion | Possible choices: sentence_prediction, adaptive_loss, label_smoothed_cross_entropy, composite_loss, binary_cross_entropy, masked_lm, sentence_ranking, legacy_masked_lm_loss, cross_entropy Default: “cross_entropy” |
--tokenizer | Possible choices: nltk, space, moses |
--bpe | Possible choices: gpt2, sentencepiece, subword_nmt, fastbpe |
--optimizer | Possible choices: adadelta, adam, adafactor, adagrad, nag, adamax, sgd Default: “nag” |
--lr-scheduler | Possible choices: cosine, reduce_lr_on_plateau, fixed, triangular, polynomial_decay, inverse_sqrt Default: “fixed” |
--task | Possible choices: sentence_prediction, translation, translation_from_pretrained_xlm, multilingual_translation, semisupervised_translation, cross_lingual_lm, masked_lm, sentence_ranking, audio_pretraining, legacy_masked_lm, translation_moe, language_modeling task Default: “language_modeling” |
--dataset-impl | Possible choices: raw, lazy, cached, mmap output dataset implementation |
Dataset and data loading¶
--num-workers | how many subprocesses to use for data loading Default: 1 |
--skip-invalid-size-inputs-valid-test | ignore too long or too short lines in valid and test set Default: False |
--max-tokens | maximum number of tokens in a batch |
--max-sentences, --batch-size | maximum number of sentences in a batch |
--required-batch-size-multiple | batch size will be a multiplier of this value Default: 8 |
--gen-subset | data subset to generate (train, valid, test) Default: “test” |
--num-shards | shard generation over N shards Default: 1 |
--shard-id | id of the shard to generate (id < num_shards) Default: 0 |
LM Evaluation¶
--path | path(s) to model file(s), colon separated |
--remove-bpe | remove BPE tokens before scoring (can be set to sentencepiece) |
--quiet | only print final scores Default: False |
--model-overrides | a dictionary used to override model args at generation that were used during model training Default: “{}” |
--results-path | path to save eval results (optional) |
--output-word-probs | if set, outputs words and their predicted log probabilities to standard output Default: False |
--output-word-stats | if set, outputs word statistics such as word count, average probability, etc Default: False |
--context-window | ensures that every evaluated token has access to a context of at least this size, if possible Default: 0 |
--softmax-batch | if BxT is more than this, will batch the softmax over vocab to this amount of tokens in order to fit into GPU memory Default: 9223372036854775807 |
Overview¶
Fairseq can be extended through user-supplied plug-ins. We support five kinds of plug-ins:
- Models define the neural network architecture and encapsulate all of the learnable parameters.
- Criterions compute the loss function given the model outputs and targets.
- Tasks store dictionaries and provide helpers for loading/iterating over Datasets, initializing the Model/Criterion and calculating the loss.
- Optimizers update the Model parameters based on the gradients.
- Learning Rate Schedulers update the learning rate over the course of training.
Training Flow
Given a model, criterion, task, optimizer and lr_scheduler, fairseq implements the following high-level training flow:
for epoch in range(num_epochs):
itr = task.get_batch_iterator(task.dataset('train'))
for num_updates, batch in enumerate(itr):
task.train_step(batch, model, criterion, optimizer)
average_and_clip_gradients()
optimizer.step()
lr_scheduler.step_update(num_updates)
lr_scheduler.step(epoch)
where the default implementation for task.train_step is roughly:
def train_step(self, batch, model, criterion, optimizer):
loss = criterion(model, batch)
optimizer.backward(loss)
return loss
Registering new plug-ins
New plug-ins are registered through a set of @register function decorators, for example:
@register_model('my_lstm')
class MyLSTM(FairseqEncoderDecoderModel):
(...)
Once registered, new plug-ins can be used with the existing Command-line Tools. See the Tutorial sections for more detailed walkthroughs of how to add new plug-ins.
Loading plug-ins from another directory
New plug-ins can be defined in a custom module stored on the user's system. To import the module and make the plug-in available to fairseq, the command-line tools support the --user-dir flag, which can be used to specify a custom location for additional modules to load into fairseq.
For example, assuming this directory tree:
/home/user/my-module/
└── __init__.py
with __init__.py:
from fairseq.models import register_model_architecture
from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big
@register_model_architecture('transformer', 'my_transformer')
def transformer_mmt_big(args):
transformer_vaswani_wmt_en_de_big(args)
it is possible to invoke the fairseq-train script with the new architecture:
fairseq-train ... --user-dir /home/user/my-module -a my_transformer --task translation
Tutorial: Simple LSTM¶
In this tutorial we will extend fairseq by adding a new FairseqEncoderDecoderModel that encodes a source sentence with an LSTM and then passes the final hidden state to a second LSTM that decodes the target sentence (without attention).
This tutorial covers:
- Writing an Encoder and Decoder to encode/decode the source/target sentence, respectively.
- Registering a new Model so that it can be used with the existing Command-line Tools.
- Training the Model using the existing command-line tools.
- Making generation faster by modifying the Decoder to use Incremental decoding.
1. Building an Encoder and Decoder¶
In this section we'll define a simple LSTM Encoder and Decoder. All Encoders should implement the FairseqEncoder interface and Decoders should implement the FairseqDecoder interface. These interfaces themselves extend torch.nn.Module, so FairseqEncoders and FairseqDecoders can be written and used in the same ways as ordinary PyTorch Modules.
Encoder¶
Our Encoder will embed the tokens in the source sentence, feed them to a torch.nn.LSTM and return the final hidden state. To create our encoder, save the following in a new file named fairseq/models/simple_lstm.py:
import torch.nn as nn
from fairseq import utils
from fairseq.models import FairseqEncoder
class SimpleLSTMEncoder(FairseqEncoder):
def __init__(
self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
):
super().__init__(dictionary)
self.args = args
# Our encoder will embed the inputs before feeding them to the LSTM.
self.embed_tokens = nn.Embedding(
num_embeddings=len(dictionary),
embedding_dim=embed_dim,
padding_idx=dictionary.pad(),
)
self.dropout = nn.Dropout(p=dropout)
# We'll use a single-layer, unidirectional LSTM for simplicity.
self.lstm = nn.LSTM(
input_size=embed_dim,
hidden_size=hidden_dim,
num_layers=1,
bidirectional=False,
)
def forward(self, src_tokens, src_lengths):
# The inputs to the ``forward()`` function are determined by the
# Task, and in particular the ``'net_input'`` key in each
# mini-batch. We discuss Tasks in the next tutorial, but for now just
# know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*
# has shape `(batch)`.
# Note that the source is typically padded on the left. This can be
# configured by adding the `--left-pad-source "False"` command-line
# argument, but here we'll make the Encoder handle either kind of
# padding by converting everything to be right-padded.
if self.args.left_pad_source:
# Convert left-padding to right-padding.
src_tokens = utils.convert_padding_direction(
src_tokens,
padding_idx=self.dictionary.pad(),
left_to_right=True
)
# Embed the source.
x = self.embed_tokens(src_tokens)
# Apply dropout.
x = self.dropout(x)
# Pack the sequence into a PackedSequence object to feed to the LSTM.
x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True)
# Get the output from the LSTM.
_outputs, (final_hidden, _final_cell) = self.lstm(x)
# Return the Encoder's output. This can be any object and will be
# passed directly to the Decoder.
return {
# this will have shape `(bsz, hidden_dim)`
'final_hidden': final_hidden.squeeze(0),
}
# Encoders are required to implement this method so that we can rearrange
# the order of the batch elements during inference (e.g., beam search).
def reorder_encoder_out(self, encoder_out, new_order):
"""
Reorder encoder output according to `new_order`.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
`encoder_out` rearranged according to `new_order`
"""
final_hidden = encoder_out['final_hidden']
return {
'final_hidden': final_hidden.index_select(0, new_order),
}
Decoder¶
Our Decoder will predict the next word, conditioned on the Encoder's final hidden state and an embedded representation of the previous target word (a setup sometimes called teacher forcing). More specifically, we'll use a torch.nn.LSTM to produce a sequence of hidden states that we'll project to the size of the output vocabulary to predict each target word.
import torch
from fairseq.models import FairseqDecoder
class SimpleLSTMDecoder(FairseqDecoder):
def __init__(
self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
dropout=0.1,
):
super().__init__(dictionary)
# Our decoder will embed the inputs before feeding them to the LSTM.
self.embed_tokens = nn.Embedding(
num_embeddings=len(dictionary),
embedding_dim=embed_dim,
padding_idx=dictionary.pad(),
)
self.dropout = nn.Dropout(p=dropout)
# We'll use a single-layer, unidirectional LSTM for simplicity.
self.lstm = nn.LSTM(
# For the first layer we'll concatenate the Encoder's final hidden
# state with the embedded target tokens.
input_size=encoder_hidden_dim + embed_dim,
hidden_size=hidden_dim,
num_layers=1,
bidirectional=False,
)
# Define the output projection.
self.output_projection = nn.Linear(hidden_dim, len(dictionary))
# During training Decoders are expected to take the entire target sequence
# (shifted right by one position) and produce logits over the vocabulary.
# The *prev_output_tokens* tensor begins with the end-of-sentence symbol,
# ``dictionary.eos()``, followed by the target sequence.
def forward(self, prev_output_tokens, encoder_out):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (Tensor, optional): output from the encoder, used for
encoder-side attention
Returns:
tuple:
- the last decoder layer's output of shape
`(batch, tgt_len, vocab)`
- the last decoder layer's attention weights of shape
`(batch, tgt_len, src_len)`
"""
bsz, tgt_len = prev_output_tokens.size()
# Extract the final hidden state from the Encoder.
final_encoder_hidden = encoder_out['final_hidden']
# Embed the target sequence, which has been shifted right by one
# position and now starts with the end-of-sentence symbol.
x = self.embed_tokens(prev_output_tokens)
# Apply dropout.
x = self.dropout(x)
# Concatenate the Encoder's final hidden state to *every* embedded
# target token.
x = torch.cat(
[x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
dim=2,
)
# Using PackedSequence objects in the Decoder is harder than in the
# Encoder, since the targets are not sorted in descending length order,
# which is a requirement of ``pack_padded_sequence()``. Instead we'll
# feed nn.LSTM directly.
initial_state = (
final_encoder_hidden.unsqueeze(0), # hidden
torch.zeros_like(final_encoder_hidden).unsqueeze(0), # cell
)
output, _ = self.lstm(
x.transpose(0, 1), # convert to shape `(tgt_len, bsz, dim)`
initial_state,
)
x = output.transpose(0, 1) # convert to shape `(bsz, tgt_len, hidden)`
# Project the outputs to the size of the vocabulary.
x = self.output_projection(x)
# Return the logits and ``None`` for the attention weights
return x, None
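Before moving on, it can be useful to sanity-check the shapes by wiring the two modules together on toy data. The following sketch is not part of the tutorial: it builds a throwaway Dictionary by hand (in real training the Task provides the dictionaries) and assumes the SimpleLSTMEncoder and SimpleLSTMDecoder definitions above:
# A toy forward pass through the Encoder and Decoder defined above (a sketch only).
import argparse
import torch
from fairseq.data import Dictionary

vocab = Dictionary()
for w in ['guten', 'tag', 'good', 'day']:
    vocab.add_symbol(w)

args = argparse.Namespace(left_pad_source=False)
encoder = SimpleLSTMEncoder(args, vocab, embed_dim=32, hidden_dim=32)
decoder = SimpleLSTMDecoder(vocab, encoder_hidden_dim=32, embed_dim=32, hidden_dim=32)

src = torch.LongTensor([[vocab.index('guten'), vocab.index('tag')]])
prev = torch.LongTensor([[vocab.eos(), vocab.index('good'), vocab.index('day')]])
encoder_out = encoder(src, torch.LongTensor([2]))
logits, _ = decoder(prev, encoder_out)
print(logits.shape)  # torch.Size([1, 3, len(vocab)])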
2. Registering the Model¶
Now that we've defined our Encoder and Decoder we must register our model with fairseq using the register_model() function decorator. Once the model is registered we'll be able to use it with the existing Command-line Tools.
All registered models must implement the BaseFairseqModel interface. For sequence-to-sequence models (i.e., any model with a single Encoder and Decoder), we can instead implement the FairseqEncoderDecoderModel interface.
Create a small wrapper class in the same file and register it in fairseq with the name 'simple_lstm':
from fairseq.models import FairseqEncoderDecoderModel, register_model
# Note: the register_model "decorator" should immediately precede the
# definition of the Model class.
@register_model('simple_lstm')
class SimpleLSTMModel(FairseqEncoderDecoderModel):
@staticmethod
def add_args(parser):
# Models can override this method to add new command-line arguments.
# Here we'll add some new command-line arguments to configure dropout
# and the dimensionality of the embeddings and hidden states.
parser.add_argument(
'--encoder-embed-dim', type=int, metavar='N',
help='dimensionality of the encoder embeddings',
)
parser.add_argument(
'--encoder-hidden-dim', type=int, metavar='N',
help='dimensionality of the encoder hidden state',
)
parser.add_argument(
'--encoder-dropout', type=float, default=0.1,
help='encoder dropout probability',
)
parser.add_argument(
'--decoder-embed-dim', type=int, metavar='N',
help='dimensionality of the decoder embeddings',
)
parser.add_argument(
'--decoder-hidden-dim', type=int, metavar='N',
help='dimensionality of the decoder hidden state',
)
parser.add_argument(
'--decoder-dropout', type=float, default=0.1,
help='decoder dropout probability',
)
@classmethod
def build_model(cls, args, task):
# Fairseq initializes models by calling the ``build_model()``
# function. This provides more flexibility, since the returned model
# instance can be of a different type than the one that was called.
# In this case we'll just return a SimpleLSTMModel instance.
# Initialize our Encoder and Decoder.
encoder = SimpleLSTMEncoder(
args=args,
dictionary=task.source_dictionary,
embed_dim=args.encoder_embed_dim,
hidden_dim=args.encoder_hidden_dim,
dropout=args.encoder_dropout,
)
decoder = SimpleLSTMDecoder(
dictionary=task.target_dictionary,
encoder_hidden_dim=args.encoder_hidden_dim,
embed_dim=args.decoder_embed_dim,
hidden_dim=args.decoder_hidden_dim,
dropout=args.decoder_dropout,
)
model = SimpleLSTMModel(encoder, decoder)
# Print the model architecture.
print(model)
return model
# We could override the ``forward()`` if we wanted more control over how
# the encoder and decoder interact, but it's not necessary for this
# tutorial since we can inherit the default implementation provided by
# the FairseqEncoderDecoderModel base class, which looks like:
#
# def forward(self, src_tokens, src_lengths, prev_output_tokens):
# encoder_out = self.encoder(src_tokens, src_lengths)
# decoder_out = self.decoder(prev_output_tokens, encoder_out)
# return decoder_out
Finally let's define a named architecture with the configuration for our model. This is done with the register_model_architecture() function decorator. Thereafter this named architecture can be used with the --arch command-line argument, e.g., --arch tutorial_simple_lstm:
from fairseq.models import register_model_architecture
# The first argument to ``register_model_architecture()`` should be the name
# of the model we registered above (i.e., 'simple_lstm'). The function we
# register here should take a single argument *args* and modify it in-place
# to match the desired architecture.
@register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
def tutorial_simple_lstm(args):
# We use ``getattr()`` to prioritize arguments that are explicitly given
# on the command-line, so that the defaults defined below are only used
# when no other value has been specified.
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)
3. Training the Model¶
Now we're ready to train the model. We can use the existing fairseq-train command-line tool for this, making sure to specify our new Model architecture (--arch tutorial_simple_lstm).
Note
Make sure you've already preprocessed the data from the IWSLT example in the examples/translation/ directory.
> fairseq-train data-bin/iwslt14.tokenized.de-en \
--arch tutorial_simple_lstm \
--encoder-dropout 0.2 --decoder-dropout 0.2 \
--optimizer adam --lr 0.005 --lr-shrink 0.5 \
--max-tokens 12000
(...)
| epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
| epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954
The model files should appear in the checkpoints/ directory. While this model architecture is not very good, we can use the fairseq-generate script to generate translations and compute our BLEU score over the test set:
> fairseq-generate data-bin/iwslt14.tokenized.de-en \
--path checkpoints/checkpoint_best.pt \
--beam 5 \
--remove-bpe
(...)
| Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
4. Making generation faster¶
While autoregressive generation from sequence-to-sequence models is inherently slow, our implementation above is especially slow because it recomputes the entire sequence of Decoder hidden states for every output token (i.e., it is O(n^2)). We can make this significantly faster by instead caching the previous hidden states.
In fairseq this is called Incremental decoding. Incremental decoding is a special mode at inference time where the Model only receives a single timestep of input corresponding to the immediately previous output token (for teacher forcing) and must produce the next output incrementally. Thus the model must cache any long-term state that is needed about the sequence, e.g., hidden states, convolutional states, etc.
To implement incremental decoding we will modify our model to implement the FairseqIncrementalDecoder interface. Compared to the standard FairseqDecoder interface, the incremental decoder interface allows forward() methods to take an extra keyword argument (incremental_state) that can be used to cache state across time-steps.
Let's replace our SimpleLSTMDecoder with an incremental one:
import torch
from fairseq.models import FairseqIncrementalDecoder
class SimpleLSTMDecoder(FairseqIncrementalDecoder):
def __init__(
self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
dropout=0.1,
):
# This remains the same as before.
super().__init__(dictionary)
self.embed_tokens = nn.Embedding(
num_embeddings=len(dictionary),
embedding_dim=embed_dim,
padding_idx=dictionary.pad(),
)
self.dropout = nn.Dropout(p=dropout)
self.lstm = nn.LSTM(
input_size=encoder_hidden_dim + embed_dim,
hidden_size=hidden_dim,
num_layers=1,
bidirectional=False,
)
self.output_projection = nn.Linear(hidden_dim, len(dictionary))
# We now take an additional kwarg (*incremental_state*) for caching the
# previous hidden and cell states.
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
if incremental_state is not None:
# If the *incremental_state* argument is not ``None`` then we are
# in incremental inference mode. While *prev_output_tokens* will
# still contain the entire decoded prefix, we will only use the
# last step and assume that the rest of the state is cached.
prev_output_tokens = prev_output_tokens[:, -1:]
# This remains the same as before.
bsz, tgt_len = prev_output_tokens.size()
final_encoder_hidden = encoder_out['final_hidden']
x = self.embed_tokens(prev_output_tokens)
x = self.dropout(x)
x = torch.cat(
[x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
dim=2,
)
# We will now check the cache and load the cached previous hidden and
# cell states, if they exist, otherwise we will initialize them to
# zeros (as before). We will use the ``utils.get_incremental_state()``
# and ``utils.set_incremental_state()`` helpers.
initial_state = utils.get_incremental_state(
self, incremental_state, 'prev_state',
)
if initial_state is None:
# first time initialization, same as the original version
initial_state = (
final_encoder_hidden.unsqueeze(0), # hidden
torch.zeros_like(final_encoder_hidden).unsqueeze(0), # cell
)
# Run one step of our LSTM.
output, latest_state = self.lstm(x.transpose(0, 1), initial_state)
# Update the cache with the latest hidden and cell states.
utils.set_incremental_state(
self, incremental_state, 'prev_state', latest_state,
)
# This remains the same as before
x = output.transpose(0, 1)
x = self.output_projection(x)
return x, None
# The ``FairseqIncrementalDecoder`` interface also requires implementing a
# ``reorder_incremental_state()`` method, which is used during beam search
# to select and reorder the incremental state.
def reorder_incremental_state(self, incremental_state, new_order):
# Load the cached state.
prev_state = utils.get_incremental_state(
self, incremental_state, 'prev_state',
)
# Reorder batches according to *new_order*.
reordered_state = (
prev_state[0].index_select(1, new_order), # hidden
prev_state[1].index_select(1, new_order), # cell
)
# Update the cached state.
utils.set_incremental_state(
self, incremental_state, 'prev_state', reordered_state,
)
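To see the incremental interface in action outside of fairseq-generate, here is a toy greedy-decoding loop (a sketch, not part of the tutorial, reusing the throwaway vocab, encoder and src objects from the earlier smoke test and the incremental SimpleLSTMDecoder defined above). A single incremental_state dict is threaded through successive forward() calls so each step only feeds the newest token:
# A toy greedy-decoding loop demonstrating incremental_state (sketch only).
import torch

decoder = SimpleLSTMDecoder(vocab, encoder_hidden_dim=32, embed_dim=32, hidden_dim=32)
encoder_out = encoder(src, torch.LongTensor([2]))
incremental_state = {}                      # cache shared across decoding steps
tokens = [vocab.eos()]                      # decoding starts from the EOS symbol
for _ in range(5):
    prev = torch.LongTensor([tokens])
    logits, _ = decoder(prev, encoder_out, incremental_state=incremental_state)
    tokens.append(logits[0, -1].argmax().item())   # pick the most likely next token
print(tokens)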
Finally, we can rerun generation and observe the speedup:
# Before
> fairseq-generate data-bin/iwslt14.tokenized.de-en \
--path checkpoints/checkpoint_best.pt \
--beam 5 \
--remove-bpe
(...)
| Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
# After
> fairseq-generate data-bin/iwslt14.tokenized.de-en \
--path checkpoints/checkpoint_best.pt \
--beam 5 \
--remove-bpe
(...)
| Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
Tutorial: Classifying Names with a Character-Level RNN¶
In this tutorial we will extend fairseq to support classification tasks. In particular we will re-implement the PyTorch tutorial for Classifying Names with a Character-Level RNN in fairseq. It is recommended to quickly skim that tutorial before beginning this one.
This tutorial covers:
- Preprocessing the data to create dictionaries.
- Registering a new Model that encodes an input sentence with a simple RNN and predicts the output label.
- Registering a new Task that loads our dictionaries and dataset.
- Training the Model using the existing command-line tools.
- Writing an evaluation script that imports fairseq and allows us to interactively evaluate our model on new inputs.
1. Preprocessing the data¶
The original tutorial provides raw data, but we’ll work with a modified version of the data that is already tokenized into characters and split into separate train, valid and test sets.
Download and extract the data from here: tutorial_names.tar.gz
Once extracted, let's preprocess the data using the fairseq-preprocess command-line tool to create the dictionaries. While this tool is primarily intended for sequence-to-sequence problems, we're able to reuse it here by treating the label as a "target" sequence of length 1. We'll also output the preprocessed files in "raw" format using the --dataset-impl option to enhance readability:
> fairseq-preprocess \
--trainpref names/train --validpref names/valid --testpref names/test \
--source-lang input --target-lang label \
--destdir names-bin --dataset-impl raw
After running the above command you should see a new directory, names-bin/, containing the dictionaries for inputs and labels.
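If you want to inspect those dictionaries from Python, they can be loaded directly (a sketch; the dict.*.txt filenames follow fairseq-preprocess's usual destdir naming for the input and label languages):
# Peek at the dictionaries written by fairseq-preprocess (sketch only).
from fairseq.data import Dictionary

input_vocab = Dictionary.load('names-bin/dict.input.txt')
label_vocab = Dictionary.load('names-bin/dict.label.txt')
print(len(input_vocab), 'input symbols;', len(label_vocab), 'label symbols')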
2. Registering a new Model¶
Next we’ll register a new model in fairseq that will encode an input sentence with a simple RNN and predict the output label. Compared to the original PyTorch tutorial, our version will also work with batches of data and GPU Tensors.
First let’s copy the simple RNN module implemented in the PyTorch tutorial.
Create a new file named fairseq/models/rnn_classifier.py with the following contents:
import torch
import torch.nn as nn
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RNN, self).__init__()
self.hidden_size = hidden_size
self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
self.i2o = nn.Linear(input_size + hidden_size, output_size)
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, input, hidden):
combined = torch.cat((input, hidden), 1)
hidden = self.i2h(combined)
output = self.i2o(combined)
output = self.softmax(output)
return output, hidden
def initHidden(self):
return torch.zeros(1, self.hidden_size)
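As a quick sanity check (not part of the tutorial), the module above processes one character at a time and returns an output distribution together with the updated hidden state:
# A quick shape check for the RNN module above; sizes are arbitrary.
rnn = RNN(input_size=10, hidden_size=16, output_size=3)
step_input = torch.zeros(1, 10)     # one one-hot encoded character
hidden = rnn.initHidden()
output, hidden = rnn(step_input, hidden)
print(output.shape, hidden.shape)   # torch.Size([1, 3]) torch.Size([1, 16])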
We must also register this model with fairseq using the register_model() function decorator. Once the model is registered we'll be able to use it with the existing Command-line Tools.
All registered models must implement the BaseFairseqModel interface, so we'll create a small wrapper class in the same file and register it in fairseq with the name 'rnn_classifier':
from fairseq.models import BaseFairseqModel, register_model

# Note: the register_model "decorator" should immediately precede the
# definition of the Model class.

@register_model('rnn_classifier')
class FairseqRNNClassifier(BaseFairseqModel):

    @staticmethod
    def add_args(parser):
        # Models can override this method to add new command-line arguments.
        # Here we'll add a new command-line argument to configure the
        # dimensionality of the hidden state.
        parser.add_argument(
            '--hidden-dim', type=int, metavar='N',
            help='dimensionality of the hidden state',
        )

    @classmethod
    def build_model(cls, args, task):
        # Fairseq initializes models by calling the ``build_model()``
        # function. This provides more flexibility, since the returned model
        # instance can be of a different type than the one that was called.
        # In this case we'll just return a FairseqRNNClassifier instance.

        # Initialize our RNN module
        rnn = RNN(
            # We'll define the Task in the next section, but for now just
            # notice that the task holds the dictionaries for the "source"
            # (i.e., the input sentence) and "target" (i.e., the label).
            input_size=len(task.source_dictionary),
            hidden_size=args.hidden_dim,
            output_size=len(task.target_dictionary),
        )

        # Return the wrapped version of the module
        return FairseqRNNClassifier(
            rnn=rnn,
            input_vocab=task.source_dictionary,
        )

    def __init__(self, rnn, input_vocab):
        super(FairseqRNNClassifier, self).__init__()

        self.rnn = rnn
        self.input_vocab = input_vocab

        # The RNN module in the tutorial expects one-hot inputs, so we can
        # precompute the identity matrix to help convert from indices to
        # one-hot vectors. We register it as a buffer so that it is moved to
        # the GPU when ``cuda()`` is called.
        self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab)))

    def forward(self, src_tokens, src_lengths):
        # The inputs to the ``forward()`` function are determined by the
        # Task, and in particular the ``'net_input'`` key in each
        # mini-batch. We'll define the Task in the next section, but for
        # now just know that *src_tokens* has shape `(batch, src_len)` and
        # *src_lengths* has shape `(batch)`.
        bsz, max_src_len = src_tokens.size()

        # Initialize the RNN hidden state. Compared to the original PyTorch
        # tutorial we'll also handle batched inputs and work on the GPU.
        hidden = self.rnn.initHidden()
        hidden = hidden.repeat(bsz, 1)  # expand for batched inputs
        hidden = hidden.to(src_tokens.device)  # move to GPU

        for i in range(max_src_len):
            # WARNING: The inputs have padding, so we should mask those
            # elements here so that padding doesn't affect the results.
            # This is left as an exercise for the reader. The padding symbol
            # is given by ``self.input_vocab.pad()`` and the unpadded length
            # of each input is given by *src_lengths*.

            # One-hot encode a batch of input characters.
            input = self.one_hot_inputs[src_tokens[:, i].long()]

            # Feed the input to our RNN.
            output, hidden = self.rnn(input, hidden)

        # Return the final output state for making a prediction
        return output
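The WARNING above leaves padding handling as an exercise. One possible approach (a sketch of ours, not the tutorial's official solution) is to stop updating the hidden state and output once a sequence has run out of real tokens; the helper name masked_rnn_forward is made up for illustration:
import torch

def masked_rnn_forward(rnn, one_hot_inputs, src_tokens, src_lengths):
    """Run ``rnn`` over a padded batch while ignoring padded positions."""
    bsz, max_src_len = src_tokens.size()
    hidden = rnn.initHidden().repeat(bsz, 1).to(src_tokens.device)
    output = None
    for i in range(max_src_len):
        inp = one_hot_inputs[src_tokens[:, i].long()]
        new_output, new_hidden = rnn(inp, hidden)
        # keep == 1.0 while position i is inside the unpadded sequence, else 0.0
        keep = (src_lengths > i).to(new_hidden.dtype).unsqueeze(1)
        hidden = keep * new_hidden + (1.0 - keep) * hidden
        output = new_output if output is None else keep * new_output + (1.0 - keep) * output
    return output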
Finally let’s define a named architecture with the configuration for our
model. This is done with the register_model_architecture()
function decorator. Thereafter this named architecture can be used with the
--arch
command-line argument, e.g., --arch pytorch_tutorial_rnn
:
from fairseq.models import register_model_architecture

# The first argument to ``register_model_architecture()`` should be the name
# of the model we registered above (i.e., 'rnn_classifier'). The function we
# register here should take a single argument *args* and modify it in-place
# to match the desired architecture.

@register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn')
def pytorch_tutorial_rnn(args):
    # We use ``getattr()`` to prioritize arguments that are explicitly given
    # on the command-line, so that the defaults defined below are only used
    # when no other value has been specified.
    args.hidden_dim = getattr(args, 'hidden_dim', 128)
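As an aside (not part of the tutorial), additional named architectures can reuse the same model with different defaults; the name 'pytorch_tutorial_rnn_big' below is made up:
from fairseq.models import register_model_architecture

@register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn_big')
def pytorch_tutorial_rnn_big(args):
    # same model, but a larger default hidden state
    args.hidden_dim = getattr(args, 'hidden_dim', 512)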
3. Registering a new Task¶
Now we’ll register a new FairseqTask
that will load our
dictionaries and dataset. Tasks can also control how the data is batched into
mini-batches, but in this tutorial we’ll reuse the batching provided by
fairseq.data.LanguagePairDataset
.
Create a new file named fairseq/tasks/simple_classification.py
with the
following contents:
import os
import torch

from fairseq.data import Dictionary, LanguagePairDataset
from fairseq.tasks import FairseqTask, register_task


@register_task('simple_classification')
class SimpleClassificationTask(FairseqTask):

    @staticmethod
    def add_args(parser):
        # Add some command-line arguments for specifying where the data is
        # located and the maximum supported input length.
        parser.add_argument('data', metavar='FILE',
                            help='file prefix for data')
        parser.add_argument('--max-positions', default=1024, type=int,
                            help='max input length')

    @classmethod
    def setup_task(cls, args, **kwargs):
        # Here we can perform any setup required for the task. This may include
        # loading Dictionaries, initializing shared Embedding layers, etc.
        # In this case we'll just load the Dictionaries.
        input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt'))
        label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt'))
        print('| [input] dictionary: {} types'.format(len(input_vocab)))
        print('| [label] dictionary: {} types'.format(len(label_vocab)))

        return SimpleClassificationTask(args, input_vocab, label_vocab)

    def __init__(self, args, input_vocab, label_vocab):
        super().__init__(args)
        self.input_vocab = input_vocab
        self.label_vocab = label_vocab

    def load_dataset(self, split, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""

        prefix = os.path.join(self.args.data, '{}.input-label'.format(split))

        # Read input sentences.
        sentences, lengths = [], []
        with open(prefix + '.input', encoding='utf-8') as file:
            for line in file:
                sentence = line.strip()

                # Tokenize the sentence, splitting on spaces
                tokens = self.input_vocab.encode_line(
                    sentence, add_if_not_exist=False,
                )

                sentences.append(tokens)
                lengths.append(tokens.numel())

        # Read labels.
        labels = []
        with open(prefix + '.label', encoding='utf-8') as file:
            for line in file:
                label = line.strip()
                labels.append(
                    # Convert label to a numeric ID.
                    torch.LongTensor([self.label_vocab.add_symbol(label)])
                )

        assert len(sentences) == len(labels)
        print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))

        # We reuse LanguagePairDataset since classification can be modeled as a
        # sequence-to-sequence task where the target sequence has length 1.
        self.datasets[split] = LanguagePairDataset(
            src=sentences,
            src_sizes=lengths,
            src_dict=self.input_vocab,
            tgt=labels,
            tgt_sizes=torch.ones(len(labels)),  # targets have length 1
            tgt_dict=self.label_vocab,
            left_pad_source=False,
            max_source_positions=self.args.max_positions,
            max_target_positions=1,
            # Since our target is a single class label, there's no need for
            # teacher forcing. If we set this to ``True`` then our Model's
            # ``forward()`` method would receive an additional argument called
            # *prev_output_tokens* that would contain a shifted version of the
            # target sequence.
            input_feeding=False,
        )

    def max_positions(self):
        """Return the max input length allowed by the task."""
        # The source should be less than *args.max_positions* and the "target"
        # has max length 1.
        return (self.args.max_positions, 1)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.input_vocab

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.label_vocab

    # We could override this method if we wanted more control over how batches
    # are constructed, but it's not necessary for this tutorial since we can
    # reuse the batching provided by LanguagePairDataset.
    #
    # def get_batch_iterator(
    #     self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
    #     ignore_invalid_inputs=False, required_batch_size_multiple=1,
    #     seed=1, num_shards=1, shard_id=0,
    # ):
    #     (...)
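Before moving on to training, the task can be exercised directly from Python as a quick sanity check (optional, and not part of the original tutorial); the paths assume the names-bin/ directory created earlier:
import argparse
from fairseq.tasks.simple_classification import SimpleClassificationTask

args = argparse.Namespace(data='names-bin', max_positions=1024)
task = SimpleClassificationTask.setup_task(args)
task.load_dataset('valid')
print('{} validation examples'.format(len(task.dataset('valid'))))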
4. Training the Model¶
Now we’re ready to train the model. We can use the existing fairseq-train
command-line tool for this, making sure to specify our new Task (--task
simple_classification
) and Model architecture (--arch
pytorch_tutorial_rnn
):
Note
You can also configure the dimensionality of the hidden state by passing the
--hidden-dim
argument to fairseq-train.
> fairseq-train names-bin \
--task simple_classification \
--arch pytorch_tutorial_rnn \
--optimizer adam --lr 0.001 --lr-shrink 0.5 \
--max-tokens 1000
(...)
| epoch 027 | loss 1.200 | ppl 2.30 | wps 15728 | ups 119.4 | wpb 116 | bsz 116 | num_updates 3726 | lr 1.5625e-05 | gnorm 1.290 | clip 0% | oom 0 | wall 32 | train_wall 21
| epoch 027 | valid on 'valid' subset | valid_loss 1.41304 | valid_ppl 2.66 | num_updates 3726 | best 1.41208
| done training in 31.6 seconds
The model files should appear in the checkpoints/
directory.
5. Writing an evaluation script¶
Finally we can write a short script to evaluate our model on new inputs. Create
a new file named eval_classifier.py
with the following contents:
from fairseq import checkpoint_utils, data, options, tasks

# Parse command-line arguments for generation
parser = options.get_generation_parser(default_task='simple_classification')
args = options.parse_args_and_arch(parser)

# Setup task
task = tasks.setup_task(args)

# Load model
print('| loading model from {}'.format(args.path))
models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
model = models[0]

while True:
    sentence = input('\nInput: ')

    # Tokenize into characters
    chars = ' '.join(list(sentence.strip()))
    tokens = task.source_dictionary.encode_line(
        chars, add_if_not_exist=False,
    )

    # Build mini-batch to feed to the model
    batch = data.language_pair_dataset.collate(
        samples=[{'id': -1, 'source': tokens}],  # bsz = 1
        pad_idx=task.source_dictionary.pad(),
        eos_idx=task.source_dictionary.eos(),
        left_pad_source=False,
        input_feeding=False,
    )

    # Feed batch to the model and get predictions
    preds = model(**batch['net_input'])

    # Print top 3 predictions and their log-probabilities
    top_scores, top_labels = preds[0].topk(k=3)
    for score, label_idx in zip(top_scores, top_labels):
        label_name = task.target_dictionary.string([label_idx])
        print('({:.2f})\t{}'.format(score, label_name))
Now we can evaluate our model interactively. Note that we have included the
original data path (names-bin/
) so that the dictionaries can be loaded:
> python eval_classifier.py names-bin --path checkpoints/checkpoint_best.pt
| [input] dictionary: 64 types
| [label] dictionary: 24 types
| loading model from checkpoints/checkpoint_best.pt
Input: Satoshi
(-0.61) Japanese
(-1.20) Arabic
(-2.86) Italian
Input: Sinbad
(-0.30) Arabic
(-1.76) English
(-4.08) Russian
Tasks¶
Tasks store dictionaries and provide helpers for loading/iterating over Datasets, initializing the Model/Criterion and calculating the loss.
Tasks can be selected via the --task
command-line argument. Once selected, a
task may expose additional command-line arguments for further configuration.
Example usage:
# setup the task (e.g., load dictionaries)
task = fairseq.tasks.setup_task(args)
# build model and criterion
model = task.build_model(args)
criterion = task.build_criterion(args)
# load datasets
task.load_dataset('train')
task.load_dataset('valid')
# iterate over mini-batches of data
batch_itr = task.get_batch_iterator(
    task.dataset('train'), max_tokens=4096,
)
for batch in batch_itr:
    # compute the loss
    loss, sample_size, logging_output = task.get_loss(
        model, criterion, batch,
    )
    loss.backward()
Translation¶
-
class
fairseq.tasks.translation.
TranslationTask
(args, src_dict, tgt_dict)[source]¶ Translate from one (source) language to another (target) language.
Parameters: - src_dict (Dictionary) – dictionary for the source language
- tgt_dict (Dictionary) – dictionary for the target language
Note
The translation task is compatible with
fairseq-train
,fairseq-generate
andfairseq-interactive
.The translation task provides the following additional command-line arguments:
usage: [--task translation] [-s SRC] [-t TARGET] [--lazy-load] [--raw-text] [--left-pad-source BOOL] [--left-pad-target BOOL] [--max-source-positions N] [--max-target-positions N] [--upsample-primary UPSAMPLE_PRIMARY] data
Task name¶
--task Enable this task with: --task=translation
Additional command-line arguments¶
data colon separated path to data directories list, will be iterated upon during epochs in round-robin manner
-s, --source-lang source language
-t, --target-lang target language
--lazy-load load the dataset lazily
Default: False
--raw-text load raw text dataset
Default: False
--left-pad-source pad the source on the left
Default: “True”
--left-pad-target pad the target on the left
Default: “False”
--max-source-positions max number of tokens in the source sequence
Default: 1024
--max-target-positions max number of tokens in the target sequence
Default: 1024
--upsample-primary amount to upsample primary dataset
Default: 1
Language Modeling¶
-
class
fairseq.tasks.language_modeling.
LanguageModelingTask
(args, dictionary, output_dictionary=None, targets=None)[source]¶ Train a language model.
Parameters: - dictionary (Dictionary) – the dictionary for the input of the language model
- output_dictionary (Dictionary) – the dictionary for the
output of the language model. In most cases it will be the same as
dictionary, but could possibly be a more limited version of the
dictionary (if
--output-dictionary-size
is used). - targets (List[str]) – list of the target types that the language model should predict. Can be one of “self”, “future”, and “past”. Defaults to “future”.
Note
The language modeling task is compatible with
fairseq-train
,fairseq-generate
,fairseq-interactive
andfairseq-eval-lm
.The language modeling task provides the following additional command-line arguments:
usage: [--task language_modeling] [--sample-break-mode {none,complete,complete_doc,eos}] [--tokens-per-sample TOKENS_PER_SAMPLE] [--lazy-load] [--raw-text] [--output-dictionary-size OUTPUT_DICTIONARY_SIZE] [--self-target] [--future-target] [--past-target] [--add-bos-token] [--max-target-positions N] data
Task name¶
--task Enable this task with: --task=language_modeling
Additional command-line arguments¶
data path to data directory
--sample-break-mode Possible choices: none, complete, complete_doc, eos
If omitted or “none”, fills each sample with tokens-per-sample tokens. If set to “complete”, splits samples only at the end of sentence, but may include multiple sentences per sample. “complete_doc” is similar but respects doc boundaries. If set to “eos”, includes only one sentence per sample.
Default: “none”
--tokens-per-sample max number of tokens per sample for LM dataset
Default: 1024
--lazy-load load the dataset lazily
Default: False
--raw-text load raw text dataset
Default: False
--output-dictionary-size limit the size of output dictionary
Default: -1
--self-target include self target
Default: False
--future-target include future target
Default: False
--past-target include past target
Default: False
--add-bos-token prepend beginning of sentence token (<s>)
Default: False
--max-target-positions max number of tokens in the target sequence
Adding new tasks¶
-
fairseq.tasks.
register_task
(name)[source]¶ New tasks can be added to fairseq with the
register_task()
function decorator.For example:
@register_task('classification')
class ClassificationTask(FairseqTask):
    (...)
Note
All Tasks must implement the
FairseqTask
interface.Please see the
Parameters: name (str) – the name of the task
-
class
fairseq.tasks.
FairseqTask
(args)[source]¶ Tasks store dictionaries and provide helpers for loading/iterating over Datasets, initializing the Model/Criterion and calculating the loss.
-
build_criterion
(args)[source]¶ Build the
FairseqCriterion
instance for this task.Parameters: args (argparse.Namespace) – parsed command-line arguments Returns: a FairseqCriterion
instance
-
classmethod
build_dictionary
(filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8)[source]¶ Build the dictionary
Parameters: - filenames (list) – list of filenames
- workers (int) – number of concurrent workers
- threshold (int) – defines the minimum word count
- nwords (int) – defines the total number of words in the final dictionary, including special symbols
- padding_factor (int) – can be used to pad the dictionary size to be a multiple of 8, which is important on some hardware (e.g., Nvidia Tensor Cores).
-
build_model
(args)[source]¶ Build the
BaseFairseqModel
instance for this task.Parameters: args (argparse.Namespace) – parsed command-line arguments Returns: a BaseFairseqModel
instance
-
dataset
(split)[source]¶ Return a loaded dataset split.
Parameters: split (str) – name of the split (e.g., train, valid, test) Returns: a FairseqDataset
corresponding to split
-
get_batch_iterator
(dataset, max_tokens=None, max_sentences=None, max_positions=None, ignore_invalid_inputs=False, required_batch_size_multiple=1, seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=0)[source]¶ Get an iterator that yields batches of data from the given dataset.
Parameters: - dataset (FairseqDataset) – dataset to batch
- max_tokens (int, optional) – max number of tokens in each batch (default: None).
- max_sentences (int, optional) – max number of sentences in each batch (default: None).
- max_positions (optional) – max sentence length supported by the model (default: None).
- ignore_invalid_inputs (bool, optional) – don’t raise Exception for sentences that are too long (default: False).
- required_batch_size_multiple (int, optional) – require batch size to be a multiple of N (default: 1).
- seed (int, optional) – seed for random number generator for reproducibility (default: 1).
- num_shards (int, optional) – shard the data iterator into N shards (default: 1).
- shard_id (int, optional) – which shard of the data iterator to return (default: 0).
- num_workers (int, optional) – how many subprocesses to use for data loading. 0 means the data will be loaded in the main process (default: 0).
- epoch (int, optional) – the epoch to start the iterator from (default: 0).
Returns: - a batched iterator over the
given dataset split
Return type: EpochBatchIterator
-
load_dataset
(split, combine=False, **kwargs)[source]¶ Load a given dataset split.
Parameters: split (str) – name of the split (e.g., train, valid, test)
-
classmethod
load_dictionary
(filename)[source]¶ Load the dictionary from the filename
Parameters: filename (str) – the filename
-
classmethod
setup_task
(args, **kwargs)[source]¶ Setup the task (e.g., load dictionaries).
Parameters: args (argparse.Namespace) – parsed command-line arguments
-
source_dictionary
¶ Return the source
Dictionary
(if applicable for this task).
-
target_dictionary
¶ Return the target
Dictionary
(if applicable for this task).
-
train_step
(sample, model, criterion, optimizer, ignore_grad=False)[source]¶ Do forward and backward, and return the loss as computed by criterion for the given model and sample.
Parameters: - sample (dict) – the mini-batch. The format is defined by the
FairseqDataset
. - model (BaseFairseqModel) – the model
- criterion (FairseqCriterion) – the criterion
- optimizer (FairseqOptimizer) – the optimizer
- ignore_grad (bool) – multiply loss by 0 if this is set to True
Returns: - the loss
- the sample size, which is used as the denominator for the gradient
- logging outputs to display while training
Return type: tuple
-
Models¶
A Model defines the neural network’s forward()
method and encapsulates all
of the learnable parameters in the network. Each model also provides a set of
named architectures that define the precise network configuration (e.g.,
embedding dimension, number of layers, etc.).
Both the model type and architecture are selected via the --arch
command-line argument. Once selected, a model may expose additional command-line
arguments for further configuration.
Note
All fairseq Models extend BaseFairseqModel
, which in turn extends
torch.nn.Module
. Thus any fairseq Model can be used as a
stand-alone Module in other PyTorch code.
Convolutional Neural Networks (CNN)¶
-
class
fairseq.models.fconv.
FConvModel
(encoder, decoder)[source]¶ A fully convolutional model, i.e. a convolutional encoder and a convolutional decoder, as described in “Convolutional Sequence to Sequence Learning” (Gehring et al., 2017).
Parameters: - encoder (FConvEncoder) – the encoder
- decoder (FConvDecoder) – the decoder
The Convolutional model provides the following named architectures and command-line arguments:
usage: [--arch {fconv,fconv_iwslt_de_en,fconv_wmt_en_ro,fconv_wmt_en_de,fconv_wmt_en_fr}] [--dropout D] [--encoder-embed-dim N] [--encoder-embed-path STR] [--encoder-layers EXPR] [--decoder-embed-dim N] [--decoder-embed-path STR] [--decoder-layers EXPR] [--decoder-out-embed-dim N] [--decoder-attention EXPR] [--share-input-output-embed]
Named architectures¶
--arch Possible choices: fconv, fconv_iwslt_de_en, fconv_wmt_en_ro, fconv_wmt_en_de, fconv_wmt_en_fr
Additional command-line arguments¶
--dropout dropout probability
--encoder-embed-dim encoder embedding dimension
--encoder-embed-path path to pre-trained encoder embedding
--encoder-layers encoder layers [(dim, kernel_size), …]
--decoder-embed-dim decoder embedding dimension
--decoder-embed-path path to pre-trained decoder embedding
--decoder-layers decoder layers [(dim, kernel_size), …]
--decoder-out-embed-dim decoder output embedding dimension
--decoder-attention decoder attention [True, …]
--share-input-output-embed share input and output embeddings (requires --decoder-out-embed-dim and --decoder-embed-dim to be equal)
Default: False
-
class
fairseq.models.fconv.
FConvEncoder
(dictionary, embed_dim=512, embed_dict=None, max_positions=1024, convolutions=((512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3)), dropout=0.1)[source]¶ Convolutional encoder consisting of len(convolutions) layers.
Parameters: - dictionary (Dictionary) – encoding dictionary
- embed_dim (int, optional) – embedding dimension
- embed_dict (str, optional) – filename from which to load pre-trained embeddings
- max_positions (int, optional) – maximum supported input sequence length
- convolutions (list, optional) – the convolutional layer structure. Each
list item i corresponds to convolutional layer i. Layers are
given as
(out_channels, kernel_width, [residual])
. Residual connections are added between layers whenresidual=1
(which is the default behavior). - dropout (float, optional) – dropout to be applied before each conv layer
-
forward
(src_tokens, src_lengths)[source]¶ Parameters: - src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
- src_lengths (LongTensor) – lengths of each source sentence of shape (batch)
Returns: - encoder_out (tuple): a tuple with two elements, where the first element is the last encoder layer’s output and the second element is the same quantity summed with the input embedding (used for attention). The shape of both tensors is (batch, src_len, embed_dim).
- encoder_padding_mask (ByteTensor): the positions of padding elements of shape (batch, src_len)
Return type:
-
class
fairseq.models.fconv.
FConvDecoder
(dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256, max_positions=1024, convolutions=((512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3)), attention=True, dropout=0.1, share_embed=False, positional_embeddings=True, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0)[source]¶ Convolutional decoder
-
forward
(prev_output_tokens, encoder_out=None, incremental_state=None, **unused)[source]¶ Parameters: - prev_output_tokens (LongTensor) – shifted output tokens of shape (batch, tgt_len), for teacher forcing
- encoder_out (dict, optional) – output from the encoder, used for encoder-side attention
- incremental_state (dict, optional) – dictionary used for storing state during Incremental decoding
Returns: - the decoder’s output of shape (batch, tgt_len, vocab)
- a dictionary with any model-specific outputs
Return type:
-
Long Short-Term Memory (LSTM) networks¶
-
class
fairseq.models.lstm.
LSTMEncoder
(dictionary, embed_dim=512, hidden_size=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, left_pad=True, pretrained_embed=None, padding_value=0.0)[source]¶ LSTM encoder.
-
class
fairseq.models.lstm.
LSTMDecoder
(dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True, encoder_output_units=512, pretrained_embed=None, share_input_output_embed=False, adaptive_softmax_cutoff=None)[source]¶ LSTM decoder.
-
forward
(prev_output_tokens, encoder_out, incremental_state=None)[source]¶ Parameters: - prev_output_tokens (LongTensor) – shifted output tokens of shape (batch, tgt_len), for teacher forcing
- encoder_out (dict, optional) – output from the encoder, used for encoder-side attention
- incremental_state (dict, optional) – dictionary used for storing state during Incremental decoding
Returns: - the decoder’s output of shape (batch, tgt_len, vocab)
- a dictionary with any model-specific outputs
Return type:
-
Transformer (self-attention) networks¶
-
class
fairseq.models.transformer.
TransformerModel
(encoder, decoder)[source]¶ Transformer model from “Attention Is All You Need” (Vaswani, et al, 2017).
Parameters: - encoder (TransformerEncoder) – the encoder
- decoder (TransformerDecoder) – the decoder
The Transformer model provides the following named architectures and command-line arguments:
usage: [--arch {transformer,transformer_iwslt_de_en,transformer_wmt_en_de,transformer_vaswani_wmt_en_de_big,transformer_vaswani_wmt_en_fr_big,transformer_wmt_en_de_big,transformer_wmt_en_de_big_t2t}] [--activation-fn {relu,gelu,gelu_fast,gelu_accurate,tanh,linear}] [--dropout D] [--attention-dropout D] [--activation-dropout D] [--encoder-embed-path STR] [--encoder-embed-dim N] [--encoder-ffn-embed-dim N] [--encoder-layers N] [--encoder-attention-heads N] [--encoder-normalize-before] [--encoder-learned-pos] [--decoder-embed-path STR] [--decoder-embed-dim N] [--decoder-ffn-embed-dim N] [--decoder-layers N] [--decoder-attention-heads N] [--decoder-learned-pos] [--decoder-normalize-before] [--share-decoder-input-output-embed] [--share-all-embeddings] [--no-token-positional-embeddings] [--adaptive-softmax-cutoff EXPR] [--adaptive-softmax-dropout D]
Named architectures¶
--arch Possible choices: transformer, transformer_iwslt_de_en, transformer_wmt_en_de, transformer_vaswani_wmt_en_de_big, transformer_vaswani_wmt_en_fr_big, transformer_wmt_en_de_big, transformer_wmt_en_de_big_t2t
Additional command-line arguments¶
--activation-fn Possible choices: relu, gelu, gelu_fast, gelu_accurate, tanh, linear
activation function to use
--dropout dropout probability
--attention-dropout dropout probability for attention weights
--activation-dropout, --relu-dropout dropout probability after activation in FFN
--encoder-embed-path path to pre-trained encoder embedding
--encoder-embed-dim encoder embedding dimension
--encoder-ffn-embed-dim encoder embedding dimension for FFN
--encoder-layers num encoder layers
--encoder-attention-heads num encoder attention heads
--encoder-normalize-before apply layernorm before each encoder block
Default: False
--encoder-learned-pos use learned positional embeddings in the encoder
Default: False
--decoder-embed-path path to pre-trained decoder embedding
--decoder-embed-dim decoder embedding dimension
--decoder-ffn-embed-dim decoder embedding dimension for FFN
--decoder-layers num decoder layers
--decoder-attention-heads num decoder attention heads
--decoder-learned-pos use learned positional embeddings in the decoder
Default: False
--decoder-normalize-before apply layernorm before each decoder block
Default: False
--share-decoder-input-output-embed share decoder input and output embeddings
Default: False
--share-all-embeddings share encoder, decoder and output embeddings (requires shared dictionary and embed dim)
Default: False
--no-token-positional-embeddings if set, disables positional embeddings (outside self attention)
Default: False
--adaptive-softmax-cutoff comma separated list of adaptive softmax cutoff points. Must be used with adaptive_loss criterion
--adaptive-softmax-dropout sets adaptive softmax dropout for the tail projections
-
class
fairseq.models.transformer.
TransformerEncoder
(args, dictionary, embed_tokens)[source]¶ Transformer encoder consisting of args.encoder_layers layers. Each layer is a
TransformerEncoderLayer
.Parameters: - args (argparse.Namespace) – parsed command-line arguments
- dictionary (Dictionary) – encoding dictionary
- embed_tokens (torch.nn.Embedding) – input embedding
-
forward
(src_tokens, src_lengths)[source]¶ Parameters: - src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
- src_lengths (torch.LongTensor) – lengths of each source sentence of shape (batch)
Returns: - encoder_out (Tensor): the last encoder layer’s output of shape (src_len, batch, embed_dim)
- encoder_padding_mask (ByteTensor): the positions of padding elements of shape (batch, src_len)
Return type:
-
class
fairseq.models.transformer.
TransformerEncoderLayer
(args)[source]¶ Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is postprocessed with: dropout -> add residual -> layernorm. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: dropout -> add residual. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting args.encoder_normalize_before to
True
.Parameters: args (argparse.Namespace) – parsed command-line arguments -
forward
(x, encoder_padding_mask, attn_mask=None)[source]¶ Parameters: - x (Tensor) – input to the layer of shape (seq_len, batch, embed_dim)
- encoder_padding_mask (ByteTensor) – binary ByteTensor of shape
(batch, src_len) where padding elements are indicated by
1
. - attn_mask (ByteTensor) – binary tensor of shape (T_tgt, T_src), where T_tgt is the length of the query and T_src is the length of the key (here both query and key are x); attn_mask[t_tgt, t_src] = 1 means that t_src is excluded from (masked out of) attention when calculating the embedding for t_tgt.
Returns: encoded output of shape (seq_len, batch, embed_dim)
-
-
class
fairseq.models.transformer.
TransformerDecoder
(args, dictionary, embed_tokens, no_encoder_attn=False)[source]¶ Transformer decoder consisting of args.decoder_layers layers. Each layer is a
TransformerDecoderLayer
.Parameters: - args (argparse.Namespace) – parsed command-line arguments
- dictionary (Dictionary) – decoding dictionary
- embed_tokens (torch.nn.Embedding) – output embedding
- no_encoder_attn (bool, optional) – whether to attend to encoder outputs (default: False).
-
extract_features
(prev_output_tokens, encoder_out=None, incremental_state=None, **unused)[source]¶ Similar to forward but only return features.
Returns: - the decoder’s features of shape (batch, tgt_len, embed_dim)
- a dictionary with any model-specific outputs
Return type: tuple
-
forward
(prev_output_tokens, encoder_out=None, incremental_state=None, **unused)[source]¶ Parameters: - prev_output_tokens (LongTensor) – previous decoder outputs of shape (batch, tgt_len), for teacher forcing
- encoder_out (Tensor, optional) – output from the encoder, used for encoder-side attention
- incremental_state (dict) – dictionary used for storing state during Incremental decoding
Returns: - the decoder’s output of shape (batch, tgt_len, vocab)
- a dictionary with any model-specific outputs
Return type:
-
class
fairseq.models.transformer.
TransformerDecoderLayer
(args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False)[source]¶ Decoder layer block.
In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: dropout -> add residual -> layernorm. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: dropout -> add residual. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting args.decoder_normalize_before to
True
.Parameters: - args (argparse.Namespace) – parsed command-line arguments
- no_encoder_attn (bool, optional) – whether to attend to encoder outputs (default: False).
-
forward
(x, encoder_out=None, encoder_padding_mask=None, incremental_state=None, prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None)[source]¶ Parameters: - x (Tensor) – input to the layer of shape (seq_len, batch, embed_dim)
- encoder_padding_mask (ByteTensor) – binary ByteTensor of shape
(batch, src_len) where padding elements are indicated by
1
.
Returns: encoded output of shape (seq_len, batch, embed_dim)
Adding new models¶
-
fairseq.models.
register_model
(name)[source]¶ New model types can be added to fairseq with the
register_model()
function decorator.For example:
@register_model('lstm')
class LSTM(FairseqEncoderDecoderModel):
    (...)
Note
All models must implement the
BaseFairseqModel
interface. Typically you will extendFairseqEncoderDecoderModel
for sequence-to-sequence tasks orFairseqLanguageModel
for language modeling tasks.Parameters: name (str) – the name of the model
-
fairseq.models.
register_model_architecture
(model_name, arch_name)[source]¶ New model architectures can be added to fairseq with the
register_model_architecture()
function decorator. After registration, model architectures can be selected with the--arch
command-line argument.For example:
@register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
def lstm_luong_wmt_en_de(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000)
    (...)
The decorated function should take a single argument args, which is an argparse.Namespace of arguments parsed from the command-line. The decorated function should modify these arguments in-place to match the desired architecture.
Parameters: - model_name (str) – the name of the registered Model
- arch_name (str) – the name of the model architecture (selected with --arch)
-
class
fairseq.models.
BaseFairseqModel
[source]¶ Base class for fairseq models.
-
classmethod
from_pretrained
(model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', **kwargs)[source]¶ Load a
FairseqModel
from a pre-trained model file. Downloads and caches the pre-trained model file if needed.The base implementation returns a
GeneratorHubInterface
, which can be used to generate translations or sample from language models. The underlyingFairseqModel
can be accessed via the generator.models attribute.Other models may override this to implement custom hub interfaces.
Parameters: - model_name_or_path (str) – either the name of a pre-trained model to load or a path/URL to a pre-trained model state dict
- checkpoint_file (str, optional) – colon-separated list of checkpoint files in the model archive to ensemble (default: ‘model.pt’)
- data_name_or_path (str, optional) – point args.data to the archive at the given path/URL. Can start with ‘.’ or ‘./’ to reuse the model archive path.
-
get_normalized_probs
(net_output, log_probs, sample=None)[source]¶ Get normalized probabilities (or log probs) from a net’s output.
-
classmethod
-
class
fairseq.models.
FairseqEncoderDecoderModel
(encoder, decoder)[source]¶ Base class for encoder-decoder models.
Parameters: - encoder (FairseqEncoder) – the encoder
- decoder (FairseqDecoder) – the decoder
-
extract_features
(src_tokens, src_lengths, prev_output_tokens, **kwargs)[source]¶ Similar to forward but only return features.
Returns: - the decoder’s features of shape (batch, tgt_len, embed_dim)
- a dictionary with any model-specific outputs
Return type: tuple
-
forward
(src_tokens, src_lengths, prev_output_tokens, **kwargs)[source]¶ Run the forward pass for an encoder-decoder model.
First feed a batch of source tokens through the encoder. Then, feed the encoder output and previous decoder outputs (i.e., teacher forcing) to the decoder to produce the next outputs:
encoder_out = self.encoder(src_tokens, src_lengths)
return self.decoder(prev_output_tokens, encoder_out)
Parameters: - src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
- src_lengths (LongTensor) – source sentence lengths of shape (batch)
- prev_output_tokens (LongTensor) – previous decoder outputs of shape (batch, tgt_len), for teacher forcing
Returns: - the decoder’s output of shape (batch, tgt_len, vocab)
- a dictionary with any model-specific outputs
Return type:
-
class
fairseq.models.
FairseqEncoderModel
(encoder)[source]¶ Base class for encoder-only models.
Parameters: encoder (FairseqEncoder) – the encoder -
forward
(src_tokens, src_lengths, **kwargs)[source]¶ Run the forward pass for an encoder-only model.
Feeds a batch of tokens through the encoder to generate features.
Parameters: - src_tokens (LongTensor) – input tokens of shape (batch, src_len)
- src_lengths (LongTensor) – source sentence lengths of shape (batch)
Returns: the encoder’s output, typically of shape (batch, src_len, features)
-
-
class
fairseq.models.
FairseqLanguageModel
(decoder)[source]¶ Base class for decoder-only models.
Parameters: decoder (FairseqDecoder) – the decoder -
extract_features
(src_tokens, **kwargs)[source]¶ Similar to forward but only return features.
Returns: - the decoder’s features of shape (batch, seq_len, embed_dim)
- a dictionary with any model-specific outputs
Return type: tuple
-
forward
(src_tokens, **kwargs)[source]¶ Run the forward pass for a decoder-only model.
Feeds a batch of tokens through the decoder to predict the next tokens.
Parameters: - src_tokens (LongTensor) – tokens on which to condition the decoder, of shape (batch, tgt_len)
- src_lengths (LongTensor) – source sentence lengths of shape (batch)
Returns: - the decoder’s output of shape (batch, seq_len, vocab)
- a dictionary with any model-specific outputs
Return type:
-
output_layer
(features, **kwargs)[source]¶ Project features to the default output size (typically vocabulary size).
-
supported_targets
¶
-
-
class
fairseq.models.
FairseqMultiModel
(encoders, decoders)[source]¶ Base class for combining multiple encoder-decoder models.
Helper function to build shared embeddings for a set of languages after checking that all dicts corresponding to those languages are equivalent.
Parameters: - dicts – Dict of lang_id to its corresponding Dictionary
- langs – languages that we want to share embeddings for
- embed_dim – embedding dimension
- build_embedding – callable function to actually build the embedding
- pretrained_embed_path – Optional path to load pretrained embeddings
-
decoder
¶
-
encoder
¶
-
forward
(src_tokens, src_lengths, prev_output_tokens, **kwargs)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
class
fairseq.models.
FairseqEncoder
(dictionary)[source]¶ Base class for encoders.
-
forward
(src_tokens, src_lengths=None, **kwargs)[source]¶ Parameters: - src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
- src_lengths (LongTensor) – lengths of each source sentence of shape (batch)
-
-
class
fairseq.models.
CompositeEncoder
(encoders)[source]¶ A wrapper around a dictionary of
FairseqEncoder
objects.We run forward on each encoder and return a dictionary of outputs. The first encoder’s dictionary is used for initialization.
Parameters: encoders (dict) – a dictionary of FairseqEncoder
objects.
-
class
fairseq.models.
FairseqDecoder
(dictionary)[source]¶ Base class for decoders.
-
extract_features
(prev_output_tokens, encoder_out=None, **kwargs)[source]¶ Returns: - the decoder’s features of shape (batch, tgt_len, embed_dim)
- a dictionary with any model-specific outputs
Return type: tuple
-
forward
(prev_output_tokens, encoder_out=None, **kwargs)[source]¶ Parameters: - prev_output_tokens (LongTensor) – shifted output tokens of shape (batch, tgt_len), for teacher forcing
- encoder_out (dict, optional) – output from the encoder, used for encoder-side attention
Returns: - the decoder’s output of shape (batch, tgt_len, vocab)
- a dictionary with any model-specific outputs
Return type:
-
get_normalized_probs
(net_output, log_probs, sample)[source]¶ Get normalized probabilities (or log probs) from a net’s output.
-
Incremental decoding¶
-
class
fairseq.models.
FairseqIncrementalDecoder
(dictionary)[source]¶ Base class for incremental decoders.
Incremental decoding is a special mode at inference time where the Model only receives a single timestep of input corresponding to the previous output token (for teacher forcing) and must produce the next output incrementally. Thus the model must cache any long-term state that is needed about the sequence, e.g., hidden states, convolutional states, etc.
Compared to the standard
FairseqDecoder
interface, the incremental decoder interface allowsforward()
functions to take an extra keyword argument (incremental_state) that can be used to cache state across time-steps.The
FairseqIncrementalDecoder
interface also defines thereorder_incremental_state()
method, which is used during beam search to select and reorder the incremental state based on the selection of beams.To learn more about how incremental decoding works, refer to this blog.
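As a rough illustration of the pattern (this is a sketch, not fairseq's SequenceGenerator), a greedy decoder calls the incremental decoder once per step with only the latest token, while the incremental_state dictionary carries the cached state forward:
import torch

def greedy_decode(model, encoder_out, start_idx, eos_idx, max_len=200):
    incremental_state = {}  # per-sequence cache; reordered during beam search
    tokens = torch.LongTensor([[start_idx]])  # fairseq generators typically start from EOS
    for _ in range(max_len):
        # only the most recent token is fed; earlier timesteps live in the cache
        logits, _extra = model.decoder(
            tokens[:, -1:],
            encoder_out=encoder_out,
            incremental_state=incremental_state,
        )
        next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_tok], dim=1)
        if next_tok.item() == eos_idx:
            break
    return tokens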
-
extract_features
(prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs)[source]¶ Returns: - the decoder’s features of shape (batch, tgt_len, embed_dim)
- a dictionary with any model-specific outputs
Return type: tuple
-
forward
(prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs)[source]¶ Parameters: - prev_output_tokens (LongTensor) – shifted output tokens of shape (batch, tgt_len), for teacher forcing
- encoder_out (dict, optional) – output from the encoder, used for encoder-side attention
- incremental_state (dict, optional) – dictionary used for storing state during Incremental decoding
Returns: - the decoder’s output of shape (batch, tgt_len, vocab)
- a dictionary with any model-specific outputs
Return type:
-
Criterions¶
Criterions compute the loss function given the model and batch, roughly:
loss = criterion(model, batch)
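For example, a minimal custom criterion (our sketch, not one of fairseq's built-ins) registers itself with register_criterion() and returns the (loss, sample_size, logging_output) tuple that the task's train_step() expects; padding handling and aggregate_logging_outputs() are omitted for brevity:
import torch.nn.functional as F
from fairseq.criterions import FairseqCriterion, register_criterion

@register_criterion('simple_cross_entropy')
class SimpleCrossEntropyCriterion(FairseqCriterion):

    def forward(self, model, sample, reduce=True):
        # run the model on the mini-batch and compute a summed NLL loss
        net_output = model(**sample['net_input'])
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        target = model.get_targets(sample, net_output).view(-1)
        loss = F.nll_loss(
            lprobs.view(-1, lprobs.size(-1)), target,
            reduction='sum' if reduce else 'none',
        )
        sample_size = sample['ntokens']
        logging_output = {
            'loss': loss.data,
            'ntokens': sample['ntokens'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output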
-
class
fairseq.criterions.
FairseqCriterion
(args, task)[source]¶ -
-
static
aggregate_logging_outputs
(logging_outputs)[source]¶ Aggregate logging outputs from data parallel training.
-
static
-
class
fairseq.criterions.adaptive_loss.
AdaptiveLoss
(args, task)[source]¶ This is an implementation of the loss function accompanying the adaptive softmax approximation for graphical processing units (GPU), described in the paper “Efficient softmax approximation for GPUs” (http://arxiv.org/abs/1609.04309).
-
class
fairseq.criterions.composite_loss.
CompositeLoss
(args, task)[source]¶ This is a composite loss that, given a list of model outputs and a list of targets, computes an average of losses for each output-target pair
-
class
fairseq.criterions.cross_entropy.
CrossEntropyCriterion
(args, task)[source]¶
Optimizers¶
Optimizers update the Model parameters based on the gradients.
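The built-in optimizers are thin wrappers around torch.optim classes. A minimal sketch of that pattern is shown below (the name 'plain_sgd' is made up, and a real optimizer would also implement add_args() for its hyper-parameters):
import torch.optim
from fairseq.optim import FairseqOptimizer, register_optimizer

@register_optimizer('plain_sgd')
class PlainSGD(FairseqOptimizer):

    def __init__(self, args, params):
        super().__init__(args, params)
        # the wrapped torch optimizer must be stored as ``self._optimizer``
        self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        # kwargs used to build (and, when loading checkpoints, rebuild) the
        # underlying torch optimizer; ``args.lr`` is parsed as a list
        return {'lr': self.args.lr[0]}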
-
class
fairseq.optim.
FP16Optimizer
(args, params, fp32_optimizer, fp32_params)[source]¶ Wrap an optimizer to support FP16 (mixed precision) training.
-
backward
(loss)[source]¶ Computes the sum of gradients of the given tensor w.r.t. graph leaves.
Compared to
fairseq.optim.FairseqOptimizer.backward()
, this function additionally dynamically scales the loss to avoid gradient underflow.
-
classmethod
build_optimizer
(args, params)[source]¶ Parameters: - args (argparse.Namespace) – fairseq args
- params (iterable) – iterable of parameters to optimize
-
load_state_dict
(state_dict, optimizer_overrides=None)[source]¶ Load an optimizer state dict.
In general we should prefer the configuration of the existing optimizer instance (e.g., learning rate) over that found in the state_dict. This allows us to resume training from a checkpoint using a new set of optimizer args.
-
optimizer
¶ Return a torch.optim.optimizer.Optimizer instance.
-
optimizer_config
¶ Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.
-
-
class
fairseq.optim.
MemoryEfficientFP16Optimizer
(args, params, optimizer)[source]¶ Wrap an optimizer to support FP16 (mixed precision) training.
Compared to
fairseq.optim.FP16Optimizer
, this version does not maintain an FP32 copy of the model. We instead expect the optimizer to convert the gradients to FP32 internally and sync the results back to the FP16 model params. This significantly reduces memory usage but slightly increases the time spent in the optimizer.Since this wrapper depends on specific functionality in the wrapped optimizer (i.e., on-the-fly conversion of grads to FP32), only certain optimizers can be wrapped. This is determined by the supports_memory_efficient_fp16 property.
-
backward
(loss)[source]¶ Computes the sum of gradients of the given tensor w.r.t. graph leaves.
Compared to
fairseq.optim.FairseqOptimizer.backward()
, this function additionally dynamically scales the loss to avoid gradient underflow.
-
classmethod
build_optimizer
(args, params)[source]¶ Parameters: - args (argparse.Namespace) – fairseq args
- params (iterable) – iterable of parameters to optimize
-
load_state_dict
(state_dict, optimizer_overrides=None)[source]¶ Load an optimizer state dict.
In general we should prefer the configuration of the existing optimizer instance (e.g., learning rate) over that found in the state_dict. This allows us to resume training from a checkpoint using a new set of optimizer args.
-
optimizer
¶ Return a torch.optim.optimizer.Optimizer instance.
-
optimizer_config
¶ Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.
-
-
class
fairseq.optim.
FairseqOptimizer
(args, params)[source]¶ -
-
load_state_dict
(state_dict, optimizer_overrides=None)[source]¶ Load an optimizer state dict.
In general we should prefer the configuration of the existing optimizer instance (e.g., learning rate) over that found in the state_dict. This allows us to resume training from a checkpoint using a new set of optimizer args.
-
optimizer
¶ Return a torch.optim.optimizer.Optimizer instance.
-
optimizer_config
¶ Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.
-
supports_memory_efficient_fp16
¶
-
-
class
fairseq.optim.adadelta.
Adadelta
(args, params)[source]¶ -
-
optimizer_config
¶ Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.
-
-
class
fairseq.optim.adagrad.
Adagrad
(args, params)[source]¶ -
-
optimizer_config
¶ Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.
-
-
class
fairseq.optim.adafactor.
FairseqAdafactor
(args, params)[source]¶ -
-
optimizer_config
¶ Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate. Note: convergence issues have been empirically observed with fp16 enabled; finding a working configuration may require some search.
-
-
class
fairseq.optim.adam.
FairseqAdam
(args, params)[source]¶ -
-
optimizer_config
¶ Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.
-
-
class
fairseq.optim.fp16_optimizer.
FP16Optimizer
(args, params, fp32_optimizer, fp32_params)[source]¶ Wrap an optimizer to support FP16 (mixed precision) training.
-
backward
(loss)[source]¶ Computes the sum of gradients of the given tensor w.r.t. graph leaves.
Compared to
fairseq.optim.FairseqOptimizer.backward()
, this function additionally dynamically scales the loss to avoid gradient underflow.
-
classmethod
build_optimizer
(args, params)[source]¶ Parameters: - args (argparse.Namespace) – fairseq args
- params (iterable) – iterable of parameters to optimize
-
load_state_dict
(state_dict, optimizer_overrides=None)[source]¶ Load an optimizer state dict.
In general we should prefer the configuration of the existing optimizer instance (e.g., learning rate) over that found in the state_dict. This allows us to resume training from a checkpoint using a new set of optimizer args.
-
optimizer
¶ Return a torch.optim.optimizer.Optimizer instance.
-
optimizer_config
¶ Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.
-
-
class
fairseq.optim.nag.
FairseqNAG
(args, params)[source]¶ -
-
optimizer_config
¶ Return a kwarg dictionary that will be used to override optimizer args stored in checkpoints. This allows us to load a checkpoint and resume training using a different set of optimizer args, e.g., with a different learning rate.
-
Learning Rate Schedulers¶
Learning Rate Schedulers update the learning rate over the course of training.
Learning rates can be updated after each update via step_update()
or at
epoch boundaries via step()
.
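A rough sketch of how the two hooks are typically driven by a training loop (illustrative only; build_lr_scheduler is the standard factory, while batches_for_epoch and the forward/backward step are placeholders):
from fairseq.optim.lr_scheduler import build_lr_scheduler

def train_with_schedule(args, optimizer, batches_for_epoch, max_epoch=10):
    lr_scheduler = build_lr_scheduler(args, optimizer)
    num_updates = 0
    for epoch in range(1, max_epoch + 1):
        for batch in batches_for_epoch(epoch):
            # ... forward/backward and optimizer.step() happen here ...
            num_updates += 1
            lr_scheduler.step_update(num_updates)  # per-update schedules (e.g., inverse_sqrt)
        # per-epoch schedules; ReduceLROnPlateau also accepts a validation loss here
        lr_scheduler.step(epoch)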
-
class
fairseq.optim.lr_scheduler.cosine_lr_scheduler.
CosineSchedule
(args, optimizer)[source]¶ Assign LR based on a cyclical schedule that follows the cosine function.
See https://arxiv.org/pdf/1608.03983.pdf for details.
We also support a warmup phase where we linearly increase the learning rate from some initial learning rate (
--warmup-init-lr
) until the configured max learning rate (--max-lr
).During warmup:
lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lr = lrs[update_num]
After warmup:
lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i))
where
t_curr
is current percentage of updates within the current period range andt_i
is the current period range, which is scaled byt_mul
after every iteration.
-
class
fairseq.optim.lr_scheduler.fixed_schedule.
FixedSchedule
(args, optimizer)[source]¶ Decay the LR on a fixed schedule.
-
class
fairseq.optim.lr_scheduler.inverse_square_root_schedule.
InverseSquareRootSchedule
(args, optimizer)[source]¶ Decay the LR based on the inverse square root of the update number.
We also support a warmup phase where we linearly increase the learning rate from some initial learning rate (
--warmup-init-lr
) until the configured learning rate (--lr
). Thereafter we decay proportional to the number of updates, with a decay factor set to align with the configured learning rate.During warmup:
lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lr = lrs[update_num]
After warmup:
decay_factor = args.lr * sqrt(args.warmup_updates)
lr = decay_factor / sqrt(update_num)
-
class
fairseq.optim.lr_scheduler.reduce_lr_on_plateau.
ReduceLROnPlateau
(args, optimizer)[source]¶ Decay the LR by a factor every time the validation loss plateaus.
-
class
fairseq.optim.lr_scheduler.triangular_lr_scheduler.
TriangularSchedule
(args, optimizer)[source]¶ Assign LR based on a triangular cyclical schedule.
See https://arxiv.org/pdf/1506.01186.pdf for details.
Data Loading and Utilities¶
Datasets¶
Datasets define the data format and provide helpers for creating mini-batches.
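As a sketch of the interface documented below (our example, not a fairseq class), a minimal dataset wrapping a list of pre-tokenized LongTensors only needs to implement indexing plus the batching helpers:
import numpy as np
from fairseq.data import FairseqDataset, data_utils

class TensorListDataset(FairseqDataset):
    """Wrap a list of 1D LongTensors so fairseq can batch them."""

    def __init__(self, tensors, pad_idx):
        self.tensors = tensors
        self.pad_idx = pad_idx
        self.sizes = np.array([t.numel() for t in tensors])

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(self.tensors)

    def collater(self, samples):
        # pad variable-length samples into a (bsz, max_len) LongTensor
        return data_utils.collate_tokens(samples, self.pad_idx)

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    def ordered_indices(self):
        # batch by length so that padding is minimized
        return np.argsort(self.sizes)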
-
class
fairseq.data.
FairseqDataset
[source]¶ A dataset that provides helpers for batching.
-
collater
(samples)[source]¶ Merge a list of samples to form a mini-batch.
Parameters: samples (List[dict]) – samples to collate Returns: a mini-batch suitable for forwarding with a Model Return type: dict
-
num_tokens
(index)[source]¶ Return the number of tokens in a sample. This value is used to enforce
--max-tokens
during batching.
-
ordered_indices
()[source]¶ Return an ordered list of indices. Batches will be constructed based on this order.
-
size
(index)[source]¶ Return an example’s size as a float or tuple. This value is used when filtering a dataset with
--max-positions
.
-
supports_prefetch
¶ Whether this dataset supports prefetching.
-
-
class
fairseq.data.
LanguagePairDataset
(src, src_sizes, src_dict, tgt=None, tgt_sizes=None, tgt_dict=None, left_pad_source=True, left_pad_target=False, max_source_positions=1024, max_target_positions=1024, shuffle=True, input_feeding=True, remove_eos_from_source=False, append_eos_to_target=False)[source]¶ A pair of torch.utils.data.Datasets.
Parameters: - src (torch.utils.data.Dataset) – source dataset to wrap
- src_sizes (List[int]) – source sentence lengths
- src_dict (Dictionary) – source vocabulary
- tgt (torch.utils.data.Dataset, optional) – target dataset to wrap
- tgt_sizes (List[int], optional) – target sentence lengths
- tgt_dict (Dictionary, optional) – target vocabulary
- left_pad_source (bool, optional) – pad source tensors on the left side (default: True).
- left_pad_target (bool, optional) – pad target tensors on the left side (default: False).
- max_source_positions (int, optional) – max number of tokens in the source sentence (default: 1024).
- max_target_positions (int, optional) – max number of tokens in the target sentence (default: 1024).
- shuffle (bool, optional) – shuffle dataset elements before batching (default: True).
- input_feeding (bool, optional) – create a shifted version of the targets to be passed into the model for teacher forcing (default: True).
- remove_eos_from_source (bool, optional) – if set, removes eos from end of source if it’s present (default: False).
- append_eos_to_target (bool, optional) – if set, appends eos to end of target if it’s absent (default: False).
-
collater
(samples)[source]¶ Merge a list of samples to form a mini-batch.
Parameters: samples (List[dict]) – samples to collate Returns: a mini-batch with the following keys: - id (LongTensor): example IDs in the original input order
- ntokens (int): total number of tokens in the batch
- net_input (dict): the input to the Model, containing keys:
- src_tokens (LongTensor): a padded 2D Tensor of tokens in
the source sentence of shape (bsz, src_len). Padding will
appear on the left if left_pad_source is
True
. - src_lengths (LongTensor): 1D Tensor of the unpadded lengths of each source sentence of shape (bsz)
- prev_output_tokens (LongTensor): a padded 2D Tensor of
tokens in the target sentence, shifted right by one
position for teacher forcing, of shape (bsz, tgt_len).
This key will not be present if input_feeding is
False
. Padding will appear on the left if left_pad_target isTrue
.
- src_tokens (LongTensor): a padded 2D Tensor of tokens in
the source sentence of shape (bsz, src_len). Padding will
appear on the left if left_pad_source is
- target (LongTensor): a padded 2D Tensor of tokens in the
target sentence of shape (bsz, tgt_len). Padding will appear
on the left if left_pad_target is
True
.
Return type: dict
-
num_tokens
(index)[source]¶ Return the number of tokens in a sample. This value is used to enforce
--max-tokens
during batching.
-
ordered_indices
()[source]¶ Return an ordered list of indices. Batches will be constructed based on this order.
-
size
(index)[source]¶ Return an example’s size as a float or tuple. This value is used when filtering a dataset with
--max-positions
.
-
supports_prefetch
¶ Whether this dataset supports prefetching.
-
class
fairseq.data.
MonolingualDataset
(dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle, targets=None, add_bos_token=False)[source]¶ A wrapper around torch.utils.data.Dataset for monolingual data.
Parameters: - dataset (torch.utils.data.Dataset) – dataset to wrap
- sizes (List[int]) – sentence lengths
- vocab (Dictionary) – vocabulary
- shuffle (bool, optional) – shuffle the elements before batching (default: True).
-
collater
(samples)[source]¶ Merge a list of samples to form a mini-batch.
Parameters: samples (List[dict]) – samples to collate Returns: a mini-batch with the following keys: - id (LongTensor): example IDs in the original input order
- ntokens (int): total number of tokens in the batch
- net_input (dict): the input to the Model, containing keys:
- src_tokens (LongTensor): a padded 2D Tensor of tokens in the source sentence of shape (bsz, src_len). Padding will appear on the right.
- target (LongTensor): a padded 2D Tensor of tokens in the target sentence of shape (bsz, tgt_len). Padding will appear on the right.
Return type: dict
-
num_tokens(index)[source]¶ Return the number of tokens in a sample. This value is used to enforce --max-tokens during batching.
-
ordered_indices()[source]¶ Return an ordered list of indices. Batches will be constructed based on this order.
-
size(index)[source]¶ Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.
-
supports_prefetch¶ Whether this dataset supports prefetching.
Helper Datasets
These datasets wrap other fairseq.data.FairseqDataset instances and provide additional functionality:
-
class fairseq.data.BacktranslationDataset(tgt_dataset, src_dict, tgt_dict=None, backtranslation_fn=None, output_collater=None, cuda=True, **kwargs)[source]¶ Sets up a backtranslation dataset which takes a tgt batch, generates a src using a tgt-src backtranslation function (backtranslation_fn), and returns the corresponding {generated src, input tgt} batch.
Parameters: - tgt_dataset (FairseqDataset) – the dataset to be backtranslated. Only the source side of this dataset will be used. After backtranslation, the source sentences in this dataset will be returned as the targets.
- src_dict (Dictionary) – the dictionary of backtranslated sentences.
- tgt_dict (Dictionary, optional) – the dictionary of sentences to be backtranslated.
- backtranslation_fn (callable, optional) – function to call to generate backtranslations. This is typically the generate method of a SequenceGenerator object. Pass in None when it is not available at initialization time, and use set_backtranslation_fn to set it when it becomes available (see the example below).
- output_collater (callable, optional) – function to call on the backtranslated samples to create the final batch (default: tgt_dataset.collater).
- cuda – use GPU for generation
-
collater(samples)[source]¶ Merge and backtranslate a list of samples to form a mini-batch.
Using the samples from tgt_dataset, load a collated target sample to feed to the backtranslation model. Then take the backtranslation with the best score as the source and the original input as the target.
Note: we expect tgt_dataset to provide a function collater() that will collate samples into the format expected by backtranslation_fn. After backtranslation, we will feed the new list of samples (i.e., the (backtranslated source, original source) pairs) to output_collater and return the result.
Parameters: samples (List[dict]) – samples to backtranslate and collate
Returns: a mini-batch with keys coming from output_collater
Return type: dict
-
size(index)[source]¶ Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.
Note: we use tgt_dataset to approximate the length of the source sentence, since we do not know the actual length until after backtranslation.
-
supports_prefetch¶ Whether this dataset supports prefetching.
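The sketch below is a hedged illustration of the deferred-construction pattern described above; tgt_dataset, src_dict, tgt_dict and generator (a SequenceGenerator for the backward direction) are assumed to already exist.

from fairseq.data import BacktranslationDataset

# Hedged sketch: defer the backtranslation function until the backward
# model's generator is available. All inputs here are assumed to exist.
bt_dataset = BacktranslationDataset(
    tgt_dataset=tgt_dataset,
    src_dict=src_dict,
    tgt_dict=tgt_dict,
    backtranslation_fn=None,       # not available yet at init time
)
# later, once the tgt->src generator is ready:
bt_dataset.set_backtranslation_fn(generator.generate)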
-
class fairseq.data.ConcatDataset(datasets, sample_ratios=1)[source]¶
-
collater(samples)[source]¶ Merge a list of samples to form a mini-batch.
Parameters: samples (List[dict]) – samples to collate
Returns: a mini-batch suitable for forwarding with a Model
Return type: dict
-
num_tokens(index: int)[source]¶ Return the number of tokens in a sample. This value is used to enforce --max-tokens during batching.
-
supports_prefetch¶ Whether this dataset supports prefetching.
-
-
class fairseq.data.RoundRobinZipDatasets(datasets, eval_key=None)[source]¶ Zip multiple FairseqDataset instances together.
Shorter datasets are repeated in a round-robin fashion to match the length of the longest one.
Parameters: - datasets (Dict[FairseqDataset]) – a dictionary of FairseqDataset instances.
- eval_key (str, optional) – a key used at evaluation time that causes this instance to pass-through batches from datasets[eval_key].
-
size(index)[source]¶ Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.
-
supports_prefetch¶ Whether this dataset supports prefetching.
-
class fairseq.data.TransformEosDataset(dataset, eos, append_eos_to_src=False, remove_eos_from_src=False, append_eos_to_tgt=False, remove_eos_from_tgt=False, has_target=True)[source]¶ A FairseqDataset wrapper that appends/prepends/strips EOS.
Note that the transformation is applied in collater().
Parameters: - dataset (FairseqDataset) – dataset to wrap
- eos (int) – index of the end-of-sentence symbol
- append_eos_to_src (bool, optional) – append EOS to the end of src
- remove_eos_from_src (bool, optional) – remove EOS from the end of src
- append_eos_to_tgt (bool, optional) – append EOS to the end of tgt
- remove_eos_from_tgt (bool, optional) – remove EOS from the end of tgt
-
collater(samples)[source]¶ Merge a list of samples to form a mini-batch.
Parameters: samples (List[dict]) – samples to collate
Returns: a mini-batch suitable for forwarding with a Model
Return type: dict
-
num_tokens(index)[source]¶ Return the number of tokens in a sample. This value is used to enforce --max-tokens during batching.
-
ordered_indices()[source]¶ Return an ordered list of indices. Batches will be constructed based on this order.
-
size(index)[source]¶ Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.
-
supports_prefetch¶ Whether this dataset supports prefetching.
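A hedged sketch of the wrapper in use: EOS is stripped from the source side of an existing dataset at collation time. pair_dataset and src_dict are hypothetical names for an existing FairseqDataset and its source Dictionary.

from fairseq.data import TransformEosDataset

# Hedged sketch: strip EOS from the source side of an existing dataset.
# `pair_dataset` and `src_dict` are assumed to exist already.
no_src_eos = TransformEosDataset(
    pair_dataset,
    eos=src_dict.eos(),
    remove_eos_from_src=True,
)
# the transformation happens when collating:
batch = no_src_eos.collater([no_src_eos[i] for i in range(4)])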
Dictionary¶
-
class fairseq.data.Dictionary(pad='<pad>', eos='</s>', unk='<unk>', bos='<s>', extra_special_symbols=None)[source]¶ A mapping from symbols to consecutive integers.
-
add_from_file(f, ignore_utf_errors=False)[source]¶ Loads a pre-existing dictionary from a text file and adds its symbols to this instance.
-
finalize(threshold=-1, nwords=-1, padding_factor=8)[source]¶ Sort symbols by frequency in descending order, ignoring special ones.
Parameters: - threshold – defines the minimum word count
- nwords – defines the total number of words in the final dictionary, including special symbols
- padding_factor – can be used to pad the dictionary size to be a multiple of 8, which is important on some hardware (e.g., Nvidia Tensor Cores).
-
classmethod load(f, ignore_utf_errors=False)[source]¶ Loads the dictionary from a text file with the format:
<symbol0> <count0>
<symbol1> <count1>
...
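For illustration, a hedged round trip through load() with a tiny dictionary file in the format above. The filename dict.en.txt is hypothetical, and load() is assumed to accept a plain file path.

from fairseq.data import Dictionary

# Hedged sketch: write a two-symbol dictionary file in the
# "<symbol> <count>" format shown above, then load it.
with open("dict.en.txt", "w") as f:
    f.write("hello 42\nworld 17\n")

d = Dictionary.load("dict.en.txt")
print(len(d), d.index("hello"), d.pad(), d.eos(), d.unk())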
-
Iterators¶
-
class fairseq.data.CountingIterator(iterable, start=0)[source]¶ Wrapper around an iterable that maintains the iteration count.
Parameters: iterable (iterable) – iterable to wrap
-
class fairseq.data.EpochBatchIterator(dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=0)[source]¶ A multi-epoch iterator over a torch.utils.data.Dataset.
Compared to torch.utils.data.DataLoader, this iterator:
- can be reused across multiple epochs with the next_epoch_itr() method (optionally shuffled between epochs)
- can be serialized/deserialized with the state_dict() and load_state_dict() methods
- supports sharding with the num_shards and shard_id arguments
Parameters: - dataset (Dataset) – dataset from which to load the data
- collate_fn (callable) – merges a list of samples to form a mini-batch
- batch_sampler (Sampler) – an iterator over batches of indices
- seed (int, optional) – seed for random number generator for reproducibility (default: 1).
- num_shards (int, optional) – shard the data iterator into N shards (default: 1).
- shard_id (int, optional) – which shard of the data iterator to return (default: 0).
- num_workers (int, optional) – how many subprocesses to use for data loading. 0 means the data will be loaded in the main process (default: 0).
- epoch (int, optional) – the epoch to start the iterator from (default: 0).
-
iterations_in_epoch¶ The number of consumed batches in the current epoch.
-
next_epoch_itr(shuffle=True, fix_batches_to_gpus=False)[source]¶ Return a new iterator over the dataset.
Parameters: - shuffle (bool, optional) – shuffle batches before returning the iterator (default: True).
- fix_batches_to_gpus – ensure that batches are always allocated to the same shards across epochs. Requires that dataset supports prefetching (default: False).
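A hedged sketch of the multi-epoch loop this iterator enables; dataset is assumed to be a FairseqDataset and batches a precomputed list of lists of example indices.

from fairseq.data import EpochBatchIterator

# Hedged sketch: iterate for several epochs and checkpoint the iterator state.
# `dataset` and `batches` (a list of lists of indices) are assumed to exist.
epoch_itr = EpochBatchIterator(
    dataset=dataset,
    collate_fn=dataset.collater,
    batch_sampler=batches,
    seed=1,
)
for epoch in range(3):
    itr = epoch_itr.next_epoch_itr(shuffle=True)
    for batch in itr:
        pass  # forward/backward on `batch` would go here

state = epoch_itr.state_dict()      # can be saved alongside a checkpoint
epoch_itr.load_state_dict(state)    # and restored later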
-
class fairseq.data.GroupedIterator(iterable, chunk_size)[source]¶ Wrapper around an iterable that returns groups (chunks) of items.
Parameters: - iterable (iterable) – iterable to wrap
- chunk_size (int) – size of each chunk
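A small, hedged example combining the two iterator wrappers above over a plain Python range; it assumes CountingIterator exposes the documented iteration count as a count attribute.

from fairseq.data import CountingIterator, GroupedIterator

# Hedged sketch: wrap a plain iterable, then chunk it into groups of 4.
itr = CountingIterator(range(10))
for chunk in GroupedIterator(itr, chunk_size=4):
    print(chunk)        # [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]
print(itr.count)        # number of items consumed so far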
Modules¶
Fairseq provides several stand-alone torch.nn.Module classes that may be helpful when implementing a new BaseFairseqModel.
-
class fairseq.modules.AdaptiveInput(vocab_size: int, padding_idx: int, initial_dim: int, factor: float, output_dim: int, cutoff: List[int])[source]¶
-
forward(input: torch.Tensor)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
-
class fairseq.modules.AdaptiveSoftmax(vocab_size, input_dim, cutoff, dropout, factor=4.0, adaptive_inputs=None, tie_proj=False)[source]¶ This is an implementation of the efficient softmax approximation for graphical processing units (GPUs), described in the paper “Efficient softmax approximation for GPUs” (http://arxiv.org/abs/1609.04309).
-
adapt_target(target)[source]¶ To be efficient, the AdaptiveSoftmax does not compute the scores for all the words of the vocabulary for all the examples. It is thus necessary to call the adapt_target method of the AdaptiveSoftmax layer inside each forward pass.
-
forward(input, target)[source]¶
Parameters: - input – (b x t x d)
- target – (b x t)
Returns: output for each cutoff section and new targets by cutoff
Return type: 2 lists
-
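A hedged sketch of calling the documented forward(input, target) interface; the vocabulary size and cutoff thresholds are illustrative only.

import torch
from fairseq.modules import AdaptiveSoftmax

# Hedged sketch: an adaptive softmax head over a toy vocabulary.
vocab, dim = 10000, 64
asm = AdaptiveSoftmax(vocab, dim, cutoff=[2000, 6000], dropout=0.1)

x = torch.randn(8, 5, dim)                 # b x t x d
tgt = torch.randint(0, vocab, (8, 5))      # b x t
outputs, new_targets = asm(x, tgt)         # two lists, one entry per cutoff section
print(len(outputs), len(new_targets))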
-
class fairseq.modules.BeamableMM(beam_size=None)[source]¶ This module provides an optimized MM for beam decoding with attention.
It leverages the fact that the source-side of the input is replicated beam times and the target-side of the input is of width one. This layer speeds up inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)} with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.
-
forward(input1, input2)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
-
class fairseq.modules.CharacterTokenEmbedder(vocab: fairseq.data.dictionary.Dictionary, filters: List[Tuple[int, int]], char_embed_dim: int, word_embed_dim: int, highway_layers: int, max_char_len: int = 50, char_inputs: bool = False)[source]¶
-
forward(input: torch.Tensor)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
padding_idx¶
-
-
class fairseq.modules.ConvTBC(in_channels, out_channels, kernel_size, padding=0)[source]¶ 1D convolution over an input of shape (time x batch x channel).
The implementation uses gemm to perform the convolution. This implementation is faster than cuDNN for small kernel sizes.
-
forward(input)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
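A hedged sketch of the time-batch-channel layout ConvTBC expects; with kernel size 3 and padding 1 the time dimension should be preserved.

import torch
from fairseq.modules import ConvTBC

# Hedged sketch: apply ConvTBC to a (time, batch, channel) tensor.
conv = ConvTBC(in_channels=16, out_channels=32, kernel_size=3, padding=1)
x = torch.randn(10, 4, 16)     # T x B x C_in
y = conv(x)
print(y.shape)                  # expected: (10, 4, 32)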
-
class fairseq.modules.DownsampledMultiHeadAttention(out_channels, embed_dim, num_heads, dropout=0.0, bias=True, project_input=True, gated=False, downsample=False)[source]¶ Multi-headed attention with gating and downsampling.
-
forward(query, key, value, mask_future_timesteps=False, key_padding_mask=None, use_scalar_bias=False)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
-
class fairseq.modules.DynamicConv1dTBC(input_size, kernel_size=1, padding_l=None, num_heads=1, weight_dropout=0.0, weight_softmax=False, renorm_padding=False, bias=False, conv_bias=False, query_size=None, in_proj=False)[source]¶ Dynamic lightweight convolution taking T x B x C inputs.
Parameters: - input_size – # of channels of the input
- kernel_size – convolution channels
- padding_l – padding to the left when using “same” padding
- num_heads – number of heads used. The weight is of shape (num_heads, 1, kernel_size)
- weight_dropout – the drop rate of the DropConnect to drop the weight
- weight_softmax – normalize the weight with softmax before the convolution
- renorm_padding – re-normalize the filters to ignore the padded part (only the non-padding parts sum up to 1)
- bias – use bias
- conv_bias – bias of the convolution
- query_size – specified when feeding a different input as the query
- in_proj – project the input and generate the filter together
Shape:
Input: T x B x C, i.e. (timesteps, batch_size, input_size)
Output: T x B x C, i.e. (timesteps, batch_size, input_size)
-
weight¶ the learnable weights of the module of shape (num_heads, 1, kernel_size)
-
bias¶ the learnable bias of the module of shape (input_size)
-
extra_repr()[source]¶ Set the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
-
forward(x, incremental_state=None, query=None, unfold=None)[source]¶ Takes an input x of shape T x B x C and produces an output of shape T x B x C.
Parameters: - x – input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
- incremental_state – a dict to keep the state
- unfold – unfold the input or not. If not, we use the matrix trick instead
- query – use the specified query to predict the conv filters
-
in_proj¶
-
class fairseq.modules.GradMultiply[source]¶
-
static backward(ctx, grad)[source]¶ Defines a formula for differentiating the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs its gradient computed w.r.t. the output.
-
static forward(ctx, x, scale)[source]¶ Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
The context can be used to store tensors that can then be retrieved during the backward pass.
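GradMultiply is a torch.autograd.Function, so it is invoked through the standard .apply() interface. The sketch below assumes, based on its use elsewhere in fairseq, that it passes the input through unchanged in the forward pass and scales the gradient in the backward pass.

import torch
from fairseq.modules import GradMultiply

# Hedged sketch: forward value is unchanged; the gradient is scaled by 0.5.
x = torch.randn(4, 8, requires_grad=True)
y = GradMultiply.apply(x, 0.5)
y.sum().backward()
print(x.grad[0, :3])    # expected to be 0.5 everywhere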
-
-
class fairseq.modules.Highway(input_dim: int, num_layers: int = 1)[source]¶ A Highway layer. Adapted from the AllenNLP implementation.
-
forward(x: torch.Tensor)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
-
fairseq.modules.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, export=False)[source]¶
-
class fairseq.modules.LearnedPositionalEmbedding(num_embeddings: int, embedding_dim: int, padding_idx: int)[source]¶ This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to the forward function.
-
class fairseq.modules.LightweightConv1dTBC(input_size, kernel_size=1, padding_l=None, num_heads=1, weight_dropout=0.0, weight_softmax=False, bias=False)[source]¶ Lightweight convolution assuming the input is T x B x C.
Parameters: - input_size – # of channels of the input
- kernel_size – convolution channels
- padding_l – padding to the left when using “same” padding
- num_heads – number of heads used. The weight is of shape (num_heads, 1, kernel_size)
- weight_dropout – the drop rate of the DropConnect to drop the weight
- weight_softmax – normalize the weight with softmax before the convolution
- bias – use bias
Shape:
Input: T x B x C, i.e. (timesteps, batch_size, input_size)
Output: T x B x C, i.e. (timesteps, batch_size, input_size)
-
weight¶ the learnable weights of the module of shape (num_heads, 1, kernel_size)
-
bias¶ the learnable bias of the module of shape (input_size)
-
extra_repr()[source]¶ Set the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
-
forward(x, incremental_state=None, unfold=False)[source]¶ Takes an input x of shape T x B x C and produces an output of shape T x B x C.
Parameters: - x – input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
- incremental_state – a dict to keep the state
- unfold – unfold the input or not. If not, we use the matrix trick instead
-
class fairseq.modules.LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)[source]¶ An optimized version of nn.Conv1d.
At training time, this module uses ConvTBC, which is an optimized version of Conv1d. At inference time, it optimizes incremental generation (i.e., one time step at a time) by replacing the convolutions with linear layers. Note that the input order changes from training to inference.
-
forward(input, incremental_state=None)[source]¶
Parameters: incremental_state – Used to buffer signal; if not None, then input is expected to contain a single frame. If the input order changes between time steps, call reorder_incremental_state.
Input:
Time x Batch x Channel during training
Batch x Time x Channel during inference
-
-
class fairseq.modules.LogSumExpMoE[source]¶ Standard LogSumExp forward pass, but uses the posterior for the backward pass.
See “Mixture Models for Diverse Machine Translation: Tricks of the Trade” (Shen et al., 2019).
-
static backward(ctx, grad_output)[source]¶ Defines a formula for differentiating the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs its gradient computed w.r.t. the output.
-
static forward(ctx, logp, posterior, dim=-1)[source]¶ Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
The context can be used to store tensors that can then be retrieved during the backward pass.
-
-
class fairseq.modules.MeanPoolGatingNetwork(embed_dim, num_experts, dropout=None)[source]¶ A simple mean-pooling gating network for selecting experts.
This module applies mean pooling over an encoder’s output and returns responsibilities for each expert. The encoder format is expected to match fairseq.models.transformer.TransformerEncoder.
-
forward(encoder_out)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
-
class fairseq.modules.MultiheadAttention(embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False)[source]¶ Multi-headed attention.
See “Attention Is All You Need” for more details.
-
forward(query, key, value, key_padding_mask=None, incremental_state=None, need_weights=True, static_kv=False, attn_mask=None)[source]¶ Input shape: Time x Batch x Channel
Timesteps can be masked by supplying a T x T mask in the attn_mask argument. Padding elements can be excluded from the key by passing a binary ByteTensor (key_padding_mask) with shape batch x src_len, where padding elements are indicated by 1s. See the example below.
-
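A hedged sketch of self-attention over a (time, batch, channel) input with a key_padding_mask marking padded positions. The (attn, attn_weights) return convention and the bool mask dtype reflect common fairseq behavior and may differ slightly between versions (older releases expect a ByteTensor mask).

import torch
from fairseq.modules import MultiheadAttention

# Hedged sketch: self-attention with two padded positions in sample 1.
T, B, C = 7, 2, 16
x = torch.randn(T, B, C)
pad_mask = torch.zeros(B, T, dtype=torch.bool)
pad_mask[1, -2:] = True      # last two source positions of sample 1 are padding

attn = MultiheadAttention(embed_dim=C, num_heads=4, self_attention=True)
out, attn_weights = attn(x, x, x, key_padding_mask=pad_mask)
print(out.shape)              # expected: (T, B, C)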
-
fairseq.modules.PositionalEmbedding(num_embeddings: int, embedding_dim: int, padding_idx: int, learned: bool = False)[source]¶
-
class fairseq.modules.ScalarBias[source]¶ Adds a vector of scalars, used in the self-attention mechanism to allow the model to optionally attend to this vector instead of the past.
-
static backward(ctx, grad)[source]¶ Defines a formula for differentiating the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs its gradient computed w.r.t. the output.
-
static forward(ctx, input, dim, bias_init)[source]¶ Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
The context can be used to store tensors that can then be retrieved during the backward pass.
-
-
class fairseq.modules.SinusoidalPositionalEmbedding(embedding_dim, padding_idx, init_size=1024)[source]¶ This module produces sinusoidal positional embeddings of any length.
Padding symbols are ignored.
-
forward(input, incremental_state=None, timestep=None, **kwargs)[source]¶ Input is expected to be of size [bsz x seqlen].
-
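A hedged sketch: positions are derived from the token ids, and entries equal to padding_idx are treated as padding.

import torch
from fairseq.modules import SinusoidalPositionalEmbedding

# Hedged sketch: sinusoidal position embeddings for a [bsz x seqlen] batch.
pos_emb = SinusoidalPositionalEmbedding(embedding_dim=32, padding_idx=1, init_size=128)
tokens = torch.randint(2, 100, (4, 9))    # ids >= 2, so nothing is treated as padding
out = pos_emb(tokens)
print(out.shape)                           # expected: (4, 9, 32)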
-
class fairseq.modules.TransformerSentenceEncoderLayer(embedding_dim: float = 768, ffn_embedding_dim: float = 3072, num_attention_heads: float = 8, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, activation_fn: str = 'relu', add_bias_kv: bool = False, add_zero_attn: bool = False, export: bool = False)[source]¶ Implements a Transformer encoder layer used in BERT/XLM style pre-trained models.
-
class fairseq.modules.TransformerSentenceEncoder(padding_idx: int, vocab_size: int, num_encoder_layers: int = 6, embedding_dim: int = 768, ffn_embedding_dim: int = 3072, num_attention_heads: int = 8, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, max_seq_len: int = 256, num_segments: int = 2, use_position_embeddings: bool = True, offset_positions_by_padding: bool = True, encoder_normalize_before: bool = False, apply_bert_init: bool = False, activation_fn: str = 'relu', learned_pos_embedding: bool = True, add_bias_kv: bool = False, add_zero_attn: bool = False, embed_scale: float = None, freeze_embeddings: bool = False, n_trans_layers_to_freeze: int = 0, export: bool = False)[source]¶ Implementation of a bi-directional Transformer-based sentence encoder used in BERT/XLM style pre-trained models.
This first computes the token embeddings using the token embedding matrix, position embeddings (if specified) and segment embeddings (if specified). After applying the specified number of TransformerEncoderLayers, it outputs all the internal states of the encoder as well as the final representation associated with the first token (usually the CLS token).
Input:
- tokens: B x T matrix representing sentences
- segment_labels: B x T matrix representing segment labels for tokens
Output:
a tuple of the following:
- a list of internal model states used to compute the predictions, where each tensor has shape B x T x C
- the sentence representation associated with the first input token, of shape B x C
-
forward(tokens: torch.Tensor, segment_labels: torch.Tensor = None, last_state_only: bool = False, positions: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor][source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
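A hedged sketch of running a deliberately small encoder over a batch of token ids. The hyper-parameters are illustrative only, and the exact shape convention of the returned inner states may differ between fairseq versions.

import torch
from fairseq.modules import TransformerSentenceEncoder

# Hedged sketch: a small BERT-style encoder over a B x T batch of token ids.
encoder = TransformerSentenceEncoder(
    padding_idx=1,
    vocab_size=1000,
    num_encoder_layers=2,
    embedding_dim=64,
    ffn_embedding_dim=128,
    num_attention_heads=4,
)
tokens = torch.randint(2, 1000, (3, 12))          # B x T, ids >= 2 to avoid the pad index
segments = torch.zeros_like(tokens)                # single-segment input
inner_states, sentence_rep = encoder(tokens, segment_labels=segments)
print(len(inner_states), sentence_rep.shape)       # number of layer outputs, (B, C)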
-
class fairseq.modules.TransformerDecoderLayer(args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False)[source]¶ Decoder layer block.
In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: dropout -> add residual -> layernorm. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: dropout -> add residual. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting args.decoder_normalize_before to True.
Parameters: - args (argparse.Namespace) – parsed command-line arguments
- no_encoder_attn (bool, optional) – whether to attend to encoder outputs (default: False).
-
forward(x, encoder_out=None, encoder_padding_mask=None, incremental_state=None, prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None)[source]¶
Parameters: - x (Tensor) – input to the layer of shape (seq_len, batch, embed_dim)
- encoder_padding_mask (ByteTensor) – binary ByteTensor of shape (batch, src_len) where padding elements are indicated by 1.
Returns: encoded output of shape (seq_len, batch, embed_dim)
-
class fairseq.modules.TransformerEncoderLayer(args)[source]¶ Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is postprocessed with: dropout -> add residual -> layernorm. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: dropout -> add residual. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting args.encoder_normalize_before to True.
Parameters: args (argparse.Namespace) – parsed command-line arguments
-
forward(x, encoder_padding_mask, attn_mask=None)[source]¶
Parameters: - x (Tensor) – input to the layer of shape (seq_len, batch, embed_dim)
- encoder_padding_mask (ByteTensor) – binary ByteTensor of shape (batch, src_len) where padding elements are indicated by 1.
- attn_mask (ByteTensor) – binary tensor of shape (T_tgt, T_src), where T_tgt is the length of the query and T_src is the length of the key (here both query and key are x). attn_mask[t_tgt, t_src] = 1 means that, when calculating the embedding for t_tgt, t_src is excluded from attention.
Returns: encoded output of shape (seq_len, batch, embed_dim)
-
-
class fairseq.modules.VGGBlock(in_channels, out_channels, conv_kernel_size, pooling_kernel_size, num_conv_layers, input_dim, conv_stride=1, padding=None, layer_norm=False)[source]¶ VGG-motivated CNN module (https://arxiv.org/pdf/1409.1556.pdf).
Parameters: - in_channels – (int) number of input channels (typically 1)
- out_channels – (int) number of output channels
- conv_kernel_size – kernel size of the convolution layers
- pooling_kernel_size – the size of the pooling window to take a max over
- num_conv_layers – (int) number of convolution layers
- input_dim – (int) input dimension
- conv_stride – the stride of the convolving kernel. Can be a single number or a tuple (sH, sW). Default: 1
- padding – implicit paddings on both sides of the input. Can be a single number or a tuple (padH, padW). Default: None
- layer_norm – (bool) whether layer norm is applied. Default: False
Shape:
Input: B x C x T x feat, i.e. (batch_size, input_size, timesteps, features)
Output: B x C x T x feat, i.e. (batch_size, input_size, timesteps, features)
-
forward(x)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.