# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
import torch.nn as nn
import torch.onnx.operators
from fairseq import utils
[docs]class SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length.
Padding symbols are ignored.
"""
def __init__(self, embedding_dim, padding_idx, init_size=1024):
super().__init__()
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.weights = SinusoidalPositionalEmbedding.get_embedding(
init_size,
embedding_dim,
padding_idx,
)
self.onnx_trace = False
self.register_buffer('_float_tensor', torch.FloatTensor(1))
[docs] def prepare_for_onnx_export_(self):
self.onnx_trace = True
[docs] @staticmethod
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
"""Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
if embedding_dim % 2 == 1:
# zero pad
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
[docs] def forward(self, input, incremental_state=None, timestep=None, **kwargs):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = torch.onnx.operators.shape_as_tensor(input)
max_pos = self.padding_idx + 1 + seq_len
if self.weights is None or max_pos > self.weights.size(0):
# recompute/expand embeddings if needed
self.weights = SinusoidalPositionalEmbedding.get_embedding(
max_pos,
self.embedding_dim,
self.padding_idx,
)
self.weights = self.weights.to(self._float_tensor)
if incremental_state is not None:
# positions is the same for every token when decoding a single step
pos = (timestep.int() + 1).long() if timestep is not None else seq_len
if self.onnx_trace:
return self.weights[self.padding_idx + pos, :].unsqueeze(1).repeat(bsz, 1, 1)
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
positions = utils.make_positions(input, self.padding_idx, onnx_trace=self.onnx_trace)
if self.onnx_trace:
flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
embedding_shape = torch.cat((bsz.view(1), seq_len.view(1), torch.LongTensor([-1])))
embeddings = torch.onnx.operators.reshape_from_tensor_shape(flat_embeddings, embedding_shape)
return embeddings
return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
[docs] def max_positions(self):
"""Maximum number of supported positions."""
return int(1e5) # an arbitrary large number