# Source code for fairseq.models.fairseq_incremental_decoder

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from fairseq.models import FairseqDecoder


class FairseqIncrementalDecoder(FairseqDecoder):
    """Base class for incremental decoders.

    Incremental decoding is a special mode at inference time where the Model
    only receives a single timestep of input corresponding to the previous
    output token (for input feeding) and must produce the next output
    *incrementally*. Thus the model must cache any long-term state that is
    needed about the sequence, e.g., hidden states, convolutional states, etc.

    Compared to the standard :class:`FairseqDecoder` interface, the incremental
    decoder interface allows :func:`forward` functions to take an extra keyword
    argument (*incremental_state*) that can be used to cache state across
    time-steps.

    The :class:`FairseqIncrementalDecoder` interface also defines the
    :func:`reorder_incremental_state` method, which is used during beam search
    to select and reorder the incremental state based on the selection of beams.

    To learn more about how incremental decoding works, refer to `this blog
    <http://www.telesens.co/2019/04/21/understanding-incremental-decoding-in-fairseq/>`_.
    """

    def __init__(self, dictionary):
        super().__init__(dictionary)

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
        """
        Args:
            prev_output_tokens (LongTensor): shifted output tokens of shape
                `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_out (dict, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict, optional): dictionary used for storing
                state during :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

    def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
        """
        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder incremental state.

        This should be called when the order of the input has changed from the
        previous time step. A typical use case is beam search, where the input
        order changes between time steps based on the selection of beams.

        Args:
            incremental_state: cached decoding state; forwarded unchanged to
                every submodule that defines ``reorder_incremental_state``.
            new_order: the new ordering; forwarded unchanged alongside
                *incremental_state*.
        """
        # Modules already handled, so that one reachable more than once is
        # reordered a single time per call.
        seen = set()

        def apply_reorder_incremental_state(module):
            # Skip this decoder itself (identity check): re-entering our own
            # method from the traversal would recurse.
            if (
                module is not self
                and hasattr(module, 'reorder_incremental_state')
                and module not in seen
            ):
                seen.add(module)
                module.reorder_incremental_state(incremental_state, new_order)

        # NOTE(review): ``apply`` is inherited, not defined in this class;
        # presumably torch.nn.Module.apply -- confirm against FairseqDecoder.
        self.apply(apply_reorder_incremental_state)

    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children.

        Does nothing when *beam_size* equals the value remembered from the
        previous call.

        Args:
            beam_size: the beam size forwarded to every submodule that defines
                ``set_beam_size`` and then cached on this decoder as
                ``_beam_size``.
        """
        # -1 stands in for "never set", so the first call always propagates.
        if getattr(self, '_beam_size', -1) != beam_size:
            # Modules already handled, so each one is updated once per call.
            seen = set()

            def apply_set_beam_size(module):
                # Skip this decoder itself (identity check) to avoid
                # re-entering set_beam_size from the traversal.
                if (
                    module is not self
                    and hasattr(module, 'set_beam_size')
                    and module not in seen
                ):
                    seen.add(module)
                    module.set_beam_size(beam_size)

            # NOTE(review): ``apply`` is inherited, not defined in this class;
            # presumably torch.nn.Module.apply -- confirm against FairseqDecoder.
            self.apply(apply_set_beam_size)
            self._beam_size = beam_size