Source code for fairseq.modules.adaptive_input

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

import torch
from torch import nn

from fairseq.modules.quant_noise import quant_noise


class AdaptiveInput(nn.Module):
    """Embed token ids with a separate embedding table per vocabulary band.

    ``cutoff`` partitions the vocabulary into consecutive bands
    ``[0, cutoff[0])``, ``[cutoff[0], cutoff[1])``, ...  Band ``i`` owns an
    ``nn.Embedding`` of width ``int(initial_dim // factor**i)`` followed by a
    bias-free ``nn.Linear`` (wrapped by ``quant_noise``) that projects to
    ``output_dim``, so every band produces vectors of the same size.

    Args:
        vocab_size: total number of token ids. Appended to ``cutoff`` as the
            final boundary when it is larger than ``cutoff[-1]``.
        padding_idx: id of the padding token. It is handed only to the first
            band's embedding, so it is expected to fall inside that band.
        initial_dim: embedding width of the first band.
        factor: each subsequent band's width is divided by another power of
            this value.
        output_dim: size of the vectors returned by :meth:`forward`.
        cutoff: increasing upper bounds (exclusive) of the bands.
        q_noise: forwarded to ``quant_noise`` for the projection layers.
        qn_block_size: forwarded to ``quant_noise`` for the projection layers.

    Raises:
        AssertionError: if ``vocab_size`` is smaller than ``cutoff[-1]``.
    """

    def __init__(
        self,
        vocab_size: int,
        padding_idx: int,
        initial_dim: int,
        factor: float,
        output_dim: int,
        cutoff: List[int],
        q_noise: float = 0,
        qn_block_size: int = 8,
    ):
        super().__init__()

        if vocab_size > cutoff[-1]:
            # Build a new list so the caller's ``cutoff`` is not mutated.
            cutoff = cutoff + [vocab_size]
        else:
            assert (
                vocab_size == cutoff[-1]
            ), "cannot specify cutoff larger than vocab size"

        self.cutoff = cutoff
        self.embedding_dim = output_dim
        self.padding_idx = padding_idx

        self.embeddings = nn.ModuleList()
        band_start = 0
        for i, band_end in enumerate(self.cutoff):
            size = band_end - band_start
            dim = int(initial_dim // (factor**i))
            # Only the first band contains the padding token; ids in later
            # bands are re-based to start at 0, so passing ``padding_idx``
            # there would freeze the row of an unrelated real token.
            band_padding_idx = padding_idx if i == 0 else None
            seq = nn.Sequential(
                nn.Embedding(size, dim, band_padding_idx),
                quant_noise(
                    nn.Linear(dim, output_dim, bias=False), q_noise, qn_block_size
                ),
            )
            self.embeddings.append(seq)
            band_start = band_end

        def init_weights(m):
            if isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0, std=m.weight.shape[1] ** -0.5)
                # Zero the padding row only where one exists.  Doing it for
                # every band wiped a real token's embedding in the tail bands
                # and raised IndexError when a tail band had fewer than
                # ``padding_idx + 1`` rows.
                if m.padding_idx is not None:
                    nn.init.constant_(m.weight[m.padding_idx], 0)
            elif hasattr(m, "weight"):
                nn.init.xavier_uniform_(m.weight)

        self.apply(init_weights)

        # Never read for its value; it only tracks the module's dtype/device
        # so ``forward`` can allocate its output accordingly.
        self.register_buffer("_float_tensor", torch.FloatTensor(1))

    def weights_for_band(self, band: int):
        """Return ``(embedding_weight, projection_weight)`` of band ``band``."""
        embedding, projection = self.embeddings[band][0], self.embeddings[band][1]
        return embedding.weight, projection.weight

    def forward(self, input: torch.Tensor):
        """Embed ``input`` and return a tensor of shape ``input.shape + (output_dim,)``.

        Each id is looked up in the band that contains it, after subtracting
        the band's lower bound.  The output buffer is allocated uninitialised
        and only positions matched by some band are written, so ids outside
        ``[0, cutoff[-1])`` leave their output slot undefined.
        """
        result = self._float_tensor.new_empty(input.shape + (self.embedding_dim,))
        band_start = 0
        for i, band_end in enumerate(self.cutoff):
            mask = input.lt(band_end)
            if i > 0:
                mask = mask & input.ge(band_start)
            if mask.any():
                chunk_input = input[mask]
                if i > 0:
                    # Re-base ids so they index this band's own table.
                    chunk_input = chunk_input - band_start
                result[mask] = self.embeddings[i](chunk_input)
            band_start = band_end
        return result