Source code for fairseq.models.fairseq_incremental_decoder

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Dict, Optional

from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.models import FairseqDecoder
from torch import Tensor


logger = logging.getLogger(__name__)


@with_incremental_state
class FairseqIncrementalDecoder(FairseqDecoder):
    """Base class for incremental decoders.

    Incremental decoding is a special mode at inference time where the Model
    only receives a single timestep of input corresponding to the previous
    output token (for teacher forcing) and must produce the next output
    *incrementally*. Thus the model must cache any long-term state that is
    needed about the sequence, e.g., hidden states, convolutional states, etc.

    Compared to the standard :class:`FairseqDecoder` interface, the
    incremental decoder interface allows :func:`forward` functions to take an
    extra keyword argument (*incremental_state*) that can be used to cache
    state across time-steps.

    The :class:`FairseqIncrementalDecoder` interface also defines the
    :func:`reorder_incremental_state` method, which is used during beam search
    to select and reorder the incremental state based on the selection of
    beams.

    To learn more about how incremental decoding works, refer to `this blog
    <http://www.telesens.co/2019/04/21/understanding-incremental-decoding-in-fairseq/>`_.
    """

    def __init__(self, dictionary):
        super().__init__(dictionary)

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
        """
        Args:
            prev_output_tokens (LongTensor): shifted output tokens of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (dict, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict, optional): dictionary used for storing
                state during :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        raise NotImplementedError

    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
        """
        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        raise NotImplementedError

    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder incremental state.

        This will be called when the order of the input has changed from the
        previous time step. A typical use case is beam search, where the input
        order changes between time steps based on the selection of beams.
        """
        pass

    def reorder_incremental_state_scripting(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Main entry point for reordering the incremental state.

        Due to limitations in TorchScript, we call this function in
        :class:`fairseq.sequence_generator.SequenceGenerator` instead of
        calling :func:`reorder_incremental_state` directly.
        """
        for module in self.modules():
            if hasattr(module, "reorder_incremental_state"):
                result = module.reorder_incremental_state(incremental_state, new_order)
                if result is not None:
                    incremental_state = result

    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children."""
        if getattr(self, "_beam_size", -1) != beam_size:
            seen = set()

            def apply_set_beam_size(module):
                if (
                    module != self
                    and hasattr(module, "set_beam_size")
                    and module not in seen
                ):
                    seen.add(module)
                    module.set_beam_size(beam_size)

            self.apply(apply_set_beam_size)
            self._beam_size = beam_size
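
The class above only defines the interface. Below is a minimal sketch (not part of fairseq) of how a subclass might use *incremental_state*: a toy GRU decoder that caches its hidden state between time steps and reorders it when beams are reselected. The class name ``ToyIncrementalDecoder``, the cache key ``"cached_state"``, and the layer sizes are illustrative assumptions; real fairseq decoders typically store state through the ``get_incremental_state``/``set_incremental_state`` helpers added by ``@with_incremental_state`` rather than indexing the dictionary directly.

from typing import Dict, Optional

import torch.nn as nn
from torch import Tensor

from fairseq.models import FairseqIncrementalDecoder


class ToyIncrementalDecoder(FairseqIncrementalDecoder):
    """Illustrative single-layer GRU decoder that caches its hidden state."""

    def __init__(self, dictionary, embed_dim=32, hidden_dim=32):
        super().__init__(dictionary)
        self.embed = nn.Embedding(len(dictionary), embed_dim, padding_idx=dictionary.pad())
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.out_proj = nn.Linear(hidden_dim, len(dictionary))

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
        if incremental_state is not None:
            # Incremental mode: only feed the most recent token; the earlier
            # tokens are summarized by the cached hidden state.
            prev_output_tokens = prev_output_tokens[:, -1:]
        x = self.embed(prev_output_tokens)

        hidden = None
        if incremental_state is not None and "cached_state" in incremental_state:
            hidden = incremental_state["cached_state"]["hidden"]

        x, hidden = self.rnn(x, hidden)

        if incremental_state is not None:
            # Cache the updated hidden state for the next time step.
            incremental_state["cached_state"] = {"hidden": hidden}

        return self.out_proj(x), None

    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        # During beam search the batch dimension is reshuffled between time
        # steps, so the cached state must follow the surviving beams.
        if "cached_state" in incremental_state:
            hidden = incremental_state["cached_state"]["hidden"]
            if hidden is not None:
                # GRU hidden state has shape (num_layers, batch, hidden_dim),
                # so the batch dimension to reorder is dim 1.
                incremental_state["cached_state"]["hidden"] = hidden.index_select(1, new_order)

At inference time, :class:`fairseq.sequence_generator.SequenceGenerator` calls ``reorder_incremental_state_scripting`` after each beam-selection step, which dispatches to the ``reorder_incremental_state`` defined in the sketch so the cached hidden state stays aligned with the selected beams.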