```#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import math
import torch
from torch import nn
from fairseq.modules.rotary_positional_embedding import (
RotaryPositionalEmbedding,
apply_rotary_pos_emb,
)

Args:
n_feat: The number of features.
dropout: Dropout rate.
"""

assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = nn.Dropout(p=dropout)

def forward_qkv(self, query, key, value, **kwargs):
    """Project query, key and value and split each into attention heads.

    Args:
        query: Query tensor  B X T1 X C
        key: Key tensor B X T2 X C
        value: Value tensor  B X T2 X C
        **kwargs: Accepted for call compatibility; not used here.
    Returns:
        torch.Tensor: Projected query tensor  B X n_head X T1 X d_k
        torch.Tensor: Projected key tensor B X n_head X T2 X d_k
        torch.Tensor: Projected value tensor  B X n_head X T2 X d_k
    """
    # The batch size is taken from the query and reused for key and value.
    per_head_shape = (query.size(0), -1, self.h, self.d_k)
    projections = (
        self.linear_q(query),
        self.linear_k(key),
        self.linear_v(value),
    )
    # (batch, time, C) -> (batch, time, head, d_k) -> (batch, head, time, d_k)
    q, k, v = (
        projected.view(per_head_shape).transpose(1, 2)
        for projected in projections
    )
    return q, k, v

"""Compute attention context vector.
Args:
value: Transformed value B X n_head X T2 X d_k.
scores: Attention score  B X n_head X T1 X T2
Returns:
torch.Tensor: Transformed value  B X T1 X d_model
weighted by the attention score  B X T1 X T2
"""
n_batch = value.size(0)
float("-inf"),  # (batch, head, time1, time2)
)
self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

else:
self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
)  # (batch, time1, d_model)

return self.linear_out(x)  # (batch, time1, d_model)

"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor T X B X C
key (torch.Tensor): Key tensor T X B X C
value (torch.Tensor): Value tensor T X B X C
Returns:
torch.Tensor: Output tensor T X B X D.
"""
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)

q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
scores = scores.transpose(0, 1)
return scores, None

"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_feat: The number of features.
n_head: The number of heads.
dropout: Dropout rate.
zero_triu: Whether to zero the upper triangular part of attention matrix.
"""

def __init__(self, n_feat, n_head, dropout, zero_triu=False):
    """Create the relative-position specific parameters of the layer.

    Args:
        n_feat: The number of features.
        n_head: The number of heads (not read directly in this block).
        dropout: Dropout rate (not read directly in this block).
        zero_triu: Whether to zero the upper triangular part of attention matrix.
    """
    # NOTE(review): self.h and self.d_k are read below but never assigned in
    # this block -- presumably a base-class initializer provides them; confirm.
    self.zero_triu = zero_triu
    # Bias-free linear transformation applied to the positional encoding.
    self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
    # Two learnable biases (u and v) used in matrix c and matrix d,
    # as described in https://arxiv.org/abs/1901.02860 Section 3.3.
    bias_shape = (self.h, self.d_k)
    self.pos_bias_u, self.pos_bias_v = (
        nn.Parameter(torch.zeros(bias_shape)) for _ in range(2)
    )
    # Both biases start as zeros and are then Xavier-initialised, u before v.
    for bias in (self.pos_bias_u, self.pos_bias_v):
        torch.nn.init.xavier_uniform_(bias)

def rel_shift(self, x):
    """Realign relative-position scores so that columns index key positions.

    A zero column is prepended, the last two axes are re-viewed so that each
    row slides by one extra step, and the first (padding-only) row is dropped.
    The net effect is ``out[..., i, j] = x[..., i, j - i + T - 1]``.

    Args:
        x: Input tensor B X n_head X T X 2T-1
    Returns:
        torch.Tensor: Output tensor B X n_head X T X T. When ``self.zero_triu``
        is set, entries above the diagonal are zeroed.
    """
    zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
    # (batch, head, time1, 2*time1)
    x_padded = torch.cat([zero_pad, x], dim=-1)

    # Re-viewing with the last two axes swapped in size shifts row i by i steps.
    x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
    x = x_padded[:, :, 1:].view_as(x)[
        :, :, :, : x.size(-1) // 2 + 1
    ]  # only keep the positions from 0 to time2

    if self.zero_triu:
        ones = torch.ones((x.size(2), x.size(3)), device=x.device)
        x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

    return x

"""Compute scaled dot product attention.
Args:
query: Query tensor T X B X C
key: Key tensor T X B X C
value: Value tensor T X B X C
pos_emb: Positional embedding tensor 2T-1 X B X C
Returns:
torch.Tensor: Output tensor T X B X C.
"""
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
pos_emb = pos_emb.transpose(0, 1)
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2)  # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

# compute matrix b and matrix d
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)

scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k
)  # (batch, head, time1, time2)

scores = scores.transpose(0, 1)
return scores, None

def __init__(
    self,
    n_feat,
    dropout,
    precision,
    rotary_emd_base=10000,
):
    """Build the rotary positional embedding used by this attention layer.

    Args:
        n_feat: The number of features (not read directly in this block).
        dropout: Dropout rate (not read directly in this block).
        precision: Precision selector. The string "fp16" selects
            ``torch.half``; any other value selects ``torch.float``.
        rotary_emd_base: Value forwarded to ``RotaryPositionalEmbedding``
            as its ``base`` argument.
    """
    # NOTE(review): self.d_k is read here but never assigned in this block --
    # presumably a base-class initializer provides it; confirm.
    self.rotary_ndims = self.d_k  # also try self.d_k//2

    # Inspect the caller's `precision` before anything rebinds it; replacing
    # it with a dtype first would make the "fp16" request silently ignored.
    rotary_dtype = torch.half if precision == "fp16" else torch.float

    self.rotary_emb = RotaryPositionalEmbedding(
        self.rotary_ndims, base=rotary_emd_base, precision=rotary_dtype
    )

"""Compute rotary position attention.
Args:
query: Query tensor T X B X C
key: Key tensor T X B X C
value: Value tensor T X B X C
Returns:
torch.Tensor: Output tensor T X B X D.
Notes:
Assumes self attn
"""

T, B, C = value.size()
query = query.view(T, B, self.h, self.d_k)
key = key.view(T, B, self.h, self.d_k)
value = value.view(T, B, self.h, self.d_k)
cos, sin = self.rotary_emb(value, seq_len=T)
query, key = apply_rotary_pos_emb(
query, key, cos, sin, offset=0
)  # offset is based on layer_past

query = query.view(T, B, self.h * self.d_k)
key = key.view(T, B, self.h * self.d_k)
value = value.view(T, B, self.h * self.d_k)

# TBD to BTD
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)

q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)