Data Loading and Utilities

Datasets

Datasets define the data format and provide helpers for creating mini-batches.

class fairseq.data.FairseqDataset

A dataset that provides helpers for batching.

batch_by_size(indices, max_tokens=None, max_sentences=None, required_batch_size_multiple=1)

Given an ordered set of indices, return batches according to max_tokens, max_sentences and required_batch_size_multiple.
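The batching rule can be made concrete with a simplified, stand-alone sketch in plain Python. It does not import fairseq, and `batch_by_size_sketch` and `num_tokens_fn` are hypothetical names. The sketch rests on two assumptions:

- A batch's cost is its padded size, that is, the number of samples times the longest sample.
- A full batch is trimmed to a multiple of required_batch_size_multiple, and the leftovers carry into the next batch.

fairseq's own implementation may differ in details.

```python
from typing import Callable, List, Optional, Sequence


def batch_by_size_sketch(
    indices: Sequence[int],
    num_tokens_fn: Callable[[int], int],
    max_tokens: Optional[int] = None,
    max_sentences: Optional[int] = None,
    required_batch_size_multiple: int = 1,
) -> List[List[int]]:
    """Greedily group ordered indices into batches (simplified sketch).

    A batch's cost is (number of samples) * (longest sample), because every
    row is padded to the longest sample in the batch.
    """
    limit_tokens = max_tokens if max_tokens is not None else float("inf")
    limit_sents = max_sentences if max_sentences is not None else float("inf")
    mult = required_batch_size_multiple

    batches: List[List[int]] = []
    batch: List[int] = []
    longest = 0  # longest sample in the current batch
    for idx in indices:
        n = num_tokens_fn(idx)
        if n > limit_tokens:
            raise ValueError(f"sample {idx} has {n} tokens > max_tokens={max_tokens}")
        too_many_tokens = (len(batch) + 1) * max(longest, n) > limit_tokens
        too_many_sents = len(batch) + 1 > limit_sents
        if batch and (too_many_tokens or too_many_sents):
            # Emit a multiple of `mult` samples (or the whole batch if it is
            # smaller than `mult`) and carry the leftovers into the next batch.
            keep = max(mult * (len(batch) // mult), len(batch) % mult)
            batches.append(batch[:keep])
            batch = batch[keep:]
            longest = max((num_tokens_fn(i) for i in batch), default=0)
        batch.append(idx)
        longest = max(longest, n)
    if batch:
        batches.append(batch)
    return batches


lengths = [2, 2, 3, 3, 5]
print(batch_by_size_sketch(range(5), lambda i: lengths[i], max_tokens=8))
# [[0, 1], [2, 3], [4]]
```

Indices usually come from ordered_indices(), which typically sorts by length. As a result, samples of similar length tend to land in the same batch and little padding is wasted.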

collater(samples)

Merge a list of samples to form a mini-batch.

Parameters: samples (List[dict]) – samples to collate
Returns: a mini-batch suitable for forwarding with a Model
Return type: dict

filter_indices_by_size(indices, max_sizes)

Filter a list of sample indices. Remove those that are longer than specified in max_sizes.

WARNING: don’t modify this method; override it in child classes instead.

Parameters:
  • indices (np.array) – original array of sample indices
  • max_sizes (int or list[int] or tuple[int]) – max sample size, can be defined separately for src and tgt (then list or tuple)
Returns:

a tuple of the filtered sample array and the list of removed indices

Return type:

tuple(np.array, list)
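The filtering contract can be shown with a minimal sketch. It assumes sizes is a plain list in which each entry is either an int or a (source, target) tuple. `filter_indices_by_size_sketch` and `fits` are hypothetical helpers, not fairseq functions. Applying a scalar limit to both sides of a tuple is this sketch's own choice.

```python
from typing import List, Sequence, Tuple, Union

Size = Union[int, Tuple[int, ...]]


def fits(size: Size, max_size: Size) -> bool:
    """True if a sample of `size` respects `max_size` (scalar or per-side)."""
    sides = size if isinstance(size, tuple) else (size,)
    if isinstance(max_size, (int, float)):
        # One limit shared by every side of the sample.
        return all(s <= max_size for s in sides)
    # Per-side limits, e.g. (max_source_positions, max_target_positions).
    return all(s <= m for s, m in zip(sides, max_size))


def filter_indices_by_size_sketch(
    indices: Sequence[int], sizes: Sequence[Size], max_sizes: Size
) -> Tuple[List[int], List[int]]:
    """Split indices into (kept, removed) according to `max_sizes`."""
    kept = [i for i in indices if fits(sizes[i], max_sizes)]
    removed = [i for i in indices if not fits(sizes[i], max_sizes)]
    return kept, removed


print(filter_indices_by_size_sketch([0, 1, 2], [(3, 4), (10, 2), (5, 9)], (6, 8)))
# ([0], [1, 2])
```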

get_batch_shapes()

Return a list of valid batch shapes, for example:

[(8, 512), (16, 256), (32, 128)]

The first dimension of each tuple is the batch size and can be None to automatically infer the max batch size based on --max-tokens. The second dimension of each tuple is the max supported length as given by fairseq.data.FairseqDataset.num_tokens().

This will be used by fairseq.data.FairseqDataset.batch_by_size() to restrict batch shapes. This is useful on TPUs to avoid too many dynamic shapes (and recompilations).
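One way to read such a shape list can be sketched as follows:

- Each None batch size is filled in as max_tokens // max_length.
- A batch whose longest sample has length L uses the smallest allowed shape that fits it.

`resolve_batch_shapes` and `pick_shape` are hypothetical helpers. fairseq's batching code may also round batch sizes or apply max_sentences.

```python
from typing import List, Optional, Sequence, Tuple

Shape = Tuple[Optional[int], int]


def resolve_batch_shapes(shapes: Sequence[Shape], max_tokens: Optional[int]) -> List[Tuple[int, int]]:
    """Fill in None batch sizes as max_tokens // max_length; sort by length."""
    resolved = []
    for bsz, max_len in shapes:
        if bsz is None:
            if max_tokens is None:
                raise ValueError("a None batch size requires max_tokens")
            bsz = max_tokens // max_len
        resolved.append((bsz, max_len))
    return sorted(resolved, key=lambda shape: shape[1])


def pick_shape(resolved: Sequence[Tuple[int, int]], longest: int) -> Optional[Tuple[int, int]]:
    """Smallest allowed shape whose max length fits the longest sample."""
    for bsz, max_len in resolved:
        if longest <= max_len:
            return bsz, max_len
    return None  # no shape fits; such samples would have to be filtered out


shapes = resolve_batch_shapes([(8, 512), (None, 256), (32, 128)], max_tokens=4096)
print(shapes)                   # [(32, 128), (16, 256), (8, 512)]
print(pick_shape(shapes, 200))  # (16, 256)
```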

num_tokens(index)

Return the number of tokens in a sample. This value is used to enforce --max-tokens during batching.

num_tokens_vec(indices)

Return the number of tokens for a set of positions defined by indices. This value is used to enforce --max-tokens during batching.

ordered_indices()

Return an ordered list of indices. Batches will be constructed based on this order.

prefetch(indices)

Prefetch the data required for this epoch.

size(index)

Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.

supports_fetch_outside_dataloader

Whether this dataset supports fetching outside the workers of the dataloader.

supports_prefetch

Whether this dataset supports prefetching.

class fairseq.data.LanguagePairDataset(src, src_sizes, src_dict, tgt=None, tgt_sizes=None, tgt_dict=None, left_pad_source=True, left_pad_target=False, shuffle=True, input_feeding=True, remove_eos_from_source=False, append_eos_to_target=False, align_dataset=None, constraints=None, append_bos=False, eos=None, num_buckets=0, src_lang_id=None, tgt_lang_id=None, pad_to_multiple=1)

A pair of torch.utils.data.Datasets.

Parameters:
  • src (torch.utils.data.Dataset) – source dataset to wrap
  • src_sizes (List[int]) – source sentence lengths
  • src_dict (Dictionary) – source vocabulary
  • tgt (torch.utils.data.Dataset, optional) – target dataset to wrap
  • tgt_sizes (List[int], optional) – target sentence lengths
  • tgt_dict (Dictionary, optional) – target vocabulary
  • left_pad_source (bool, optional) – pad source tensors on the left side (default: True).
  • left_pad_target (bool, optional) – pad target tensors on the left side (default: False).
  • shuffle (bool, optional) – shuffle dataset elements before batching (default: True).
  • input_feeding (bool, optional) – create a shifted version of the targets to be passed into the model for teacher forcing (default: True).
  • remove_eos_from_source (bool, optional) – if set, removes eos from end of source if it’s present (default: False).
  • append_eos_to_target (bool, optional) – if set, appends eos to end of target if it’s absent (default: False).
  • align_dataset (torch.utils.data.Dataset, optional) – dataset containing alignments.
  • constraints (Tensor, optional) – 2d tensor with a concatenated, zero-delimited list of constraints for each sentence.
  • append_bos (bool, optional) – if set, appends bos to the beginning of source/target sentence.
  • num_buckets (int, optional) – if set to a value greater than 0, then batches will be bucketed into the given number of batch shapes.
  • src_lang_id (int, optional) – source language ID, if set, the collated batch will contain a field ‘src_lang_id’ in ‘net_input’ which indicates the source language of the samples.
  • tgt_lang_id (int, optional) – target language ID, if set, the collated batch will contain a field ‘tgt_lang_id’ which indicates the target language of the samples.

collater(samples, pad_to_length=None)

Merge a list of samples to form a mini-batch.

Parameters:
  • samples (List[dict]) – samples to collate
  • pad_to_length (dict, optional) – a dictionary of {‘source’: source_pad_to_length, ‘target’: target_pad_to_length} to indicate the max length to pad to in source and target respectively.
Returns:

a mini-batch with the following keys:

  • id (LongTensor): example IDs in the original input order
  • ntokens (int): total number of tokens in the batch
  • net_input (dict): the input to the Model, containing keys:
    • src_tokens (LongTensor): a padded 2D Tensor of tokens in the source sentence of shape (bsz, src_len). Padding will appear on the left if left_pad_source is True.
    • src_lengths (LongTensor): 1D Tensor of the unpadded lengths of each source sentence of shape (bsz)
    • prev_output_tokens (LongTensor): a padded 2D Tensor of tokens in the target sentence, shifted right by one position for teacher forcing, of shape (bsz, tgt_len). This key will not be present if input_feeding is False. Padding will appear on the left if left_pad_target is True.
    • src_lang_id (LongTensor): a long Tensor which contains source language IDs of each sample in the batch
  • target (LongTensor): a padded 2D Tensor of tokens in the target sentence of shape (bsz, tgt_len). Padding will appear on the left if left_pad_target is True.
  • tgt_lang_id (LongTensor): a long Tensor which contains target language IDs of each sample in the batch

Return type:

dict
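The padding conventions above can be sketched on plain Python lists. `collate_tokens_sketch` is a hypothetical helper. Pad index 1 and EOS index 2 are arbitrary choices for the example. The sketch builds three outputs:

- left-padded src_tokens;
- right-padded target;
- prev_output_tokens, built by moving the final EOS to the front, which shifts the target right by one position.

It does not reorder samples or build the other keys of the batch.

```python
from typing import List, Optional


def collate_tokens_sketch(
    values: List[List[int]],
    pad_idx: int,
    left_pad: bool = False,
    move_eos_to_beginning: bool = False,
    eos_idx: Optional[int] = None,
) -> List[List[int]]:
    """Pad token lists to a common length; optionally rotate EOS to the front."""
    size = max(len(v) for v in values)
    out = []
    for v in values:
        if move_eos_to_beginning:
            # Teacher-forcing input: [w1, w2, eos] -> [eos, w1, w2]
            if not v or v[-1] != eos_idx:
                raise ValueError("expected each sequence to end with eos_idx")
            v = [eos_idx] + v[:-1]
        padding = [pad_idx] * (size - len(v))
        out.append(padding + v if left_pad else v + padding)
    return out


PAD, EOS = 1, 2
src = [[4, 5, EOS], [6, EOS]]
tgt = [[7, 8, EOS], [9, EOS]]
src_tokens = collate_tokens_sketch(src, PAD, left_pad=True)  # [[4, 5, 2], [1, 6, 2]]
target = collate_tokens_sketch(tgt, PAD)                      # [[7, 8, 2], [9, 2, 1]]
prev_output_tokens = collate_tokens_sketch(
    tgt, PAD, move_eos_to_beginning=True, eos_idx=EOS
)                                                             # [[2, 7, 8], [2, 9, 1]]
```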

filter_indices_by_size(indices, max_sizes)

Filter a list of sample indices. Remove those that are longer than specified in max_sizes.

Parameters:
  • indices (np.array) – original array of sample indices
  • max_sizes (int or list[int] or tuple[int]) – max sample size, can be defined separately for src and tgt (then list or tuple)
Returns:

a tuple of the filtered sample array and the list of removed indices

Return type:

tuple(np.array, list)

get_batch_shapes()

Return a list of valid batch shapes, for example:

[(8, 512), (16, 256), (32, 128)]

The first dimension of each tuple is the batch size and can be None to automatically infer the max batch size based on --max-tokens. The second dimension of each tuple is the max supported length as given by fairseq.data.FairseqDataset.num_tokens().

This will be used by fairseq.data.FairseqDataset.batch_by_size() to restrict batch shapes. This is useful on TPUs to avoid too many dynamic shapes (and recompilations).

num_tokens(index)

Return the number of tokens in a sample. This value is used to enforce --max-tokens during batching.

num_tokens_vec(indices)

Return the number of tokens for a set of positions defined by indices. This value is used to enforce --max-tokens during batching.

ordered_indices()

Return an ordered list of indices. Batches will be constructed based on this order.
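A length-based ordering for paired data can be sketched in plain Python. The sketch assumes the following order of operations:

1. An optional random shuffle.
2. A stable sort by target length.
3. A stable sort by source length.

Ties in source length are then broken by target length, and any remaining ties keep the shuffled order. `ordered_indices_sketch` is hypothetical. Check LanguagePairDataset.ordered_indices() in your fairseq version; for example, num_buckets may change the behavior.

```python
import random
from typing import List, Optional, Sequence


def ordered_indices_sketch(
    src_sizes: Sequence[int],
    tgt_sizes: Optional[Sequence[int]] = None,
    shuffle: bool = True,
    seed: int = 0,
) -> List[int]:
    """Order indices by source length, breaking ties by target length.

    list.sort is stable, so sorting by target length first and then by
    source length yields a (source, target) lexicographic order.
    """
    indices = list(range(len(src_sizes)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    if tgt_sizes is not None:
        indices.sort(key=lambda i: tgt_sizes[i])
    indices.sort(key=lambda i: src_sizes[i])
    return indices


print(ordered_indices_sketch([3, 1, 2, 1], [1, 2, 1, 1], shuffle=False))  # [3, 1, 2, 0]
```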

prefetch(indices)

Prefetch the data required for this epoch.

size(index)

Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.
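num_tokens() and size() serve different limits. The first feeds --max-tokens during batching; the second feeds --max-positions during filtering. The sketch below assumes, as an illustration, that a paired dataset reports:

- the larger of the two side lengths as num_tokens;
- the (source, target) pair as size.

`PairSizes` is a hypothetical class, not part of fairseq.

```python
from typing import List, Optional, Sequence, Tuple


class PairSizes:
    """Hypothetical helper showing how a paired dataset might report sizes."""

    def __init__(self, src_sizes: Sequence[int], tgt_sizes: Optional[Sequence[int]] = None):
        self.src_sizes = list(src_sizes)
        self.tgt_sizes = list(tgt_sizes) if tgt_sizes is not None else None

    def _tgt(self, index: int) -> int:
        return self.tgt_sizes[index] if self.tgt_sizes is not None else 0

    def num_tokens(self, index: int) -> int:
        # For --max-tokens: the longer side dominates the padded batch cost.
        return max(self.src_sizes[index], self._tgt(index))

    def num_tokens_vec(self, indices: Sequence[int]) -> List[int]:
        # Batched form of num_tokens for many indices at once.
        return [self.num_tokens(i) for i in indices]

    def size(self, index: int) -> Tuple[int, int]:
        # For --max-positions: each side is checked against its own limit.
        return self.src_sizes[index], self._tgt(index)


d = PairSizes([5, 12], [9, 4])
print(d.num_tokens(0), d.size(0))  # 9 (5, 9)
print(d.num_tokens_vec([0, 1]))    # [9, 12]
```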

supports_prefetch

Whether this dataset supports prefetching.

class fairseq.data.MonolingualDataset(dataset, sizes, src_vocab, tgt_vocab=None, add_eos_for_other_targets=False, shuffle=False, targets=None, add_bos_token=False, fixed_pad_length=None, pad_to_bsz=None, src_lang_idx=None, tgt_lang_idx=None)

A wrapper around torch.utils.data.Dataset for monolingual data.

Parameters:
  • dataset (torch.utils.data.Dataset) – dataset to wrap
  • sizes (List[int]) – sentence lengths
  • src_vocab (Dictionary) – vocabulary
  • shuffle (bool, optional) – shuffle the elements before batching (default: False).

collater(samples)

Merge a list of samples to form a mini-batch.

Parameters: samples (List[dict]) – samples to collate
Returns: a mini-batch with the following keys:
  • id (LongTensor): example IDs in the original input order
  • ntokens (int): total number of tokens in the batch
  • net_input (dict): the input to the Model, containing keys:
    • src_tokens (LongTensor): a padded 2D Tensor of tokens in the source sentence of shape (bsz, src_len). Padding will appear on the right.
  • target (LongTensor): a padded 2D Tensor of tokens in the target sentence of shape (bsz, tgt_len). Padding will appear on the right.
Return type: dict

num_tokens(index)

Return the number of tokens in a sample. This value is used to enforce --max-tokens during batching.

num_tokens_vec(indices)

Return the number of tokens for a set of positions defined by indices. This value is used to enforce --max-tokens during batching.

ordered_indices()

Return an ordered list of indices. Batches will be constructed based on this order.

prefetch(indices)

Prefetch the data required for this epoch.

size(index)

Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.

supports_prefetch

Whether this dataset supports prefetching.

Helper Datasets

These datasets wrap other fairseq.data.FairseqDataset instances and provide additional functionality:

class fairseq.data.BacktranslationDataset(tgt_dataset, src_dict, tgt_dict=None, backtranslation_fn=None, output_collater=None, cuda=True, **kwargs)

Sets up a backtranslation dataset which takes a tgt batch, generates a src using a tgt-src backtranslation function (backtranslation_fn), and returns the corresponding {generated src, input tgt} batch.

Parameters:
  • tgt_dataset (FairseqDataset) – the dataset to be backtranslated. Only the source side of this dataset will be used. After backtranslation, the source sentences in this dataset will be returned as the targets.
  • src_dict (Dictionary) – the dictionary of backtranslated sentences.
  • tgt_dict (Dictionary, optional) – the dictionary of sentences to be backtranslated.
  • backtranslation_fn (callable, optional) – function to call to generate backtranslations. This is typically the generate method of a SequenceGenerator object. Pass in None when it is not available at initialization time, and use set_backtranslation_fn function to set it when available.
  • output_collater (callable, optional) – function to call on the backtranslated samples to create the final batch (default: tgt_dataset.collater).
  • cuda (bool, optional) – use GPU for generation (default: True).

collater(samples)

Merge and backtranslate a list of samples to form a mini-batch.

Using the samples from tgt_dataset, load a collated target sample to feed to the backtranslation model. Then take the backtranslation with the best score as the source and the original input as the target.

Note: we expect tgt_dataset to provide a function collater() that will collate samples into the format expected by backtranslation_fn. After backtranslation, we will feed the new list of samples (i.e., the (backtranslated source, original source) pairs) to output_collater and return the result.

Parameters: samples (List[dict]) – samples to backtranslate and collate
Returns: a mini-batch with keys coming from output_collater
Return type: dict
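The flow described above can be sketched with toy stand-ins. `backtranslate_collate_sketch`, `toy_collate` and `toy_backtranslate` are hypothetical. The sketch assumes backtranslation_fn returns one n-best list per batch entry, best hypothesis first, with each hypothesis a dict holding a "tokens" entry. The ids are read from the collated batch, so pairs stay aligned even if collate_fn reorders samples.

```python
from typing import Callable, Dict, List


def backtranslate_collate_sketch(
    samples: List[Dict],
    collate_fn: Callable[[List[Dict]], Dict],
    backtranslation_fn: Callable[[Dict], List[List[Dict]]],
    output_collater: Callable[[List[Dict]], Dict],
) -> Dict:
    """Build (generated source, original sentence) pairs, then collate them."""
    batch = collate_fn(samples)
    nbest_lists = backtranslation_fn(batch)  # one n-best list per entry, best first
    id_to_src = {s["id"]: s["source"] for s in samples}
    pairs = [
        {"id": sid, "source": nbest[0]["tokens"], "target": id_to_src[sid]}
        for sid, nbest in zip(batch["id"], nbest_lists)
    ]
    return output_collater(pairs)


# Toy stand-ins for illustration only: the "model" reverses each sentence.
def toy_collate(samples: List[Dict]) -> Dict:
    return {"id": [s["id"] for s in samples], "src_tokens": [s["source"] for s in samples]}


def toy_backtranslate(batch: Dict) -> List[List[Dict]]:
    return [[{"tokens": toks[::-1]}] for toks in batch["src_tokens"]]


samples = [{"id": 0, "source": [5, 6, 7]}, {"id": 1, "source": [8, 9]}]
result = backtranslate_collate_sketch(samples, toy_collate, toy_backtranslate, lambda p: {"pairs": p})
print(result["pairs"][1])  # {'id': 1, 'source': [9, 8], 'target': [8, 9]}
```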

num_tokens(index)

Delegates to tgt_dataset.num_tokens(index).

ordered_indices()

Delegates to tgt_dataset.ordered_indices().

prefetch(indices)

Prefetch the data required for this epoch.

size(index)

Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.

Note: we use tgt_dataset to approximate the length of the source sentence, since we do not know the actual length until after backtranslation.

supports_prefetch

Whether this dataset supports prefetching.

class fairseq.data.ConcatDataset(datasets, sample_ratios=1)

can_reuse_epoch_itr_across_epochs

Whether we can reuse the fairseq.data.EpochBatchIterator for this dataset across epochs.

This needs to return False if the sample sizes can change across epochs, in which case we may need to regenerate batches at each epoch. If your dataset relies on set_epoch, then you should consider setting this to False.

collater(samples, **extra_args)

Merge a list of samples to form a mini-batch.

Parameters: samples (List[dict]) – samples to collate
Returns: a mini-batch suitable for forwarding with a Model
Return type: dict

num_tokens(index: int)

Return the number of tokens in a sample. This value is used to enforce --max-tokens during batching.

ordered_indices()

Returns indices sorted by length, so that less padding is needed.

prefetch(indices)

Prefetch the data required for this epoch.

set_epoch(epoch)

Will receive the updated epoch number at the beginning of the epoch.

size(idx: int)

Return an example’s size as a float or tuple.

supports_prefetch

Whether this dataset supports prefetching.

class fairseq.data.ResamplingDataset(dataset, weights=None, replace=True, size_ratio=1.0, batch_by_size=True, seed=0, epoch=1)

Randomly samples from a given dataset at each epoch.

Sampling is done with or without replacement, depending on the “replace” parameter.

Optionally, the epoch size can be rescaled. This is potentially desirable to increase per-epoch coverage of the base dataset (since sampling with replacement means that many items in the dataset will be left out). In the case of sampling without replacement, size_ratio should be strictly less than 1.

Parameters:
  • dataset (Dataset) – dataset on which to sample.
  • weights (List[float]) – list of probability weights (default: None, which corresponds to uniform sampling).
  • replace (bool) – sampling mode; True for “with replacement”, or False for “without replacement” (default: True)
  • size_ratio (float) – the ratio to subsample to; must be positive (default: 1.0).
  • batch_by_size (bool) – whether or not to batch by sequence length (default: True).
  • seed (int) – RNG seed to use (default: 0).
  • epoch (int) – starting epoch number (default: 1).
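Per-epoch resampling can be sketched under two assumptions. First, the sample size is ceil(len(dataset) * size_ratio). Second, the random generator is seeded from both seed and the epoch number, so each epoch draws a different but reproducible sample. `resample_indices` is a hypothetical helper. Weighted sampling without replacement is left out for brevity.

```python
import math
import random
from typing import List, Optional, Sequence


def resample_indices(
    n: int,
    epoch: int,
    weights: Optional[Sequence[float]] = None,
    replace: bool = True,
    size_ratio: float = 1.0,
    seed: int = 0,
) -> List[int]:
    """Draw this epoch's indices into a base dataset of length n (sketch)."""
    if size_ratio <= 0:
        raise ValueError("size_ratio must be positive")
    size = math.ceil(n * size_ratio)
    # Seeding from (seed, epoch): a new draw every epoch, identical on reruns.
    rng = random.Random(f"{seed}:{epoch}")
    if replace:
        return rng.choices(range(n), weights=weights, k=size)
    if weights is not None:
        raise NotImplementedError("weighted sampling without replacement is omitted")
    return rng.sample(range(n), k=size)  # raises ValueError if size > n


print(len(resample_indices(10, epoch=1, replace=False, size_ratio=0.5)))  # 5
```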

can_reuse_epoch_itr_across_epochs

Whether we can reuse the fairseq.data.EpochBatchIterator for this dataset across epochs.

This needs to return False if the sample sizes can change across epochs, in which case we may need to regenerate batches at each epoch. If your dataset relies on set_epoch, then you should consider setting this to False.

num_tokens(index)

Return the number of tokens in a sample. This value is used to enforce --max-tokens during batching.

ordered_indices()

Return an ordered list of indices. Batches will be constructed based on this order.

prefetch(indices)

Prefetch the data required for this epoch.

set_epoch(epoch)

Will receive the updated epoch number at the beginning of the epoch.

size(index)

Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.

class fairseq.data.RoundRobinZipDatasets(datasets, eval_key=None)

Zip multiple FairseqDataset instances together.

Shorter datasets are repeated in a round-robin fashion to match the length of the longest one.

Parameters:
  • datasets (Dict[FairseqDataset]) – a dictionary of FairseqDataset instances.
  • eval_key (str, optional) – a key used at evaluation time that causes this instance to pass-through batches from datasets[eval_key].
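The round-robin pairing can be sketched on plain lists. `round_robin_zip` is a hypothetical helper. It assumes that index i of a shorter dataset of length n maps to element i % n.

```python
from typing import Dict, List, Sequence


def round_robin_zip(datasets: Dict[str, Sequence]) -> List[Dict[str, object]]:
    """Zip datasets by index, cycling shorter ones up to the longest length."""
    longest = max(len(d) for d in datasets.values())
    return [
        {key: d[i % len(d)] for key, d in datasets.items()}
        for i in range(longest)
    ]


zipped = round_robin_zip({"en-de": ["a", "b", "c"], "en-fr": ["x"]})
print(zipped[2])  # {'en-de': 'c', 'en-fr': 'x'}
```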

collater(samples)

Merge a list of samples to form a mini-batch.

filter_indices_by_size(indices, max_positions=None)

Filter each sub-dataset independently, then update the round robin to work on the filtered sub-datasets.

num_tokens(index)

Return an example’s length (number of tokens), used for batching.

ordered_indices()

Ordered indices for batching.

prefetch(indices)

Prefetch the data required for this epoch.

size(index)

Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.

supports_prefetch

Whether this dataset supports prefetching.

class fairseq.data.TransformEosDataset(dataset, eos, append_eos_to_src=False, remove_eos_from_src=False, append_eos_to_tgt=False, remove_eos_from_tgt=False, has_target=True)

A FairseqDataset wrapper that appends/prepends/strips EOS.

Note that the transformation is applied in collater().

Parameters:
  • dataset (FairseqDataset) – dataset to wrap
  • eos (int) – index of the end-of-sentence symbol
  • append_eos_to_src (bool, optional) – append EOS to the end of src
  • remove_eos_from_src (bool, optional) – remove EOS from the end of src
  • append_eos_to_tgt (bool, optional) – append EOS to the end of tgt
  • remove_eos_from_tgt (bool, optional) – remove EOS from the end of tgt
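The EOS options can be sketched per sequence, assuming a sequence is a list of ints. `transform_eos` is hypothetical, and this version is idempotent. The real wrapper applies the transformation in collater() and may instead validate that its input already has, or lacks, a trailing EOS.

```python
from typing import List


def transform_eos(tokens: List[int], eos: int, append: bool = False, remove: bool = False) -> List[int]:
    """Append or strip a trailing EOS on one sequence (sketch)."""
    if append and remove:
        raise ValueError("choose at most one of append/remove")
    if append and (not tokens or tokens[-1] != eos):
        return tokens + [eos]
    if remove and tokens and tokens[-1] == eos:
        return tokens[:-1]
    return list(tokens)  # already in the requested state; return a copy


print(transform_eos([5, 6], eos=2, append=True))  # [5, 6, 2]
```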

collater(samples)

Merge a list of samples to form a mini-batch.

Parameters: samples (List[dict]) – samples to collate
Returns: a mini-batch suitable for forwarding with a Model
Return type: dict

num_tokens(index)

Return the number of tokens in a sample. This value is used to enforce --max-tokens during batching.

ordered_indices()

Return an ordered list of indices. Batches will be constructed based on this order.

prefetch(indices)

Prefetch the data required for this epoch.

size(index)

Return an example’s size as a float or tuple. This value is used when filtering a dataset with --max-positions.

supports_prefetch

Whether this dataset supports prefetching.

Dictionary

class fairseq.data.Dictionary(*, bos='<s>', pad='<pad>', eos='</s>', unk='<unk>', extra_special_symbols=None)

A mapping from symbols to consecutive integers

add_from_file(f)

Loads a pre-existing dictionary from a text file and adds its symbols to this instance.

add_symbol(word, n=1, overwrite=False)

Adds a word to the dictionary

bos()

Helper to get index of beginning-of-sentence symbol

eos()

Helper to get index of end-of-sentence symbol

finalize(threshold=-1, nwords=-1, padding_factor=8)

Sort symbols by frequency in descending order, ignoring special ones.

Parameters:
  • threshold – defines the minimum word count
  • nwords – defines the total number of words in the final dictionary, including special symbols
  • padding_factor – can be used to pad the dictionary size to a multiple of padding_factor (default: 8), which is important on some hardware (e.g., Nvidia Tensor Cores).
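The sketch below mimics the behavior the docstrings describe:

- Special symbols get the first indices. They are added here in constructor-signature order, which is an assumption.
- Unknown symbols map to <unk>.
- finalize() keeps non-special symbols in descending order of count, applies threshold and nwords, and pads the size to a multiple of padding_factor.

`TinyDictionary`, alphabetical tie-breaking and the `<fillerN>` padding symbols are all inventions of this sketch.

```python
from typing import Dict, List


class TinyDictionary:
    """Hypothetical sketch of a symbol <-> index mapping; not fairseq's class."""

    SPECIALS = ("<s>", "<pad>", "</s>", "<unk>")

    def __init__(self) -> None:
        self.symbols: List[str] = []
        self.count: List[int] = []
        self.indices: Dict[str, int] = {}
        for sym in self.SPECIALS:
            self.add_symbol(sym)
        self.nspecial = len(self.SPECIALS)
        self.unk_index = self.indices["<unk>"]

    def add_symbol(self, word: str, n: int = 1) -> int:
        """Add `word`, or bump its count if present, and return its index."""
        if word in self.indices:
            self.count[self.indices[word]] += n
            return self.indices[word]
        self.indices[word] = len(self.symbols)
        self.symbols.append(word)
        self.count.append(n)
        return self.indices[word]

    def index(self, sym: str) -> int:
        """Index of `sym`, falling back to the <unk> index."""
        return self.indices.get(sym, self.unk_index)

    def finalize(self, threshold: int = -1, nwords: int = -1, padding_factor: int = 8) -> None:
        """Sort non-special symbols by count, filter them, then pad the size."""
        if nwords <= 0:
            nwords = len(self.symbols)
        ranked = sorted(
            zip(self.symbols[self.nspecial:], self.count[self.nspecial:]),
            key=lambda sc: (-sc[1], sc[0]),  # most frequent first, ties alphabetical
        )
        kept = [(s, c) for s, c in ranked if c >= threshold][: max(nwords - self.nspecial, 0)]
        entries = list(zip(self.symbols[: self.nspecial], self.count[: self.nspecial])) + kept
        self.symbols = [s for s, _ in entries]
        self.count = [c for _, c in entries]
        self.indices = {s: i for i, s in enumerate(self.symbols)}
        # Append zero-count filler symbols until the size is a multiple.
        i = 0
        while padding_factor > 1 and len(self.symbols) % padding_factor != 0:
            self.add_symbol(f"<filler{i}>", n=0)
            i += 1


d = TinyDictionary()
for word in "b a b c b a".split():
    d.add_symbol(word)
d.finalize(padding_factor=8)
print(d.symbols)  # ['<s>', '<pad>', '</s>', '<unk>', 'b', 'a', 'c', '<filler0>']
```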

index(sym)

Returns the index of the specified symbol

classmethod load(f)

Loads the dictionary from a text file with the format:

    <symbol0> <count0>
    <symbol1> <count1>
    ...
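A minimal reader for that file format can be sketched as follows. It assumes the count is the last space-separated field on each line, so symbols containing spaces still parse, and that blank lines can be skipped. `read_dict_lines` is a hypothetical helper, not Dictionary.load().

```python
import io
from typing import List, TextIO, Tuple


def read_dict_lines(f: TextIO) -> List[Tuple[str, int]]:
    """Parse '<symbol> <count>' lines; the count is the last field."""
    entries = []
    for lineno, line in enumerate(f, start=1):
        line = line.rstrip("\n")
        if not line.strip():
            continue  # tolerate blank lines
        try:
            symbol, count = line.rsplit(" ", 1)
            entries.append((symbol, int(count)))
        except ValueError:
            raise ValueError(f"line {lineno}: expected '<symbol> <count>', got {line!r}") from None
    return entries


print(read_dict_lines(io.StringIO("hello 3\nworld 2\n")))  # [('hello', 3), ('world', 2)]
```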

pad()

Helper to get index of pad symbol

pad_to_multiple_(padding_factor)

Pad Dictionary size to be a multiple of padding_factor.

save(f)

Stores dictionary into a text file

string(tensor, bpe_symbol=None, escape_unk=False, extra_symbols_to_ignore=None, unk_string=None, include_eos=False, separator=' ')

Helper for converting a tensor of token indices to a string.

Can optionally remove BPE symbols or escape <unk> words.
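The id-to-string conversion with BPE removal can be sketched as follows, assuming subword-nmt style continuation markers ("@@ "). `ids_to_string` is hypothetical. Here the set of ids to skip, such as pad and EOS, is passed in explicitly.

```python
from typing import Iterable, Optional, Sequence, Set


def ids_to_string(
    ids: Iterable[int],
    symbols: Sequence[str],
    ignore: Set[int],
    bpe_symbol: Optional[str] = None,
    separator: str = " ",
) -> str:
    """Map ids to symbols, drop ignored ids, then optionally undo BPE."""
    text = separator.join(symbols[i] for i in ids if i not in ignore)
    if bpe_symbol is not None:
        # Continuation markers join subwords: "he@@ llo" -> "hello"
        text = (text + " ").replace(bpe_symbol, "").rstrip()
    return text


vocab = ["<s>", "<pad>", "</s>", "<unk>", "he@@", "llo", "world"]
print(ids_to_string([4, 5, 6, 2, 1], vocab, ignore={1, 2}, bpe_symbol="@@ "))  # hello world
```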

unk()

Helper to get index of unk symbol

unk_string(escape=False)

Return unknown string, optionally escaped as: <<unk>>

update(new_dict)

Updates counts from new dictionary.

Iterators

class fairseq.data.CountingIterator(iterable, start=None, total=None)

Wrapper around an iterable that maintains the iteration count.

Parameters:
  • iterable (iterable) – iterable to wrap
  • start (int) – starting iteration count. Note that this doesn’t actually advance the iterator.
  • total (int) – override the iterator length returned by __len__. This can be used to truncate the iterator.

n

number of elements consumed from this iterator

Type: int

has_next()

Whether the iterator has more elements, i.e., has not been exhausted.

skip(n)

Fast-forward the iterator by skipping n elements.

take(n)

Truncate the iterator to n elements at most.
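The counting behavior can be sketched as a small wrapper class. `CountingIteratorSketch` is hypothetical. As in the parameter description, start only initializes the counter and does not advance the wrapped iterator. Here has_next() returns True while elements remain.

```python
from typing import Optional


class CountingIteratorSketch:
    """Hypothetical sketch of an iterator wrapper that tracks consumption."""

    def __init__(self, iterable, start: int = 0, total: Optional[int] = None):
        self._itr = iter(iterable)
        self.n = start  # counter only; the wrapped iterator is not advanced
        self.total = total if total is not None else start + len(iterable)

    def __len__(self) -> int:
        return self.total

    def __iter__(self) -> "CountingIteratorSketch":
        return self

    def __next__(self):
        if not self.has_next():
            raise StopIteration
        item = next(self._itr)
        self.n += 1
        return item

    def has_next(self) -> bool:
        """True while fewer than `total` elements have been consumed."""
        return self.n < self.total

    def skip(self, n: int) -> "CountingIteratorSketch":
        """Consume and discard up to n elements."""
        for _ in range(n):
            if not self.has_next():
                break
            next(self)
        return self

    def take(self, n: int) -> "CountingIteratorSketch":
        """Cap the iterator at n elements in total."""
        self.total = min(self.total, n)
        return self


it = CountingIteratorSketch(list(range(10)))
it.skip(3)
print(next(it), it.n)  # 3 4
```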

class fairseq.data.EpochBatchIterator(dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1, buffer_size=0, timeout=0, disable_shuffling=False, skip_remainder_batch=False, grouped_shuffling=False, reuse_dataloader=False, persistent_workers=False)

A multi-epoch iterator over a torch.utils.data.Dataset.

Compared to torch.utils.data.DataLoader, this iterator:

  • can be reused across multiple epochs with the next_epoch_itr() method (optionally shuffled between epochs)
  • can be serialized/deserialized with the state_dict() and load_state_dict() methods
  • supports sharding with the num_shards and shard_id arguments

Parameters:
  • dataset (Dataset) – dataset from which to load the data
  • collate_fn (callable) – merges a list of samples to form a mini-batch
  • batch_sampler (Sampler or a callable) – an iterator over batches of indices (a torch.utils.data.Sampler), or a callable that creates such an iterator. A callable batch_sampler is called once per epoch, which enables dynamic, per-epoch batch iterators.
  • seed (int, optional) – seed for random number generator for reproducibility (default: 1).
  • num_shards (int, optional) – shard the data iterator into N shards (default: 1).
  • shard_id (int, optional) – which shard of the data iterator to return (default: 0).
  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means the data will be loaded in the main process (default: 0).
  • epoch (int, optional) – the epoch to start the iterator from (default: 1).
  • buffer_size (int, optional) – the number of batches to keep ready in the queue. Helps speed up data loading. When buffer_size is zero, the default torch.utils.data.DataLoader preloading is used.
  • timeout (int, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative (default: 0).
  • disable_shuffling (bool, optional) – force disable shuffling (default: False).
  • skip_remainder_batch (bool, optional) – if set, discard the last batch in an epoch for the sake of training stability, as the last batch is usually smaller than local_batch_size * distributed_world_size (default: False).
  • grouped_shuffling (bool, optional) – enable shuffling batches in groups of num_shards. Ensures that each GPU receives similar length sequences when batches are sorted by length.

end_of_epoch() → bool

Returns whether the most recent epoch iterator has been exhausted

iterations_in_epoch

The number of consumed batches in the current epoch.

load_state_dict(state_dict)

Copies the state of the iterator from the given state_dict.

next_epoch_idx

Return the epoch index after next_epoch_itr is called.

next_epoch_itr(shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True)

Return a new iterator over the dataset.

Parameters:
  • shuffle (bool, optional) – shuffle batches before returning the iterator (default: True).
  • fix_batches_to_gpus (bool, optional) – ensure that batches are always allocated to the same shards across epochs. Requires that dataset supports prefetching (default: False).
  • set_dataset_epoch (bool, optional) – update the wrapped Dataset with the new epoch number (default: True).

state_dict()

Returns a dictionary containing the whole state of the iterator.
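The reproducibility and resumption properties can be illustrated with a toy iterator over precomputed batches. `ToyEpochIterator` is hypothetical and much simpler than EpochBatchIterator: it has no sharding, workers or DataLoader. It assumes shuffling is seeded with seed + epoch, and that state_dict() only needs the epoch and the number of batches consumed.

```python
import random
from typing import Dict, Iterator, List, Sequence


class ToyEpochIterator:
    """Hypothetical sketch: reproducible per-epoch shuffling, resumable state."""

    def __init__(self, batches: Sequence[Sequence[int]], seed: int = 1, epoch: int = 1):
        self.batches = [list(b) for b in batches]
        self.seed = seed
        self.epoch = epoch
        self.iterations_in_epoch = 0

    def epoch_order(self, epoch: int) -> List[List[int]]:
        # The order depends only on (seed, epoch), so reruns reproduce it.
        order = list(self.batches)
        random.Random(self.seed + epoch).shuffle(order)
        return order

    def __iter__(self) -> Iterator[List[int]]:
        order = self.epoch_order(self.epoch)
        while self.iterations_in_epoch < len(order):
            self.iterations_in_epoch += 1
            yield order[self.iterations_in_epoch - 1]

    def next_epoch(self) -> None:
        self.epoch += 1
        self.iterations_in_epoch = 0

    def state_dict(self) -> Dict[str, int]:
        return {"epoch": self.epoch, "iterations_in_epoch": self.iterations_in_epoch}

    def load_state_dict(self, state: Dict[str, int]) -> None:
        self.epoch = state["epoch"]
        self.iterations_in_epoch = state["iterations_in_epoch"]


batches = [[0, 1], [2, 3], [4, 5], [6]]
run = ToyEpochIterator(batches, seed=1)
itr = iter(run)
consumed = [next(itr), next(itr)]
resumed = ToyEpochIterator(batches, seed=1)
resumed.load_state_dict(run.state_dict())  # e.g. after restoring a checkpoint
remaining = list(resumed)                  # the two batches not yet consumed
```

The order is a pure function of seed and epoch. A run restored from the saved state therefore skips exactly the batches that were already consumed.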

class fairseq.data.GroupedIterator(iterable, chunk_size, skip_remainder_batch=False)

Wrapper around an iterable that returns groups (chunks) of items.

Parameters:
  • iterable (iterable) – iterable to wrap
  • chunk_size (int) – size of each chunk
  • skip_remainder_batch (bool, optional) – if set, discard the last grouped batch in each training epoch, as the last grouped batch is usually smaller than local_batch_size * distributed_world_size * chunk_size (default: False).
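The chunking behavior can be sketched on any iterable. `grouped` is a hypothetical generator function. With skip_remainder_batch set, it drops only a trailing chunk shorter than chunk_size, which is this sketch's reading of discarding the last grouped batch.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def grouped(iterable: Iterable[T], chunk_size: int, skip_remainder_batch: bool = False) -> Iterator[List[T]]:
    """Yield lists of chunk_size items; optionally drop a short final chunk."""
    it = iter(iterable)
    while True:
        chunk = list(islice(it, chunk_size))
        if not chunk:
            return
        if skip_remainder_batch and len(chunk) < chunk_size:
            return
        yield chunk


print(list(grouped(range(5), 2)))                             # [[0, 1], [2, 3], [4]]
print(list(grouped(range(5), 2, skip_remainder_batch=True)))  # [[0, 1], [2, 3]]
```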

n

number of elements consumed from this iterator

Type: int

class fairseq.data.ShardedIterator(iterable, num_shards, shard_id, fill_value=None, skip_remainder_batch=None)

A sharded wrapper around an iterable, padded to length.

Parameters:
  • iterable (iterable) – iterable to wrap
  • num_shards (int) – number of shards to split the iterable into
  • shard_id (int) – which shard to iterate over
  • fill_value (Any, optional) – padding value when the iterable doesn’t evenly divide num_shards (default: None).
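Strided sharding with padding can be sketched under these assumptions:

- Each shard takes every num_shards-th element, starting at shard_id.
- Each shard is padded with fill_value to ceil(len / num_shards) elements, so all shards have the same length.

`shard` is a hypothetical helper that returns a list rather than an iterator.

```python
import math
from typing import Any, List, Sequence


def shard(items: Sequence[Any], num_shards: int, shard_id: int, fill_value: Any = None) -> List[Any]:
    """Take every num_shards-th item from shard_id, padded to equal length."""
    if not 0 <= shard_id < num_shards:
        raise ValueError("shard_id must be in [0, num_shards)")
    sharded_len = math.ceil(len(items) / num_shards)
    picked = list(items[shard_id::num_shards])
    return picked + [fill_value] * (sharded_len - len(picked))


print(shard(range(5), 2, 0))  # [0, 2, 4]
print(shard(range(5), 2, 1))  # [1, 3, None]
```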

n

number of elements consumed from this iterator

Type: int