# Source code for fairseq.criterions.fairseq_criterion

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from torch.nn.modules.loss import _Loss


class FairseqCriterion(_Loss):
    """Base class for criterions (loss functions).

    Subclasses are expected to override :meth:`forward` and
    :meth:`aggregate_logging_outputs`, both of which raise
    ``NotImplementedError`` here.

    Attributes set by ``__init__``:
        args: the ``args`` object passed to the constructor, stored as-is.
        padding_idx: ``task.target_dictionary.pad()`` when the task has a
            target dictionary, otherwise ``-100``.
    """

    def __init__(self, args, task):
        """Store *args* and derive the padding index from *task*.

        Args:
            args: configuration object; kept on ``self.args`` unchanged.
            task: object exposing a ``target_dictionary`` attribute, which
                is either ``None`` or provides a ``pad()`` method.
        """
        super().__init__()
        self.args = args

        target_dictionary = task.target_dictionary
        if target_dictionary is not None:
            self.padding_idx = target_dictionary.pad()
        else:
            # NOTE(review): -100 presumably mirrors the default
            # ``ignore_index`` of PyTorch's loss functions -- confirm.
            self.padding_idx = -100

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser.

        The base implementation adds nothing and returns ``None``.
        """

    @classmethod
    def build_criterion(cls, args, task):
        """Construct an instance of *cls* from *args* and *task*."""
        return cls(args, task)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def grad_denom(sample_sizes):
        """Compute the gradient denominator for a set of sample sizes.

        Returns the sum of *sample_sizes*.
        """
        return sum(sample_sizes)