Source code for fairseq.modules.sinusoidal_positional_embedding

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math

import torch
import torch.nn as nn
import torch.onnx.operators

from fairseq import utils


class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        self.onnx_trace = False
        # _float_tensor only tracks the module's device/dtype, so that
        # forward() can move self.weights (a plain tensor, not a parameter
        # or buffer) onto the right device.
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        # geometric progression of inverse frequencies, from 1 down to 1/10000
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        # outer product of positions and inverse frequencies: [num_embeddings x half_dim]
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad the final (odd) dimension
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None, timestep=None, **kwargs):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = torch.onnx.operators.shape_as_tensor(input)
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = (timestep.int() + 1).long() if timestep is not None else seq_len
            if self.onnx_trace:
                return self.weights[self.padding_idx + pos, :].unsqueeze(1).repeat(bsz, 1, 1)
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        # positions start at padding_idx + 1; padded tokens map to padding_idx,
        # whose row in the table is all zeros
        positions = utils.make_positions(input, self.padding_idx, onnx_trace=self.onnx_trace)
        if self.onnx_trace:
            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
            embedding_shape = torch.cat((bsz.view(1), seq_len.view(1), torch.LongTensor([-1])))
            embeddings = torch.onnx.operators.reshape_from_tensor_shape(flat_embeddings, embedding_shape)
            return embeddings
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number
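
The following is a minimal usage sketch, not part of the module above. It assumes a fairseq installation where this module is importable and uses padding_idx=1, the usual fairseq convention; the exact behavior for padded tokens depends on this version's utils.make_positions. The sketch first builds a small table with get_embedding to show its layout (sine components in the first half of the columns, cosines in the second half, padding row zeroed), then runs a right-padded batch of token ids through the module.

import torch

from fairseq.modules.sinusoidal_positional_embedding import SinusoidalPositionalEmbedding

# Build a small table directly: 6 positions, dimension 8, padding row at index 1.
table = SinusoidalPositionalEmbedding.get_embedding(6, 8, padding_idx=1)
print(table.shape)      # torch.Size([6, 8])
print(table[1])         # the padding row is all zeros
print(table[2, :4])     # sine components for row 2
print(table[2, 4:])     # cosine components for row 2

# Use the module on a batch of token ids; only padding_idx is treated specially.
pad = 1
emb = SinusoidalPositionalEmbedding(embedding_dim=8, padding_idx=pad, init_size=16)
tokens = torch.tensor([
    [5, 6, 7, 8],
    [5, 6, pad, pad],   # trailing padding
])
out = emb(tokens)
print(out.shape)        # expected: torch.Size([2, 4, 8])
# Padded positions are mapped to the zeroed padding row, so their embedding
# sums to zero (this relies on make_positions behavior in this version).
print(out[1, 2].abs().sum())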