Tasks

Tasks store dictionaries and provide helpers for loading/iterating over Datasets, initializing the Model/Criterion and calculating the loss.

Tasks can be selected via the --task command-line argument. Once selected, a task may expose additional command-line arguments for further configuration.
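Because the chosen task can expose extra options, the command line has to be parsed in two passes. The first pass learns which task was chosen, and the second parses again with that task's arguments added. The sketch below shows this pattern using only the standard-library argparse module. The TASKS registry and the DemoLMTask class are hypothetical, and this is not fairseq's own option-parsing code.

```python
import argparse


class DemoLMTask:
    """Hypothetical task that exposes one extra option through add_args."""

    @staticmethod
    def add_args(parser):
        parser.add_argument("--tokens-per-sample", type=int, default=1024)


# hypothetical registry mapping task names to task classes
TASKS = {"demo_lm": DemoLMTask}


def parse_args(argv):
    # first pass: only discover which task was requested
    pre = argparse.ArgumentParser(add_help=False)
    pre.add_argument("--task", choices=sorted(TASKS), required=True)
    known, _ = pre.parse_known_args(argv)

    # second pass: full parser, extended with the selected task's arguments
    parser = argparse.ArgumentParser(parents=[pre])
    TASKS[known.task].add_args(parser)
    return parser.parse_args(argv)


args = parse_args(["--task", "demo_lm", "--tokens-per-sample", "512"])
print(args.task, args.tokens_per_sample)  # demo_lm 512
```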

Example usage:

import fairseq.tasks

# set up the task (e.g., load dictionaries); args holds the parsed command-line arguments
task = fairseq.tasks.setup_task(args)

# build model and criterion
model = task.build_model(args)
criterion = task.build_criterion(args)

# load datasets
task.load_dataset('train')
task.load_dataset('valid')

# iterate over mini-batches of data; get_batch_iterator returns an
# EpochBatchIterator, and next_epoch_itr() yields the batches of one epoch
batch_itr = task.get_batch_iterator(
    task.dataset('train'), max_tokens=4096,
)
for batch in batch_itr.next_epoch_itr():
    # compute the loss; the criterion returns (loss, sample_size, logging_output)
    loss, sample_size, logging_output = criterion(model, batch)
    loss.backward()

Translation

class fairseq.tasks.translation.TranslationTask(cfg: fairseq.tasks.translation.TranslationConfig, src_dict, tgt_dict)

Translate from one (source) language to another (target) language.

Parameters:
  • src_dict (Dictionary) – dictionary for the source language
  • tgt_dict (Dictionary) – dictionary for the target language

Note

The translation task is compatible with fairseq-train, fairseq-generate and fairseq-interactive.

Language Modeling

class fairseq.tasks.language_modeling.LanguageModelingTask(args, dictionary, output_dictionary=None, targets=None)

Train a language model.

Parameters:
  • dictionary (Dictionary) – the dictionary for the input of the language model
  • output_dictionary (Dictionary) – the dictionary for the output of the language model. In most cases it will be the same as dictionary, but it may be a more limited version of it (if --output-dictionary-size is used).
  • targets (List[str]) – list of the target types that the language model should predict. Each entry can be one of “self”, “future”, or “past”. Defaults to “future”.
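The following simplified sketch makes the three target types concrete. For each input position, it lists the token that would be predicted under “self”, “future”, and “past”. The make_targets helper and the <pad> filler are assumptions for illustration. fairseq's own dataset code also handles end-of-sentence symbols and batching, which this sketch ignores.

```python
from typing import Dict, List

PAD = "<pad>"  # filler for positions with nothing to predict (assumption)


def make_targets(tokens: List[str], targets: List[str]) -> Dict[str, List[str]]:
    """Simplified sketch: per-position targets for each requested target type."""
    out = {}
    for kind in targets:
        if kind == "self":
            # predict the current token
            out[kind] = list(tokens)
        elif kind == "future":
            # predict the next token; the last position has no successor
            out[kind] = tokens[1:] + [PAD]
        elif kind == "past":
            # predict the previous token; the first position has no predecessor
            out[kind] = [PAD] + tokens[:-1]
        else:
            raise ValueError(f"unknown target type: {kind}")
    return out


print(make_targets(["the", "cat", "sat"], ["future", "past"]))
```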

Note

The language modeling task is compatible with fairseq-train, fairseq-generate, fairseq-interactive and fairseq-eval-lm.

The language modeling task provides the following additional command-line arguments:

usage:  [--task language_modeling]
        [--sample-break-mode {none,complete,complete_doc,eos}]
        [--tokens-per-sample TOKENS_PER_SAMPLE]
        [--output-dictionary-size OUTPUT_DICTIONARY_SIZE] [--self-target]
        [--future-target] [--past-target] [--add-bos-token]
        [--max-target-positions MAX_TARGET_POSITIONS]
        [--shorten-method {none,truncate,random_crop}]
        [--shorten-data-split-list SHORTEN_DATA_SPLIT_LIST]
        [--pad-to-fixed-length] [--pad-to-fixed-bsz]
        data

Task name

--task

Enable this task with: --task=language_modeling

Additional command-line arguments

data

path to data directory

--sample-break-mode

Possible choices: none, complete, complete_doc, eos

If omitted or “none”, fills each sample with tokens-per-sample tokens. If set to “complete”, splits samples only at the end of a sentence, but may include multiple sentences per sample. “complete_doc” is similar but respects document boundaries. If set to “eos”, includes only one sentence per sample.

Default: “none”
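The difference between the break modes is easiest to see on a toy token stream. This sketch assumes each sentence is already a list of token ids. It implements only “none”, “complete”, and “eos” (not “complete_doc”). make_samples is a hypothetical helper, not a fairseq function.

```python
from typing import List

Sentence = List[int]


def make_samples(sentences: List[Sentence], tokens_per_sample: int, mode: str) -> List[List[int]]:
    """Simplified sketch of --sample-break-mode for modes none, complete and eos."""
    if mode == "none":
        # concatenate everything and cut into fixed-size blocks,
        # ignoring sentence boundaries
        stream = [tok for sent in sentences for tok in sent]
        return [stream[i:i + tokens_per_sample] for i in range(0, len(stream), tokens_per_sample)]
    if mode == "eos":
        # exactly one sentence per sample
        return [list(sent) for sent in sentences]
    if mode == "complete":
        # pack whole sentences until the next one would overflow the block;
        # in this sketch a sentence longer than the block is kept whole
        samples: List[List[int]] = []
        current: List[int] = []
        for sent in sentences:
            if current and len(current) + len(sent) > tokens_per_sample:
                samples.append(current)
                current = []
            current.extend(sent)
        if current:
            samples.append(current)
        return samples
    raise ValueError(f"unsupported mode in this sketch: {mode}")
```

With sentences [[1, 2], [3], [4, 5, 6]] and tokens_per_sample=4, “none” yields [[1, 2, 3, 4], [5, 6]], “complete” yields [[1, 2, 3], [4, 5, 6]], and “eos” yields [[1, 2], [3], [4, 5, 6]].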

--tokens-per-sample

max number of tokens per sample for LM dataset

Default: 1024

--output-dictionary-size

limit the size of output dictionary

Default: -1

--self-target

include self target

Default: False

--future-target

include future target

Default: False

--past-target

include past target

Default: False

--add-bos-token

prepend beginning of sentence token (<s>)

Default: False

--max-target-positions

max number of tokens in the target sequence

--shorten-method

Possible choices: none, truncate, random_crop

if not none, shorten sequences that exceed --tokens-per-sample

Default: “none”
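As a rough illustration of --shorten-method, the sketch below handles a sequence that is longer than the limit. It either keeps the first tokens (truncate) or keeps a randomly placed contiguous window (random_crop). The shorten function is a hypothetical stand-in written for this example, not fairseq's dataset wrapper.

```python
import random
from typing import List, Optional


def shorten(tokens: List[int], max_len: int, method: str, rng: Optional[random.Random] = None) -> List[int]:
    """Sketch of --shorten-method; sequences within the limit are left untouched."""
    if method == "none" or len(tokens) <= max_len:
        return tokens
    if method == "truncate":
        # keep the first max_len tokens
        return tokens[:max_len]
    if method == "random_crop":
        # keep a random contiguous window of max_len tokens
        rng = rng or random.Random()
        start = rng.randint(0, len(tokens) - max_len)
        return tokens[start:start + max_len]
    raise ValueError(f"unknown shorten method: {method}")
```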

--shorten-data-split-list

comma-separated list of dataset splits to apply shortening to, e.g., “train,valid” (default: all dataset splits)

Default: “”

--pad-to-fixed-length

pad to fixed length

Default: False

--pad-to-fixed-bsz

pad to fixed batch size

Default: False

Adding new tasks

fairseq.tasks.register_task(name, dataclass=None)

New tasks can be added to fairseq with the register_task() function decorator.

For example:

from fairseq.tasks import FairseqTask, register_task

@register_task('classification')
class ClassificationTask(FairseqTask):
    ...  # task-specific implementation

Note

All Tasks must implement the FairseqTask interface.

Parameters: name (str) – the name of the task

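The decorator-based registration described above follows a common registry pattern, shown here in plain Python. BaseTask stands in for FairseqTask. TASK_REGISTRY and the error messages are assumptions, not fairseq's exact implementation, which also handles the dataclass argument.

```python
from typing import Callable, Dict, Type


class BaseTask:
    """Stand-in for FairseqTask in this sketch."""


# hypothetical registry: task name -> task class
TASK_REGISTRY: Dict[str, Type[BaseTask]] = {}


def register_task(name: str) -> Callable[[Type[BaseTask]], Type[BaseTask]]:
    def decorator(cls: Type[BaseTask]) -> Type[BaseTask]:
        # refuse duplicate names so two tasks cannot silently shadow each other
        if name in TASK_REGISTRY:
            raise ValueError(f"Cannot register duplicate task ({name})")
        # enforce the common interface
        if not issubclass(cls, BaseTask):
            raise ValueError(f"Task ({name}: {cls.__name__}) must extend BaseTask")
        TASK_REGISTRY[name] = cls
        return cls

    return decorator


@register_task("classification")
class ClassificationTask(BaseTask):
    pass
```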
class fairseq.tasks.FairseqTask(cfg: fairseq.dataclass.configs.FairseqDataclass, **kwargs)

Tasks store dictionaries and provide helpers for loading/iterating over Datasets, initializing the Model/Criterion and calculating the loss.

Tasks have limited statefulness. In particular, state that needs to be saved to/loaded from checkpoints needs to be stored in the self.state StatefulContainer object. For example:

self.state.add_factory("dictionary", self.load_dictionary)
print(self.state.dictionary)  # calls self.load_dictionary()

This is necessary so that when loading checkpoints, we can properly recreate the task state after initializing the task instance.
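The lazy-factory behaviour of self.state can be pictured as a small container with two rules. It builds a value on first attribute access, and values restored from a checkpoint take precedence over the factory. LazyState is a hypothetical minimal stand-in, and fairseq's StatefulContainer may differ in its details.

```python
from typing import Any, Callable, Dict


class LazyState:
    """Hypothetical minimal stand-in for a StatefulContainer-like object."""

    def __init__(self) -> None:
        self._state: Dict[str, Any] = {}
        self._factories: Dict[str, Callable[[], Any]] = {}

    def add_factory(self, name: str, factory: Callable[[], Any]) -> None:
        self._factories[name] = factory

    def merge_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # values restored from a checkpoint take precedence over factories
        self._state.update(state_dict)

    def state_dict(self) -> Dict[str, Any]:
        return self._state

    def __getattr__(self, name: str) -> Any:
        # only called when normal attribute lookup fails
        state = self.__dict__.get("_state", {})
        if name in state:
            return state[name]
        factories = self.__dict__.get("_factories", {})
        if name in factories:
            value = factories[name]()  # build lazily on first access, then cache
            state[name] = value
            return value
        raise AttributeError(name)
```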

classmethod add_args(parser)

Add task-specific arguments to the parser.

aggregate_logging_outputs(logging_outputs, criterion)

[deprecated] Aggregate logging outputs from data parallel training.

begin_epoch(epoch, model)

Hook function called before the start of each epoch.

begin_valid_epoch(epoch, model)

Hook function called before the start of each validation epoch.

build_bpe(args)

Build the tokenizer for this task.

build_criterion(cfg: omegaconf.dictconfig.DictConfig)

Build the FairseqCriterion instance for this task.

Parameters: cfg (omegaconf.DictConfig) – configuration object
Returns: a FairseqCriterion instance

build_dataset_for_inference(src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs) → torch.utils.data.dataset.Dataset

classmethod build_dictionary(filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8)

Build the dictionary

Parameters:
  • filenames (list) – list of filenames
  • workers (int) – number of concurrent workers
  • threshold (int) – defines the minimum word count
  • nwords (int) – defines the total number of words in the final dictionary, including special symbols
  • padding_factor (int) – pads the dictionary size to a multiple of padding_factor (default: 8), which is important on some hardware (e.g., Nvidia Tensor Cores).
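The interplay of threshold, nwords, and padding_factor can be sketched with collections.Counter. This is a simplified model under stated assumptions, not the Dictionary implementation. build_vocab is hypothetical, the four special symbols are assumed to come first, and the madeupwordNNNN filler names are illustrative.

```python
from collections import Counter
from typing import List

SPECIALS = ["<s>", "<pad>", "</s>", "<unk>"]  # assumed special symbols


def build_vocab(lines: List[str], threshold: int = -1, nwords: int = -1, padding_factor: int = 8) -> List[str]:
    """Simplified sketch of dictionary building; not fairseq's Dictionary class."""
    counts = Counter(word for line in lines for word in line.split())
    vocab = list(SPECIALS)
    # nwords includes the special symbols; -1 means no limit
    budget = None if nwords < 0 else max(nwords - len(SPECIALS), 0)
    for word, count in counts.most_common(budget):
        if count >= threshold:  # threshold=-1 keeps every word
            vocab.append(word)
    # pad with filler symbols until the size is a multiple of padding_factor
    i = 0
    while padding_factor > 1 and len(vocab) % padding_factor != 0:
        vocab.append(f"madeupword{i:04d}")
        i += 1
    return vocab
```

For the lines "a b a" and "c a b", the defaults give 4 specials plus a, b, c (7 entries), padded to 8.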
build_generator(models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None)

Build a SequenceGenerator instance for this task.

Parameters:
  • models (List[FairseqModel]) – ensemble of models
  • args (fairseq.dataclass.configs.GenerationConfig) – configuration object (dataclass) for generation
  • extra_gen_cls_kwargs (Dict[str, Any]) – extra options to pass through to SequenceGenerator
  • prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]) – If provided, this function constrains the beam search to allowed tokens only at each step. The provided function should take 2 arguments: the batch ID (batch_id: int) and a unidimensional tensor of token ids (inputs_ids: torch.Tensor). It has to return a List[int] with the allowed tokens for the next generation step conditioned on the previously generated tokens (inputs_ids) and the batch ID (batch_id). This argument is useful for constrained generation conditioned on the prefix, as described in “Autoregressive Entity Retrieval” (https://arxiv.org/abs/2010.00904) and https://github.com/facebookresearch/GENRE.
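As an example, suppose beam search should only produce one of a fixed set of token sequences. A matching prefix_allowed_tokens_fn can be built from a prefix index (a flattened trie), in the spirit of the GENRE work cited above. The allowed sequences and token ids below are made up. Whether inputs_ids includes a leading start-of-sequence token depends on the generator, so the stored sequences would need to follow the same convention.

```python
from typing import Dict, List, Sequence, Set, Tuple

# made-up token ids for two allowed outputs, e.g. two entity names
ALLOWED_SEQUENCES: List[List[int]] = [[5, 8, 2], [5, 9, 11, 2]]


def build_prefix_index(sequences: List[List[int]]) -> Dict[Tuple[int, ...], List[int]]:
    """Map every prefix to the sorted list of tokens that may follow it."""
    index: Dict[Tuple[int, ...], Set[int]] = {}
    for seq in sequences:
        for i in range(len(seq)):
            index.setdefault(tuple(seq[:i]), set()).add(seq[i])
    return {prefix: sorted(toks) for prefix, toks in index.items()}


PREFIX_INDEX = build_prefix_index(ALLOWED_SEQUENCES)


def prefix_allowed_tokens_fn(batch_id: int, inputs_ids: Sequence[int]) -> List[int]:
    # int() works for plain ints and for 0-d tensor elements alike
    prefix = tuple(int(t) for t in inputs_ids)
    # an empty list means no continuation is allowed from this prefix
    return PREFIX_INDEX.get(prefix, [])
```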
build_model(cfg: fairseq.dataclass.configs.FairseqDataclass, from_checkpoint=False)

Build the BaseFairseqModel instance for this task.

Parameters: cfg (FairseqDataclass) – configuration object
Returns: a BaseFairseqModel instance

build_tokenizer(args)

Build the pre-tokenizer for this task.

can_reuse_epoch_itr(dataset)

dataset(split)

Return a loaded dataset split.

Parameters: split (str) – name of the split (e.g., train, valid, test)
Returns: a FairseqDataset corresponding to split

filter_indices_by_size(indices, dataset, max_positions=None, ignore_invalid_inputs=False)

Filter examples that are too large

Parameters:
  • indices (np.array) – original array of sample indices
  • dataset (FairseqDataset) – dataset to batch
  • max_positions (optional) – max sentence length supported by the model (default: None).
  • ignore_invalid_inputs (bool, optional) – don’t raise Exception for sentences that are too long (default: False).
Returns:

array of filtered sample indices

Return type:

np.array
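The filtering rule can be sketched as follows. The sketch uses plain lists instead of np.array and a sizes sequence in place of the dataset. filter_by_size is a hypothetical helper that mirrors the documented behaviour: it drops over-long samples and raises unless ignore_invalid_inputs is set.

```python
from typing import List, Optional, Sequence


def filter_by_size(
    indices: Sequence[int],
    sizes: Sequence[int],
    max_positions: Optional[int] = None,
    ignore_invalid_inputs: bool = False,
) -> List[int]:
    """Sketch: drop indices whose sample length exceeds max_positions."""
    if max_positions is None:
        return list(indices)
    kept = [i for i in indices if sizes[i] <= max_positions]
    ignored = [i for i in indices if sizes[i] > max_positions]
    if ignored and not ignore_invalid_inputs:
        raise Exception(
            f"{len(ignored)} samples exceed max_positions={max_positions}, e.g. index {ignored[0]}"
        )
    return kept
```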

get_batch_iterator(dataset, max_tokens=None, max_sentences=None, max_positions=None, ignore_invalid_inputs=False, required_batch_size_multiple=1, seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1, data_buffer_size=0, disable_iterator_cache=False, skip_remainder_batch=False, grouped_shuffling=False, update_epoch_batch_itr=False)

Get an iterator that yields batches of data from the given dataset.

Parameters:
  • dataset (FairseqDataset) – dataset to batch
  • max_tokens (int, optional) – max number of tokens in each batch (default: None).
  • max_sentences (int, optional) – max number of sentences in each batch (default: None).
  • max_positions (optional) – max sentence length supported by the model (default: None).
  • ignore_invalid_inputs (bool, optional) – don’t raise Exception for sentences that are too long (default: False).
  • required_batch_size_multiple (int, optional) – require batch size to be a multiple of N (default: 1).
  • seed (int, optional) – seed for random number generator for reproducibility (default: 1).
  • num_shards (int, optional) – shard the data iterator into N shards (default: 1).
  • shard_id (int, optional) – which shard of the data iterator to return (default: 0).
  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means the data will be loaded in the main process (default: 0).
  • epoch (int, optional) – the epoch to start the iterator from (default: 1).
  • data_buffer_size (int, optional) – number of batches to preload (default: 0).
  • disable_iterator_cache (bool, optional) – don’t cache the EpochBatchIterator (ignores FairseqTask::can_reuse_epoch_itr) (default: False).
  • skip_remainder_batch (bool, optional) – if set, discard the last batch in each training epoch, as the last batch is often smaller than local_batch_size * distributed_world_size (default: False).
  • grouped_shuffling (bool, optional) – group batches into groups of num_shards batches each and shuffle the groups. This reduces the difference in sequence lengths among workers when batches are sorted by length (default: False).
  • update_epoch_batch_itr (bool, optional) – if True, do not use the cached batch iterator for the epoch (default: False).
Returns:

a batched iterator over the given dataset split

Return type:

EpochBatchIterator
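The greedy batching sketch below shows how max_tokens, max_sentences, and num_shards/shard_id interact. It assumes the padded cost of a batch is the longest sample length times the number of samples. It leaves out required_batch_size_multiple, shuffling, and caching. batch_by_size and shard are hypothetical helpers, not fairseq's optimized routines.

```python
from typing import List, Optional, Sequence


def batch_by_size(
    indices: Sequence[int],
    sizes: Sequence[int],
    max_tokens: Optional[int] = None,
    max_sentences: Optional[int] = None,
) -> List[List[int]]:
    """Greedy sketch: close a batch when one more sample would exceed a limit."""
    batches: List[List[int]] = []
    batch: List[int] = []
    longest = 0
    for idx in indices:
        new_longest = max(longest, sizes[idx])
        # padded cost: every sample in a batch is padded to the longest one
        too_many_tokens = max_tokens is not None and new_longest * (len(batch) + 1) > max_tokens
        too_many_sents = max_sentences is not None and len(batch) >= max_sentences
        if batch and (too_many_tokens or too_many_sents):
            batches.append(batch)
            batch, new_longest = [], sizes[idx]
        # a single over-long sample still gets its own batch here; such
        # samples are normally removed earlier by filter_indices_by_size
        batch.append(idx)
        longest = new_longest
    if batch:
        batches.append(batch)
    return batches


def shard(batches: List[List[int]], num_shards: int, shard_id: int) -> List[List[int]]:
    # each worker keeps every num_shards-th batch, starting at shard_id
    return batches[shard_id::num_shards]


sizes = [5, 2, 8, 3, 7, 1]
# sorting by length first keeps similarly sized samples together
order = sorted(range(len(sizes)), key=lambda i: sizes[i])
batches = batch_by_size(order, sizes, max_tokens=12)
```

Here the batches come out as [[5, 1, 3], [0], [4], [2]], and shard 0 of 2 receives [[5, 1, 3], [4]].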

get_interactive_tokens_and_lengths(lines, encode_fn)

has_sharded_data(split)

inference_step(generator, models, sample, prefix_tokens=None, constraints=None)

load_dataset(split: str, combine: bool = False, task_cfg: fairseq.dataclass.configs.FairseqDataclass = None, **kwargs)

Load a given dataset split.

Parameters:
  • split (str) – name of the split (e.g., train, valid, test)
  • combine (bool) – combines a split segmented into pieces into one dataset
  • task_cfg (FairseqDataclass) – optional task configuration stored in the checkpoint that can be used to load datasets
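The combine flag refers to a split stored in several numbered pieces. The sketch below assumes those pieces are named train, train1, train2, and so on, each with a .bin suffix. It stops at the first missing piece. find_split_files and the file layout are assumptions for illustration, not the exact on-disk format of every fairseq task. The exists parameter lets the logic be tried without touching the filesystem.

```python
import itertools
import os
from typing import Callable, List


def find_split_files(
    data_dir: str,
    split: str,
    combine: bool = False,
    suffix: str = ".bin",
    exists: Callable[[str], bool] = os.path.exists,
) -> List[str]:
    """Collect split, split1, split2, ... stopping at the first missing file."""
    paths: List[str] = []
    for k in itertools.count():
        name = split + (str(k) if k > 0 else "")
        path = os.path.join(data_dir, name + suffix)
        if not exists(path):
            if k == 0:
                raise FileNotFoundError(f"Dataset not found: {path}")
            break
        paths.append(path)
        if not combine:
            # without combine, only the unsuffixed split is loaded
            break
    return paths
```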
classmethod load_dictionary(filename)

Load the dictionary from the filename.

Parameters: filename (str) – the filename

load_state_dict(state_dict: Dict[str, Any])

static logging_outputs_can_be_summed(criterion) → bool

Whether the logging outputs returned by train_step and valid_step can be summed across workers prior to calling aggregate_logging_outputs. Setting this to True will improve distributed training speed.
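Summing is safe only when each logged value is itself a sum, such as a loss total or a token count, rather than an average. The sketch below shows why: normalizing after summing gives the true per-token loss, whereas averaging the per-worker averages would not. sum_logging_outputs and the key names are hypothetical.

```python
from typing import Dict, List


def sum_logging_outputs(outputs: List[Dict[str, float]]) -> Dict[str, float]:
    """Sum every key across workers (only valid if all values are summable)."""
    total: Dict[str, float] = {}
    for out in outputs:
        for key, value in out.items():
            total[key] = total.get(key, 0.0) + value
    return total


workers = [
    {"loss": 10.0, "ntokens": 4, "sample_size": 4},  # per-token loss 2.5
    {"loss": 2.0, "ntokens": 1, "sample_size": 1},   # per-token loss 2.0
]
summed = sum_logging_outputs(workers)
# normalize only after summing, so the result is a true per-token average
print(summed["loss"] / summed["sample_size"])  # 2.4
```

Averaging the two per-worker losses (2.5 and 2.0) would give 2.25 instead of 2.4.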

max_positions()

Return the max input length allowed by the task.

optimizer_step(optimizer, model, update_num)

reduce_metrics(logging_outputs, criterion)

Aggregate logging outputs from data parallel training.

classmethod setup_task(cfg: omegaconf.dictconfig.DictConfig, **kwargs)

Set up the task (e.g., load dictionaries).

Parameters: cfg (omegaconf.DictConfig) – parsed command-line arguments

source_dictionary

Return the source Dictionary (if applicable for this task).

state_dict()

target_dictionary

Return the target Dictionary (if applicable for this task).

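train_step(sample, model, criterion, optimizer, update_num, ignore_grad=False)

Do forward and backward, and return the loss as computed by criterion for the given model and sample.

Parameters:
  • sample (dict) – the mini-batch. The format is defined by the FairseqDataset.
  • model (BaseFairseqModel) – the model
  • criterion (FairseqCriterion) – the criterion
  • optimizer (FairseqOptimizer) – the optimizer
  • update_num (int) – the current update
  • ignore_grad (bool) – multiply loss by 0 if this is set to True

Returns: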

  • the loss
  • the sample size, which is used as the denominator for the gradient
  • logging outputs to display while training

Return type:

tuple
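The documented contract (run the forward and backward pass, then return the loss, the sample size, and logging outputs) can be illustrated without PyTorch. The toy below uses a one-parameter model whose gradient is written out by hand. Param and toy_train_step are hypothetical, and the toy mirrors the contract rather than fairseq's implementation. The caller divides the accumulated gradient by sample_size, which is the role the return value description assigns to it.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Param:
    value: float
    grad: float = 0.0


def toy_train_step(w: Param, sample: Dict[str, List[float]], ignore_grad: bool = False) -> Tuple[float, int, Dict[str, float]]:
    """Toy stand-in for train_step: forward, hand-written backward, documented triple."""
    xs, ys = sample["x"], sample["y"]
    # forward: summed (not averaged) squared error of the model y = w * x
    loss = sum((w.value * x - y) ** 2 for x, y in zip(xs, ys))
    grad = sum(2 * (w.value * x - y) * x for x, y in zip(xs, ys))
    if ignore_grad:
        # multiply by 0 so this batch contributes nothing to the update
        loss, grad = loss * 0, grad * 0
    # backward: accumulate into .grad, as autograd would
    w.grad += grad
    sample_size = len(xs)
    return loss, sample_size, {"loss": loss, "sample_size": sample_size}


w = Param(0.0)
loss, sample_size, logging_output = toy_train_step(w, {"x": [1.0, 2.0], "y": [2.0, 4.0]})
# sample_size is the denominator that turns the summed gradient into a mean
w.value -= 0.1 * w.grad / sample_size
w.grad = 0.0
```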

valid_step(sample, model, criterion)