Source code for fairseq.data.fairseq_dataset

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch.utils.data


class FairseqDataset(torch.utils.data.Dataset):
    """A dataset that provides helpers for batching."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        raise NotImplementedError

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used
        when filtering a dataset with ``--max-positions``."""
        raise NotImplementedError

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed
        based on this order."""
        return np.arange(len(self))

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return False

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError

    def set_epoch(self, epoch):
        pass
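

# ---------------------------------------------------------------------------
# Example (not part of fairseq): a minimal sketch of a concrete subclass.
# ``TensorListDataset`` and its fields are hypothetical names chosen for
# illustration of the interface above; fairseq's real datasets (e.g.
# LanguagePairDataset) are considerably more involved.
# ---------------------------------------------------------------------------

import torch


class TensorListDataset(FairseqDataset):
    """Wraps a list of 1D LongTensors of token ids."""

    def __init__(self, sequences, pad_idx=1):
        self.sequences = sequences  # List[torch.LongTensor]
        self.pad_idx = pad_idx

    def __getitem__(self, index):
        # Samples are plain dicts; ``collater`` decides how to merge them.
        return {"id": index, "tokens": self.sequences[index]}

    def __len__(self):
        return len(self.sequences)

    def collater(self, samples):
        # Left-align each sequence in a single padded 2D batch tensor.
        max_len = max(s["tokens"].numel() for s in samples)
        batch = torch.full(
            (len(samples), max_len), self.pad_idx, dtype=torch.long
        )
        for i, s in enumerate(samples):
            batch[i, : s["tokens"].numel()] = s["tokens"]
        return {
            "id": torch.tensor([s["id"] for s in samples]),
            "net_input": {"src_tokens": batch},
        }

    def num_tokens(self, index):
        # Drives ``--max-tokens`` batching: here, simply the sequence length.
        return self.sequences[index].numel()

    def size(self, index):
        # Used when filtering the dataset with ``--max-positions``.
        return self.sequences[index].numel()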
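

# ---------------------------------------------------------------------------
# Example usage of the sketch above. The default ``ordered_indices`` from the
# base class is simply 0..len-1; fairseq datasets commonly override it to
# sort by length so batches need less padding.
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    data = [torch.randint(2, 100, (n,)) for n in (5, 3, 8)]
    dataset = TensorListDataset(data)

    # Batches are constructed following this order.
    indices = dataset.ordered_indices()

    # Collate every sample into one mini-batch.
    batch = dataset.collater([dataset[i] for i in indices])
    print(batch["net_input"]["src_tokens"].shape)  # torch.Size([3, 8])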