# Source code for fairseq.models.fairseq_model

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Base classes for various fairseq models.
"""

import logging
from argparse import Namespace
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import Dictionary
from fairseq.dataclass.utils import (
    convert_namespace_to_omegaconf,
    gen_parser_from_dataclass,
)
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig
from torch import Tensor


logger = logging.getLogger(__name__)


def check_type(module, expected_type):
    """Assert that *module* (or the module it wraps) is an *expected_type*.

    If *module* exposes an ``unwrapped_module`` attribute, that attribute is
    the object being checked; otherwise *module* itself is checked.

    Args:
        module: the object to validate
        expected_type: class (or tuple of classes) accepted by ``isinstance``

    Raises:
        AssertionError: if the checked object is not an instance of
            *expected_type*; the message shows the actual and expected types
    """
    if hasattr(module, "unwrapped_module"):
        candidate = module.unwrapped_module
    else:
        candidate = module
    assert isinstance(candidate, expected_type), f"{type(candidate)} != {expected_type}"


class BaseFairseqModel(nn.Module):
    """Base class for fairseq models."""

    def __init__(self):
        super().__init__()
        # Set to True by make_generation_fast_ so that it only runs once.
        self._is_generation_fast = False

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser.

        Arguments are generated from the class attribute ``__dataclass`` when
        the class defines one; otherwise the parser is left untouched.

        Args:
            parser: the argument parser to extend
        """
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            # do not set defaults so that settings defaults from various architectures still works
            gen_parser_from_dataclass(parser, dc(), delete_default=True)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Raises:
            NotImplementedError: always; subclasses must override this method
        """
        raise NotImplementedError("Model must implement the build_model method")

    def get_targets(self, sample, net_output):
        """Get targets from either the sample or the net's output.

        The base implementation ignores *net_output* and returns
        ``sample["target"]``.
        """
        return sample["target"]

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)

    # TorchScript doesn't support super() method so that the scriptable Subclass
    # can't access the base class model in Torchscript.
    # Current workaround is to add a helper function with different name and
    # call the helper function from scriptable Subclass.
    def get_normalized_probs_scriptable(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel

        Delegates to ``self.decoder.get_normalized_probs`` when the model has
        a ``decoder`` attribute. Otherwise, if *net_output* is a tensor, it is
        treated as logits and (log-)softmax is applied over the last dimension.

        Raises:
            NotImplementedError: if the model has no decoder and *net_output*
                is not a tensor
        """
        if hasattr(self, "decoder"):
            return self.decoder.get_normalized_probs(net_output, log_probs, sample)
        elif torch.is_tensor(net_output):
            # syntactic sugar for simple models which don't have a decoder
            # (e.g., the classification tutorial)
            logits = net_output.float()
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            else:
                return F.softmax(logits, dim=-1)
        raise NotImplementedError

    def extract_features(self, *args, **kwargs):
        """Similar to *forward* but only return features."""
        return self(*args, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return None

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg: Optional[DictConfig] = None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.

        Args:
            state_dict: the state dictionary to load; it is upgraded in place
                before being pruned and loaded
            strict: forwarded to :meth:`nn.Module.load_state_dict`
            model_cfg: model config passed to ``prune_state_dict``
            args: deprecated namespace-style config; only used (converted to
                an omegaconf config) when *model_cfg* is None
        """

        if model_cfg is None and args is not None:
            # Logger.warn is a deprecated alias; warning() is the supported API.
            logger.warning(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        # Imported here rather than at module level, as in the original code.
        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)

    def upgrade_state_dict(self, state_dict):
        """Upgrade old state dicts to work with newer code."""
        self.upgrade_state_dict_named(state_dict, "")

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old state dicts to work with newer code.

        Recursively visits the children of this module and calls each child's
        ``upgrade_state_dict_named`` (preferred) or ``upgrade_state_dict``
        hook when it defines one.

        Args:
            state_dict (dict): state dictionary to upgrade, in place
            name (str): the state dict key corresponding to the current module
        """
        assert state_dict is not None

        def do_upgrade(m, prefix):
            if len(prefix) > 0:
                prefix += "."

            for n, c in m.named_children():
                # Distinct local name so the outer ``name`` is not shadowed.
                child_name = prefix + n
                if hasattr(c, "upgrade_state_dict_named"):
                    c.upgrade_state_dict_named(state_dict, child_name)
                elif hasattr(c, "upgrade_state_dict"):
                    c.upgrade_state_dict(state_dict)
                do_upgrade(c, child_name)

        do_upgrade(self, name)

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update.

        Forwards *num_updates* to every submodule (other than this module
        itself) that defines ``set_num_updates``.
        """
        for m in self.modules():
            # Identity check: skip this module itself to avoid infinite recursion.
            if hasattr(m, "set_num_updates") and m is not self:
                m.set_num_updates(num_updates)

    def prepare_for_inference_(self, cfg: DictConfig):
        """Prepare model for inference.

        Translates the ``cfg.generation`` settings into keyword arguments and
        passes them to :meth:`make_generation_fast_`.
        """
        kwargs = {}
        kwargs["beamable_mm_beam_size"] = (
            None
            if getattr(cfg.generation, "no_beamable_mm", False)
            else getattr(cfg.generation, "beam", 5)
        )
        kwargs["need_attn"] = getattr(cfg.generation, "print_alignment", False)
        if getattr(cfg.generation, "retain_dropout", False):
            kwargs["retain_dropout"] = cfg.generation.retain_dropout
            kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules
        self.make_generation_fast_(**kwargs)

    def make_generation_fast_(self, **kwargs):
        """
        Legacy entry point to optimize model for faster generation.
        Prefer prepare_for_inference_.

        Removes weight norm from all submodules, calls the
        ``make_generation_fast_`` hook of every submodule that provides its own
        implementation, switches the model to eval mode and replaces
        ``self.train`` so that re-enabling training raises ``RuntimeError``.
        """
        if self._is_generation_fast:
            return  # only apply once
        self._is_generation_fast = True

        # remove weight norm from all modules in the network
        def apply_remove_weight_norm(module):
            try:
                nn.utils.remove_weight_norm(module)
            except (AttributeError, ValueError):  # this module didn't have weight norm
                return

        self.apply(apply_remove_weight_norm)

        def apply_make_generation_fast_(module, prefix):
            if len(prefix) > 0:
                prefix += "."

            base_func = BaseFairseqModel.make_generation_fast_
            for n, m in module.named_modules():
                if (
                    m is not self
                    and hasattr(m, "make_generation_fast_")
                    # don't call this implementation again, e.g., if
                    # children modules also inherit from BaseFairseqModel
                    and m.make_generation_fast_.__func__ is not base_func
                ):
                    name = prefix + n
                    m.make_generation_fast_(name=name, **kwargs)

        apply_make_generation_fast_(self, "")

        def train(mode=True):
            if mode:
                raise RuntimeError("cannot train after make_generation_fast")

        # this model should no longer be used for training
        # (eval() must run before ``self.train`` is replaced below)
        self.eval()
        self.train = train

    def prepare_for_onnx_export_(self, **kwargs):
        """Make model exportable via ONNX trace.

        Calls ``prepare_for_onnx_export_(**kwargs)`` once on every submodule
        (other than this module itself) that defines it.
        """
        seen = set()

        def apply_prepare_for_onnx_export_(module):
            if (
                module is not self
                and hasattr(module, "prepare_for_onnx_export_")
                and module not in seen
            ):
                seen.add(module)
                module.prepare_for_onnx_export_(**kwargs)

        self.apply(apply_prepare_for_onnx_export_)

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        **kwargs,
    ):
        """
        Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
        file. Downloads and caches the pre-trained model file if needed.

        The base implementation returns a
        :class:`~fairseq.hub_utils.GeneratorHubInterface`, which can be used to
        generate translations or sample from language models. The underlying
        :class:`~fairseq.models.FairseqModel` can be accessed via the
        *generator.models* attribute.

        Other models may override this to implement custom hub interfaces.

        Args:
            model_name_or_path (str): either the name of a pre-trained model to
                load or a path/URL to a pre-trained model state dict
            checkpoint_file (str, optional): colon-separated list of checkpoint
                files in the model archive to ensemble (default: 'model.pt')
            data_name_or_path (str, optional): point args.data to the archive
                at the given path/URL. Can start with '.' or './' to reuse the
                model archive path.
        """
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            **kwargs,
        )
        logger.info(x["args"])
        return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"])

    @classmethod
    def hub_models(cls):
        """Return the archive map passed to ``hub_utils.from_pretrained``.

        The base implementation returns an empty dict.
        """
        return {}
class FairseqEncoderDecoderModel(BaseFairseqModel):
    """Base class for encoder-decoder models.

    Args:
        encoder (FairseqEncoder): the encoder
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

        # Validate both halves after assignment, encoder first.
        for part, required in (
            (self.encoder, FairseqEncoder),
            (self.decoder, FairseqDecoder),
        ):
            check_type(part, required)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Run the forward pass for an encoder-decoder model.

        The source batch goes through the encoder first; the encoder output is
        then handed to the decoder together with the previous decoder outputs
        (i.e., teacher forcing) to produce the next outputs::

            encoder_out = self.encoder(src_tokens, src_lengths)
            return self.decoder(prev_output_tokens, encoder_out)

        Any extra keyword arguments are passed to both the encoder and the
        decoder.

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        enc_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        return self.decoder(prev_output_tokens, encoder_out=enc_out, **kwargs)

    def forward_decoder(self, prev_output_tokens, **kwargs):
        """Run only the decoder on *prev_output_tokens*, forwarding *kwargs*."""
        return self.decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        enc_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        return self.decoder.extract_features(
            prev_output_tokens, encoder_out=enc_out, **kwargs
        )

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model, as an (encoder, decoder) pair."""
        encoder_limit = self.encoder.max_positions()
        decoder_limit = self.decoder.max_positions()
        return (encoder_limit, decoder_limit)

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()
class FairseqModel(FairseqEncoderDecoderModel):
    """Deprecated variant of :class:`FairseqEncoderDecoderModel`.

    Behaves like its parent class but emits a deprecation warning each time
    an instance is constructed.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        message = (
            "FairseqModel is deprecated, please use FairseqEncoderDecoderModel "
            "or BaseFairseqModel instead"
        )
        utils.deprecation_warning(message, stacklevel=4)
class FairseqMultiModel(BaseFairseqModel):
    """Base class for combining multiple encoder-decoder models.

    Args:
        encoders: mapping from key to :class:`FairseqEncoder`
        decoders: mapping from key to :class:`FairseqDecoder`; must have
            exactly the same keys as *encoders*
    """

    def __init__(self, encoders, decoders):
        super().__init__()
        assert encoders.keys() == decoders.keys()
        self.keys = list(encoders.keys())
        for key in self.keys:
            check_type(encoders[key], FairseqEncoder)
            check_type(decoders[key], FairseqDecoder)

        # One encoder-decoder model per key, registered as submodules.
        self.models = nn.ModuleDict(
            {
                key: FairseqEncoderDecoderModel(encoders[key], decoders[key])
                for key in self.keys
            }
        )

    @staticmethod
    def build_shared_embeddings(
        dicts: Dict[str, Dictionary],
        langs: List[str],
        embed_dim: int,
        build_embedding: callable,
        pretrained_embed_path: Optional[str] = None,
    ):
        """
        Helper function to build shared embeddings for a set of languages after
        checking that all dicts corresponding to those languages are equivalent.

        Args:
            dicts: Dict of lang_id to its corresponding Dictionary
            langs: languages that we want to share embeddings for
            embed_dim: embedding dimension
            build_embedding: callable function to actually build the embedding
            pretrained_embed_path: Optional path to load pretrained embeddings

        Returns:
            whatever ``build_embedding(shared_dict, embed_dim,
            pretrained_embed_path)`` returns, where ``shared_dict`` is the
            dictionary of the first language in *langs*

        Raises:
            ValueError: if any language in *langs* has a dictionary that
                differs from the first language's dictionary
        """
        shared_dict = dicts[langs[0]]
        if any(dicts[lang] != shared_dict for lang in langs):
            raise ValueError(
                "--share-*-embeddings requires a joined dictionary: "
                "--share-encoder-embeddings requires a joined source "
                "dictionary, --share-decoder-embeddings requires a joined "
                "target dictionary, and --share-all-embeddings requires a "
                "joint source + target dictionary."
            )
        return build_embedding(shared_dict, embed_dim, pretrained_embed_path)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """Not implemented in this base class.

        Raises:
            NotImplementedError: always
        """
        raise NotImplementedError

    def max_positions(self):
        """Maximum length supported by the model.

        Returns:
            dict mapping each key to an (encoder, decoder) pair of maximum
            positions for that key's model
        """
        return {
            key: (
                self.models[key].encoder.max_positions(),
                self.models[key].decoder.max_positions(),
            )
            for key in self.keys
        }

    def max_decoder_positions(self):
        """Maximum length supported by the decoder.

        This is the smallest decoder limit across all wrapped models.
        """
        return min(model.decoder.max_positions() for model in self.models.values())

    @property
    def encoder(self):
        """Encoder of the model stored under the first key."""
        return self.models[self.keys[0]].encoder

    @property
    def decoder(self):
        """Decoder of the model stored under the first key."""
        return self.models[self.keys[0]].decoder

    def forward_decoder(self, prev_output_tokens, **kwargs):
        """Run the first key's decoder on *prev_output_tokens*, forwarding *kwargs*."""
        return self.decoder(prev_output_tokens, **kwargs)

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg=None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.

        Args:
            state_dict: the state dictionary to load; it is upgraded in place
                before being pruned and loaded
            strict: forwarded to the parent ``load_state_dict``
            model_cfg: model config passed to ``prune_state_dict``
            args: deprecated namespace-style config; only used (converted to
                an omegaconf config) when *model_cfg* is None
        """

        if model_cfg is None and args is not None:
            # Logger.warn is a deprecated alias; warning() is the supported API.
            logger.warning(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        # Imported here rather than at module level, as in the original code.
        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)
class FairseqLanguageModel(BaseFairseqModel):
    """Base class for decoder-only models.

    Args:
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
        check_type(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, **kwargs):
        """
        Run the forward pass for a decoder-only model.

        Feeds a batch of tokens through the decoder to predict the next tokens.

        Args:
            src_tokens (LongTensor): tokens on which to condition the decoder,
                of shape `(batch, tgt_len)`
            **kwargs: extra keyword arguments (e.g. ``src_lengths``) passed to
                the decoder unchanged

        Returns:
            tuple:
                - the decoder's output of shape `(batch, seq_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        decoder = self.decoder
        return decoder(src_tokens, **kwargs)

    def forward_decoder(self, prev_output_tokens, **kwargs):
        """Run the decoder on *prev_output_tokens*, forwarding *kwargs*."""
        decoder = self.decoder
        return decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, **kwargs):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, seq_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        decoder = self.decoder
        return decoder.extract_features(src_tokens, **kwargs)

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        decoder = self.decoder
        return decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model (the decoder's limit)."""
        return self.max_decoder_positions()

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()

    @property
    def supported_targets(self):
        """Set of target types this model supports: only ``"future"``."""
        targets = {"future"}
        return targets
class FairseqEncoderModel(BaseFairseqModel):
    """Base class for encoder-only models.

    Args:
        encoder (FairseqEncoder): the encoder
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        check_type(self.encoder, FairseqEncoder)

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        Run the forward pass for a encoder-only model.

        Feeds a batch of tokens through the encoder to generate features.

        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`

        Returns:
            the encoder's output, typically of shape `(batch, src_len, features)`
        """
        encoder = self.encoder
        return encoder(src_tokens, src_lengths, **kwargs)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output.

        Reads ``net_output["encoder_out"]`` and, when it is a tensor, applies
        softmax (or log-softmax if *log_probs* is true) over its last
        dimension. *sample* is not used by this implementation.

        Raises:
            NotImplementedError: if ``net_output["encoder_out"]`` is not a
                tensor
        """
        encoder_out = net_output["encoder_out"]
        # Guard clause: only plain tensors are handled here.
        if not torch.is_tensor(encoder_out):
            raise NotImplementedError

        logits = encoder_out.float()
        normalize = F.log_softmax if log_probs else F.softmax
        return normalize(logits, dim=-1)

    def max_positions(self):
        """Maximum length supported by the model."""
        return self.encoder.max_positions()