# Source code for fairseq.tasks.fairseq_task

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch

from fairseq import tokenizer
from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary

class FairseqTask(object):
    """
    Tasks store dictionaries and provide helpers for loading/iterating over
    Datasets, initializing the Model/Criterion and calculating the loss.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        pass

    def __init__(self, args):
        """Store the parsed arguments and start with no loaded datasets.

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        self.args = args
        # Maps split name -> loaded dataset; read back by :func:`dataset`.
        self.datasets = {}

    @classmethod
    def load_dictionary(cls, filename):
        """Load the dictionary from the filename

        Args:
            filename (str): the filename
        """
        return Dictionary.load(filename)

    @classmethod
    def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8):
        """Build the dictionary

        Args:
            filenames (list): list of filenames
            workers (int): number of concurrent workers
            threshold (int): defines the minimum word count
            nwords (int): defines the total number of words in the final
                dictionary, including special symbols
            padding_factor (int): can be used to pad the dictionary size to be
                a multiple of 8, which is important on some hardware (e.g.,
                Nvidia Tensor Cores).
        """
        d = Dictionary()
        for filename in filenames:
            Dictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
        d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
        return d

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        return cls(args, **kwargs)

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        raise NotImplementedError

    def dataset(self, split):
        """
        Return a loaded dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)

        Returns:
            a :class:`~fairseq.data.FairseqDataset` corresponding to *split*

        Raises:
            KeyError: if *split* has not been loaded into ``self.datasets``
            TypeError: if the stored dataset is not a FairseqDataset
        """
        if split not in self.datasets:
            raise KeyError('Dataset not loaded: ' + split)
        loaded = self.datasets[split]
        if not isinstance(loaded, FairseqDataset):
            raise TypeError('Datasets are expected to be of type FairseqDataset')
        return loaded

    def get_batch_iterator(
        self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
        ignore_invalid_inputs=False, required_batch_size_multiple=1,
        seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=0,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 0).

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split

        Raises:
            TypeError: if *dataset* is not a FairseqDataset
        """
        # Explicit raise instead of ``assert`` so the check is not stripped
        # under ``python -O`` and matches the error style of ``dataset()``.
        if not isinstance(dataset, FairseqDataset):
            raise TypeError('Datasets are expected to be of type FairseqDataset')

        # get indices ordered by example size
        with data_utils.numpy_seed(seed):
            indices = dataset.ordered_indices()

        # filter examples that are too large
        if max_positions is not None:
            indices = data_utils.filter_by_size(
                indices, dataset.size, max_positions,
                raise_exception=(not ignore_invalid_inputs),
            )

        # create mini-batches with given size constraints
        batch_sampler = data_utils.batch_by_size(
            indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

        # return a reusable, sharded iterator
        return iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )

    def build_model(self, args):
        """
        Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
        task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~fairseq.models.BaseFairseqModel` instance
        """
        from fairseq import models
        return models.build_model(args, self)

    def build_criterion(self, args):
        """
        Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
        this task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~fairseq.criterions.FairseqCriterion` instance
        """
        from fairseq import criterions
        return criterions.build_criterion(args, self)

    def build_generator(self, args):
        """
        Build the object later passed to :func:`inference_step`.

        Every option is read with ``getattr`` and a fallback, so *args* does
        not need to define all of the attributes used here.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a ``SequenceScorer`` over the target dictionary when
            ``args.score_reference`` is truthy, otherwise a
            ``SequenceGenerator`` configured from *args*
        """
        if getattr(args, 'score_reference', False):
            from fairseq.sequence_scorer import SequenceScorer
            return SequenceScorer(self.target_dictionary)
        else:
            from fairseq.sequence_generator import SequenceGenerator
            return SequenceGenerator(
                self.target_dictionary,
                beam_size=getattr(args, 'beam', 5),
                max_len_a=getattr(args, 'max_len_a', 0),
                max_len_b=getattr(args, 'max_len_b', 200),
                min_len=getattr(args, 'min_len', 1),
                stop_early=(not getattr(args, 'no_early_stop', False)),
                normalize_scores=(not getattr(args, 'unnormalized', False)),
                len_penalty=getattr(args, 'lenpen', 1),
                unk_penalty=getattr(args, 'unkpen', 0),
                sampling=getattr(args, 'sampling', False),
                sampling_topk=getattr(args, 'sampling_topk', -1),
                sampling_topp=getattr(args, 'sampling_topp', -1.0),
                temperature=getattr(args, 'temperature', 1.),
                diverse_beam_groups=getattr(args, 'diverse_beam_groups', -1),
                diverse_beam_strength=getattr(args, 'diverse_beam_strength', 0.5),
                match_source_len=getattr(args, 'match_source_len', False),
                no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
            )

    def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        loss, sample_size, logging_output = criterion(model, sample)
        if ignore_grad:
            # The backward pass below still runs, but on a zeroed loss.
            loss *= 0
        optimizer.backward(loss)
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        """
        Evaluate *criterion* on *sample* with the model in eval mode and
        gradient tracking disabled.

        Args:
            sample (dict): the mini-batch
            model: the model; ``model.eval()`` is called on it
            criterion: callable invoked as ``criterion(model, sample)``

        Returns:
            tuple: the ``(loss, sample_size, logging_output)`` triple returned
            by *criterion*
        """
        model.eval()
        with torch.no_grad():
            loss, sample_size, logging_output = criterion(model, sample)
        return loss, sample_size, logging_output

    def inference_step(self, generator, models, sample, prefix_tokens=None):
        """
        Run ``generator.generate`` with gradient tracking disabled.

        Args:
            generator: object exposing ``generate(models, sample, prefix_tokens=...)``
                (e.g., the result of :func:`build_generator`)
            models: forwarded to the generator unchanged
            sample (dict): the mini-batch, forwarded unchanged
            prefix_tokens (optional): forwarded unchanged (default: None)

        Returns:
            whatever ``generator.generate`` returns
        """
        with torch.no_grad():
            return generator.generate(models, sample, prefix_tokens=prefix_tokens)

    def update_step(self, num_updates):
        """Task level update when number of update increases. This is called
        after optimization step and learning rate update of each step"""
        pass

    def grad_denom(self, sample_sizes, criterion):
        """Delegate to ``grad_denom`` on the class of *criterion*.

        Args:
            sample_sizes: forwarded unchanged to the criterion class
            criterion: only its class is used

        Returns:
            the value returned by ``type(criterion).grad_denom(sample_sizes)``
        """
        return criterion.__class__.grad_denom(sample_sizes)

    def aggregate_logging_outputs(self, logging_outputs, criterion):
        """Delegate to ``aggregate_logging_outputs`` on the class of *criterion*.

        Args:
            logging_outputs: forwarded unchanged to the criterion class
            criterion: only its class is used

        Returns:
            the value returned by
            ``type(criterion).aggregate_logging_outputs(logging_outputs)``
        """
        return criterion.__class__.aggregate_logging_outputs(logging_outputs)

    def max_positions(self):
        """Return the max input length allowed by the task."""
        return None

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        raise NotImplementedError

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        raise NotImplementedError