Tasks

Tasks store dictionaries and provide helpers for loading/iterating over Datasets, initializing the Model/Criterion and calculating the loss.

Tasks can be selected via the --task command-line argument. Once selected, a task may expose additional command-line arguments for further configuration.

Example usage:

# setup the task (e.g., load dictionaries)
task = fairseq.tasks.setup_task(args)

# build model and criterion
model = task.build_model(args)
criterion = task.build_criterion(args)

# load datasets
task.load_dataset('train')
task.load_dataset('valid')

# iterate over mini-batches of data
batch_itr = task.get_batch_iterator(
    task.dataset('train'), max_tokens=4096,
)
for batch in batch_itr:
    # compute the loss
    loss, sample_size, logging_output = task.get_loss(
        model, criterion, batch,
    )
    loss.backward()

Translation

class fairseq.tasks.translation.TranslationTask(args, src_dict, tgt_dict)[source]

Translate from one (source) language to another (target) language.

Parameters:
  • src_dict (Dictionary) – dictionary for the source language
  • tgt_dict (Dictionary) – dictionary for the target language

Note

The translation task is compatible with fairseq-train, fairseq-generate and fairseq-interactive.

The translation task provides the following additional command-line arguments:

usage:  [--task translation] [-s SRC] [-t TARGET] [--lazy-load] [--raw-text]
        [--load-alignments] [--left-pad-source BOOL] [--left-pad-target BOOL]
        [--max-source-positions N] [--max-target-positions N]
        [--upsample-primary UPSAMPLE_PRIMARY] [--truncate-source]
        data

Task name

--task Enable this task with: --task=translation

Additional command-line arguments

data colon-separated list of paths to data directories; the paths are iterated over in round-robin order across epochs
-s, --source-lang source language
-t, --target-lang target language
--lazy-load

load the dataset lazily

Default: False

--raw-text

load raw text dataset

Default: False

--load-alignments

load the binarized alignments

Default: False

--left-pad-source

pad the source on the left

Default: “True”

--left-pad-target

pad the target on the left

Default: “False”

--max-source-positions

max number of tokens in the source sequence

Default: 1024

--max-target-positions

max number of tokens in the target sequence

Default: 1024

--upsample-primary

amount to upsample primary dataset

Default: 1

--truncate-source

truncate source sequences to --max-source-positions

Default: False
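
The options above can also be supplied programmatically. Below is a minimal sketch, assuming a data directory produced by fairseq-preprocess (containing dict.de.txt, dict.en.txt and the binarized splits); the path and language pair are placeholders.

import argparse
from fairseq.tasks.translation import TranslationTask

parser = argparse.ArgumentParser()
TranslationTask.add_args(parser)  # registers the options listed above
args = parser.parse_args([
    'data-bin/wmt17.de-en',           # placeholder data directory
    '--source-lang', 'de',
    '--target-lang', 'en',
    '--max-source-positions', '1024',
])

task = TranslationTask.setup_task(args)  # loads dict.de.txt / dict.en.txt
task.load_dataset('valid')               # loads the binarized valid split
print(task.source_dictionary, task.target_dictionary)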

Language Modeling

class fairseq.tasks.language_modeling.LanguageModelingTask(args, dictionary, output_dictionary=None, targets=None)[source]

Train a language model.

Parameters:
  • dictionary (Dictionary) – the dictionary for the input of the language model
  • output_dictionary (Dictionary) – the dictionary for the output of the language model. In most cases it will be the same as dictionary, but could possibly be a more limited version of the dictionary (if --output-dictionary-size is used).
  • targets (List[str]) – list of the target types that the language model should predict. Can be one of “self”, “future”, and “past”. Defaults to “future”.

Note

The language modeling task is compatible with fairseq-train, fairseq-generate, fairseq-interactive and fairseq-eval-lm.

The language modeling task provides the following additional command-line arguments:

usage:  [--task language_modeling]
        [--sample-break-mode {none,complete,complete_doc,eos}]
        [--tokens-per-sample TOKENS_PER_SAMPLE] [--lazy-load] [--raw-text]
        [--output-dictionary-size OUTPUT_DICTIONARY_SIZE] [--self-target]
        [--future-target] [--past-target] [--add-bos-token]
        [--max-target-positions N]
        data

Task name

--task Enable this task with: --task=language_modeling

Additional command-line arguments

data path to data directory
--sample-break-mode

Possible choices: none, complete, complete_doc, eos

If omitted or “none”, fills each sample with tokens-per-sample tokens. If set to “complete”, splits samples only at the end of sentence, but may include multiple sentences per sample. “complete_doc” is similar but respects doc boundaries. If set to “eos”, includes only one sentence per sample.

Default: “none”

--tokens-per-sample

max number of tokens per sample for LM dataset

Default: 1024

--lazy-load

load the dataset lazily

Default: False

--raw-text

load raw text dataset

Default: False

--output-dictionary-size

limit the size of the output dictionary

Default: -1

--self-target

include self target

Default: False

--future-target

include future target

Default: False

--past-target

include past target

Default: False

--add-bos-token

prepend beginning of sentence token (<s>)

Default: False

--max-target-positions

max number of tokens in the target sequence
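
As with translation, these options can be driven from Python. A minimal sketch, assuming a data directory containing dict.txt and a binarized monolingual corpus (the path is a placeholder):

import argparse
from fairseq.tasks.language_modeling import LanguageModelingTask

parser = argparse.ArgumentParser()
LanguageModelingTask.add_args(parser)
args = parser.parse_args([
    'data-bin/wikitext-103',            # placeholder data directory
    '--sample-break-mode', 'complete',  # break samples only at sentence boundaries
    '--tokens-per-sample', '512',
    '--future-target',                  # predict the next token (the default target)
])

task = LanguageModelingTask.setup_task(args)
task.load_dataset('valid')
print(len(task.source_dictionary), 'types in the dictionary')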

Adding new tasks

fairseq.tasks.register_task(name)[source]

New tasks can be added to fairseq with the register_task() function decorator.

For example:

@register_task('classification')
class ClassificationTask(FairseqTask):
    (...)

Note

All Tasks must implement the FairseqTask interface.

Please see the FairseqTask documentation below for the full set of methods that a task can override.

Parameters: name (str) – the name of the task
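
Fleshing out the example, a hedged skeleton of a custom task (the file names, the dict.txt vocabulary and the dataset layout are all hypothetical; a real task would build a FairseqDataset in load_dataset()):

import os

from fairseq.data import Dictionary
from fairseq.tasks import FairseqTask, register_task


@register_task('classification')
class ClassificationTask(FairseqTask):

    @staticmethod
    def add_args(parser):
        # Task-specific command-line arguments.
        parser.add_argument('data', help='path to the data directory')
        parser.add_argument('--num-classes', type=int, default=2,
                            help='number of target classes')

    @classmethod
    def setup_task(cls, args, **kwargs):
        # Load the (hypothetical) vocabulary built during preprocessing.
        vocab = Dictionary.load(os.path.join(args.data, 'dict.txt'))
        return cls(args, vocab)

    def __init__(self, args, vocab):
        super().__init__(args)
        self.vocab = vocab

    def load_dataset(self, split, **kwargs):
        # Populate self.datasets[split] with a FairseqDataset for `split`.
        raise NotImplementedError

    def max_positions(self):
        return 512  # max input length supported by this (hypothetical) task

    @property
    def source_dictionary(self):
        return self.vocab

    @property
    def target_dictionary(self):
        return self.vocab
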
class fairseq.tasks.FairseqTask(args)[source]

Tasks store dictionaries and provide helpers for loading/iterating over Datasets, initializing the Model/Criterion and calculating the loss.

static add_args(parser)[source]

Add task-specific arguments to the parser.

aggregate_logging_outputs(logging_outputs, criterion)[source]
build_criterion(args)[source]

Build the FairseqCriterion instance for this task.

Parameters: args (argparse.Namespace) – parsed command-line arguments
Returns: a FairseqCriterion instance
classmethod build_dictionary(filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8)[source]

Build the dictionary

Parameters:
  • filenames (list) – list of filenames
  • workers (int) – number of concurrent workers
  • threshold (int) – defines the minimum word count
  • nwords (int) – defines the total number of words in the final dictionary, including special symbols
  • padding_factor (int) – can be used to pad the dictionary size to be a multiple of 8, which is important on some hardware (e.g., Nvidia Tensor Cores).
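
For example, a hedged sketch (corpus.en is a placeholder for a tokenized plain-text file):

from fairseq.tasks import FairseqTask

d = FairseqTask.build_dictionary(
    ['corpus.en'],      # one or more tokenized text files
    workers=4,          # count words with 4 worker processes
    threshold=5,        # drop words seen fewer than 5 times
    padding_factor=8,   # round the vocabulary size up to a multiple of 8
)
d.save('dict.en.txt')   # can be re-loaded later with load_dictionary()
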
build_generator(args)[source]
build_model(args)[source]

Build the BaseFairseqModel instance for this task.

Parameters: args (argparse.Namespace) – parsed command-line arguments
Returns: a BaseFairseqModel instance
dataset(split)[source]

Return a loaded dataset split.

Parameters: split (str) – name of the split (e.g., train, valid, test)
Returns: a FairseqDataset corresponding to split
get_batch_iterator(dataset, max_tokens=None, max_sentences=None, max_positions=None, ignore_invalid_inputs=False, required_batch_size_multiple=1, seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=0)[source]

Get an iterator that yields batches of data from the given dataset.

Parameters:
  • dataset (FairseqDataset) – dataset to batch
  • max_tokens (int, optional) – max number of tokens in each batch (default: None).
  • max_sentences (int, optional) – max number of sentences in each batch (default: None).
  • max_positions (optional) – max sentence length supported by the model (default: None).
  • ignore_invalid_inputs (bool, optional) – don’t raise Exception for sentences that are too long (default: False).
  • required_batch_size_multiple (int, optional) – require batch size to be a multiple of N (default: 1).
  • seed (int, optional) – seed for random number generator for reproducibility (default: 1).
  • num_shards (int, optional) – shard the data iterator into N shards (default: 1).
  • shard_id (int, optional) – which shard of the data iterator to return (default: 0).
  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means the data will be loaded in the main process (default: 0).
  • epoch (int, optional) – the epoch to start the iterator from (default: 0).
Returns: a batched iterator over the given dataset split
Return type: EpochBatchIterator
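
A hedged sketch of sharding the iterator for distributed data-parallel training (each of num_shards workers sees a disjoint slice of the batches); task is assumed to be set up and 'train' loaded as in the example at the top of this page:

epoch_itr = task.get_batch_iterator(
    task.dataset('train'),
    max_tokens=4096,                     # cap on tokens per batch
    max_positions=task.max_positions(),  # skip inputs the model cannot handle
    ignore_invalid_inputs=True,          # drop too-long examples instead of raising
    num_shards=2, shard_id=0,            # this worker's shard; the other worker passes shard_id=1
    seed=1,
)
# The returned EpochBatchIterator yields a fresh, optionally shuffled pass per epoch.
itr = epoch_itr.next_epoch_itr(shuffle=True)
for sample in itr:
    pass  # feed `sample` to task.train_step / task.valid_step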

grad_denom(sample_sizes, criterion)[source]
inference_step(generator, models, sample, prefix_tokens=None)[source]
load_dataset(split, combine=False, **kwargs)[source]

Load a given dataset split.

Parameters: split (str) – name of the split (e.g., train, valid, test)
classmethod load_dictionary(filename)[source]

Load the dictionary from the filename

Parameters: filename (str) – the filename
max_positions()[source]

Return the max input length allowed by the task.

classmethod setup_task(args, **kwargs)[source]

Setup the task (e.g., load dictionaries).

Parameters: args (argparse.Namespace) – parsed command-line arguments
source_dictionary

Return the source Dictionary (if applicable for this task).

target_dictionary

Return the target Dictionary (if applicable for this task).

train_step(sample, model, criterion, optimizer, ignore_grad=False)[source]

Do forward and backward, and return the loss as computed by criterion for the given model and sample.

Parameters:
  • sample (dict) – the mini-batch
  • model (BaseFairseqModel) – the model
  • criterion (FairseqCriterion) – the criterion
  • optimizer (FairseqOptimizer) – the optimizer
  • ignore_grad (bool) – multiply the loss by 0 if this is set to True

Returns:

  • the loss
  • the sample size, which is used as the denominator for the gradient
  • logging outputs to display while training

Return type: tuple
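
A hedged sketch of the resulting inner training loop (assumes task, args and an iterator itr built as in the earlier examples; fairseq.optim.build_optimizer needs the usual optimizer settings on args, e.g. --optimizer and --lr):

from fairseq import optim

model = task.build_model(args)
criterion = task.build_criterion(args)
optimizer = optim.build_optimizer(args, model.parameters())

for sample in itr:
    optimizer.zero_grad()
    # train_step runs the forward and backward passes and returns the loss,
    # the gradient denominator (sample_size) and logging statistics.
    loss, sample_size, logging_output = task.train_step(
        sample, model, criterion, optimizer,
    )
    optimizer.step()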

update_step(num_updates)[source]

Task-level update when the number of updates increases. This is called after each optimization step and learning-rate update.

valid_step(sample, model, criterion)[source]