Source code for fairseq.models.fairseq_encoder

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, NamedTuple, Optional

import torch
import torch.nn as nn
from torch import Tensor


EncoderOut = NamedTuple(
    "EncoderOut",
    [
        ("encoder_out", Tensor),  # T x B x C
        ("encoder_padding_mask", Optional[Tensor]),  # B x T
        ("encoder_embedding", Optional[Tensor]),  # B x T x C
        ("encoder_states", Optional[List[Tensor]]),  # List[T x B x C]
        ("src_tokens", Optional[Tensor]),  # B x T
        ("src_lengths", Optional[Tensor]),  # B x 1
    ],
)
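
# --- Illustrative sketch (not part of the fairseq source) ---------------------
# The helper below shows how an encoder would typically package its outputs
# into an EncoderOut with the shapes documented above. The function name and
# all sizes are placeholders chosen for demonstration only.
def _example_encoder_out(src_len: int = 7, batch: int = 2, channels: int = 16) -> EncoderOut:
    x = torch.zeros(src_len, batch, channels)  # T x B x C encoder features
    padding_mask = torch.zeros(batch, src_len, dtype=torch.bool)  # B x T, True where padded
    return EncoderOut(
        encoder_out=x,
        encoder_padding_mask=padding_mask,
        encoder_embedding=None,
        encoder_states=None,
        src_tokens=None,
        src_lengths=None,
    )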


class FairseqEncoder(nn.Module):
    """Base class for encoders."""

    def __init__(self, dictionary):
        super().__init__()
        self.dictionary = dictionary

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of
                shape `(batch)`
        """
        raise NotImplementedError

    def forward_torchscript(self, net_input: Dict[str, Tensor]):
        """A TorchScript-compatible version of forward.

        Encoders which use additional arguments may want to override
        this method for TorchScript compatibility.
        """
        if torch.jit.is_scripting():
            return self.forward(
                src_tokens=net_input["src_tokens"],
                src_lengths=net_input["src_lengths"],
            )
        else:
            return self.forward_non_torchscript(net_input)

    @torch.jit.unused
    def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
        encoder_input = {
            k: v for k, v in net_input.items() if k != "prev_output_tokens"
        }
        return self.forward(**encoder_input)

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to `new_order`.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            `encoder_out` rearranged according to `new_order`
        """
        raise NotImplementedError

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return 1e6  # an arbitrary large number

    def upgrade_state_dict(self, state_dict):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        return state_dict

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""

        def _apply(m):
            if hasattr(m, "set_num_updates") and m != self:
                m.set_num_updates(num_updates)

        self.apply(_apply)
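

# --- Illustrative sketch (not part of the fairseq source) ---------------------
# A minimal FairseqEncoder subclass, written as a hedged example of the API
# above. The class name, the embedding layer, and the use of len(dictionary)
# for the vocabulary size are assumptions made for this sketch only.
class _ToyEncoder(FairseqEncoder):
    def __init__(self, dictionary, embed_dim=32):
        super().__init__(dictionary)
        self.embed = nn.Embedding(len(dictionary), embed_dim)

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        x = self.embed(src_tokens)  # B x T x C
        x = x.transpose(0, 1)       # T x B x C, the layout EncoderOut expects
        return EncoderOut(
            encoder_out=x,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        # Beam search reorders along the batch dimension (dim 1 of T x B x C).
        return encoder_out._replace(
            encoder_out=encoder_out.encoder_out.index_select(1, new_order)
        )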