Modules

Fairseq provides several stand-alone torch.nn.Module classes that may be helpful when implementing a new BaseFairseqModel.

class fairseq.modules.AdaptiveInput(vocab_size: int, padding_idx: int, initial_dim: int, factor: float, output_dim: int, cutoff: List[int])[source]
forward(input: torch.Tensor)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

weights_for_band(band: int)[source]
class fairseq.modules.AdaptiveSoftmax(vocab_size, input_dim, cutoff, dropout, factor=4.0, adaptive_inputs=None, tie_proj=False)[source]

This is an implementation of the efficient softmax approximation for graphics processing units (GPUs), described in the paper “Efficient softmax approximation for GPUs” (http://arxiv.org/abs/1609.04309).

adapt_target(target)[source]

In order to be efficient, the AdaptiveSoftmax does not compute the scores for all the words of the vocabulary for all the examples. It is thus necessary to call the method adapt_target of the AdaptiveSoftmax layer inside each forward pass (see the usage sketch after this class).

forward(input, target)[source]
Parameters:
  • input – (b x t x d)
  • target – (b x t)
Returns:

output for each cutoff section and new targets by cutoff

Return type:

2 lists

get_log_prob(input, target)[source]

Computes the log probabilities for all the words of the vocabulary, given a 2D tensor of hidden vectors.

upgrade_state_dict_named(state_dict, name)[source]
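
A minimal usage sketch (the sizes, cutoffs, and dropout below are illustrative assumptions, not values from fairseq; adapt_target is invoked internally by forward):

    import torch
    from fairseq.modules import AdaptiveSoftmax

    # Illustrative: a 10k-word vocabulary split into a head and two tail clusters.
    vocab_size, input_dim = 10000, 512
    asm = AdaptiveSoftmax(vocab_size, input_dim, cutoff=[2000, 6000], dropout=0.1)

    hidden = torch.randn(8, 20, input_dim)        # b x t x d decoder states
    target = torch.randint(vocab_size, (8, 20))   # b x t gold indices

    # Returns the output for each cutoff section and the remapped targets per
    # section, matching the two lists documented above.
    outputs, new_targets = asm(hidden, target)
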
class fairseq.modules.BeamableMM(beam_size=None)[source]

This module provides an optimized MM for beam decoding with attention.

It leverages the fact that the source-side of the input is replicated beam times and the target-side of the input is of width one. This layer speeds up inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)} with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)} (a sketch of the idea follows this class).

forward(input1, input2)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

set_beam_size(beam_size)[source]
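
The memory saving can be illustrated without the module itself: because every beam hypothesis of a sentence shares the same source-side tensor, the batched matrix product only needs one source copy per sentence. A rough sketch of the equivalence (shapes and names are illustrative, not the module's internals):

    import torch

    bsz, beam, nhu, sz2 = 40, 5, 512, 30         # bsz already includes the beam factor
    q = torch.randn(bsz, 1, nhu)                 # target side: one step per hypothesis
    k_unique = torch.randn(bsz // beam, nhu, sz2)
    k = k_unique.repeat_interleave(beam, dim=0)  # source side replicated beam times

    naive = torch.bmm(q, k)                      # bsz small matrix products

    # Group the beam hypotheses of each sentence so the shared source tensor
    # is multiplied once per sentence instead of once per hypothesis.
    fast = torch.bmm(q.view(bsz // beam, beam, nhu), k_unique).view(bsz, 1, sz2)

    assert torch.allclose(naive, fast)
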
class fairseq.modules.CharacterTokenEmbedder(vocab: fairseq.data.dictionary.Dictionary, filters: List[Tuple[int, int]], char_embed_dim: int, word_embed_dim: int, highway_layers: int, max_char_len: int = 50, char_inputs: bool = False)[source]
forward(input: torch.Tensor)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

padding_idx
prepare_for_onnx_export_()[source]
reset_parameters()[source]
set_vocab(vocab, max_char_len)[source]
class fairseq.modules.ConvTBC(in_channels, out_channels, kernel_size, padding=0)[source]

1D convolution over an input of shape (time x batch x channel)

The implementation uses gemm to perform the convolution. This implementation is faster than cuDNN for small kernel sizes.

forward(input)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
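
A minimal usage sketch (sizes are illustrative):

    import torch
    from fairseq.modules import ConvTBC

    conv = ConvTBC(in_channels=256, out_channels=512, kernel_size=3, padding=1)

    x = torch.randn(20, 8, 256)   # time x batch x channel
    y = conv(x)                   # time x batch x out_channels
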

class fairseq.modules.DownsampledMultiHeadAttention(out_channels, embed_dim, num_heads, dropout=0.0, bias=True, project_input=True, gated=False, downsample=False)[source]

Multi-headed attention with Gating and Downsampling

forward(query, key, value, mask_future_timesteps=False, key_padding_mask=None, use_scalar_bias=False)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class fairseq.modules.DynamicConv1dTBC(input_size, kernel_size=1, padding_l=None, num_heads=1, weight_dropout=0.0, weight_softmax=False, renorm_padding=False, bias=False, conv_bias=False, query_size=None, in_proj=False)[source]

Dynamic lightweight convolution taking T x B x C inputs.

Parameters:
  • input_size – number of channels of the input
  • kernel_size – size of the convolution kernel
  • padding_l – padding to the left when using “same” padding
  • num_heads – number of heads used. The weight is of shape (num_heads, 1, kernel_size)
  • weight_dropout – the drop rate of the DropConnect to drop the weight
  • weight_softmax – normalize the weight with softmax before the convolution
  • renorm_padding – re-normalize the filters to ignore the padded part (only the non-padding parts sum up to 1)
  • bias – use bias
  • conv_bias – bias of the convolution
  • query_size – specified when feeding a different input as the query
  • in_proj – project the input and generate the filter together

Shape:
  Input: T x B x C, i.e. (timesteps, batch_size, input_size)
  Output: T x B x C, i.e. (timesteps, batch_size, input_size)
weight

the learnable weights of the module of shape (num_heads, 1, kernel_size)

bias

the learnable bias of the module of shape (input_size)

extra_repr()[source]

Set the extra representation of the module

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x, incremental_state=None, query=None, unfold=None)[source]

Takes an input x of shape T x B x C and produces an output of shape T x B x C.

Parameters:
  • x – input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
  • incremental_state – a dict to keep the state
  • query – use the specified query to predict the conv filters
  • unfold – unfold the input or not. If not, we use the matrix trick instead

in_proj
reorder_incremental_state(incremental_state, new_order)[source]
reset_parameters()[source]
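
A minimal usage sketch (the sizes and the left-only padding_l = kernel_size - 1 choice are illustrative assumptions):

    import torch
    from fairseq.modules import DynamicConv1dTBC

    conv = DynamicConv1dTBC(
        input_size=512,
        kernel_size=3,
        padding_l=2,          # kernel_size - 1 pads only on the left
        num_heads=8,
        weight_softmax=True,
        weight_dropout=0.1,
    )

    x = torch.randn(20, 4, 512)   # T x B x C
    y = conv(x)                   # T x B x C, same shape as the input
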
fairseq.modules.gelu(x: torch.Tensor) → torch.Tensor[source]
fairseq.modules.gelu_accurate(x)[source]
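
For reference, the two standard GELU variants are the exact erf-based form and the tanh approximation of Hendrycks & Gimpel (2016); a sketch of both formulas follows (which of the two functions above corresponds to which variant is not asserted here):

    import math
    import torch

    def gelu_erf(x: torch.Tensor) -> torch.Tensor:
        # Exact GELU: x * Phi(x), with Phi the standard normal CDF.
        return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

    def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
        # Tanh approximation.
        return 0.5 * x * (1.0 + torch.tanh(
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

    x = torch.linspace(-3, 3, steps=7)
    print((gelu_erf(x) - gelu_tanh(x)).abs().max())   # the two agree to within ~1e-3
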
class fairseq.modules.GradMultiply[source]
static backward(ctx, grad)[source]

Defines a formula for differentiating the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs its gradient computed w.r.t. the output.

static forward(ctx, x, scale)[source]

Performs the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store tensors that can be then retrieved during the backward pass.
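
A minimal sketch of the intended effect, assuming forward passes the input through unchanged and backward scales the incoming gradient by scale:

    import torch
    from fairseq.modules import GradMultiply

    x = torch.ones(3, requires_grad=True)

    # The forward value is unchanged; only the gradient flowing back into x is scaled.
    y = GradMultiply.apply(x, 0.5)
    y.sum().backward()

    print(x.grad)   # expected: tensor([0.5000, 0.5000, 0.5000])
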

class fairseq.modules.Highway(input_dim: int, num_layers: int = 1)[source]

A Highway layer. Adopted from the AllenNLP implementation.

forward(x: torch.Tensor)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

reset_parameters()[source]
fairseq.modules.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, export=False)[source]
class fairseq.modules.LearnedPositionalEmbedding(num_embeddings: int, embedding_dim: int, padding_idx: int)[source]

This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to the forward function.

forward(input, incremental_state=None, positions=None)[source]

Input is expected to be of size [bsz x seqlen].

max_positions()[source]

Maximum number of supported positions.

class fairseq.modules.LightweightConv1dTBC(input_size, kernel_size=1, padding_l=None, num_heads=1, weight_dropout=0.0, weight_softmax=False, bias=False)[source]

Lightweight convolution assuming the input is T x B x C.

Parameters:
  • input_size – number of channels of the input
  • kernel_size – size of the convolution kernel
  • padding_l – padding to the left when using “same” padding
  • num_heads – number of heads used. The weight is of shape (num_heads, 1, kernel_size)
  • weight_dropout – the drop rate of the DropConnect to drop the weight
  • weight_softmax – normalize the weight with softmax before the convolution
  • bias – use bias

Shape:
  Input: T x B x C, i.e. (timesteps, batch_size, input_size)
  Output: T x B x C, i.e. (timesteps, batch_size, input_size)
weight

the learnable weights of the module of shape (num_heads, 1, kernel_size)

bias

the learnable bias of the module of shape (input_size)

extra_repr()[source]

Set the extra representation of the module

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x, incremental_state=None, unfold=False)[source]

Takes an input x of shape T x B x C and produces an output of shape T x B x C.

Parameters:
  • x – input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
  • incremental_state – a dict to keep the state
  • unfold – unfold the input or not. If not, we use the matrix trick instead

prepare_for_onnx_export_()[source]
reorder_incremental_state(incremental_state, new_order)[source]
reset_parameters()[source]
class fairseq.modules.LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)[source]

An optimized version of nn.Conv1d.

At training time, this module uses ConvTBC, which is an optimized version of Conv1d. At inference time, it optimizes incremental generation (i.e., one time step at a time) by replacing the convolutions with linear layers. Note that the input order changes from training to inference.

forward(input, incremental_state=None)[source]
Parameters: incremental_state – Used to buffer signal; if not None, then input is expected to contain a single frame. If the input order changes between time steps, call reorder_incremental_state.
Input:
  Time x Batch x Channel during training
  Batch x Time x Channel during inference
reorder_incremental_state(incremental_state, new_order)[source]
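
A rough usage sketch of the two regimes (the kernel size, padding, and shapes are illustrative assumptions):

    import torch
    from fairseq.modules import LinearizedConvolution

    conv = LinearizedConvolution(in_channels=256, out_channels=512,
                                 kernel_size=3, padding=2)

    # Training / full-sequence mode: Time x Batch x Channel.
    x = torch.randn(20, 8, 256)
    y = conv(x)

    # Incremental inference: Batch x Time x Channel, one frame at a time,
    # with a dict that buffers the previously seen frames.
    conv.eval()
    incremental_state = {}
    step = torch.randn(8, 1, 256)
    out = conv(step, incremental_state=incremental_state)
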
class fairseq.modules.LogSumExpMoE[source]

Standard LogSumExp forward pass, but use posterior for the backward.

See “Mixture Models for Diverse Machine Translation: Tricks of the Trade” (Shen et al., 2019).

static backward(ctx, grad_output)[source]

Defines a formula for differentiating the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs its gradient computed w.r.t. the output.

static forward(ctx, logp, posterior, dim=-1)[source]

Performs the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store tensors that can be then retrieved during the backward pass.
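
A minimal sketch, assuming the forward value is a plain logsumexp over dim while the backward pass distributes the gradient according to the supplied posterior:

    import torch
    from fairseq.modules import LogSumExpMoE

    # Per-expert log-likelihoods for a batch of 8 examples and 4 experts,
    # and expert responsibilities (one distribution per example).
    logp = torch.randn(8, 4, requires_grad=True)
    posterior = torch.softmax(torch.randn(8, 4), dim=-1)

    loss = -LogSumExpMoE.apply(logp, posterior, -1).sum()
    loss.backward()   # logp.grad follows the posterior rather than the softmax of logp
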

class fairseq.modules.MeanPoolGatingNetwork(embed_dim, num_experts, dropout=None)[source]

A simple mean-pooling gating network for selecting experts.

This module applies mean pooling over an encoder’s output and returns responsibilities for each expert. The encoder format is expected to match fairseq.models.transformer.TransformerEncoder.

forward(encoder_out)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class fairseq.modules.MultiheadAttention(embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False)[source]

Multi-headed attention.

See “Attention Is All You Need” for more details.

forward(query, key, value, key_padding_mask=None, incremental_state=None, need_weights=True, static_kv=False, attn_mask=None)[source]

Input shape: Time x Batch x Channel

Timesteps can be masked by supplying a T x T mask in the attn_mask argument. Padding elements can be excluded from the key by passing a binary ByteTensor (key_padding_mask) with shape: batch x src_len, where padding elements are indicated by 1s.

in_proj_k(key)[source]
in_proj_q(query)[source]
in_proj_qkv(query)[source]
in_proj_v(value)[source]
prepare_for_onnx_export_()[source]
reorder_incremental_state(incremental_state, new_order)[source]

Reorder buffered internal state (for incremental generation).

reset_parameters()[source]
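
A minimal self-attention usage sketch (sizes and the padding mask are illustrative; the docstring above describes the mask as a ByteTensor, a bool mask is used here for brevity):

    import torch
    from fairseq.modules import MultiheadAttention

    attn = MultiheadAttention(embed_dim=512, num_heads=8, dropout=0.1,
                              self_attention=True)

    src_len, bsz = 10, 4
    x = torch.randn(src_len, bsz, 512)                     # Time x Batch x Channel

    # 1s/True mark padding positions that keys should ignore (batch x src_len).
    key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
    key_padding_mask[:, -2:] = True                        # last two positions are padding

    out, attn_weights = attn(query=x, key=x, value=x,
                             key_padding_mask=key_padding_mask)
    # out: Time x Batch x Channel
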
fairseq.modules.PositionalEmbedding(num_embeddings: int, embedding_dim: int, padding_idx: int, learned: bool = False)[source]
class fairseq.modules.ScalarBias[source]

Adds a vector of scalars, used in the self-attention mechanism to allow the model to optionally attend to this vector instead of the past.

static backward(ctx, grad)[source]

Defines a formula for differentiating the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs its gradient computed w.r.t. the output.

static forward(ctx, input, dim, bias_init)[source]

Performs the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store tensors that can be then retrieved during the backward pass.

class fairseq.modules.SinusoidalPositionalEmbedding(embedding_dim, padding_idx, init_size=1024)[source]

This module produces sinusoidal positional embeddings of any length.

Padding symbols are ignored.

forward(input, incremental_state=None, timestep=None, **kwargs)[source]

Input is expected to be of size [bsz x seqlen].

static get_embedding(num_embeddings, embedding_dim, padding_idx=None)[source]

Build sinusoidal embeddings.

This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of “Attention Is All You Need”.
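
For reference, the construction described in Section 3.5 of “Attention Is All You Need” is

    PE_{(pos,\,2i)}   = \sin\bigl(pos / 10000^{2i/d_{\mathrm{model}}}\bigr)
    PE_{(pos,\,2i+1)} = \cos\bigl(pos / 10000^{2i/d_{\mathrm{model}}}\bigr)

The tensor2tensor variant that this module matches differs mainly in how the sine and cosine channels are arranged, which is the slight difference referred to above.
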

max_positions()[source]

Maximum number of supported positions.

prepare_for_onnx_export_()[source]
class fairseq.modules.TransformerSentenceEncoderLayer(embedding_dim: float = 768, ffn_embedding_dim: float = 3072, num_attention_heads: float = 8, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, activation_fn: str = 'relu', add_bias_kv: bool = False, add_zero_attn: bool = False, export: bool = False)[source]

Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained models.

forward(x: torch.Tensor, self_attn_mask: torch.Tensor = None, self_attn_padding_mask: torch.Tensor = None)[source]

LayerNorm is applied either before or after the self-attention/ffn modules, similar to the original Transformer implementation.

class fairseq.modules.TransformerSentenceEncoder(padding_idx: int, vocab_size: int, num_encoder_layers: int = 6, embedding_dim: int = 768, ffn_embedding_dim: int = 3072, num_attention_heads: int = 8, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, max_seq_len: int = 256, num_segments: int = 2, use_position_embeddings: bool = True, offset_positions_by_padding: bool = True, encoder_normalize_before: bool = False, apply_bert_init: bool = False, activation_fn: str = 'relu', learned_pos_embedding: bool = True, add_bias_kv: bool = False, add_zero_attn: bool = False, embed_scale: float = None, freeze_embeddings: bool = False, n_trans_layers_to_freeze: int = 0, export: bool = False)[source]

Implementation for a Bi-directional Transformer based Sentence Encoder used in BERT/XLM style pre-trained models.

This first computes the token embedding using the token embedding matrix, position embeddings (if specified) and segment embeddings (if specified). After applying the specified number of TransformerEncoderLayers, it outputs all the internal states of the encoder as well as the final representation associated with the first token (usually CLS token).

Input:
  • tokens: B x T matrix representing sentences
  • segment_labels: B x T matrix representing segment label for tokens
Output:
  • a tuple of the following:
    • a list of internal model states used to compute the predictions where each tensor has shape B x T x C
    • sentence representation associated with first input token in format B x C.
forward(tokens: torch.Tensor, segment_labels: torch.Tensor, last_state_only: bool = False, positions: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor][source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
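
A minimal usage sketch (the vocabulary size, padding index, and batch below are illustrative assumptions):

    import torch
    from fairseq.modules import TransformerSentenceEncoder

    encoder = TransformerSentenceEncoder(
        padding_idx=1,
        vocab_size=30000,
        num_encoder_layers=6,
        embedding_dim=768,
        ffn_embedding_dim=3072,
        num_attention_heads=8,
        max_seq_len=256,
    )

    tokens = torch.randint(2, 30000, (4, 32))               # B x T token ids
    segment_labels = torch.zeros(4, 32, dtype=torch.long)   # single-segment input

    # Per the Input/Output description above: a list of internal states
    # and the representation of the first token of each sentence.
    inner_states, sentence_rep = encoder(tokens, segment_labels)
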

fairseq.modules.unfold1d(x, kernel_size, padding_l, pad_value=0)[source]

unfold T x B x C to T x B x C x K