# Source code for fairseq.criterions.fairseq_criterion

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from typing import Any, Dict, List

from fairseq import metrics, utils
from fairseq.dataclass.utils import gen_parser_from_dataclass
from torch.nn.modules.loss import _Loss

class FairseqCriterion(_Loss):
    """Base class for criterions (loss functions).

    Subclasses are expected to override :meth:`forward` and
    :meth:`reduce_metrics`; the versions defined here either raise
    ``NotImplementedError`` or fall back to the deprecated
    :meth:`aggregate_logging_outputs` API.

    Attributes set by ``__init__``:
        task: the task object passed to the constructor.
        padding_idx: only set when ``task`` has a ``target_dictionary``
            attribute. It is ``task.target_dictionary.pad()``, or ``-100``
            when the target dictionary is ``None``.
    """

    def __init__(self, task):
        super().__init__()
        self.task = task
        # NOTE: padding_idx is intentionally left undefined when the task
        # exposes no ``target_dictionary`` attribute at all.
        if hasattr(task, "target_dictionary"):
            tgt_dict = task.target_dictionary
            self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100

    @classmethod
    def add_args(cls, parser):
        """Add criterion-specific arguments to the parser."""
        # The attribute is looked up by its literal name "__dataclass" (a
        # string passed to getattr is not name-mangled). Presumably it is
        # attached by the criterion registry -- verify against the caller.
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    @classmethod
    def build_criterion(cls, args, task):
        """Construct a criterion from command-line args.

        Each parameter of ``cls.__init__`` is filled in by name: ``task``
        receives *task*, any other parameter is read from the attribute of
        the same name on *args*, and parameters that are absent from *args*
        but have a default are left to that default.

        Raises:
            NotImplementedError: if the constructor uses positional-only,
                ``*args`` or ``**kwargs`` parameters, or if a required
                parameter cannot be found on *args*.
        """
        # Criterions can override this, but for convenience we also try
        # to automatically map argparse.Namespace keys to corresponding
        # arguments in the __init__.
        init_args = {}
        for p in inspect.signature(cls).parameters.values():
            if p.kind in (p.POSITIONAL_ONLY, p.VAR_POSITIONAL, p.VAR_KEYWORD):
                # we haven't implemented inference for these argument types,
                # but PRs welcome :)
                raise NotImplementedError("{} not supported".format(p.kind))

            assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}

            if p.name == "task":
                init_args["task"] = task
            elif hasattr(args, p.name):
                init_args[p.name] = getattr(args, p.name)
            elif p.default is not p.empty:
                # ``empty`` is a sentinel, so compare by identity; the
                # constructor's own default value will be used.
                pass
            else:
                raise NotImplementedError(
                    "Unable to infer Criterion arguments, please implement "
                    "{}.build_criterion".format(cls.__name__)
                )
        return cls(**init_args)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        raise NotImplementedError

    @staticmethod
    def aggregate_logging_outputs(
        logging_outputs: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Aggregate logging outputs from data parallel training.

        Deprecated: this base implementation emits a deprecation warning and
        then raises ``NotImplementedError``.
        """
        utils.deprecation_warning(
            "The aggregate_logging_outputs API is deprecated. "
            "Please use the reduce_metrics API instead."
        )
        raise NotImplementedError

    @classmethod
    def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
        """Aggregate logging outputs from data parallel training.

        This base implementation is a fallback: it warns, delegates to the
        deprecated :meth:`aggregate_logging_outputs`, and logs every
        aggregated value as a scalar except the bookkeeping keys
        ``nsentences``, ``ntokens`` and ``sample_size``.
        """
        utils.deprecation_warning(
            "Criterions should implement the reduce_metrics API. "
            "Falling back to deprecated aggregate_logging_outputs API."
        )
        agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
        for k, v in agg_logging_outputs.items():
            if k in {"nsentences", "ntokens", "sample_size"}:
                continue
            metrics.log_scalar(k, v)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True can improve distributed training speed.
        """
        return False
class LegacyFairseqCriterion(FairseqCriterion):
    """Deprecated criterion base that keeps the raw ``args`` namespace.

    Unlike :class:`FairseqCriterion`, the constructor takes the whole
    ``args`` object and stores it on ``self.args``. Instantiating it emits a
    deprecation warning asking authors to extend ``FairseqCriterion`` with
    explicit constructor arguments instead.
    """

    def __init__(self, args, task):
        super().__init__(task=task)
        self.args = args

        utils.deprecation_warning(
            "Criterions should take explicit arguments instead of an "
            "argparse.Namespace object, please update your criterion by "
            "extending FairseqCriterion instead of LegacyFairseqCriterion."
        )

    @classmethod
    def build_criterion(cls, args, task):
        """Construct a criterion from command-line args."""
        # Legacy criterions receive the namespace as-is, so no signature
        # inspection is performed here.
        return cls(args, task)