#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Multi-Head Attention layer definition."""
import math
import torch
from torch import nn
from fairseq.modules.rotary_positional_embedding import (
RotaryPositionalEmbedding,
apply_rotary_pos_emb,
)
class ESPNETMultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
Args:
n_head: The number of heads.
n_feat: The number of features.
dropout: Dropout rate.
"""
def __init__(self, n_feat, n_head, dropout):
"""Construct an MultiHeadedAttention object."""
super(ESPNETMultiHeadedAttention, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = nn.Dropout(p=dropout)
def forward_qkv(self, query, key, value, **kwargs):
"""Transform query, key and value.
Args:
query: Query tensor B X T1 X C
key: Key tensor B X T2 X C
value: Value tensor B X T2 X C
Returns:
torch.Tensor: Transformed query tensor B X n_head X T1 X d_k
torch.Tensor: Transformed key tensor B X n_head X T2 X d_k
torch.Tensor: Transformed value tensor B X n_head X T2 X d_k
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
return q, k, v
def forward_attention(self, value, scores, mask):
"""Compute attention context vector.
Args:
value: Transformed value B X n_head X T2 X d_k.
scores: Attention score B X n_head X T1 X T2
            mask: Key padding mask B X T2 (True at padded positions).
Returns:
torch.Tensor: Transformed value B X T1 X d_model
weighted by the attention score B X T1 X T2
"""
n_batch = value.size(0)
        if mask is not None:
            # mask is True at padded key positions; their scores are set to -inf
            scores = scores.masked_fill(
                mask.unsqueeze(1).unsqueeze(2).to(bool),
                float("-inf"),
            )
        self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(self, query, key, value, key_padding_mask=None, **kwargs):
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor T X B X C
key (torch.Tensor): Key tensor T X B X C
value (torch.Tensor): Value tensor T X B X C
            key_padding_mask (torch.Tensor): Key padding mask B X T
Returns:
torch.Tensor: Output tensor T X B X D.
"""
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
scores = self.forward_attention(v, scores, key_padding_mask)
scores = scores.transpose(0, 1)
return scores, None
class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
        n_feat: The number of features.
        n_head: The number of heads.
dropout: Dropout rate.
        zero_triu: Whether to zero the upper triangular part of the attention matrix.
"""
def __init__(self, n_feat, n_head, dropout, zero_triu=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_feat, n_head, dropout)
self.zero_triu = zero_triu
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.zeros(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.zeros(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
    def rel_shift(self, x):
        """Shift relative-position scores so each query indexes keys by relative offset.
        Args:
            x: Input tensor B X n_head X T X 2T-1
        Returns:
            torch.Tensor: Output tensor B X n_head X T X T.
        """
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)[
:, :, :, : x.size(-1) // 2 + 1
] # only keep the positions from 0 to time2
if self.zero_triu:
ones = torch.ones((x.size(2), x.size(3)), device=x.device)
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
    def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs):
"""Compute scaled dot product attention.
Args:
query: Query tensor T X B X C
key: Key tensor T X B X C
value: Value tensor T X B X C
            pos_emb: Positional embedding tensor 2T-1 X B X C
            key_padding_mask: Key padding mask B X T
Returns:
torch.Tensor: Output tensor T X B X C.
"""
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
pos_emb = pos_emb.transpose(0, 1)
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, 2*time1-1)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k
) # (batch, head, time1, time2)
scores = self.forward_attention(v, scores, key_padding_mask)
scores = scores.transpose(0, 1)
return scores, None
class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
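    """Multi-Head Attention layer with rotary positional embedding.
    Paper: https://arxiv.org/abs/2104.09864
    Args:
        n_feat: The number of features.
        n_head: The number of heads.
        dropout: Dropout rate.
        precision: "fp16" to build the rotary embedding in half precision;
            any other value selects single precision.
        rotary_emd_base: Base of the rotary embedding frequencies.
    """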
def __init__(
self,
n_feat,
n_head,
dropout,
precision,
rotary_emd_base=10000,
):
"""Construct an RotaryPositionMultiHeadedAttention object."""
super().__init__(n_feat, n_head, dropout)
        self.rotary_ndims = self.d_k  # also try self.d_k // 2
        # map the precision flag to a torch dtype ("fp16" selects half precision)
        if precision == "fp16":
            precision = torch.half
        else:
            precision = torch.float
self.rotary_emb = RotaryPositionalEmbedding(
self.rotary_ndims, base=rotary_emd_base, precision=precision
)
    def forward(self, query, key, value, key_padding_mask=None, **kwargs):
"""Compute rotary position attention.
Args:
query: Query tensor T X B X C
key: Key tensor T X B X C
value: Value tensor T X B X C
            key_padding_mask: Key padding mask B X T
Returns:
torch.Tensor: Output tensor T X B X D.
Notes:
Assumes self attn
"""
T, B, C = value.size()
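        # reshape T x B x C to per-head layout (T x B x n_head x d_k) so the
        # rotary embedding rotates each head's d_k dimensions independently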
query = query.view(T, B, self.h, self.d_k)
key = key.view(T, B, self.h, self.d_k)
value = value.view(T, B, self.h, self.d_k)
cos, sin = self.rotary_emb(value, seq_len=T)
query, key = apply_rotary_pos_emb(
query, key, cos, sin, offset=0
) # offset is based on layer_past
query = query.view(T, B, self.h * self.d_k)
key = key.view(T, B, self.h * self.d_k)
value = value.view(T, B, self.h * self.d_k)
# TBD to BTD
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
scores = self.forward_attention(v, scores, key_padding_mask)
scores = scores.transpose(0, 1)
return scores, None
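

# Minimal shape-check sketch (not part of the fairseq API): it assumes this
# module is importable as-is and only illustrates the T X B X C calling
# convention and the (output, None) return signature of the classes above.
if __name__ == "__main__":
    T, B, C, H = 5, 2, 16, 4
    x = torch.randn(T, B, C)
    pad_mask = torch.zeros(B, T, dtype=torch.bool)  # no padded positions

    mha = ESPNETMultiHeadedAttention(C, H, dropout=0.1)
    out, _ = mha(x, x, x, key_padding_mask=pad_mask)
    print(out.shape)  # torch.Size([5, 2, 16])

    rel_mha = RelPositionMultiHeadedAttention(C, H, dropout=0.1)
    pos_emb = torch.randn(2 * T - 1, 1, C)  # relative offsets -T+1 .. T-1
    out, _ = rel_mha(x, x, x, pos_emb, key_padding_mask=pad_mask)
    print(out.shape)  # torch.Size([5, 2, 16])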