Source code for fairseq.modules.sinusoidal_positional_embedding

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Any, Optional

import torch
import torch.onnx.operators
from fairseq import utils
from torch import Tensor, nn


class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.

    The embedding table is kept in ``self.weights`` as a plain attribute (not
    a registered buffer or parameter), so it is not part of ``state_dict`` and
    is rebuilt on demand when a longer sequence is seen. The registered
    ``_float_tensor`` buffer is used in :meth:`forward` as the ``.to()``
    target so the table follows that buffer's device and dtype.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        """
        Args:
            embedding_dim: size of each positional embedding vector.
            padding_idx: index of the padding symbol; ``None`` is stored
                as ``0`` in ``self.padding_idx``.
            init_size: number of rows precomputed at construction time.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx if padding_idx is not None else 0
        # NOTE(review): the initial table is built with the *raw* padding_idx
        # (possibly None, in which case no row is zeroed), whereas the rebuild
        # in forward() uses self.padding_idx (None mapped to 0). Kept as-is to
        # preserve existing outputs; confirm whether this is intended.
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size, embedding_dim, padding_idx
        )
        self.onnx_trace = False
        # Dummy buffer that only carries device/dtype information. It is
        # zero-initialised (rather than the uninitialised torch.FloatTensor(1))
        # so that state_dict contents are deterministic.
        self.register_buffer("_float_tensor", torch.zeros(1, dtype=torch.float))
        self.max_positions = int(1e5)

    def prepare_for_onnx_export_(self):
        """Switch :meth:`forward` to its ONNX-traceable code paths."""
        self.onnx_trace = True

    @staticmethod
    def get_embedding(
        num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
    ):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".

        Args:
            num_embeddings: number of positions (rows) to generate.
            embedding_dim: width of each embedding (columns).
            padding_idx: if given, that row is set to zero.

        Returns:
            A float tensor of shape ``(num_embeddings, embedding_dim)`` whose
            first half of columns holds sines and second half cosines, with a
            trailing zero column when ``embedding_dim`` is odd.
        """
        half_dim = embedding_dim // 2
        # Clamp the denominator: with half_dim == 1 (embedding_dim of 2 or 3)
        # the unclamped ``half_dim - 1`` would raise ZeroDivisionError. For a
        # single frequency the scale is irrelevant because arange(1) == [0].
        log_step = math.log(10000) / max(half_dim - 1, 1)
        inv_freq = torch.exp(torch.arange(half_dim, dtype=torch.float) * -log_step)
        angles = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
            1
        ) * inv_freq.unsqueeze(0)
        # The concatenation already has shape (num_embeddings, 2 * half_dim);
        # no extra view(..., -1) is applied because -1 is ambiguous (and
        # raises) on zero-element tensors.
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(
        self,
        input,
        incremental_state: Optional[Any] = None,
        timestep: Optional[Tensor] = None,
        positions: Optional[Any] = None,
    ):
        """Input is expected to be of size [bsz x seqlen].

        Args:
            input: tensor whose first two dimensions are read as
                ``(bsz, seq_len)``.
            incremental_state: when not ``None``, a single position is
                returned for every batch element (single-step decoding).
            timestep: optional tensor; its first element plus one is used as
                the position in incremental mode instead of ``seq_len``.
            positions: accepted for interface compatibility; any value passed
                in is ignored and overwritten by ``utils.make_positions``.

        Returns:
            ``(bsz, 1, embedding_dim)`` in incremental mode, otherwise
            ``(bsz, seq_len, embedding_dim)``.
        """
        if torch.jit.is_scripting():
            bspair = torch.onnx.operators.shape_as_tensor(input)
        elif torch.onnx.is_in_onnx_export():
            bspair = torch.onnx.operators.shape_as_tensor(input)
        else:
            bspair = input.size()
        bsz, seq_len = bspair[0], bspair[1]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos, self.embedding_dim, self.padding_idx
            )
        # self.weights is a plain attribute, so it does not follow
        # module.to(...); align it with the registered buffer on every call.
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            if self.onnx_trace:
                return (
                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
                    .unsqueeze(1)
                    .repeat(bsz, 1, 1)
                )
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        # presumably yields per-token position indices with padding tokens
        # mapped to padding_idx; verify against fairseq.utils.make_positions
        positions = utils.make_positions(
            input, self.padding_idx, onnx_trace=self.onnx_trace
        )
        if self.onnx_trace:
            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
            embedding_shape = torch.cat(
                (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
            )
            embeddings = torch.onnx.operators.reshape_from_tensor_shape(
                flat_embeddings, embedding_shape
            )
            return embeddings
        return (
            self.weights.index_select(0, positions.view(-1))
            .view(bsz, seq_len, -1)
            .detach()
        )