Source code for fairseq.modules.location_attention

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch
import torch.nn.functional as F

class LocationAttention(nn.Module):
    """Location-aware attention.

    Reference: "Attention-Based Models for Speech Recognition".

    The attention energies combine three projected terms: the encoder states,
    the current decoder hidden state, and convolutional features extracted
    from the attention state passed to :meth:`forward`.

    :param int attn_dim: attention dimension
    :param int encoder_dim: # projection-units of encoder
    :param int decoder_dim: # units of decoder
    :param int attn_state_kernel_size: # input channels of the attention
        convolution, i.e. dimension ``K`` of the ``B x K x T`` attention state
    :param int conv_dim: # channels of attention convolution
    :param int conv_kernel_size: half-width of the attention convolution
        filter; the actual filter size is ``2 * conv_kernel_size + 1``
    :param float scaling: factor applied to the energies before the softmax
    """

    def __init__(
        self,
        attn_dim,
        encoder_dim,
        decoder_dim,
        attn_state_kernel_size,
        conv_dim,
        conv_kernel_size,
        scaling=2.0,
    ):
        super().__init__()
        self.attn_dim = attn_dim
        self.decoder_dim = decoder_dim
        self.scaling = scaling
        self.proj_enc = nn.Linear(encoder_dim, attn_dim)
        self.proj_dec = nn.Linear(decoder_dim, attn_dim, bias=False)
        self.proj_attn = nn.Linear(conv_dim, attn_dim, bias=False)
        # padding == conv_kernel_size with an odd filter width keeps the
        # time dimension T unchanged by the convolution.
        self.conv = nn.Conv1d(
            attn_state_kernel_size,
            conv_dim,
            2 * conv_kernel_size + 1,
            padding=conv_kernel_size,
            bias=False,
        )
        self.proj_out = nn.Sequential(nn.Tanh(), nn.Linear(attn_dim, 1))

        # Cache of proj_enc(encoder_out); filled on the first forward() call
        # and reused until clear_cache() is called.
        self.proj_enc_out = None

    def clear_cache(self):
        """Drop the cached encoder projection.

        Must be called before :meth:`forward` is used with a different
        ``encoder_out``; otherwise the stale projection is reused.
        """
        self.proj_enc_out = None

    def forward(self, encoder_out, encoder_padding_mask, decoder_h, attn_state):
        """Compute the attention context and weights for one decoder step.

        :param torch.Tensor encoder_out: padded encoder hidden state B x T x D
        :param torch.Tensor encoder_padding_mask: encoder padding mask (B x T),
            true at padded positions; ``None`` means no position is padded
        :param torch.Tensor decoder_h: decoder hidden state B x D; ``None`` is
            treated as an all-zero state
        :param torch.Tensor attn_state: attention state B x K x T fed to the
            attention convolution
        :return: attention weighted encoder state (B, D)
        :rtype: torch.Tensor
        :return: attention weights of this step (B x T)
        :rtype: torch.Tensor
        """
        bsz, seq_len, _ = encoder_out.size()

        # The encoder projection does not depend on the decoder step, so it is
        # computed once and reused until clear_cache() is called.
        if self.proj_enc_out is None:
            self.proj_enc_out = self.proj_enc(encoder_out)

        # B x K x T -> B x C x T
        attn = self.conv(attn_state)
        # B x C x T -> B x T x C -> B x T x D
        attn = self.proj_attn(attn.transpose(1, 2))

        if decoder_h is None:
            decoder_h = encoder_out.new_zeros(bsz, self.decoder_dim)
        # B x 1 x D so it broadcasts over the time dimension.
        dec_h = self.proj_dec(decoder_h).view(bsz, 1, self.attn_dim)

        # B x T x D -> B x T x 1 -> B x T
        out = self.proj_out(attn + self.proj_enc_out + dec_h).squeeze(2)

        # Padded positions get -inf energy and hence zero weight. A missing
        # mask means there is nothing to exclude.
        # NOTE: a row that is masked everywhere produces NaN weights.
        if encoder_padding_mask is not None:
            out.masked_fill_(encoder_padding_mask, -float("inf"))

        w = F.softmax(self.scaling * out, dim=1)
        c = torch.sum(encoder_out * w.unsqueeze(-1), dim=1)
        return c, w