# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.nn as nn
from fairseq import utils
class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.

    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.

    Args:
        num_embeddings: number of rows in the embedding table.
        embedding_dim: size of each positional embedding vector.
        padding_idx: index of the padding symbol; positions derived by this
            module are offset by it. May be None, in which case the caller
            is expected to supply ``positions`` to :meth:`forward`.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # Forwarded to utils.make_positions() when positions are derived
        # from the input tokens.
        self.onnx_trace = False

    def forward(self, input, incremental_state=None, positions=None):
        """Look up the positional embeddings for ``input``.

        Input is expected to be of size [bsz x seqlen].

        Args:
            input: token ids; only used to derive positions when
                ``positions`` is not given.
            incremental_state: when not None (and ``positions`` is None),
                a single position ``padding_idx + input.size(1)`` is used,
                so the result covers one decoding step only.
            positions: pre-computed position ids. May only be given when
                ``padding_idx`` is None.

        Returns:
            The embeddings of the selected positions.

        Raises:
            AssertionError: if ``positions`` is given while ``padding_idx``
                is set.
        """
        assert (
            (positions is None) or (self.padding_idx is None)
        ), "If positions is pre-computed then padding_idx should not be set."

        if positions is None:
            if incremental_state is not None:
                # positions is the same for every token when decoding a single step
                positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
            else:
                positions = utils.make_positions(
                    input.data, self.padding_idx, onnx_trace=self.onnx_trace,
                )
        return super().forward(positions)

    def max_positions(self) -> int:
        """Maximum number of supported positions."""
        if self.padding_idx is not None:
            # Rows 0..padding_idx are not used as real positions because
            # derived positions are offset past padding_idx.
            return self.num_embeddings - self.padding_idx - 1
        else:
            return self.num_embeddings