Source code for fairseq.modules.learned_positional_embedding

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch.nn as nn

from fairseq import utils


class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.onnx_trace = False

    def forward(self, input, incremental_state=None, positions=None):
        """Input is expected to be of size [bsz x seqlen]."""
        assert (positions is None) or (
            self.padding_idx is None
        ), "If positions is pre-computed then padding_idx should not be set."
        if positions is None:
            if incremental_state is not None:
                # positions is the same for every token when decoding a single step
                # Without the int() cast, it doesn't work in some cases when exporting to ONNX
                positions = input.data.new(1, 1).fill_(
                    int(self.padding_idx + input.size(1))
                )
            else:
                positions = utils.make_positions(
                    input.data, self.padding_idx, onnx_trace=self.onnx_trace,
                )
        return super().forward(positions)

    def max_positions(self):
        """Maximum number of supported positions."""
        if self.padding_idx is not None:
            return self.num_embeddings - self.padding_idx - 1
        else:
            return self.num_embeddings
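
A minimal usage sketch, not part of the module above. It assumes fairseq is installed and picks illustrative values (padding_idx=1, 16-dim embeddings, 128 usable positions); because positions are counted starting at padding_idx + 1, the embedding table needs padding_idx + 1 extra slots beyond the desired maximum number of positions.

import torch

# Hypothetical example values: num_embeddings = 128 + padding_idx + 1,
# so that max_positions() reports 128.
pos_emb = LearnedPositionalEmbedding(
    num_embeddings=128 + 1 + 1,
    embedding_dim=16,
    padding_idx=1,
)

tokens = torch.tensor([[5, 6, 7, 1]])  # assume index 1 is the padding token
out = pos_emb(tokens)                  # shape [1, 4, 16]; the padded slot maps to the padding embedding
print(pos_emb.max_positions())         # 128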