Modules

Fairseq provides several stand-alone torch.nn.Module classes that may be helpful when implementing a new BaseFairseqModel.

class fairseq.modules.AdaptiveInput(vocab_size: int, padding_idx: int, initial_dim: int, factor: float, output_dim: int, cutoff: List[int], q_noise: float = 0, qn_block_size: int = 8)[source]
forward(input: torch.Tensor)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

weights_for_band(band: int)[source]
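
A minimal usage sketch (the vocabulary size, cutoffs, and dimensions below are made-up values, chosen only to illustrate the expected shapes):

import torch
from fairseq.modules import AdaptiveInput

# Hypothetical setup: a 10k-word vocabulary split into clusters
# [0, 2000), [2000, 6000) and [6000, 10000); rarer clusters get smaller
# embedding dims (initial_dim shrunk by factor per cluster) and are
# projected back to output_dim.
embed = AdaptiveInput(
    vocab_size=10000,
    padding_idx=1,
    initial_dim=512,
    factor=4.0,
    output_dim=512,
    cutoff=[2000, 6000],
)

tokens = torch.randint(0, 10000, (8, 32))  # (batch, seq_len) token indices
out = embed(tokens)
print(out.shape)                           # expected: torch.Size([8, 32, 512])
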
class fairseq.modules.AdaptiveSoftmax(vocab_size, input_dim, cutoff, dropout, factor=4.0, adaptive_inputs=None, tie_proj=False, q_noise=0, qn_block_size=8)[source]

This is an implementation of the efficient softmax approximation for graphics processing units (GPUs), described in the paper “Efficient softmax approximation for GPUs” (http://arxiv.org/abs/1609.04309).

adapt_target(target)[source]

In order to be efficient, the AdaptiveSoftMax does not compute the scores for all the words of the vocabulary for all the examples. It is thus necessary to call the method adapt_target of the AdaptiveSoftMax layer inside each forward pass.

forward(input, target)[source]
Parameters:
  • input – (b x t x d)
  • target – (b x t)
Returns:

output for each cutoff section and new targets by cut off

Return type:

2 lists

get_log_prob(input, target)[source]

Computes the log probabilities for all the words of the vocabulary, given a 2D tensor of hidden vectors.

upgrade_state_dict_named(state_dict, name)[source]
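
A rough usage sketch (shapes and cutoffs are illustrative, not prescriptive):

import torch
from fairseq.modules import AdaptiveSoftmax

vocab_size, hidden_dim = 10000, 512
asm = AdaptiveSoftmax(vocab_size, input_dim=hidden_dim,
                      cutoff=[2000, 6000], dropout=0.1)

hidden = torch.randn(8, 32, hidden_dim)          # (b, t, d) decoder states
target = torch.randint(0, vocab_size, (8, 32))   # (b, t) gold indices

# forward returns per-cutoff outputs and the remapped targets
outputs, new_targets = asm(hidden, target)

# full-vocabulary log-probabilities, e.g. for scoring
log_probs = asm.get_log_prob(hidden, target)     # (b, t, vocab_size)
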
class fairseq.modules.BeamableMM(beam_size=None)[source]

This module provides an optimized MM for beam decoding with attention.

It leverages the fact that the source-side of the input is replicated beam times and the target-side of the input is of width one. This layer speeds up inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)} with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.

forward(input1, input2)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

set_beam_size(beam_size)[source]
class fairseq.modules.CharacterTokenEmbedder(vocab: fairseq.data.dictionary.Dictionary, filters: List[Tuple[int, int]], char_embed_dim: int, word_embed_dim: int, highway_layers: int, max_char_len: int = 50, char_inputs: bool = False)[source]
forward(input: torch.Tensor)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

padding_idx
prepare_for_onnx_export_()[source]
reset_parameters()[source]
set_vocab(vocab, max_char_len)[source]
class fairseq.modules.ConvTBC(in_channels, out_channels, kernel_size, padding=0)[source]

1D convolution over an input of shape (time x batch x channel)

The implementation uses gemm to perform the convolution. This implementation is faster than cuDNN for small kernel sizes.

forward(input)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
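
A small sketch of the TBC layout (the sizes are illustrative):

import torch
from fairseq.modules import ConvTBC

conv = ConvTBC(in_channels=256, out_channels=512, kernel_size=3, padding=1)

x = torch.randn(20, 8, 256)   # (time, batch, channel)
y = conv(x)                   # padding=1 keeps the time dimension at 20
print(y.shape)                # expected: torch.Size([20, 8, 512])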

fairseq.modules.cross_entropy(logits, target, ignore_index=-100, reduction='mean')[source]
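
This helper can be used like torch.nn.functional.cross_entropy on flattened logits; a short illustrative sketch (the padding index of 1 is an assumption):

import torch
from fairseq.modules import cross_entropy

logits = torch.randn(64, 10000)              # (num_tokens, vocab_size)
target = torch.randint(0, 10000, (64,))
loss = cross_entropy(logits, target, ignore_index=1, reduction='sum')
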
class fairseq.modules.DownsampledMultiHeadAttention(out_channels, embed_dim, num_heads, dropout=0.0, bias=True, project_input=True, gated=False, downsample=False)[source]

Multi-headed attention with Gating and Downsampling

forward(query, key, value, mask_future_timesteps=False, key_padding_mask=None, use_scalar_bias=False)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

class fairseq.modules.DynamicConv1dTBC(input_size, kernel_size=1, padding_l=None, num_heads=1, weight_dropout=0.0, weight_softmax=False, renorm_padding=False, bias=False, conv_bias=False, query_size=None, in_proj=False)[source]

Dynamic lightweight convolution taking T x B x C inputs.

Parameters:
  • input_size – number of channels of the input
  • kernel_size – size of the convolution kernel
  • padding_l – padding to the left when using “same” padding
  • num_heads – number of heads used. The weight is of shape (num_heads, 1, kernel_size)
  • weight_dropout – the drop rate of the DropConnect to drop the weight
  • weight_softmax – normalize the weight with softmax before the convolution
  • renorm_padding – re-normalize the filters to ignore the padded part (only the non-padding parts sum up to 1)
  • bias – use bias
  • conv_bias – bias of the convolution
  • query_size – specified when feeding a different input as the query
  • in_proj – project the input and generate the filter together

Shape:
  • Input: T x B x C, i.e. (timesteps, batch_size, input_size)
  • Output: T x B x C, i.e. (timesteps, batch_size, input_size)
weight

the learnable weights of the module of shape (num_heads, 1, kernel_size)

bias

the learnable bias of the module of shape (input_size)

extra_repr()[source]

Set the extra representation of the module

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x, incremental_state=None, query=None, unfold=None)[source]

Takes an input x of shape T x B x C and produces an output of shape T x B x C.

Parameters:
  • x – input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
  • incremental_state – a dict to keep the state
  • unfold – whether to unfold the input; if not, the matrix trick is used instead
  • query – use the specified query to predict the conv filters

in_proj
reorder_incremental_state(incremental_state, new_order)[source]
reset_parameters()[source]
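
An illustrative sketch (padding_l = kernel_size - 1 is a common choice for causal “same” padding; all sizes are made up):

import torch
from fairseq.modules import DynamicConv1dTBC

conv = DynamicConv1dTBC(input_size=256, kernel_size=3, padding_l=2,
                        num_heads=4, weight_softmax=True)

x = torch.randn(20, 8, 256)   # (timesteps, batch, input_size)
y = conv(x)                   # same shape: (timesteps, batch, input_size)
print(y.shape)
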
fairseq.modules.DynamicConv(input_size, kernel_size=1, padding_l=None, num_heads=1, weight_dropout=0.0, weight_softmax=False, renorm_padding=False, bias=False, conv_bias=False, query_size=None, in_proj=False)[source]
class fairseq.modules.DynamicCRF(num_embedding, low_rank=32, beam_size=64)[source]

The dynamic CRF layer approximates a traditional Conditional Random Field (CRF),

$P(y \mid x) = \frac{1}{Z(x)} \exp\left(\sum_i s(y_i, x) + \sum_i t(y_{i-1}, y_i, x)\right)$

where the emission scores $s$ are assumed to be given and the transition score is a $|V| \times |V|$ matrix $M$. It differs from a standard CRF in two aspects:
  1. it uses a low-rank approximation for the transition matrix: $M = E_1 E_2^T$
  2. it uses a beam to estimate the normalizing factor $Z(x)$
extra_repr()[source]

Set the extra representation of the module

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(emissions, targets, masks, beam=None)[source]

Compute the conditional log-likelihood of a sequence of target tokens given emission scores

Parameters:
  • emissions (~torch.Tensor) – Emission scores, usually the unnormalized decoder output of shape (batch_size, seq_len, vocab_size). We assume batch-first inputs.
  • targets (~torch.LongTensor) – Sequence of target token indices of shape (batch_size, seq_len)
  • masks (~torch.ByteTensor) – Mask tensor with the same size as targets
Returns:

approximated log-likelihood

Return type:

~torch.Tensor

forward_decoder(emissions, masks=None, beam=None)[source]

Find the most likely output sequence using Viterbi algorithm.

Parameters:
  • emissions (~torch.Tensor) – Emission scores, usually the unnormalized decoder output of shape (batch_size, seq_len, vocab_size). We assume batch-first inputs.
  • masks (~torch.ByteTensor) – Mask tensor with the same size as targets
Returns:

decoded sequence from the CRF model

Return type:

~torch.LongTensor
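
A rough sketch of training-time scoring and decoding (the vocabulary size, padding index, and beam size are illustrative):

import torch
from fairseq.modules import DynamicCRF

vocab = 1000
crf = DynamicCRF(num_embedding=vocab, low_rank=32, beam_size=16)

emissions = torch.randn(8, 20, vocab)            # (batch, seq_len, vocab), batch-first
targets = torch.randint(2, vocab, (8, 20))
masks = targets.ne(1)                            # assume padding index 1

log_likelihood = crf(emissions, targets, masks)  # approximated log-likelihood
decoded = crf.forward_decoder(emissions, masks)  # most likely output sequence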

class fairseq.modules.FairseqDropout(p, module_name=None)[source]
forward(x, inplace: bool = False)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

make_generation_fast_(name: str, retain_dropout: bool = False, retain_dropout_modules: Optional[List[str]] = None, **kwargs)[source]
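
FairseqDropout behaves like standard dropout during training and becomes a no-op in eval mode unless dropout is explicitly retained via make_generation_fast_. A minimal sketch:

import torch
from fairseq.modules import FairseqDropout

drop = FairseqDropout(p=0.1, module_name="my_layer")  # module_name is used for logging only

x = torch.randn(20, 8, 512)
y = drop(x)                      # dropout applied in training mode

drop.eval()
assert torch.equal(drop(x), x)   # identity in eval mode
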
class fairseq.modules.Fp32GroupNorm(*args, **kwargs)[source]
forward(input)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

class fairseq.modules.Fp32LayerNorm(*args, **kwargs)[source]
forward(input)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

fairseq.modules.gelu(x: torch.Tensor) → torch.Tensor[source]
fairseq.modules.gelu_accurate(x)[source]
class fairseq.modules.GradMultiply[source]
static backward(ctx, grad)[source]

Defines a formula for differentiating the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, x, scale)[source]

Performs the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store tensors that can be then retrieved during the backward pass.
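
GradMultiply is the identity in the forward pass and scales gradients in the backward pass, for example:

import torch
from fairseq.modules import GradMultiply

x = torch.randn(4, requires_grad=True)
y = GradMultiply.apply(x, 0.5)   # forward: y == x
y.sum().backward()
print(x.grad)                    # all entries are 0.5, i.e. the gradient scaled by 0.5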

class fairseq.modules.GumbelVectorQuantizer(dim, num_vars, temp, groups, combine_groups, vq_dim, time_first, activation=GELU(), weight_proj_depth=1, weight_proj_factor=1)[source]
codebook()[source]
forward(x, produce_targets=False)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

forward_idx(x)[source]
get_codebook_indices()[source]
sample_from_codebook(b, n)[source]
set_num_updates(num_updates)[source]
to_codebook_index(indices)[source]
class fairseq.modules.KmeansVectorQuantizer(dim, num_vars, groups, combine_groups, vq_dim, time_first, gamma=0.25)[source]
expand_embedding
forward(x, produce_targets=False)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

forward_idx(x)[source]
class fairseq.modules.LayerDropModuleList(p, modules=None)[source]

A LayerDrop implementation based on torch.nn.ModuleList.

We refresh the choice of which layers to drop every time we iterate over the LayerDropModuleList instance. During evaluation we always iterate over all layers.

Usage:

layers = LayerDropModuleList(p=0.5, modules=[layer1, layer2, layer3])
for layer in layers:  # this might iterate over layers 1 and 3
    x = layer(x)
for layer in layers:  # this might iterate over all layers
    x = layer(x)
for layer in layers:  # this might not iterate over any layers
    x = layer(x)
Parameters:
  • p (float) – probability of dropping out each layer
  • modules (iterable, optional) – an iterable of modules to add
fairseq.modules.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, export=False)[source]
class fairseq.modules.LearnedPositionalEmbedding(num_embeddings: int, embedding_dim: int, padding_idx: int)[source]

This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to the forward function.

forward(input: torch.Tensor, incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]] = None, positions: Optional[torch.Tensor] = None)[source]

Input is expected to be of size [bsz x seqlen].

class fairseq.modules.LightweightConv1dTBC(input_size, kernel_size=1, padding_l=None, num_heads=1, weight_dropout=0.0, weight_softmax=False, bias=False)[source]

Lightweight convolution assuming the input is T x B x C.

Parameters:
  • input_size – number of channels of the input
  • kernel_size – size of the convolution kernel
  • padding_l – padding to the left when using “same” padding
  • num_heads – number of heads used. The weight is of shape (num_heads, 1, kernel_size)
  • weight_dropout – the drop rate of the DropConnect to drop the weight
  • weight_softmax – normalize the weight with softmax before the convolution
  • bias – use bias

Shape:
  • Input: T x B x C, i.e. (timesteps, batch_size, input_size)
  • Output: T x B x C, i.e. (timesteps, batch_size, input_size)
weight

the learnable weights of the module of shape (num_heads, 1, kernel_size)

bias

the learnable bias of the module of shape (input_size)

extra_repr()[source]

Set the extra representation of the module

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x, incremental_state=None, unfold=False)[source]

Takes an input x of shape T x B x C and produces an output of shape T x B x C.

Parameters:
  • x – input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
  • incremental_state – a dict to keep the state
  • unfold – whether to unfold the input; if not, the matrix trick is used instead

prepare_for_onnx_export_()[source]
reorder_incremental_state(incremental_state, new_order)[source]
reset_parameters()[source]
fairseq.modules.LightweightConv(input_size, kernel_size=1, padding_l=None, num_heads=1, weight_dropout=0.0, weight_softmax=False, bias=False)[source]
class fairseq.modules.LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)[source]

An optimized version of nn.Conv1d.

At training time, this module uses ConvTBC, which is an optimized version of Conv1d. At inference time, it optimizes incremental generation (i.e., one time step at a time) by replacing the convolutions with linear layers. Note that the input order changes from training to inference.

forward(input, incremental_state=None)[source]
Parameters:incremental_state – Used to buffer signal; if not None, then input is expected to contain a single frame. If the input order changes between time steps, call reorder_incremental_state.
Input:
  • Time x Batch x Channel during training
  • Batch x Time x Channel during inference
reorder_incremental_state(incremental_state, new_order)[source]
state_dict(destination=None, prefix='', keep_vars=False)[source]

Returns a dictionary containing a whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names.

Returns:a dictionary containing a whole state of the module
Return type:dict

Example:

>>> module.state_dict().keys()
['bias', 'weight']
upgrade_state_dict_named(state_dict, name)[source]
class fairseq.modules.MultiheadAttention(embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, q_noise=0.0, qn_block_size=8)[source]

Multi-headed attention.

See “Attention Is All You Need” for more details.

apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int)[source]
forward(query, key: Optional[torch.Tensor], value: Optional[torch.Tensor], key_padding_mask: Optional[torch.Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]] = None, need_weights: bool = True, static_kv: bool = False, attn_mask: Optional[torch.Tensor] = None, before_softmax: bool = False, need_head_weights: bool = False) → Tuple[torch.Tensor, Optional[torch.Tensor]][source]

Input shape: Time x Batch x Channel

Parameters:
  • key_padding_mask (ByteTensor, optional) – mask to exclude keys that are pads, of shape (batch, src_len), where padding elements are indicated by 1s.
  • need_weights (bool, optional) – return the attention weights, averaged over heads (default: False).
  • attn_mask (ByteTensor, optional) – typically used to implement causal attention, where the mask prevents the attention from looking forward in time (default: None).
  • before_softmax (bool, optional) – return the raw attention weights and values before the attention softmax.
  • need_head_weights (bool, optional) – return the attention weights for each head. Implies need_weights. Default: return the average attention weights over all heads.
prepare_for_onnx_export_()[source]
prepare_for_tpu_(**kwargs)[source]
reorder_incremental_state(incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]], new_order: torch.Tensor)[source]

Reorder buffered internal state (for incremental generation).

reset_parameters()[source]
upgrade_state_dict_named(state_dict, name)[source]
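
An illustrative self-attention sketch (dimensions are made up; True in the padding mask marks padded positions):

import torch
from fairseq.modules import MultiheadAttention

attn = MultiheadAttention(embed_dim=512, num_heads=8, dropout=0.1,
                          self_attention=True)

x = torch.randn(20, 8, 512)                               # (time, batch, channel)
key_padding_mask = torch.zeros(8, 20, dtype=torch.bool)   # (batch, src_len), no padding here

out, attn_weights = attn(query=x, key=x, value=x,
                         key_padding_mask=key_padding_mask,
                         need_weights=True)
print(out.shape)            # (20, 8, 512)
print(attn_weights.shape)   # averaged over heads: (8, 20, 20)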
fairseq.modules.PositionalEmbedding(num_embeddings: int, embedding_dim: int, padding_idx: int, learned: bool = False)[source]
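
This factory returns either a learned or a sinusoidal positional embedding depending on the learned flag; a short sketch (padding index and sizes are illustrative):

import torch
from fairseq.modules import PositionalEmbedding

pad = 1
pos_embed = PositionalEmbedding(num_embeddings=512, embedding_dim=64,
                                padding_idx=pad, learned=False)

tokens = torch.tensor([[5, 6, 7, pad],
                       [8, 9, pad, pad]])   # (bsz, seqlen)
positions = pos_embed(tokens)               # (bsz, seqlen, 64)
print(positions.shape)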
class fairseq.modules.SamePad(kernel_size)[source]
forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

class fairseq.modules.ScalarBias[source]

Adds a vector of scalars, used in the self-attention mechanism to allow the model to optionally attend to this vector instead of the past.

static backward(ctx, grad)[source]

Defines a formula for differentiating the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, input, dim, bias_init)[source]

Performs the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store tensors that can be then retrieved during the backward pass.

class fairseq.modules.SinusoidalPositionalEmbedding(embedding_dim, padding_idx, init_size=1024)[source]

This module produces sinusoidal positional embeddings of any length.

Padding symbols are ignored.

forward(input, incremental_state: Optional[Any] = None, timestep: Optional[torch.Tensor] = None, positions: Optional[Any] = None)[source]

Input is expected to be of size [bsz x seqlen].

static get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None)[source]

Build sinusoidal embeddings.

This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of “Attention Is All You Need”.

prepare_for_onnx_export_()[source]
class fairseq.modules.TransformerSentenceEncoderLayer(embedding_dim: int = 768, ffn_embedding_dim: int = 3072, num_attention_heads: int = 8, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, activation_fn: str = 'relu', export: bool = False, q_noise: float = 0.0, qn_block_size: int = 8, init_fn: Callable = None)[source]

Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained models.

build_fc1(input_dim, output_dim, q_noise, qn_block_size)[source]
build_fc2(input_dim, output_dim, q_noise, qn_block_size)[source]
build_self_attention(embed_dim, num_attention_heads, dropout, self_attention, q_noise, qn_block_size)[source]
forward(x: torch.Tensor, self_attn_mask: Optional[torch.Tensor] = None, self_attn_padding_mask: Optional[torch.Tensor] = None)[source]

LayerNorm is applied either before or after the self-attention/ffn modules similar to the original Transformer implementation.

class fairseq.modules.TransformerSentenceEncoder(padding_idx: int, vocab_size: int, num_encoder_layers: int = 6, embedding_dim: int = 768, ffn_embedding_dim: int = 3072, num_attention_heads: int = 8, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, layerdrop: float = 0.0, max_seq_len: int = 256, num_segments: int = 2, use_position_embeddings: bool = True, offset_positions_by_padding: bool = True, encoder_normalize_before: bool = False, apply_bert_init: bool = False, activation_fn: str = 'relu', learned_pos_embedding: bool = True, embed_scale: float = None, freeze_embeddings: bool = False, n_trans_layers_to_freeze: int = 0, export: bool = False, traceable: bool = False, q_noise: float = 0.0, qn_block_size: int = 8)[source]

Implementation for a Bi-directional Transformer based Sentence Encoder used in BERT/XLM style pre-trained models.

This first computes the token embedding using the token embedding matrix, position embeddings (if specified) and segment embeddings (if specified). After applying the specified number of TransformerEncoderLayers, it outputs all the internal states of the encoder as well as the final representation associated with the first token (usually CLS token).

Input:
  • tokens: B x T matrix representing sentences
  • segment_labels: B x T matrix representing segment label for tokens
Output:
  • a tuple of the following:
    • a list of internal model states used to compute the predictions where each tensor has shape T x B x C
    • the sentence representation associated with the first input token, of shape B x C
build_embedding(vocab_size, embedding_dim, padding_idx)[source]
build_transformer_sentence_encoder_layer(embedding_dim, ffn_embedding_dim, num_attention_heads, dropout, attention_dropout, activation_dropout, activation_fn, export, q_noise, qn_block_size)[source]
forward(tokens: torch.Tensor, segment_labels: torch.Tensor = None, last_state_only: bool = False, positions: Optional[torch.Tensor] = None, token_embeddings: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor][source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

prepare_for_tpu_(**kwargs)[source]
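
A small end-to-end sketch (all sizes are deliberately tiny and illustrative):

import torch
from fairseq.modules import TransformerSentenceEncoder

encoder = TransformerSentenceEncoder(
    padding_idx=1,
    vocab_size=30000,
    num_encoder_layers=2,
    embedding_dim=128,
    ffn_embedding_dim=512,
    num_attention_heads=4,
    max_seq_len=256,
)

tokens = torch.randint(2, 30000, (8, 32))         # (B, T) token ids, no padding here
inner_states, sentence_rep = encoder(tokens)

print(len(inner_states), inner_states[-1].shape)  # each inner state is (T, B, C)
print(sentence_rep.shape)                         # (B, C), taken from the first token
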
class fairseq.modules.TransformerDecoderLayer(args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False)[source]

Decoder layer block.

In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: dropout -> add residual -> layernorm. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: dropout -> add residual. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting args.decoder_normalize_before to True.

Parameters:
  • args (argparse.Namespace) – parsed command-line arguments
  • no_encoder_attn (bool, optional) – whether to disable attention over encoder outputs (default: False).
build_encoder_attention(embed_dim, args)[source]
build_fc1(input_dim, output_dim, q_noise, qn_block_size)[source]
build_fc2(input_dim, output_dim, q_noise, qn_block_size)[source]
build_self_attention(embed_dim, args, add_bias_kv=False, add_zero_attn=False)[source]
forward(x, encoder_out: Optional[torch.Tensor] = None, encoder_padding_mask: Optional[torch.Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]] = None, prev_self_attn_state: Optional[List[torch.Tensor]] = None, prev_attn_state: Optional[List[torch.Tensor]] = None, self_attn_mask: Optional[torch.Tensor] = None, self_attn_padding_mask: Optional[torch.Tensor] = None, need_attn: bool = False, need_head_weights: bool = False)[source]
Parameters:
  • x (Tensor) – input to the layer of shape (seq_len, batch, embed_dim)
  • encoder_padding_mask (ByteTensor, optional) – binary ByteTensor of shape (batch, src_len) where padding elements are indicated by 1.
  • need_attn (bool, optional) – return attention weights
  • need_head_weights (bool, optional) – return attention weights for each head (default: return average over heads).
Returns:

encoded output of shape (seq_len, batch, embed_dim)

make_generation_fast_(need_attn: bool = False, **kwargs)[source]
prepare_for_onnx_export_()[source]
residual_connection(x, residual)[source]
class fairseq.modules.TransformerEncoderLayer(args)[source]

Encoder layer block.

In the original paper each operation (multi-head attention or FFN) is postprocessed with: dropout -> add residual -> layernorm. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: dropout -> add residual. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting args.encoder_normalize_before to True.

Parameters:args (argparse.Namespace) – parsed command-line arguments
build_fc1(input_dim, output_dim, q_noise, qn_block_size)[source]
build_fc2(input_dim, output_dim, q_noise, qn_block_size)[source]
build_self_attention(embed_dim, args)[source]
forward(x, encoder_padding_mask, attn_mask: Optional[torch.Tensor] = None)[source]
Parameters:
  • x (Tensor) – input to the layer of shape (seq_len, batch, embed_dim)
  • encoder_padding_mask (ByteTensor) – binary ByteTensor of shape (batch, seq_len) where padding elements are indicated by 1.
  • attn_mask (ByteTensor) – binary tensor of shape (tgt_len, src_len), where tgt_len is the length of output and src_len is the length of input, though here both are equal to seq_len. attn_mask[tgt_i, src_j] = 1 means that when calculating the embedding for tgt_i, we exclude (mask out) src_j. This is useful for strided self-attention.
Returns:

encoded output of shape (seq_len, batch, embed_dim)

residual_connection(x, residual)[source]
upgrade_state_dict_named(state_dict, name)[source]

Rename layer norm states from …layer_norms.0.weight to …self_attn_layer_norm.weight and …layer_norms.1.weight to …final_layer_norm.weight

class fairseq.modules.TransposeLast(deconstruct_idx=None)[source]
forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

class fairseq.modules.VGGBlock(in_channels, out_channels, conv_kernel_size, pooling_kernel_size, num_conv_layers, input_dim, conv_stride=1, padding=None, layer_norm=False)[source]

VGG-motivated CNN module (https://arxiv.org/pdf/1409.1556.pdf)

Parameters:
  • in_channels – (int) number of input channels (typically 1)
  • out_channels – (int) number of output channels
  • conv_kernel_size – size of the convolution kernel
  • pooling_kernel_size – the size of the pooling window to take a max over
  • num_conv_layers – (int) number of convolution layers
  • input_dim – (int) input dimension
  • conv_stride – the stride of the convolving kernel. Can be a single number or a tuple (sH, sW) Default: 1
  • padding – implicit paddings on both sides of the input. Can be a single number or a tuple (padH, padW). Default: None
  • layer_norm – (bool) if layer norm is going to be applied. Default: False
Shape:
  • Input: B x C x T x feat, i.e. (batch_size, input_size, timesteps, features)
  • Output: B x C x T x feat, i.e. (batch_size, input_size, timesteps, features)
forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

fairseq.modules.unfold1d(x, kernel_size, padding_l, pad_value=0)[source]

Unfold an input of shape T x B x C into shape T x B x C x K, where K is the kernel size.
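
For example (shapes are illustrative):

import torch
from fairseq.modules import unfold1d

x = torch.randn(10, 4, 16)                                     # (T, B, C)
patches = unfold1d(x, kernel_size=3, padding_l=2, pad_value=0)
print(patches.shape)                                           # (T, B, C, K) = (10, 4, 16, 3)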