Source code for fairseq.models.composite_encoder

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .fairseq_encoder import FairseqEncoder


class CompositeEncoder(FairseqEncoder):
    """
    A wrapper around a dictionary of :class:`FairseqEncoder` objects.

    We run forward on each encoder and return a dictionary of outputs. The
    first encoder's dictionary is used for initialization.

    Args:
        encoders (dict): a dictionary of :class:`FairseqEncoder` objects.
    """

    def __init__(self, encoders):
        super().__init__(next(iter(encoders.values())).dictionary)
        self.encoders = encoders
        for key in self.encoders:
            self.add_module(key, self.encoders[key])

    def forward(self, src_tokens, src_lengths):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`

        Returns:
            dict: the outputs from each Encoder
        """
        encoder_out = {}
        for key in self.encoders:
            encoder_out[key] = self.encoders[key](src_tokens, src_lengths)
        return encoder_out

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder encoder output according to new_order."""
        for key in self.encoders:
            encoder_out[key] = self.encoders[key].reorder_encoder_out(
                encoder_out[key], new_order
            )
        return encoder_out

    def max_positions(self):
        return min(self.encoders[key].max_positions() for key in self.encoders)

    def upgrade_state_dict(self, state_dict):
        for key in self.encoders:
            self.encoders[key].upgrade_state_dict(state_dict)
        return state_dict
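
A minimal usage sketch, not part of the fairseq source: the ToyEncoder class, the
choice of dictionary=None, and the tensor shapes below are assumptions made purely
to show how CompositeEncoder fans one (src_tokens, src_lengths) pair out to several
named sub-encoders and returns their outputs keyed by name.

import torch

from fairseq.models import FairseqEncoder
from fairseq.models.composite_encoder import CompositeEncoder


class ToyEncoder(FairseqEncoder):
    """Hypothetical stand-in encoder that just echoes its inputs."""

    def forward(self, src_tokens, src_lengths):
        return {"encoder_out": src_tokens, "src_lengths": src_lengths}

    def max_positions(self):
        return 1024


# Two named sub-encoders; CompositeEncoder registers each as a submodule.
encoders = {
    "word": ToyEncoder(dictionary=None),
    "char": ToyEncoder(dictionary=None),
}
composite = CompositeEncoder(encoders)

src_tokens = torch.randint(0, 10, (2, 5))  # (batch, src_len)
src_lengths = torch.tensor([5, 5])         # (batch,)

out = composite(src_tokens, src_lengths)
# out is a dict keyed by encoder name, e.g. out["word"] and out["char"].
print(out["word"]["encoder_out"].shape)    # torch.Size([2, 5])
print(composite.max_positions())           # 1024, the min over sub-encoders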