Models

A Model defines the neural network’s forward() method and encapsulates all of the learnable parameters in the network. Each model also provides a set of named architectures that define the precise network configuration (e.g., embedding dimension, number of layers, etc.).

Both the model type and architecture are selected via the --arch command-line argument. Once selected, a model may expose additional command-line arguments for further configuration.

Note

All fairseq Models extend BaseFairseqModel, which in turn extends torch.nn.Module. Thus any fairseq Model can be used as a stand-alone Module in other PyTorch code.
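
For example, a translation model loaded through torch.hub is just an ordinary Module wrapped in a generation helper. This is a minimal sketch, assuming network access to the pytorch/fairseq hub entry:

import torch

# Load a pre-trained translation model; the returned object is a GeneratorHubInterface
# and the underlying fairseq Model is exposed via its .models attribute.
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model',
                       tokenizer='moses', bpe='fastbpe')
model = en2de.models[0]
assert isinstance(model, torch.nn.Module)

# Like any nn.Module, the model exposes its learnable parameters.
num_params = sum(p.numel() for p in model.parameters())
print(type(model).__name__, num_params)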

Convolutional Neural Networks (CNN)

class fairseq.models.fconv.FConvModel(encoder, decoder)[source]

A fully convolutional model, i.e. a convolutional encoder and a convolutional decoder, as described in “Convolutional Sequence to Sequence Learning” (Gehring et al., 2017).

Parameters:
  • encoder (FConvEncoder) – the encoder
  • decoder (FConvDecoder) – the decoder

The Convolutional model provides the following named architectures and command-line arguments:

usage: 
        [--arch {fconv,fconv_iwslt_de_en,fconv_wmt_en_ro,fconv_wmt_en_de,fconv_wmt_en_fr}]
        [--dropout D] [--encoder-embed-dim N] [--encoder-embed-path STR]
        [--encoder-layers EXPR] [--decoder-embed-dim N]
        [--decoder-embed-path STR] [--decoder-layers EXPR]
        [--decoder-out-embed-dim N] [--decoder-attention EXPR]
        [--share-input-output-embed]

Named architectures

--arch Possible choices: fconv, fconv_iwslt_de_en, fconv_wmt_en_ro, fconv_wmt_en_de, fconv_wmt_en_fr

Additional command-line arguments

--dropout dropout probability
--encoder-embed-dim encoder embedding dimension
--encoder-embed-path path to pre-trained encoder embedding
--encoder-layers encoder layers [(dim, kernel_size), …]
--decoder-embed-dim decoder embedding dimension
--decoder-embed-path path to pre-trained decoder embedding
--decoder-layers decoder layers [(dim, kernel_size), …]
--decoder-out-embed-dim decoder output embedding dimension
--decoder-attention decoder attention [True, …]
--share-input-output-embed share input and output embeddings (requires --decoder-out-embed-dim and --decoder-embed-dim to be equal)

Default: False

static add_args(parser)[source]

Add model-specific arguments to the parser.

classmethod build_model(args, task)[source]

Build a new model instance.

class fairseq.models.fconv.FConvEncoder(dictionary, embed_dim=512, embed_dict=None, max_positions=1024, convolutions=((512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3)), dropout=0.1)[source]

Convolutional encoder consisting of len(convolutions) layers.

Parameters:
  • dictionary (Dictionary) – encoding dictionary
  • embed_dim (int, optional) – embedding dimension
  • embed_dict (str, optional) – filename from which to load pre-trained embeddings
  • max_positions (int, optional) – maximum supported input sequence length
  • convolutions (list, optional) – the convolutional layer structure. Each list item i corresponds to convolutional layer i. Layers are given as (out_channels, kernel_width, [residual]). Residual connections are added between layers when residual=1 (which is the default behavior).
  • dropout (float, optional) – dropout to be applied before each conv layer
forward(src_tokens, src_lengths)[source]
Parameters:
  • src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
  • src_lengths (LongTensor) – lengths of each source sentence of shape (batch)
Returns:

  • encoder_out (tuple): a tuple with two elements, where the first element is the last encoder layer’s output and the second element is the same quantity summed with the input embedding (used for attention). The shape of both tensors is (batch, src_len, embed_dim).
  • encoder_padding_mask (ByteTensor): the positions of padding elements of shape (batch, src_len)

Return type:

dict

max_positions()[source]

Maximum input length supported by the encoder.

reorder_encoder_out(encoder_out, new_order)[source]

Reorder encoder output according to new_order.

Parameters:
  • encoder_out – output from the forward() method
  • new_order (LongTensor) – desired order
Returns:

encoder_out rearranged according to new_order

class fairseq.models.fconv.FConvDecoder(dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256, max_positions=1024, convolutions=((512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3)), attention=True, dropout=0.1, share_embed=False, positional_embeddings=True, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0.0)[source]

Convolutional decoder.

forward(prev_output_tokens, encoder_out=None, incremental_state=None, **unused)[source]
Parameters:
  • prev_output_tokens (LongTensor) – shifted output tokens of shape (batch, tgt_len), for teacher forcing
  • encoder_out (dict, optional) – output from the encoder, used for encoder-side attention
  • incremental_state (dict, optional) – dictionary used for storing state during Incremental decoding
Returns:

  • the decoder’s output of shape (batch, tgt_len, vocab)
  • a dictionary with any model-specific outputs

Return type:

tuple

max_positions()[source]

Maximum output length supported by the decoder.

reorder_incremental_state(incremental_state, new_order)[source]

Reorder incremental state.

This will be called when the order of the input has changed from the previous time step. A typical use case is beam search, where the input order changes between time steps based on the selection of beams.
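
For example, the following sketch wires a toy FConvEncoder and FConvDecoder into an FConvModel and runs one forward pass. The tiny dictionary and layer sizes are made up for illustration and are not a recommended configuration:

import torch
from fairseq.data import Dictionary
from fairseq.models.fconv import FConvDecoder, FConvEncoder, FConvModel

# Toy dictionary with two words plus the built-in special symbols.
d = Dictionary()
for w in ['hello', 'world']:
    d.add_symbol(w)

encoder = FConvEncoder(d, embed_dim=32, convolutions=((32, 3),) * 2, dropout=0.0)
decoder = FConvDecoder(d, embed_dim=32, out_embed_dim=32,
                       convolutions=((32, 3),) * 2, attention=True, dropout=0.0)
model = FConvModel(encoder, decoder)  # also wires encoder.num_attention_layers

src_tokens = torch.LongTensor([[d.index('hello'), d.index('world'), d.eos()]])
src_lengths = torch.LongTensor([3])
prev_output_tokens = torch.LongTensor([[d.eos(), d.index('hello'), d.index('world')]])

logits, _ = model(src_tokens, src_lengths, prev_output_tokens)
print(logits.shape)  # (batch, tgt_len, vocab), as described above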

Long Short-Term Memory (LSTM) networks

class fairseq.models.lstm.LSTMModel(encoder, decoder)[source]
static add_args(parser)[source]

Add model-specific arguments to the parser.

classmethod build_model(args, task)[source]

Build a new model instance.

forward(src_tokens, src_lengths, prev_output_tokens, incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]] = None)[source]

Run the forward pass for an encoder-decoder model.

First feed a batch of source tokens through the encoder. Then, feed the encoder output and previous decoder outputs (i.e., teacher forcing) to the decoder to produce the next outputs:

encoder_out = self.encoder(src_tokens, src_lengths)
return self.decoder(prev_output_tokens, encoder_out)
Parameters:
  • src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
  • src_lengths (LongTensor) – source sentence lengths of shape (batch)
  • prev_output_tokens (LongTensor) – previous decoder outputs of shape (batch, tgt_len), for teacher forcing
Returns:

  • the decoder’s output of shape (batch, tgt_len, vocab)
  • a dictionary with any model-specific outputs

Return type:

tuple

class fairseq.models.lstm.LSTMEncoder(dictionary, embed_dim=512, hidden_size=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, left_pad=True, pretrained_embed=None, padding_idx=None, max_source_positions=100000.0)[source]

LSTM encoder.

forward(src_tokens: torch.Tensor, src_lengths: torch.Tensor, enforce_sorted: bool = True)[source]
Parameters:
  • src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
  • src_lengths (LongTensor) – lengths of each source sentence of shape (batch)
  • enforce_sorted (bool, optional) – if True, src_tokens is expected to contain sequences sorted by length in decreasing order. If False, this condition is not required. Default: True.
max_positions()[source]

Maximum input length supported by the encoder.

reorder_encoder_out(encoder_out: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], new_order)[source]

Reorder encoder output according to new_order.

Parameters:
  • encoder_out – output from the forward() method
  • new_order (LongTensor) – desired order
Returns:

encoder_out rearranged according to new_order
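
For example, a toy LSTMEncoder can be run on a single batch as follows. This is a minimal sketch with an illustrative dictionary and sizes; a one-sentence batch trivially satisfies the enforce_sorted requirement:

import torch
from fairseq.data import Dictionary
from fairseq.models.lstm import LSTMEncoder

d = Dictionary()
for w in ['hello', 'world']:
    d.add_symbol(w)

encoder = LSTMEncoder(d, embed_dim=16, hidden_size=16, num_layers=1)

src_tokens = torch.LongTensor([[d.index('hello'), d.index('world'), d.eos()]])
src_lengths = torch.LongTensor([3])

# The encoder returns a tuple of (outputs, final hidden states, final cell states,
# padding mask); outputs are time-major, i.e. (src_len, batch, hidden_size).
outs, final_hiddens, final_cells, padding_mask = encoder(src_tokens, src_lengths)
print(outs.shape)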

class fairseq.models.lstm.LSTMDecoder(dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True, encoder_output_units=512, pretrained_embed=None, share_input_output_embed=False, adaptive_softmax_cutoff=None, max_target_positions=100000.0, residuals=False)[source]

LSTM decoder.

extract_features(prev_output_tokens, encoder_out: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]] = None)[source]

Similar to forward but only return features.

forward(prev_output_tokens, encoder_out: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]] = None, src_lengths: Optional[torch.Tensor] = None)[source]
Parameters:
  • prev_output_tokens (LongTensor) – shifted output tokens of shape (batch, tgt_len), for teacher forcing
  • encoder_out (dict, optional) – output from the encoder, used for encoder-side attention
  • incremental_state (dict, optional) – dictionary used for storing state during Incremental decoding
Returns:

  • the decoder’s output of shape (batch, tgt_len, vocab)
  • a dictionary with any model-specific outputs

Return type:

tuple

max_positions()[source]

Maximum output length supported by the decoder.

output_layer(x)[source]

Project features to the vocabulary size.

reorder_incremental_state(incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]], new_order: torch.Tensor)[source]

Reorder incremental state.

This will be called when the order of the input has changed from the previous time step. A typical use case is beam search, where the input order changes between time steps based on the selection of beams.

Transformer (self-attention) networks

class fairseq.models.transformer.TransformerModel(args, encoder, decoder)[source]

This is the legacy implementation of the transformer model that uses argparse for configuration.

classmethod add_args(parser)[source]

Add model-specific arguments to the parser.

classmethod build_model(args, task)[source]

Build a new model instance.

class fairseq.models.transformer.TransformerEncoder(args, dictionary, embed_tokens, return_fc=False)[source]

Transformer encoder consisting of args.encoder_layers layers. Each layer is a TransformerEncoderLayer.

class fairseq.models.transformer.TransformerDecoder(args, dictionary, embed_tokens, no_encoder_attn=False, output_projection=None)[source]

Transformer decoder consisting of args.decoder_layers layers. Each layer is a TransformerDecoderLayer.

Adding new models

fairseq.models.register_model(name, dataclass=None)[source]

New model types can be added to fairseq with the register_model() function decorator.

For example:

@register_model('lstm')
class LSTM(FairseqEncoderDecoderModel):
    (...)

Note

All models must implement the BaseFairseqModel interface. Typically you will extend FairseqEncoderDecoderModel for sequence-to-sequence tasks or FairseqLanguageModel for language modeling tasks.

Parameters:name (str) – the name of the model
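
A complete registration usually also implements add_args() and build_model(). The sketch below reuses the LSTM encoder and decoder documented above; the model name 'my_lstm' and the single hyperparameter are illustrative only:

from fairseq.models import FairseqEncoderDecoderModel, register_model
from fairseq.models.lstm import LSTMDecoder, LSTMEncoder

@register_model('my_lstm')
class MyLSTMModel(FairseqEncoderDecoderModel):

    @staticmethod
    def add_args(parser):
        # model-specific hyperparameters become command-line arguments
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')

    @classmethod
    def build_model(cls, args, task):
        # construct the encoder/decoder from the parsed args and the task's dictionaries;
        # defaults for missing args are normally filled in by a registered architecture
        encoder = LSTMEncoder(task.source_dictionary, embed_dim=args.encoder_embed_dim)
        decoder = LSTMDecoder(task.target_dictionary)
        return cls(encoder, decoder)
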
fairseq.models.register_model_architecture(model_name, arch_name)[source]

New model architectures can be added to fairseq with the register_model_architecture() function decorator. After registration, model architectures can be selected with the --arch command-line argument.

For example:

@register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
def lstm_luong_wmt_en_de(cfg):
    cfg.encoder_embed_dim = getattr(cfg, 'encoder_embed_dim', 1000)
    (...)

The decorated function should take a single argument cfg, which is an omegaconf.DictConfig. The decorated function should modify this configuration in-place to match the desired architecture.

Parameters:
  • model_name (str) – the name of the Model (Model must already be registered)
  • arch_name (str) – the name of the model architecture (--arch)
class fairseq.models.BaseFairseqModel[source]

Base class for fairseq models.

classmethod add_args(parser)[source]

Add model-specific arguments to the parser.

classmethod build_model(args, task)[source]

Build a new model instance.

extract_features(*args, **kwargs)[source]

Similar to forward but only return features.

classmethod from_pretrained(model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', **kwargs)[source]

Load a FairseqModel from a pre-trained model file. Downloads and caches the pre-trained model file if needed.

The base implementation returns a GeneratorHubInterface, which can be used to generate translations or sample from language models. The underlying FairseqModel can be accessed via the generator.models attribute.

Other models may override this to implement custom hub interfaces.

Parameters:
  • model_name_or_path (str) – either the name of a pre-trained model to load or a path/URL to a pre-trained model state dict
  • checkpoint_file (str, optional) – colon-separated list of checkpoint files in the model archive to ensemble (default: ‘model.pt’)
  • data_name_or_path (str, optional) – point args.data to the archive at the given path/URL. Can start with ‘.’ or ‘./’ to reuse the model archive path.
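
For example, a checkpoint directory produced by fairseq-train can be loaded as follows (the paths below are placeholders for your own training artifacts):

from fairseq.models.transformer import TransformerModel

en2de = TransformerModel.from_pretrained(
    '/path/to/checkpoints',               # directory containing the checkpoint
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='/path/to/data-bin',
)
# The returned GeneratorHubInterface can generate translations directly.
print(en2de.translate('Hello world!'))
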
get_normalized_probs(net_output: Tuple[torch.Tensor, Optional[Dict[str, List[Optional[torch.Tensor]]]]], log_probs: bool, sample: Optional[Dict[str, torch.Tensor]] = None)[source]

Get normalized probabilities (or log probs) from a net’s output.

get_normalized_probs_scriptable(net_output: Tuple[torch.Tensor, Optional[Dict[str, List[Optional[torch.Tensor]]]]], log_probs: bool, sample: Optional[Dict[str, torch.Tensor]] = None)[source]

Scriptable helper function for get_normalized_probs in BaseFairseqModel.

get_targets(sample, net_output)[source]

Get targets from either the sample or the net’s output.

classmethod hub_models()[source]
load_state_dict(state_dict, strict=True, model_cfg: Optional[omegaconf.dictconfig.DictConfig] = None, args: Optional[argparse.Namespace] = None)[source]

Copies parameters and buffers from state_dict into this module and its descendants.

Overrides the method in nn.Module. Compared with that method this additionally “upgrades” state_dicts from old checkpoints.

make_generation_fast_(**kwargs)[source]

Legacy entry point to optimize model for faster generation. Prefer prepare_for_inference_.

max_positions()[source]

Maximum length supported by the model.

prepare_for_inference_(cfg: omegaconf.dictconfig.DictConfig)[source]

Prepare model for inference.

prepare_for_onnx_export_(**kwargs)[source]

Make model exportable via ONNX trace.

set_num_updates(num_updates)[source]

State from trainer to pass along to model at every update.

upgrade_state_dict(state_dict)[source]

Upgrade old state dicts to work with newer code.

upgrade_state_dict_named(state_dict, name)[source]

Upgrade old state dicts to work with newer code.

Parameters:
  • state_dict (dict) – state dictionary to upgrade, in place
  • name (str) – the state dict key corresponding to the current module
class fairseq.models.FairseqEncoderDecoderModel(encoder, decoder)[source]

Base class for encoder-decoder models.

Parameters:
  • encoder (FairseqEncoder) – the encoder
  • decoder (FairseqDecoder) – the decoder
extract_features(src_tokens, src_lengths, prev_output_tokens, **kwargs)[source]

Similar to forward but only return features.

Returns:
  • the decoder’s features of shape (batch, tgt_len, embed_dim)
  • a dictionary with any model-specific outputs
Return type:tuple
forward(src_tokens, src_lengths, prev_output_tokens, **kwargs)[source]

Run the forward pass for an encoder-decoder model.

First feed a batch of source tokens through the encoder. Then, feed the encoder output and previous decoder outputs (i.e., teacher forcing) to the decoder to produce the next outputs:

encoder_out = self.encoder(src_tokens, src_lengths)
return self.decoder(prev_output_tokens, encoder_out)
Parameters:
  • src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
  • src_lengths (LongTensor) – source sentence lengths of shape (batch)
  • prev_output_tokens (LongTensor) – previous decoder outputs of shape (batch, tgt_len), for teacher forcing
Returns:

  • the decoder’s output of shape (batch, tgt_len, vocab)
  • a dictionary with any model-specific outputs

Return type:

tuple

forward_decoder(prev_output_tokens, **kwargs)[source]
max_decoder_positions()[source]

Maximum length supported by the decoder.

max_positions()[source]

Maximum length supported by the model.

output_layer(features, **kwargs)[source]

Project features to the default output size (typically vocabulary size).

class fairseq.models.FairseqEncoderModel(encoder)[source]

Base class for encoder-only models.

Parameters:encoder (FairseqEncoder) – the encoder
forward(src_tokens, src_lengths, **kwargs)[source]

Run the forward pass for an encoder-only model.

Feeds a batch of tokens through the encoder to generate features.

Parameters:
  • src_tokens (LongTensor) – input tokens of shape (batch, src_len)
  • src_lengths (LongTensor) – source sentence lengths of shape (batch)
Returns:

the encoder’s output, typically of shape (batch, src_len, features)

get_normalized_probs(net_output, log_probs, sample=None)[source]

Get normalized probabilities (or log probs) from a net’s output.

max_positions()[source]

Maximum length supported by the model.

class fairseq.models.FairseqLanguageModel(decoder)[source]

Base class for decoder-only models.

Parameters:decoder (FairseqDecoder) – the decoder
extract_features(src_tokens, **kwargs)[source]

Similar to forward but only return features.

Returns:
  • the decoder’s features of shape (batch, seq_len, embed_dim)
  • a dictionary with any model-specific outputs
Return type:tuple
forward(src_tokens, **kwargs)[source]

Run the forward pass for a decoder-only model.

Feeds a batch of tokens through the decoder to predict the next tokens.

Parameters:
  • src_tokens (LongTensor) – tokens on which to condition the decoder, of shape (batch, tgt_len)
  • src_lengths (LongTensor) – source sentence lengths of shape (batch)
Returns:

  • the decoder’s output of shape (batch, seq_len, vocab)
  • a dictionary with any model-specific outputs

Return type:

tuple

forward_decoder(prev_output_tokens, **kwargs)[source]
max_decoder_positions()[source]

Maximum length supported by the decoder.

max_positions()[source]

Maximum length supported by the model.

output_layer(features, **kwargs)[source]

Project features to the default output size (typically vocabulary size).

supported_targets
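
For example, a pre-trained decoder-only language model can be loaded and sampled from through the hub interface. This sketch assumes network access to the pytorch/fairseq hub entry:

import torch

en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en',
                       tokenizer='moses', bpe='fastbpe')
en_lm.eval()
# Condition the decoder on a prefix and sample a continuation.
print(en_lm.sample('Machine translation is', sampling=True, temperature=0.8))
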
class fairseq.models.FairseqMultiModel(encoders, decoders)[source]

Base class for combining multiple encoder-decoder models.

static build_shared_embeddings(dicts: Dict[str, fairseq.data.dictionary.Dictionary], langs: List[str], embed_dim: int, build_embedding: callable, pretrained_embed_path: Optional[str] = None)[source]

Helper function to build shared embeddings for a set of languages after checking that all dicts corresponding to those languages are equivalent.

Parameters:
  • dicts – Dict of lang_id to its corresponding Dictionary
  • langs – languages that we want to share embeddings for
  • embed_dim – embedding dimension
  • build_embedding – callable function to actually build the embedding
  • pretrained_embed_path – Optional path to load pretrained embeddings
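
For example (the two-language setup and the build_embedding helper below are illustrative):

import torch.nn as nn
from fairseq.data import Dictionary
from fairseq.models import FairseqMultiModel

shared = Dictionary()
dicts = {'en': shared, 'de': shared}  # dictionaries must be equivalent across langs

def build_embedding(dictionary, embed_dim, path=None):
    # a toy builder; real code would load pre-trained vectors from `path`
    return nn.Embedding(len(dictionary), embed_dim, dictionary.pad())

emb = FairseqMultiModel.build_shared_embeddings(
    dicts, langs=['en', 'de'], embed_dim=32, build_embedding=build_embedding,
)
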
decoder
encoder
forward(src_tokens, src_lengths, prev_output_tokens, **kwargs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_decoder(prev_output_tokens, **kwargs)[source]
load_state_dict(state_dict, strict=True, model_cfg=None, args: Optional[argparse.Namespace] = None)[source]

Copies parameters and buffers from state_dict into this module and its descendants.

Overrides the method in nn.Module. Compared with that method this additionally “upgrades” state_dicts from old checkpoints.

max_decoder_positions()[source]

Maximum length supported by the decoder.

max_positions()[source]

Maximum length supported by the model.

class fairseq.models.FairseqEncoder(dictionary)[source]

Base class for encoders.

forward(src_tokens, src_lengths=None, **kwargs)[source]
Parameters:
  • src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
  • src_lengths (LongTensor) – lengths of each source sentence of shape (batch)
forward_torchscript(net_input: Dict[str, torch.Tensor])[source]

A TorchScript-compatible version of forward.

Encoders which use additional arguments may want to override this method for TorchScript compatibility.

max_positions()[source]

Maximum input length supported by the encoder.

reorder_encoder_out(encoder_out, new_order)[source]

Reorder encoder output according to new_order.

Parameters:
  • encoder_out – output from the forward() method
  • new_order (LongTensor) – desired order
Returns:

encoder_out rearranged according to new_order

set_num_updates(num_updates)[source]

State from trainer to pass along to model at every update.

upgrade_state_dict_named(state_dict, name)[source]

Upgrade old state dicts to work with newer code.
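
For example, a minimal custom encoder might look as follows. The class and its encoder_out keys are illustrative; concrete models define their own output structure:

import torch
from fairseq.models import FairseqEncoder

class BagOfEmbeddingsEncoder(FairseqEncoder):

    def __init__(self, dictionary, embed_dim=64):
        super().__init__(dictionary)
        self.embed = torch.nn.Embedding(len(dictionary), embed_dim, dictionary.pad())

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        x = self.embed(src_tokens)              # (batch, src_len, embed_dim)
        return {'encoder_out': x.mean(dim=1)}   # (batch, embed_dim)

    def reorder_encoder_out(self, encoder_out, new_order):
        # select/reorder the batch dimension, e.g. during beam search
        return {'encoder_out': encoder_out['encoder_out'].index_select(0, new_order)}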

class fairseq.models.CompositeEncoder(encoders)[source]

A wrapper around a dictionary of FairseqEncoder objects.

We run forward on each encoder and return a dictionary of outputs. The first encoder’s dictionary is used for initialization.

Parameters:encoders (dict) – a dictionary of FairseqEncoder objects.
forward(src_tokens, src_lengths)[source]
Parameters:
  • src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
  • src_lengths (LongTensor) – lengths of each source sentence of shape (batch)
Returns:

the outputs from each Encoder

Return type:

dict

max_positions()[source]

Maximum input length supported by the encoder.

reorder_encoder_out(encoder_out, new_order)[source]

Reorder encoder output according to new_order.

class fairseq.models.FairseqDecoder(dictionary)[source]

Base class for decoders.

extract_features(prev_output_tokens, encoder_out=None, **kwargs)[source]
Returns:
  • the decoder’s features of shape (batch, tgt_len, embed_dim)
  • a dictionary with any model-specific outputs
Return type:tuple
forward(prev_output_tokens, encoder_out=None, **kwargs)[source]
Parameters:
  • prev_output_tokens (LongTensor) – shifted output tokens of shape (batch, tgt_len), for teacher forcing
  • encoder_out (dict, optional) – output from the encoder, used for encoder-side attention
Returns:

  • the decoder’s output of shape (batch, tgt_len, vocab)
  • a dictionary with any model-specific outputs

Return type:

tuple

get_normalized_probs(net_output: Tuple[torch.Tensor, Optional[Dict[str, List[Optional[torch.Tensor]]]]], log_probs: bool, sample: Optional[Dict[str, torch.Tensor]] = None)[source]

Get normalized probabilities (or log probs) from a net’s output.

get_normalized_probs_scriptable(net_output: Tuple[torch.Tensor, Optional[Dict[str, List[Optional[torch.Tensor]]]]], log_probs: bool, sample: Optional[Dict[str, torch.Tensor]] = None)[source]

Get normalized probabilities (or log probs) from a net’s output.

max_positions()[source]

Maximum input length supported by the decoder.

output_layer(features, **kwargs)[source]

Project features to the default output size, e.g., vocabulary size.

Parameters:features (Tensor) – features returned by extract_features.
upgrade_state_dict_named(state_dict, name)[source]

Upgrade old state dicts to work with newer code.

Incremental decoding

class fairseq.models.FairseqIncrementalDecoder(dictionary)[source]

Base class for incremental decoders.

Incremental decoding is a special mode at inference time where the Model only receives a single timestep of input corresponding to the previous output token (for teacher forcing) and must produce the next output incrementally. Thus the model must cache any long-term state that is needed about the sequence, e.g., hidden states, convolutional states, etc.

Compared to the standard FairseqDecoder interface, the incremental decoder interface allows forward() functions to take an extra keyword argument (incremental_state) that can be used to cache state across time-steps.

The FairseqIncrementalDecoder interface also defines the reorder_incremental_state() method, which is used during beam search to select and reorder the incremental state based on the selection of beams.

To learn more about how incremental decoding works, refer to this blog.
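
For example, the sketch below caches a single recurrent hidden state in incremental_state. The tiny GRU decoder and the 'cached_h' key are illustrative; real fairseq decoders normally go through the get_incremental_state()/set_incremental_state() helpers so that keys from different modules cannot collide:

import torch
from fairseq.models import FairseqIncrementalDecoder

class TinyIncrementalDecoder(FairseqIncrementalDecoder):

    def __init__(self, dictionary, embed_dim=64):
        super().__init__(dictionary)
        self.embed = torch.nn.Embedding(len(dictionary), embed_dim, dictionary.pad())
        self.rnn = torch.nn.GRUCell(embed_dim, embed_dim)
        self.out = torch.nn.Linear(embed_dim, len(dictionary))

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
        if incremental_state is not None:
            # incremental mode: only the newest token is fed; earlier state is cached
            prev_output_tokens = prev_output_tokens[:, -1:]
        x = self.embed(prev_output_tokens)  # (batch, steps, embed_dim)
        h = None
        if incremental_state is not None and 'cached_h' in incremental_state:
            h = incremental_state['cached_h']['h']
        if h is None:
            h = x.new_zeros(x.size(0), x.size(2))
        feats = []
        for t in range(x.size(1)):
            h = self.rnn(x[:, t], h)
            feats.append(h)
        if incremental_state is not None:
            incremental_state['cached_h'] = {'h': h}  # cache for the next time step
        features = torch.stack(feats, dim=1)
        return self.out(features), {'features': features}

    def reorder_incremental_state(self, incremental_state, new_order):
        # called during beam search when hypotheses are re-selected between steps
        if 'cached_h' in incremental_state:
            incremental_state['cached_h']['h'] = (
                incremental_state['cached_h']['h'].index_select(0, new_order)
            )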

extract_features(prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs)[source]
Returns:
  • the decoder’s features of shape (batch, tgt_len, embed_dim)
  • a dictionary with any model-specific outputs
Return type:tuple
forward(prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs)[source]
Parameters:
  • prev_output_tokens (LongTensor) – shifted output tokens of shape (batch, tgt_len), for teacher forcing
  • encoder_out (dict, optional) – output from the encoder, used for encoder-side attention
  • incremental_state (dict, optional) – dictionary used for storing state during Incremental decoding
Returns:

  • the decoder’s output of shape (batch, tgt_len, vocab)
  • a dictionary with any model-specific outputs

Return type:

tuple

reorder_incremental_state(incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]], new_order: torch.Tensor)[source]

Reorder incremental state.

This will be called when the order of the input has changed from the previous time step. A typical use case is beam search, where the input order changes between time steps based on the selection of beams.

reorder_incremental_state_scripting(incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]], new_order: torch.Tensor)[source]

Main entry point for reordering the incremental state.

Due to limitations in TorchScript, we call this function in fairseq.sequence_generator.SequenceGenerator instead of calling reorder_incremental_state() directly.

set_beam_size(beam_size)[source]

Sets the beam size in the decoder and all children.