Source code for fairseq.modules.transformer_layer

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.modules import LayerNorm, MultiheadAttention


class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim, args.encoder_attention_heads,
            dropout=args.attention_dropout, self_attention=True
        )
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu')
        )
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        layer_norm_map = {
            '0': 'self_attn_layer_norm',
            '1': 'final_layer_norm'
        }
        for old, new in layer_norm_map.items():
            for m in ('weight', 'bias'):
                k = '{}.layer_norms.{}.{}'.format(name, old, m)
                if k in state_dict:
                    state_dict[
                        '{}.{}.{}'.format(name, new, m)
                    ] = state_dict[k]
                    del state_dict[k]

    def forward(self, x, encoder_padding_mask, attn_mask=None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape `(T_tgt, T_src)`,
                where T_tgt is the length of the query and T_src is the length
                of the key (both are ``x`` here, since this is self-attention).
                ``attn_mask[t_tgt, t_src] = 1`` means that position ``t_src`` is
                excluded (masked out) when computing the embedding for
                ``t_tgt``; ``0`` means it is included in attention.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(attn_mask.bool(), -1e8)
        # anything in original attn_mask = 1, becomes -1e8
        # anything in original attn_mask = 0, becomes 0
        # Note that we cannot use -inf here, because at some edge cases,
        # the attention weight (before softmax) for some padded element in query
        # will become -inf, which results in NaN in model parameters
        # TODO: to formally solve this problem, we need to change fairseq's
        # MultiheadAttention. We will do this later on.
        x, _ = self.self_attn(query=x, key=x, value=x,
                              key_padding_mask=encoder_padding_mask,
                              attn_mask=attn_mask)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        return x

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x
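

# --- Usage sketch (added for illustration; not part of the original module) ----
# A minimal example of how a TransformerEncoderLayer could be built from a bare
# argparse.Namespace carrying just the hyperparameters read in __init__, then run
# on dummy inputs. All dimension and dropout values below are illustrative
# assumptions, not prescribed fairseq defaults.
def _example_encoder_layer_usage():
    import argparse
    args = argparse.Namespace(
        encoder_embed_dim=512,
        encoder_attention_heads=8,
        encoder_ffn_embed_dim=2048,
        attention_dropout=0.1,
        dropout=0.1,
        encoder_normalize_before=False,  # True enables the tensor2tensor pre-norm variant
    )
    layer = TransformerEncoderLayer(args)
    seq_len, batch = 10, 2
    x = torch.randn(seq_len, batch, args.encoder_embed_dim)        # (seq_len, batch, embed_dim)
    padding_mask = torch.zeros(batch, seq_len, dtype=torch.uint8)  # (batch, src_len); 1 marks padding
    # optional binary mask, 1 = exclude that key position; kept as float so the
    # -1e8 fill in forward() applies cleanly
    attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    return layer(x, padding_mask, attn_mask=attn_mask)             # -> (seq_len, batch, embed_dim)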


class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing
    with: `dropout -> add residual`. We default to the approach in the paper,
    but the tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): if ``True``, do not attend to encoder
            outputs (default: False).
    """

    def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, 'cross_self_attention', False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu')
        )
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        self.normalize_before = args.decoder_normalize_before

        # use LayerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO remove this once we update apex with the fix
        export = getattr(args, 'char_inputs', False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, 'encoder_embed_dim', None),
                vdim=getattr(args, 'encoder_embed_dim', None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_attn=False,
        need_head_weights=False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            self.self_attn._set_input_buffer(incremental_state, saved_state)

        if self.cross_self_attention and not (
            incremental_state is not None
            and "prev_key" in self.self_attn._get_input_buffer(incremental_state)
        ):
            if self_attn_mask is not None:
                self_attn_mask = torch.cat(
                    (x.new(x.size(0), encoder_out.size(0)).zero_(), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    encoder_padding_mask = self_attn_padding_mask.new(
                        encoder_out.size(1), encoder_out.size(0)
                    ).zero_()
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state[:2]
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            if self_attn_padding_mask is not None:
                self_attn_state = (
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                )
            else:
                self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
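

# --- Usage sketch (added for illustration; not part of the original module) ----
# Shows how a TransformerDecoderLayer could be driven at training time: a dummy
# encoder output, a causal (future-masking) self-attention mask, and one forward
# call. All hyperparameter values below are illustrative assumptions.
def _example_decoder_layer_usage():
    import argparse
    args = argparse.Namespace(
        decoder_embed_dim=512,
        decoder_attention_heads=8,
        decoder_ffn_embed_dim=2048,
        encoder_embed_dim=512,
        attention_dropout=0.1,
        dropout=0.1,
        decoder_normalize_before=False,
    )
    layer = TransformerDecoderLayer(args)
    tgt_len, src_len, batch = 7, 10, 2
    x = torch.randn(tgt_len, batch, args.decoder_embed_dim)            # (tgt_len, batch, embed_dim)
    encoder_out = torch.randn(src_len, batch, args.encoder_embed_dim)  # (src_len, batch, embed_dim)
    # standard causal mask: -inf above the diagonal blocks attention to future positions
    self_attn_mask = torch.triu(
        x.new_full((tgt_len, tgt_len), float('-inf')), diagonal=1
    )
    out, attn = layer(x, encoder_out=encoder_out, self_attn_mask=self_attn_mask)
    return out, attn  # attn is None unless attention weights were requested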


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.)
    return m
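

# --- Usage sketch (added for illustration; not part of the original module) ----
# Demonstrates what TransformerEncoderLayer.upgrade_state_dict_named does to an
# old-format checkpoint: keys under `...layer_norms.0.*` / `...layer_norms.1.*`
# are renamed in place to `...self_attn_layer_norm.*` / `...final_layer_norm.*`.
# The key prefix 'encoder.layers.0' and the tiny dimensions are illustrative.
def _example_upgrade_old_layer_norm_keys():
    import argparse
    args = argparse.Namespace(
        encoder_embed_dim=8,
        encoder_attention_heads=2,
        encoder_ffn_embed_dim=16,
        attention_dropout=0.0,
        dropout=0.0,
        encoder_normalize_before=False,
    )
    layer = TransformerEncoderLayer(args)
    old_state = {
        'encoder.layers.0.layer_norms.0.weight': torch.ones(8),
        'encoder.layers.0.layer_norms.0.bias': torch.zeros(8),
        'encoder.layers.0.layer_norms.1.weight': torch.ones(8),
        'encoder.layers.0.layer_norms.1.bias': torch.zeros(8),
    }
    layer.upgrade_state_dict_named(old_state, 'encoder.layers.0')
    # old_state now contains e.g. 'encoder.layers.0.self_attn_layer_norm.weight'
    return sorted(old_state.keys())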