Source code for fairseq.models.fairseq_encoder

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch.nn as nn


class FairseqEncoder(nn.Module):
    """Base class for encoders.

    :meth:`forward` and :meth:`reorder_encoder_out` are placeholders that
    raise ``NotImplementedError`` here; :meth:`max_positions` and
    :meth:`upgrade_state_dict` provide overridable defaults.
    """

    def __init__(self, dictionary):
        """Initialise the module and remember *dictionary*.

        Args:
            dictionary: stored unchanged as ``self.dictionary``; it is not
                inspected by this base class (presumably the source-side
                vocabulary -- confirm against callers).
        """
        super().__init__()
        self.dictionary = dictionary

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to `new_order`.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            `encoder_out` rearranged according to `new_order`

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def max_positions(self):
        """Maximum input length supported by the encoder.

        Returns:
            int: an arbitrary large number (one million) used as the default
            limit.
        """
        # A length is an integer quantity; int(1e6) still compares equal to
        # the float 1e6, so existing min()/comparison call sites are unaffected.
        return int(1e6)

    def upgrade_state_dict(self, state_dict):
        """Upgrade a (possibly old) state dict for new versions of fairseq.

        The base implementation performs no changes and returns the very same
        ``state_dict`` object it was given.
        """
        return state_dict