# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from torch.nn.modules.loss import _Loss


class FairseqCriterion(_Loss):

    def __init__(self, args, task):
        super().__init__()
        self.args = args
        if task.target_dictionary is not None:
            self.padding_idx = task.target_dictionary.pad()
        else:
            # -100 matches the default ignore_index of PyTorch's NLL and
            # cross-entropy losses, so padded positions are skipped.
            self.padding_idx = -100

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        pass
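
    # A subclass typically overrides ``add_args`` to expose its own
    # hyperparameters. A minimal sketch, modeled on fairseq's
    # label-smoothed cross entropy (the flag name is illustrative):
    #
    #     @staticmethod
    #     def add_args(parser):
    #         parser.add_argument('--label-smoothing', default=0.0, type=float,
    #                             help='epsilon for label smoothing; 0 disables it')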

    @classmethod
    def build_criterion(cls, args, task):
        return cls(args, task)
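
    # Subclasses whose constructors take extra arguments can override this
    # factory. A hedged sketch (the ``eps`` parameter is hypothetical and
    # assumes the ``--label-smoothing`` flag sketched above):
    #
    #     @classmethod
    #     def build_criterion(cls, args, task):
    #         return cls(args, task, eps=args.label_smoothing)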

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        raise NotImplementedError
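
    # A minimal sketch of a concrete ``forward``, modeled on fairseq's
    # cross-entropy criterion. It assumes the model exposes the usual
    # ``get_normalized_probs``/``get_targets`` helpers and that ``sample``
    # carries 'net_input', 'target', and 'ntokens' entries; treat the
    # details as illustrative rather than canonical:
    #
    #     def forward(self, model, sample, reduce=True):
    #         import torch.nn.functional as F
    #         net_output = model(**sample['net_input'])
    #         lprobs = model.get_normalized_probs(net_output, log_probs=True)
    #         lprobs = lprobs.view(-1, lprobs.size(-1))
    #         target = model.get_targets(sample, net_output).view(-1)
    #         loss = F.nll_loss(lprobs, target, ignore_index=self.padding_idx,
    #                           reduction='sum' if reduce else 'none')
    #         sample_size = sample['ntokens']
    #         logging_output = {
    #             'loss': loss.item() if reduce else loss.detach(),
    #             'ntokens': sample['ntokens'],
    #             'sample_size': sample_size,
    #         }
    #         return loss, sample_size, logging_output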

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        raise NotImplementedError
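
    # With N data-parallel workers, ``logging_outputs`` is a list of N
    # dicts like the one returned by ``forward``. A hedged sketch that
    # sums the counters and reports loss per sample-size unit (key names
    # follow the ``forward`` sketch above):
    #
    #     @staticmethod
    #     def aggregate_logging_outputs(logging_outputs):
    #         loss = sum(log.get('loss', 0) for log in logging_outputs)
    #         ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
    #         sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
    #         return {
    #             'loss': loss / sample_size if sample_size > 0 else 0.0,
    #             'ntokens': ntokens,
    #             'sample_size': sample_size,
    #         }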

    @staticmethod
    def grad_denom(sample_sizes):
        """Compute the gradient denominator for a set of sample sizes."""
        return sum(sample_sizes)
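
# The default gradient denominator is simply the total sample size across
# workers, e.g. per-worker token counts as in the sketches above (the
# values are illustrative):
#
#     FairseqCriterion.grad_denom([3200, 2880, 3040])  # -> 9120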