Source code for fairseq.data.fairseq_dataset

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch.utils.data


class EpochListening:
    """Mixin for receiving updates whenever the epoch increments."""
    def set_epoch(self, epoch):
        """Will receive the updated epoch number at the beginning of the epoch.
        """
        pass

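
# Illustrative sketch (not part of fairseq): a class mixing in
# EpochListening can override set_epoch() to reshuffle deterministically
# at each epoch boundary. The class name and seeding scheme below are
# hypothetical.
class _ExampleEpochShuffler(EpochListening):
    def __init__(self, num_items, seed=0):
        self.num_items = num_items
        self.seed = seed
        self.order = np.arange(num_items)

    def set_epoch(self, epoch):
        # Combine the base seed with the epoch number so each epoch yields
        # a different but reproducible shuffle order.
        rng = np.random.RandomState(self.seed + epoch)
        self.order = rng.permutation(self.num_items)
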

class FairseqDataset(torch.utils.data.Dataset, EpochListening):
    """A dataset that provides helpers for batching."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        raise NotImplementedError

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used
        when filtering a dataset with ``--max-positions``."""
        raise NotImplementedError

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed
        based on this order."""
        return np.arange(len(self))

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return False

    def attr(self, attr: str, index: int):
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError
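
# Illustrative sketch (not part of fairseq): a minimal concrete
# FairseqDataset over pre-tokenized integer sequences that right-pads each
# mini-batch to its longest example. All names below (_ExampleTokenDataset,
# pad_idx, the "net_input"/"src_tokens" keys) are hypothetical.
class _ExampleTokenDataset(FairseqDataset):
    def __init__(self, sequences, pad_idx=0):
        # sequences: list of 1-D torch.LongTensor of varying lengths
        self.sequences = sequences
        self.pad_idx = pad_idx

    def __getitem__(self, index):
        return self.sequences[index]

    def __len__(self):
        return len(self.sequences)

    def collater(self, samples):
        # Pad every sample to the length of the longest one in the batch.
        max_len = max(s.size(0) for s in samples)
        batch = torch.full((len(samples), max_len), self.pad_idx, dtype=torch.long)
        for i, s in enumerate(samples):
            batch[i, : s.size(0)] = s
        return {"net_input": {"src_tokens": batch}}

    def num_tokens(self, index):
        # Consulted to enforce ``--max-tokens`` during batching.
        return self.sequences[index].size(0)

    def size(self, index):
        # Consulted when filtering with ``--max-positions``.
        return self.sequences[index].size(0)
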
class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
    """For datasets that need to be read sequentially, usually because the
    data is being streamed or otherwise can't be manipulated on a single
    machine.
    """

    def __iter__(self):
        raise NotImplementedError
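
# Illustrative sketch (not part of fairseq): a FairseqIterableDataset that
# streams whitespace-tokenized lines from a text file, for data too large
# to index randomly. The class name and tokenization are hypothetical.
class _ExampleLineStreamDataset(FairseqIterableDataset):
    def __init__(self, path):
        self.path = path
        self.epoch = 1

    def set_epoch(self, epoch):
        # Iterable datasets still receive epoch updates via EpochListening.
        self.epoch = epoch

    def __iter__(self):
        # Read sequentially; the file is never fully loaded into memory.
        with open(self.path, encoding="utf-8") as f:
            for line in f:
                yield line.strip().split()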