Models
A Model defines the neural network’s forward() method and encapsulates all of the learnable parameters in the network. Each model also provides a set of named architectures that define the precise network configuration (e.g., embedding dimension, number of layers, etc.).
Both the model type and architecture are selected via the --arch command-line argument. Once selected, a model may expose additional command-line arguments for further configuration.
Note
All fairseq Models extend BaseFairseqModel, which in turn extends torch.nn.Module. Thus any fairseq Model can be used as a stand-alone Module in other PyTorch code.
Convolutional Neural Networks (CNN)
class fairseq.models.fconv.FConvModel(encoder, decoder)
A fully convolutional model, i.e. a convolutional encoder and a convolutional decoder, as described in “Convolutional Sequence to Sequence Learning” (Gehring et al., 2017).
Parameters:
- encoder (FConvEncoder) – the encoder
- decoder (FConvDecoder) – the decoder
The Convolutional model provides the following named architectures and command-line arguments:
usage: [--arch {fconv,fconv_iwslt_de_en,fconv_wmt_en_ro,fconv_wmt_en_de,fconv_wmt_en_fr}] [--dropout D] [--encoder-embed-dim N] [--encoder-embed-path STR] [--encoder-layers EXPR] [--decoder-embed-dim N] [--decoder-embed-path STR] [--decoder-layers EXPR] [--decoder-out-embed-dim N] [--decoder-attention EXPR] [--share-input-output-embed]
Named architectures
--arch
Possible choices: fconv, fconv_iwslt_de_en, fconv_wmt_en_ro, fconv_wmt_en_de, fconv_wmt_en_fr
Additional command-line arguments
--dropout
dropout probability
--encoder-embed-dim
encoder embedding dimension
--encoder-embed-path
path to pre-trained encoder embedding
--encoder-layers
encoder layers [(dim, kernel_size), …]
--decoder-embed-dim
decoder embedding dimension
--decoder-embed-path
path to pre-trained decoder embedding
--decoder-layers
decoder layers [(dim, kernel_size), …]
--decoder-out-embed-dim
decoder output embedding dimension
--decoder-attention
decoder attention [True, …]
--share-input-output-embed
share input and output embeddings (requires --decoder-out-embed-dim and --decoder-embed-dim to be equal)
Default: False
class fairseq.models.fconv.FConvEncoder(dictionary, embed_dim=512, embed_dict=None, max_positions=1024, convolutions=((512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3)), dropout=0.1)
Convolutional encoder consisting of len(convolutions) layers.
Parameters:
- dictionary (Dictionary) – encoding dictionary
- embed_dim (int, optional) – embedding dimension
- embed_dict (str, optional) – filename from which to load pre-trained embeddings
- max_positions (int, optional) – maximum supported input sequence length
- convolutions (list, optional) – the convolutional layer structure. Each list item i corresponds to convolutional layer i. Layers are given as (out_channels, kernel_width, [residual]). Residual connections are added between layers when residual=1 (which is the default behavior).
- dropout (float, optional) – dropout to be applied before each conv layer
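The layer-spec convention above (an optional third residual element that defaults to 1) can be illustrated with a small stand-alone helper. This is a hypothetical sketch in plain Python, not fairseq's own code; the name normalize_conv_specs is invented for illustration.

```python
from typing import List, Sequence, Tuple

# Illustrative sketch (not fairseq code): (out_channels, kernel_width, residual)
ConvSpec = Tuple[int, int, int]


def normalize_conv_specs(convolutions: Sequence[Sequence[int]]) -> List[ConvSpec]:
    """Hypothetical helper: expand 2-tuples to 3-tuples with residual=1."""
    normalized = []
    for spec in convolutions:
        if len(spec) == 3:
            out_channels, kernel_width, residual = spec
        elif len(spec) == 2:
            out_channels, kernel_width = spec
            residual = 1  # default: add a residual connection
        else:
            raise ValueError(f"invalid layer spec: {spec!r}")
        normalized.append((int(out_channels), int(kernel_width), int(residual)))
    return normalized


# Twenty identical (512, 3) layers, plus a mixed spec that disables one residual.
uniform_stack = normalize_conv_specs(((512, 3),) * 20)
mixed = normalize_conv_specs([(512, 3), (768, 3, 0)])
print(len(uniform_stack), uniform_stack[0])  # 20 (512, 3, 1)
print(mixed)  # [(512, 3, 1), (768, 3, 0)]
```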
forward(src_tokens, src_lengths)
Parameters:
- src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
- src_lengths (LongTensor) – lengths of each source sentence of shape (batch)
Returns:
- encoder_out (tuple): a tuple with two elements, where the first element is the last encoder layer’s output and the second element is the same quantity summed with the input embedding (used for attention). The shape of both tensors is (batch, src_len, embed_dim).
- encoder_padding_mask (ByteTensor): the positions of padding elements of shape (batch, src_len)
Return type: dict
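The encoder_padding_mask marks which positions of src_tokens are padding. A minimal sketch, using nested Python lists in place of tensors and assuming a padding index of 1 (that value is an assumption made only for this example):

```python
from typing import List


def make_padding_mask(src_tokens: List[List[int]], pad_idx: int) -> List[List[bool]]:
    """Return a (batch, src_len) mask that is True wherever a token is padding."""
    return [[tok == pad_idx for tok in row] for row in src_tokens]


PAD = 1  # assumed padding index for this illustrative example
batch = [
    [5, 9, 7, 2],    # full-length sentence
    [PAD, 8, 6, 2],  # shorter sentence, left-padded
]
mask = make_padding_mask(batch, PAD)
print(mask)  # [[False, False, False, False], [True, False, False, False]]
```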
class fairseq.models.fconv.FConvDecoder(dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256, max_positions=1024, convolutions=((512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3), (512, 3)), attention=True, dropout=0.1, share_embed=False, positional_embeddings=True, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0.0)
Convolutional decoder.
forward(prev_output_tokens, encoder_out=None, incremental_state=None, **unused)
Parameters:
- prev_output_tokens (LongTensor) – shifted output tokens of shape (batch, tgt_len), for teacher forcing
- encoder_out (dict, optional) – output from the encoder, used for encoder-side attention
- incremental_state (dict, optional) – dictionary used for storing state during incremental decoding
Returns:
- the decoder’s output of shape (batch, tgt_len, vocab)
- a dictionary with any model-specific outputs
Return type: tuple
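prev_output_tokens is the target sequence shifted one position to the right, so that at step t the decoder sees the gold token from step t-1 (teacher forcing). A minimal plain-Python sketch, assuming the end-of-sentence symbol is moved to the front to serve as the start token; the index values are assumptions for illustration.

```python
from typing import List


def shift_right(target: List[int], eos_idx: int) -> List[int]:
    """Illustrative sketch: build teacher-forcing decoder inputs from a target.

    Assumes the target ends with EOS; EOS is moved to the front as the start
    symbol and the last position is dropped so both sequences have equal length.
    """
    assert target and target[-1] == eos_idx, "target is expected to end with EOS"
    return [eos_idx] + target[:-1]


EOS = 2  # assumed end-of-sentence index for this example
target = [11, 12, 13, EOS]
prev_output_tokens = shift_right(target, EOS)
print(prev_output_tokens)  # [2, 11, 12, 13]

# At every position t, the decoder input is the target token from step t - 1.
for t in range(1, len(target)):
    assert prev_output_tokens[t] == target[t - 1]
```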
Long Short-Term Memory (LSTM) networks
class fairseq.models.lstm.LSTMModel(encoder, decoder)
forward(src_tokens, src_lengths, prev_output_tokens, incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]] = None)
Run the forward pass for an encoder-decoder model.
First feed a batch of source tokens through the encoder. Then, feed the encoder output and previous decoder outputs (i.e., teacher forcing) to the decoder to produce the next outputs:
    encoder_out = self.encoder(src_tokens, src_lengths)
    return self.decoder(prev_output_tokens, encoder_out)
Parameters:
- src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
- src_lengths (LongTensor) – source sentence lengths of shape (batch)
- prev_output_tokens (LongTensor) – previous decoder outputs of shape (batch, tgt_len), for teacher forcing
Returns:
- the decoder’s output of shape (batch, tgt_len, vocab)
- a dictionary with any model-specific outputs
Return type: tuple
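The two-step recipe above can be mimicked without PyTorch to show the data flow. In this hypothetical sketch, ToyEncoder, ToyDecoder and ToyEncoderDecoder are invented stand-ins: the encoder sums each (right-padded) source sentence and the decoder adds that summary to every previous output token.

```python
from typing import Dict, List, Tuple


class ToyEncoder:
    """Invented stand-in encoder: one summary number per sentence."""

    def __call__(self, src_tokens: List[List[int]], src_lengths: List[int]) -> Dict[str, List[int]]:
        # Sum the first n tokens of each row (this toy assumes right padding).
        return {"encoder_out": [sum(row[:n]) for row, n in zip(src_tokens, src_lengths)]}


class ToyDecoder:
    """Invented stand-in decoder returning (output, extra) like a fairseq decoder."""

    def __call__(self, prev_output_tokens, encoder_out) -> Tuple[List[List[int]], Dict]:
        summary = encoder_out["encoder_out"]
        out = [[tok + summary[b] for tok in row] for b, row in enumerate(prev_output_tokens)]
        return out, {"attn": None}


class ToyEncoderDecoder:
    def __init__(self, encoder, decoder):
        self.encoder, self.decoder = encoder, decoder

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        encoder_out = self.encoder(src_tokens, src_lengths)
        return self.decoder(prev_output_tokens, encoder_out)


model = ToyEncoderDecoder(ToyEncoder(), ToyDecoder())
output, extra = model.forward([[3, 4, 0], [5, 0, 0]], [2, 1], [[2, 7], [2, 8]])
print(output)  # [[9, 14], [7, 13]]
```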
class fairseq.models.lstm.LSTMEncoder(dictionary, embed_dim=512, hidden_size=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, left_pad=True, pretrained_embed=None, padding_idx=None, max_source_positions=100000.0)
LSTM encoder.
forward(src_tokens: torch.Tensor, src_lengths: torch.Tensor, enforce_sorted: bool = True)
Parameters:
- src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
- src_lengths (LongTensor) – lengths of each source sentence of shape (batch)
- enforce_sorted (bool, optional) – if True, src_tokens is expected to contain sequences sorted by length in decreasing order. If False, this condition is not required. Default: True.
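When enforce_sorted is True, the caller is responsible for ordering the batch. A hypothetical plain-Python sketch of sorting a batch by decreasing length and restoring the original order afterwards (the helper name sort_by_length is invented for illustration):

```python
from typing import List, Sequence, Tuple


def sort_by_length(lengths: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Illustrative sketch: `order` sorts by decreasing length, `inverse` undoes it."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    inverse = [0] * len(order)
    for new_pos, old_pos in enumerate(order):
        inverse[old_pos] = new_pos
    return order, inverse


src_lengths = [3, 5, 2]
sentences = ["abc", "abcde", "ab"]
order, inverse = sort_by_length(src_lengths)
sorted_sentences = [sentences[i] for i in order]
print(sorted_sentences)  # ['abcde', 'abc', 'ab']

# ... run the encoder on the sorted batch, then map results back:
restored = [sorted_sentences[inverse[i]] for i in range(len(sentences))]
assert restored == sentences
```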
class fairseq.models.lstm.LSTMDecoder(dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True, encoder_output_units=512, pretrained_embed=None, share_input_output_embed=False, adaptive_softmax_cutoff=None, max_target_positions=100000.0, residuals=False)
LSTM decoder.
extract_features(prev_output_tokens, encoder_out: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]] = None)
Similar to forward, but only returns features.
forward(prev_output_tokens, encoder_out: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]] = None, src_lengths: Optional[torch.Tensor] = None)
Parameters:
- prev_output_tokens (LongTensor) – shifted output tokens of shape (batch, tgt_len), for teacher forcing
- encoder_out (tuple, optional) – output from the encoder, used for encoder-side attention
- incremental_state (dict, optional) – dictionary used for storing state during incremental decoding
Returns:
- the decoder’s output of shape (batch, tgt_len, vocab)
- a dictionary with any model-specific outputs
Return type: tuple
reorder_incremental_state(incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]], new_order: torch.Tensor)
Reorder incremental state.
This will be called when the order of the input has changed from the previous time step. A typical use case is beam search, where the input order changes between time steps based on the selection of beams.
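Conceptually, reordering means gathering every cached per-hypothesis entry along the batch dimension according to new_order. A minimal sketch with plain lists standing in for tensors; the cache keys are invented for illustration, and a tensor implementation would use an operation such as index_select on the batch dimension.

```python
from typing import Dict, List, Sequence


def reorder_state(cache: Dict[str, List[object]], new_order: Sequence[int]) -> Dict[str, List[object]]:
    """Illustrative sketch: row i of each entry comes from hypothesis new_order[i]."""
    return {key: [rows[j] for j in new_order] for key, rows in cache.items()}


# Hypothetical cache for a beam of 3 hypotheses (one row per hypothesis).
cache = {
    "hidden": ["h0", "h1", "h2"],
    "cell": ["c0", "c1", "c2"],
}
# Beam search kept hypothesis 2 twice and hypothesis 0 once.
new_order = [2, 2, 0]
cache = reorder_state(cache, new_order)
print(cache)  # {'hidden': ['h2', 'h2', 'h0'], 'cell': ['c2', 'c2', 'c0']}
```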
Transformer (self-attention) networks
class fairseq.models.transformer.TransformerModel(args, encoder, decoder)
Transformer model from “Attention Is All You Need” (Vaswani et al., 2017).
Parameters:
- encoder (TransformerEncoder) – the encoder
- decoder (TransformerDecoder) – the decoder
The Transformer model provides the following named architectures and command-line arguments:
usage: [--arch {transformer,transformer_iwslt_de_en,transformer_wmt_en_de,transformer_vaswani_wmt_en_de_big,transformer_vaswani_wmt_en_fr_big,transformer_wmt_en_de_big,transformer_wmt_en_de_big_t2t}] [--activation-fn {relu,gelu,gelu_fast,gelu_accurate,tanh,linear}] [--dropout D] [--attention-dropout D] [--activation-dropout D] [--encoder-embed-path STR] [--encoder-embed-dim N] [--encoder-ffn-embed-dim N] [--encoder-layers N] [--encoder-attention-heads N] [--encoder-normalize-before] [--encoder-learned-pos] [--decoder-embed-path STR] [--decoder-embed-dim N] [--decoder-ffn-embed-dim N] [--decoder-layers N] [--decoder-attention-heads N] [--decoder-learned-pos] [--decoder-normalize-before] [--decoder-output-dim N] [--share-decoder-input-output-embed] [--share-all-embeddings] [--no-token-positional-embeddings] [--adaptive-softmax-cutoff EXPR] [--adaptive-softmax-dropout D] [--layernorm-embedding] [--no-scale-embedding] [--no-cross-attention] [--cross-self-attention] [--encoder-layerdrop D] [--decoder-layerdrop D] [--encoder-layers-to-keep ENCODER_LAYERS_TO_KEEP] [--decoder-layers-to-keep DECODER_LAYERS_TO_KEEP] [--quant-noise-pq D] [--quant-noise-pq-block-size D] [--quant-noise-scalar D]
Named architectures
--arch
Possible choices: transformer, transformer_iwslt_de_en, transformer_wmt_en_de, transformer_vaswani_wmt_en_de_big, transformer_vaswani_wmt_en_fr_big, transformer_wmt_en_de_big, transformer_wmt_en_de_big_t2t
Additional command-line arguments
--activation-fn
Possible choices: relu, gelu, gelu_fast, gelu_accurate, tanh, linear
activation function to use
--dropout
dropout probability
--attention-dropout
dropout probability for attention weights
--activation-dropout, --relu-dropout
dropout probability after activation in FFN
--encoder-embed-path
path to pre-trained encoder embedding
--encoder-embed-dim
encoder embedding dimension
--encoder-ffn-embed-dim
encoder embedding dimension for FFN
--encoder-layers
number of encoder layers
--encoder-attention-heads
number of encoder attention heads
--encoder-normalize-before
apply layernorm before each encoder block
Default: False
--encoder-learned-pos
use learned positional embeddings in the encoder
Default: False
--decoder-embed-path
path to pre-trained decoder embedding
--decoder-embed-dim
decoder embedding dimension
--decoder-ffn-embed-dim
decoder embedding dimension for FFN
--decoder-layers
number of decoder layers
--decoder-attention-heads
number of decoder attention heads
--decoder-learned-pos
use learned positional embeddings in the decoder
Default: False
--decoder-normalize-before
apply layernorm before each decoder block
Default: False
--decoder-output-dim
decoder output dimension (extra linear layer if different from decoder embed dim)
--share-decoder-input-output-embed
share decoder input and output embeddings
Default: False
--share-all-embeddings
share encoder, decoder and output embeddings (requires shared dictionary and embed dim)
Default: False
--no-token-positional-embeddings
if set, disables positional embeddings (outside self attention)
Default: False
--adaptive-softmax-cutoff
comma-separated list of adaptive softmax cutoff points. Must be used with the adaptive_loss criterion
--adaptive-softmax-dropout
sets adaptive softmax dropout for the tail projections
--layernorm-embedding
add layernorm to embedding
Default: False
--no-scale-embedding
if True, don't scale embeddings
Default: False
--no-cross-attention
do not perform cross-attention
Default: False
--cross-self-attention
perform cross+self-attention
Default: False
--encoder-layerdrop
LayerDrop probability for encoder
Default: 0
--decoder-layerdrop
LayerDrop probability for decoder
Default: 0
--encoder-layers-to-keep
which layers to keep when pruning, as a comma-separated list
--decoder-layers-to-keep
which layers to keep when pruning, as a comma-separated list
--quant-noise-pq
iterative PQ quantization noise at training time
Default: 0
--quant-noise-pq-block-size
block size of quantization noise at training time
Default: 8
--quant-noise-scalar
scalar quantization noise and scalar quantization at training time
Default: 0
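LayerDrop (--encoder-layerdrop, --decoder-layerdrop) skips each layer with the given probability during training, and --encoder-layers-to-keep / --decoder-layers-to-keep select a subset of layers when pruning. A plain-Python sketch of both ideas with toy layers as callables; the helper names are hypothetical and this is not fairseq's implementation.

```python
import random
from typing import Callable, List, Sequence

# Illustrative sketch (not fairseq code): a "layer" maps a float to a float.
Layer = Callable[[float], float]


def parse_layers_to_keep(spec: str) -> List[int]:
    """Parse a comma-separated list such as "0,2,4" into sorted layer indices."""
    return sorted({int(x) for x in spec.split(",") if x.strip()})


def prune_layers(layers: Sequence[Layer], spec: str) -> List[Layer]:
    """Keep only the layers whose indices appear in spec."""
    keep = parse_layers_to_keep(spec)
    if keep and max(keep) >= len(layers):
        raise ValueError(f"layer index {max(keep)} out of range for {len(layers)} layers")
    return [layers[i] for i in keep]


def run_with_layerdrop(x: float, layers: Sequence[Layer], p: float, training: bool,
                       rng: random.Random) -> float:
    """Apply layers in order; while training, skip each one with probability p."""
    for layer in layers:
        if training and rng.random() < p:
            continue  # this layer is dropped for the current forward pass
        x = layer(x)
    return x


# Six toy layers; layer k adds k + 1 to its input.
layers = [lambda x, k=k: x + (k + 1) for k in range(6)]

kept = prune_layers(layers, "0,2,4")
print(run_with_layerdrop(0.0, kept, p=0.2, training=False, rng=random.Random(0)))  # 9.0
print(run_with_layerdrop(0.0, layers, p=0.0, training=True, rng=random.Random(0)))  # 21.0
```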
forward(src_tokens, src_lengths, prev_output_tokens, return_all_hiddens: bool = True, features_only: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None)
Run the forward pass for an encoder-decoder model.
Copied from the base class, but without **kwargs, which are not supported by TorchScript.
class fairseq.models.transformer.TransformerEncoder(args, dictionary, embed_tokens)
Transformer encoder consisting of args.encoder_layers layers. Each layer is a TransformerEncoderLayer.
Parameters:
- args (argparse.Namespace) – parsed command-line arguments
- dictionary (Dictionary) – encoding dictionary
- embed_tokens (torch.nn.Embedding) – input embedding
forward(src_tokens, src_lengths, return_all_hiddens: bool = False, token_embeddings: Optional[torch.Tensor] = None)
Parameters:
- src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
- src_lengths (torch.LongTensor) – lengths of each source sentence of shape (batch)
- return_all_hiddens (bool, optional) – also return all of the intermediate hidden states (default: False).
- token_embeddings (torch.Tensor, optional) – precomputed embeddings; if None (the default), the embeddings are recomputed
Returns:
- encoder_out (Tensor): the last encoder layer’s output of shape (src_len, batch, embed_dim)
- encoder_padding_mask (ByteTensor): the positions of padding elements of shape (batch, src_len)
- encoder_embedding (Tensor): the (scaled) embedding lookup of shape (batch, src_len, embed_dim)
- encoder_states (List[Tensor]): all intermediate hidden states of shape (src_len, batch, embed_dim). Only populated if return_all_hiddens is True.
Return type: namedtuple
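Note that encoder_out and encoder_states are time-first (src_len, batch, embed_dim), while encoder_padding_mask and encoder_embedding are batch-first. With real tensors the conversion is x.transpose(0, 1); the plain-Python sketch below, with nested lists standing in for tensors, only makes the index swap explicit.

```python
from typing import List, TypeVar

T = TypeVar("T")


def time_first_to_batch_first(x: List[List[T]]) -> List[List[T]]:
    """Illustrative sketch: swap the first two axes, x[t][b] -> y[b][t]."""
    src_len = len(x)
    batch = len(x[0]) if src_len else 0
    return [[x[t][b] for t in range(src_len)] for b in range(batch)]


# src_len = 3, batch = 2, embed_dim = 1 (labels instead of numbers for clarity)
encoder_out = [
    [["t0b0"], ["t0b1"]],
    [["t1b0"], ["t1b1"]],
    [["t2b0"], ["t2b1"]],
]
batch_first = time_first_to_batch_first(encoder_out)
print(batch_first[1])  # [['t0b1'], ['t1b1'], ['t2b1']]
```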
class fairseq.models.transformer.TransformerEncoderLayer(args)
Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is postprocessed with: dropout -> add residual -> layernorm. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: dropout -> add residual. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting args.encoder_normalize_before to True.
Parameters: args (argparse.Namespace) – parsed command-line arguments
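The two orderings can be written side by side. In this hypothetical sketch, the sublayer, dropout and layer_norm are plain-Python stand-ins operating on lists of floats (dropout is the identity, as at inference time, and layer_norm has no learned scale or bias); only the order of operations is the point.

```python
import math
from typing import Callable, List

# Illustrative sketch (not fairseq code) of post-norm vs pre-norm ordering.
Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Normalize to zero mean and unit variance (no learned scale/bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def dropout(x: Vector) -> Vector:
    return x  # identity here, as at inference time


def add(a: Vector, b: Vector) -> Vector:
    return [u + v for u, v in zip(a, b)]


def post_norm_block(x: Vector, sublayer: Callable[[Vector], Vector]) -> Vector:
    # Paper default: sublayer -> dropout -> add residual -> layernorm
    return layer_norm(add(x, dropout(sublayer(x))))


def pre_norm_block(x: Vector, sublayer: Callable[[Vector], Vector]) -> Vector:
    # normalize_before=True: layernorm -> sublayer -> dropout -> add residual
    return add(x, dropout(sublayer(layer_norm(x))))


def double(v: Vector) -> Vector:
    return [2.0 * e for e in v]  # toy sublayer


x = [1.0, 2.0, 3.0]
print(post_norm_block(x, double))  # zero-mean output: the block ends with layernorm
print(pre_norm_block(x, double))   # residual path carries x through unnormalized
```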
forward(x, encoder_padding_mask, attn_mask: Optional[torch.Tensor] = None)
Parameters:
- x (Tensor) – input to the layer of shape (seq_len, batch, embed_dim)
- encoder_padding_mask (ByteTensor) – binary ByteTensor of shape (batch, seq_len) where padding elements are indicated by 1.
- attn_mask (ByteTensor) – binary tensor of shape (tgt_len, src_len), where tgt_len is the length of output and src_len is the length of input, though here both are equal to seq_len. attn_mask[tgt_i, src_j] = 1 means that when calculating the embedding for tgt_i, we exclude (mask out) src_j. This is useful for strided self-attention.
Returns: encoded output of shape (seq_len, batch, embed_dim)
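To make the attn_mask convention concrete (1 means masked out), here is a sketch that builds a mask for one possible strided pattern: position i may attend to position j only when j lies within a local window of i or the distance i - j is a multiple of the stride. That particular pattern and the helper name are assumptions for illustration; the layer itself accepts any mask with this shape and meaning.

```python
from typing import List


def strided_attn_mask(seq_len: int, stride: int) -> List[List[int]]:
    """Illustrative sketch: (tgt_len, src_len) mask where 1 means "do not attend".

    Assumed pattern: i sees j if |i - j| < stride or (i - j) is a multiple of stride.
    """
    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            visible = abs(i - j) < stride or (i - j) % stride == 0
            row.append(0 if visible else 1)
        mask.append(row)
    return mask


mask = strided_attn_mask(seq_len=6, stride=3)
for row in mask:
    print(row)
```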
class fairseq.models.transformer.TransformerDecoder(args, dictionary, embed_tokens, no_encoder_attn=False)
Transformer decoder consisting of args.decoder_layers layers. Each layer is a TransformerDecoderLayer.
Parameters:
- args (argparse.Namespace) – parsed command-line arguments
- dictionary (Dictionary) – decoding dictionary
- embed_tokens (torch.nn.Embedding) – output embedding
- no_encoder_attn (bool, optional) – if True, do not attend to encoder outputs (default: False).
extract_features(prev_output_tokens, encoder_out: Optional[fairseq.models.fairseq_encoder.EncoderOut] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None)
Returns:
- the decoder’s features of shape (batch, tgt_len, embed_dim)
- a dictionary with any model-specific outputs
Return type: tuple
extract_features_scriptable(prev_output_tokens, encoder_out: Optional[fairseq.models.fairseq_encoder.EncoderOut] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None)
Similar to forward, but only returns features.
Includes several features from “Jointly Learning to Align and Translate with Transformer Models” (Garg et al., EMNLP 2019).
Parameters:
- full_context_alignment (bool, optional) – don’t apply auto-regressive mask to self-attention (default: False).
- alignment_layer (int, optional) – return mean alignment over heads at this layer (default: last layer).
- alignment_heads (int, optional) – only average alignment over this many heads (default: all heads).
Returns:
- the decoder’s features of shape (batch, tgt_len, embed_dim)
- a dictionary with any model-specific outputs
Return type: tuple
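alignment_heads limits how many attention heads are averaged into the alignment. A plain-Python sketch, assuming the chosen layer's attention weights are given per head as (tgt_len, src_len) lists and that the first alignment_heads heads are the ones averaged (which heads are used is an assumption of this sketch):

```python
from typing import List, Optional

# Illustrative sketch (not fairseq code): average per-head attention maps.
Matrix = List[List[float]]


def average_heads(head_attn: List[Matrix], alignment_heads: Optional[int] = None) -> Matrix:
    """Average (tgt_len, src_len) attention over the first alignment_heads heads."""
    heads = head_attn if alignment_heads is None else head_attn[:alignment_heads]
    n = len(heads)
    tgt_len, src_len = len(heads[0]), len(heads[0][0])
    return [[sum(h[t][s] for h in heads) / n for s in range(src_len)] for t in range(tgt_len)]


# 3 heads, tgt_len = 1, src_len = 2
head_attn = [
    [[1.0, 0.0]],
    [[0.0, 1.0]],
    [[0.5, 0.5]],
]
print(average_heads(head_attn))                     # [[0.5, 0.5]] (all heads)
print(average_heads(head_attn, alignment_heads=1))  # [[1.0, 0.0]]
```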
forward(prev_output_tokens, encoder_out: Optional[fairseq.models.fairseq_encoder.EncoderOut] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]] = None, features_only: bool = False, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, src_lengths: Optional[Any] = None, return_all_hiddens: bool = False)
Parameters:
- prev_output_tokens (LongTensor) – previous decoder outputs of shape (batch, tgt_len), for teacher forcing
- encoder_out (optional) – output from the encoder, used for encoder-side attention
- incremental_state (dict) – dictionary used for storing state during incremental decoding
- features_only (bool, optional) – only return features without applying output layer (default: False).
- full_context_alignment (bool, optional) – don’t apply auto-regressive mask to self-attention (default: False).
Returns:
- the decoder’s output of shape (batch, tgt_len, vocab)
- a dictionary with any model-specific outputs
Return type: tuple
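The auto-regressive mask that full_context_alignment turns off prevents target position i from attending to later positions j > i. A plain-Python sketch using the same 1 = masked convention as the encoder layer's attn_mask; real implementations often use a float mask filled with -inf instead, and the helper name future_mask is hypothetical.

```python
from typing import List


def future_mask(tgt_len: int, full_context: bool = False) -> List[List[int]]:
    """Illustrative sketch: (tgt_len, tgt_len) mask where 1 blocks i from seeing j."""
    if full_context:
        return [[0] * tgt_len for _ in range(tgt_len)]  # every position visible
    return [[1 if j > i else 0 for j in range(tgt_len)] for i in range(tgt_len)]


for row in future_mask(4):
    print(row)
# [0, 1, 1, 1]
# [0, 0, 1, 1]
# [0, 0, 0, 1]
# [0, 0, 0, 0]
```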
class fairseq.models.transformer.TransformerDecoderLayer(args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False)
Decoder layer block.
In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: dropout -> add residual -> layernorm. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: dropout -> add residual. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting args.decoder_normalize_before to True.
Parameters:
- args (argparse.Namespace) – parsed command-line arguments
- no_encoder_attn (bool, optional) – if True, do not attend to encoder outputs (default: False).
forward(x, encoder_out: Optional[torch.Tensor] = None, encoder_padding_mask: Optional[torch.Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]] = None, prev_self_attn_state: Optional[List[torch.Tensor]] = None, prev_attn_state: Optional[List[torch.Tensor]] = None, self_attn_mask: Optional[torch.Tensor] = None, self_attn_padding_mask: Optional[torch.Tensor] = None, need_attn: bool = False, need_head_weights: bool = False)
Parameters:
- x (Tensor) – input to the layer of shape (seq_len, batch, embed_dim)
- encoder_padding_mask (ByteTensor, optional) – binary ByteTensor of shape (batch, src_len) where padding elements are indicated by 1.
- need_attn (bool, optional) – return attention weights
- need_head_weights (bool, optional) – return attention weights for each head (default: return average over heads).
Returns: encoded output of shape (seq_len, batch, embed_dim)
Adding new models
fairseq.models.register_model(name, dataclass=None)
New model types can be added to fairseq with the register_model() function decorator.
For example:
    @register_model('lstm')
    class LSTM(FairseqEncoderDecoderModel):
        (...)
Note
All models must implement the BaseFairseqModel interface. Typically you will extend FairseqEncoderDecoderModel for sequence-to-sequence tasks or FairseqLanguageModel for language modeling tasks.
Parameters: name (str) – the name of the model
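Conceptually, a decorator like this maintains a name-to-class registry. The following is a hypothetical, stand-alone sketch of that pattern (MODEL_REGISTRY, register, BaseModel and ToyLSTM are invented here and are not fairseq's code); it also rejects duplicate names and requires a common base class.

```python
from typing import Callable, Dict, Type


class BaseModel:
    """Illustrative stand-in for a common model interface."""


MODEL_REGISTRY: Dict[str, Type[BaseModel]] = {}


def register(name: str) -> Callable[[Type[BaseModel]], Type[BaseModel]]:
    """Hypothetical class decorator that records `cls` under `name`."""
    def wrapper(cls: Type[BaseModel]) -> Type[BaseModel]:
        if name in MODEL_REGISTRY:
            raise ValueError(f"Cannot register duplicate model ({name})")
        if not issubclass(cls, BaseModel):
            raise ValueError(f"Model ({name}: {cls.__name__}) must extend BaseModel")
        MODEL_REGISTRY[name] = cls
        return cls  # the class is returned unchanged and remains usable directly
    return wrapper


@register("toy_lstm")
class ToyLSTM(BaseModel):
    pass


model_cls = MODEL_REGISTRY["toy_lstm"]
print(model_cls.__name__)  # ToyLSTM
```

Returning the class unchanged from the wrapper is what lets the decorated class still be imported and subclassed normally after registration.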
fairseq.models.register_model_architecture(model_name, arch_name)
New model architectures can be added to fairseq with the register_model_architecture() function decorator. After registration, model architectures can be selected with the --arch command-line argument.
For example:
    @register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
    def lstm_luong_wmt_en_de(args):
        args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000)
        (...)
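The decorated function should take a single argument args, which is an argparse.Namespace of arguments parsed from the command-line. The decorated function should modify these arguments in-place to match the desired architecture.
Parameters:
- model_name (str) – the name of the Model to add the architecture to
- arch_name (str) – the name of the model architecture (--arch)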
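The getattr(args, name, default) idiom in the example only fills in a value when the attribute is missing from args, so flags given explicitly on the command line take precedence over the architecture's defaults. The stand-alone sketch below assumes the flags are declared with argparse.SUPPRESS as the argument default, so that unspecified flags are absent from the namespace; the architecture function and its default values are hypothetical.

```python
import argparse


def toy_arch_big(args: argparse.Namespace) -> None:
    """Hypothetical architecture: fill in defaults only for flags the user omitted."""
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1000)
    args.encoder_layers = getattr(args, "encoder_layers", 4)


# With SUPPRESS, flags that are not given on the command line are simply absent
# from the namespace, so getattr(...) falls back to the architecture default.
parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
parser.add_argument("--encoder-embed-dim", type=int)
parser.add_argument("--encoder-layers", type=int)

args = parser.parse_args(["--encoder-layers", "2"])
toy_arch_big(args)  # modifies args in place
print(args.encoder_embed_dim, args.encoder_layers)  # 1000 2
```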
class fairseq.models.BaseFairseqModel
Base class for fairseq models.
classmethod from_pretrained(model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', **kwargs)
Load a FairseqModel from a pre-trained model file. Downloads and caches the pre-trained model file if needed.
The base implementation returns a GeneratorHubInterface, which can be used to generate translations or sample from language models. The underlying FairseqModel can be accessed via the generator.models attribute.
Other models may override this to implement custom hub interfaces.
Parameters:
- model_name_or_path (str) – either the name of a pre-trained model to load or a path/URL to a pre-trained model state dict
- checkpoint_file (str, optional) – colon-separated list of checkpoint files in the model archive to ensemble (default: ‘model.pt’)
- data_name_or_path (str, optional) – point args.data to the archive at the given path/URL. Can start with ‘.’ or ‘./’ to reuse the model archive path.
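The checkpoint_file and data_name_or_path conventions can be illustrated with a small path-resolution helper. This is a hypothetical sketch of the documented behaviour only: it does not download or load anything, the helper name is invented, and it uses POSIX-style paths so the example is deterministic.

```python
import posixpath
from typing import List, Tuple


def resolve_archive_paths(archive_dir: str, checkpoint_file: str = "model.pt",
                          data_name_or_path: str = ".") -> Tuple[List[str], str]:
    """Hypothetical helper: expand checkpoint names and the data path for an archive."""
    # "a.pt:b.pt" names two checkpoints that would be ensembled.
    checkpoints = [posixpath.join(archive_dir, name) for name in checkpoint_file.split(":")]
    # A data path starting with "." reuses the archive directory.
    if data_name_or_path.startswith("."):
        data_path = posixpath.normpath(posixpath.join(archive_dir, data_name_or_path))
    else:
        data_path = data_name_or_path
    return checkpoints, data_path


ckpts, data = resolve_archive_paths("/models/wmt", "model1.pt:model2.pt", ".")
print(ckpts)  # ['/models/wmt/model1.pt', '/models/wmt/model2.pt']
print(data)   # /models/wmt
```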
get_normalized_probs(net_output: Tuple[torch.Tensor, Optional[Dict[str, List[Optional[torch.Tensor]]]]], log_probs: bool, sample: Optional[Dict[str, torch.Tensor]] = None)
Get normalized probabilities (or log probs) from a net’s output.
get_normalized_probs_scriptable(net_output: Tuple[torch.Tensor, Optional[Dict[str, List[Optional[torch.Tensor]]]]], log_probs: bool, sample: Optional[Dict[str, torch.Tensor]] = None)
Scriptable helper function for get_normalized_probs in BaseFairseqModel.
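For a plain output layer, normalizing means a softmax (or a log-softmax when log_probs=True) over the vocabulary dimension of the decoder's scores. A numerically stable plain-Python sketch for a single position, assuming net_output[0] holds unnormalized scores; models with an adaptive softmax normalize differently.

```python
import math
from typing import List


def normalize(logits: List[float], log_probs: bool) -> List[float]:
    """Illustrative sketch: softmax or log-softmax over one position's vocabulary scores."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    logp = [x - log_z for x in logits]
    return logp if log_probs else [math.exp(v) for v in logp]


logits = [2.0, 1.0, 0.1]  # scores for a 3-word vocabulary
probs = normalize(logits, log_probs=False)
lprobs = normalize(logits, log_probs=True)
print(round(sum(probs), 6))  # 1.0
```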
load_state_dict(state_dict, strict=True, args=None)
Copies parameters and buffers from state_dict into this module and its descendants.
Overrides the method in nn.Module. Compared with that method, this additionally “upgrades” state_dicts from old checkpoints.
make_generation_fast_(**kwargs)
Legacy entry point to optimize model for faster generation. Prefer prepare_for_inference_.
class fairseq.models.FairseqEncoderDecoderModel(encoder, decoder)
Base class for encoder-decoder models.
Parameters:
- encoder (FairseqEncoder) – the encoder
- decoder (FairseqDecoder) – the decoder
extract_features(src_tokens, src_lengths, prev_output_tokens, **kwargs)
Similar to forward, but only returns features.
Returns:
- the decoder’s features of shape (batch, tgt_len, embed_dim)
- a dictionary with any model-specific outputs
Return type: tuple
forward(src_tokens, src_lengths, prev_output_tokens, **kwargs)
Run the forward pass for an encoder-decoder model.
First feed a batch of source tokens through the encoder. Then, feed the encoder output and previous decoder outputs (i.e., teacher forcing) to the decoder to produce the next outputs:
    encoder_out = self.encoder(src_tokens, src_lengths)
    return self.decoder(prev_output_tokens, encoder_out)
Parameters:
- src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
- src_lengths (LongTensor) – source sentence lengths of shape (batch)
- prev_output_tokens (LongTensor) – previous decoder outputs of shape (batch, tgt_len), for teacher forcing
Returns:
- the decoder’s output of shape (batch, tgt_len, vocab)
- a dictionary with any model-specific outputs
Return type: tuple
class fairseq.models.FairseqEncoderModel(encoder)
Base class for encoder-only models.
Parameters: encoder (FairseqEncoder) – the encoder
forward(src_tokens, src_lengths, **kwargs)
Run the forward pass for an encoder-only model.
Feeds a batch of tokens through the encoder to generate features.
Parameters:
- src_tokens (LongTensor) – input tokens of shape (batch, src_len)
- src_lengths (LongTensor) – source sentence lengths of shape (batch)
Returns: the encoder’s output, typically of shape (batch, src_len, features)
class fairseq.models.FairseqLanguageModel(decoder)
Base class for decoder-only models.
Parameters: decoder (FairseqDecoder) – the decoder
extract_features(src_tokens, **kwargs)
Similar to forward, but only returns features.
Returns:
- the decoder’s features of shape (batch, seq_len, embed_dim)
- a dictionary with any model-specific outputs
Return type: tuple
forward(src_tokens, **kwargs)
Run the forward pass for a decoder-only model.
Feeds a batch of tokens through the decoder to predict the next tokens.
Parameters:
- src_tokens (LongTensor) – tokens on which to condition the decoder, of shape (batch, tgt_len)
- src_lengths (LongTensor) – source sentence lengths of shape (batch)
Returns:
- the decoder’s output of shape (batch, seq_len, vocab)
- a dictionary with any model-specific outputs
Return type: tuple
output_layer(features, **kwargs)
Project features to the default output size (typically vocabulary size).
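The split between extract_features and output_layer means forward can be viewed as the composition of the two. A hypothetical plain-Python sketch: the toy features are per-token vectors, and the output layer projects each one onto a 3-word vocabulary via dot products with invented output embeddings.

```python
from typing import Dict, List, Tuple

# Illustrative sketch (not fairseq code): forward = output_layer(extract_features(...)).
Vector = List[float]


class ToyLanguageModel:
    def __init__(self, embed: Dict[int, Vector], out_embed: List[Vector]):
        self.embed = embed          # token id -> feature vector
        self.out_embed = out_embed  # one row per vocabulary entry

    def extract_features(self, src_tokens: List[int]) -> Tuple[List[Vector], dict]:
        return [self.embed[t] for t in src_tokens], {}

    def output_layer(self, features: List[Vector]) -> List[Vector]:
        # Project each feature vector to vocabulary-sized scores (dot products).
        return [[sum(f * w for f, w in zip(feat, row)) for row in self.out_embed]
                for feat in features]

    def forward(self, src_tokens: List[int]) -> Tuple[List[Vector], dict]:
        features, extra = self.extract_features(src_tokens)
        return self.output_layer(features), extra


lm = ToyLanguageModel(embed={0: [1.0, 0.0], 1: [0.0, 1.0]},
                      out_embed=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
scores, _ = lm.forward([0, 1])
print(scores)  # [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
```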
supported_targets
class fairseq.models.FairseqMultiModel(encoders, decoders)
Base class for combining multiple encoder-decoder models.
static build_shared_embeddings(dicts, langs, embed_dim, build_embedding, pretrained_embed_path=None)
Helper function to build shared embeddings for a set of languages after checking that all dicts corresponding to those languages are equivalent.
Parameters:
- dicts – Dict of lang_id to its corresponding Dictionary
- langs – languages that we want to share embeddings for
- embed_dim – embedding dimension
- build_embedding – callable function to actually build the embedding
- pretrained_embed_path – Optional path to load pretrained embeddings
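The check-then-share logic described above can be sketched in plain Python. Here dictionaries are modelled as token lists, build_embedding is any callable, and the helper name build_shared is hypothetical; the point is that a single embedding object is built once and reused for every language.

```python
from typing import Callable, Dict, List, Optional


def build_shared(dicts: Dict[str, List[str]], langs: List[str], embed_dim: int,
                 build_embedding: Callable[[List[str], int, Optional[str]], object],
                 pretrained_embed_path: Optional[str] = None) -> object:
    """Illustrative sketch: build one embedding for all `langs` after checking vocabularies."""
    shared_dict = dicts[langs[0]]
    if any(dicts[lang] != shared_dict for lang in langs[1:]):
        raise ValueError(f"shared embeddings require identical dictionaries for {langs}")
    return build_embedding(shared_dict, embed_dim, pretrained_embed_path)


dicts = {"de": ["<s>", "</s>", "hallo"], "fr": ["<s>", "</s>", "hallo"]}
shared = build_shared(dicts, ["de", "fr"], 8, lambda d, dim, path: {"vocab": len(d), "dim": dim})
embeddings = {lang: shared for lang in ["de", "fr"]}
print(embeddings["de"] is embeddings["fr"])  # True
```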
decoder
encoder
forward(src_tokens, src_lengths, prev_output_tokens, **kwargs)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class fairseq.models.FairseqEncoder(dictionary)
Base class for encoders.
forward(src_tokens, src_lengths=None, **kwargs)
Parameters:
- src_tokens (LongTensor) – tokens in the source language of shape (batch, src_len)
- src_lengths (LongTensor) – lengths of each source sentence of shape (batch)
forward_torchscript(net_input: Dict[str, torch.Tensor])
A TorchScript-compatible version of forward.
Encoders which use additional arguments may want to override this method for TorchScript compatibility.
class fairseq.models.CompositeEncoder(encoders)
A wrapper around a dictionary of FairseqEncoder objects.
We run forward on each encoder and return a dictionary of outputs. The first encoder’s dictionary is used for initialization.
Parameters: encoders (dict) – a dictionary of FairseqEncoder objects.
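The wrapper's behaviour can be mimicked with a dictionary of callables. In this hypothetical sketch, each toy encoder is a function of (src_tokens, src_lengths) and the composite returns one output per key; ToyCompositeEncoder is an invented name.

```python
from typing import Callable, Dict, List

# Illustrative sketch (not fairseq code): a "encoder" is any callable.
Encoder = Callable[[List[int], int], object]


class ToyCompositeEncoder:
    def __init__(self, encoders: Dict[str, Encoder]):
        self.encoders = encoders

    def forward(self, src_tokens: List[int], src_lengths: int) -> Dict[str, object]:
        # Run every wrapped encoder on the same input and key its output by name.
        return {key: enc(src_tokens, src_lengths) for key, enc in self.encoders.items()}


composite = ToyCompositeEncoder({
    "sum": lambda toks, n: sum(toks[:n]),
    "first": lambda toks, n: toks[0],
})
print(composite.forward([4, 5, 6], 2))  # {'sum': 9, 'first': 4}
```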
class fairseq.models.FairseqDecoder(dictionary)
Base class for decoders.
extract_features(prev_output_tokens, encoder_out=None, **kwargs)
Returns:
- the decoder’s features of shape (batch, tgt_len, embed_dim)
- a dictionary with any model-specific outputs
Return type: tuple
forward(prev_output_tokens, encoder_out=None, **kwargs)
Parameters:
- prev_output_tokens (LongTensor) – shifted output tokens of shape (batch, tgt_len), for teacher forcing
- encoder_out (dict, optional) – output from the encoder, used for encoder-side attention
Returns:
- the decoder’s output of shape (batch, tgt_len, vocab)
- a dictionary with any model-specific outputs
Return type: tuple
get_normalized_probs(net_output: Tuple[torch.Tensor, Optional[Dict[str, List[Optional[torch.Tensor]]]]], log_probs: bool, sample: Optional[Dict[str, torch.Tensor]] = None)
Get normalized probabilities (or log probs) from a net’s output.
Incremental decoding
class fairseq.models.FairseqIncrementalDecoder(dictionary)
Base class for incremental decoders.
Incremental decoding is a special mode at inference time where the Model only receives a single timestep of input corresponding to the previous output token (for teacher forcing) and must produce the next output incrementally. Thus the model must cache any long-term state that is needed about the sequence, e.g., hidden states, convolutional states, etc.
Compared to the standard FairseqDecoder interface, the incremental decoder interface allows forward() functions to take an extra keyword argument (incremental_state) that can be used to cache state across time-steps.
The FairseqIncrementalDecoder interface also defines the reorder_incremental_state() method, which is used during beam search to select and reorder the incremental state based on the selection of beams.
To learn more about how incremental decoding works, refer to this blog.
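A toy illustration of the idea: a hypothetical decoder whose output at step t depends on all previous inputs (here, a running sum). Run incrementally, it caches the sum in incremental_state and consumes only the newest token, yet produces the same outputs as a full pass. The class, the cache key and the running-sum "model" are invented for illustration.

```python
from typing import Dict, List, Optional


class RunningSumDecoder:
    """Illustrative toy decoder: the output at step t is the sum of inputs 0..t."""

    def forward(self, prev_output_tokens: List[int],
                incremental_state: Optional[Dict[str, int]] = None) -> List[int]:
        if incremental_state is None:
            # Full pass: recompute everything from scratch.
            outputs, total = [], 0
            for tok in prev_output_tokens:
                total += tok
                outputs.append(total)
            return outputs
        # Incremental pass: only the newest token is consumed; the rest is cached.
        total = incremental_state.get("running_sum", 0) + prev_output_tokens[-1]
        incremental_state["running_sum"] = total
        return [total]


decoder = RunningSumDecoder()
tokens = [3, 1, 4, 1, 5]
full = decoder.forward(tokens)

state: Dict[str, int] = {}
step_by_step = [decoder.forward(tokens[: t + 1], state)[0] for t in range(len(tokens))]
print(full, step_by_step)  # [3, 4, 8, 9, 14] [3, 4, 8, 9, 14]
```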
extract_features(prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs)
Returns:
- the decoder’s features of shape (batch, tgt_len, embed_dim)
- a dictionary with any model-specific outputs
Return type: tuple
forward(prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs)
Parameters:
- prev_output_tokens (LongTensor) – shifted output tokens of shape (batch, tgt_len), for teacher forcing
- encoder_out (dict, optional) – output from the encoder, used for encoder-side attention
- incremental_state (dict, optional) – dictionary used for storing state during incremental decoding
Returns:
- the decoder’s output of shape (batch, tgt_len, vocab)
- a dictionary with any model-specific outputs
Return type: tuple
reorder_incremental_state(incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]], new_order: torch.Tensor)
Reorder incremental state.
This will be called when the order of the input has changed from the previous time step. A typical use case is beam search, where the input order changes between time steps based on the selection of beams.
reorder_incremental_state_scripting(incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]], new_order: torch.Tensor)
Main entry point for reordering the incremental state.
Due to limitations in TorchScript, we call this function in fairseq.sequence_generator.SequenceGenerator instead of calling reorder_incremental_state() directly.