Source code for fairseq.models.fairseq_decoder

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional, Tuple

import torch.nn as nn
from fairseq import utils
from torch import Tensor


[docs]class FairseqDecoder(nn.Module): """Base class for decoders.""" def __init__(self, dictionary): super().__init__() self.dictionary = dictionary self.onnx_trace = False self.adaptive_softmax = None
[docs] def forward(self, prev_output_tokens, encoder_out=None, **kwargs): """ Args: prev_output_tokens (LongTensor): shifted output tokens of shape `(batch, tgt_len)`, for teacher forcing encoder_out (dict, optional): output from the encoder, used for encoder-side attention Returns: tuple: - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ x, extra = self.extract_features( prev_output_tokens, encoder_out=encoder_out, **kwargs ) x = self.output_layer(x) return x, extra
[docs] def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs): """ Returns: tuple: - the decoder's features of shape `(batch, tgt_len, embed_dim)` - a dictionary with any model-specific outputs """ raise NotImplementedError
[docs] def output_layer(self, features, **kwargs): """ Project features to the default output size, e.g., vocabulary size. Args: features (Tensor): features returned by *extract_features*. """ raise NotImplementedError
[docs] def get_normalized_probs( self, net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], log_probs: bool, sample: Optional[Dict[str, Tensor]] = None, ): """Get normalized probabilities (or log probs) from a net's output.""" return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
# TorchScript doesn't support super() method so that the scriptable Subclass # can't access the base class model in Torchscript. # Current workaround is to add a helper function with different name and # call the helper function from scriptable Subclass.
[docs] def get_normalized_probs_scriptable( self, net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], log_probs: bool, sample: Optional[Dict[str, Tensor]] = None, ): """Get normalized probabilities (or log probs) from a net's output.""" if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None: if sample is not None: assert "target" in sample target = sample["target"] else: target = None out = self.adaptive_softmax.get_log_prob(net_output[0], target=target) return out.exp_() if not log_probs else out logits = net_output[0] if log_probs: return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace) else: return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
[docs] def max_positions(self): """Maximum input length supported by the decoder.""" return 1e6 # an arbitrary large number
[docs] def upgrade_state_dict_named(self, state_dict, name): """Upgrade old state dicts to work with newer code.""" return state_dict
def prepare_for_onnx_export_(self): self.onnx_trace = True