Source code for fairseq.data.round_robin_zip_datasets

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from collections import OrderedDict
from typing import Dict, Optional, Sequence

import numpy as np

from . import FairseqDataset, LanguagePairDataset

logger = logging.getLogger(__name__)


class RoundRobinZipDatasets(FairseqDataset):
    """Zip multiple :class:`~fairseq.data.FairseqDataset` instances together.

    Shorter datasets are repeated in a round-robin fashion to match the length
    of the longest one.

    Args:
        datasets (Dict[~fairseq.data.FairseqDataset]): a dictionary of
            :class:`~fairseq.data.FairseqDataset` instances.
        eval_key (str, optional): a key used at evaluation time that causes
            this instance to pass-through batches from *datasets[eval_key]*.
    """

    def __init__(self, datasets, eval_key=None):
        super().__init__()
        if isinstance(datasets, dict):
            datasets = OrderedDict(datasets)
        assert isinstance(datasets, OrderedDict)
        assert datasets, "Can't make a RoundRobinZipDatasets out of nothing"
        for dataset in datasets.values():
            assert isinstance(dataset, FairseqDataset)

        self.datasets = datasets
        self.eval_key = eval_key

        self.longest_dataset_key = max(datasets, key=lambda k: len(datasets[k]))
        self.longest_dataset = datasets[self.longest_dataset_key]
        self._ordered_indices: Optional[Dict[str, Sequence[int]]] = None

    def _map_index(self, key, index):
        assert (
            self._ordered_indices is not None
        ), "Must call RoundRobinZipDatasets.ordered_indices() first"
        o = self._ordered_indices[key]
        return o[index % len(o)]

    def __getitem__(self, index):
        if self.eval_key is None:
            return OrderedDict(
                [
                    (key, dataset[self._map_index(key, index)])
                    for key, dataset in self.datasets.items()
                ]
            )
        else:
            # at evaluation time it's useful to pass-through batches from a single key
            return self.datasets[self.eval_key][self._map_index(self.eval_key, index)]

    def __len__(self):
        if self._ordered_indices is not None:
            return len(self._ordered_indices[self.longest_dataset_key])
        return len(self.longest_dataset)
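
    # For example, with sub-datasets {"en-fr": 100 examples, "en-de": 40 examples}
    # and identity orderings (a simplifying assumption; sub-datasets usually
    # length-sort or shuffle), the zipped dataset has length 100 and index 57
    # yields an OrderedDict holding en-fr example 57 together with en-de
    # example 57 % 40 == 17.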

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch."""
        if len(samples) == 0:
            return None
        if self.eval_key is None:
            return OrderedDict(
                [
                    (key, dataset.collater([sample[key] for sample in samples]))
                    for key, dataset in self.datasets.items()
                ]
            )
        else:
            # at evaluation time it's useful to pass-through batches from a single key
            return self.datasets[self.eval_key].collater(samples)
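
    # In training mode (eval_key is None) the result is an OrderedDict mapping
    # each dataset key to that dataset's own collated mini-batch; in evaluation
    # mode it is simply the mini-batch produced by datasets[eval_key].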

    def num_tokens(self, index):
        """Return an example's length (number of tokens), used for batching."""
        # TODO make it configurable whether to use max() or sum() here
        return max(
            dataset.num_tokens(self._map_index(key, index))
            for key, dataset in self.datasets.items()
        )

    def size(self, index):
        """Return an example's size, as a dict mapping each sub-dataset key to
        that sub-dataset's size for this example. This value is used when
        filtering a dataset with ``--max-positions``."""
        return {
            key: dataset.size(self._map_index(key, index))
            for key, dataset in self.datasets.items()
        }

    def ordered_indices(self):
        """Ordered indices for batching."""
        if self._ordered_indices is None:
            # Call the underlying dataset's ordered_indices() here, so that we
            # get the same random ordering as we would have from using the
            # underlying sub-datasets directly.
            self._ordered_indices = OrderedDict(
                [
                    (key, dataset.ordered_indices())
                    for key, dataset in self.datasets.items()
                ]
            )
        return np.arange(len(self))

    def filter_indices_by_size(self, indices, max_positions=None):
        """
        Filter each sub-dataset independently, then update the round robin to
        work on the filtered sub-datasets.
        """

        def _deep_until_language_pair(dataset):
            if isinstance(dataset, LanguagePairDataset):
                return dataset
            if hasattr(dataset, "tgt_dataset"):
                return _deep_until_language_pair(dataset.tgt_dataset)
            if hasattr(dataset, "dataset"):
                return _deep_until_language_pair(dataset.dataset)
            raise Exception(f"Don't know how to unwrap this dataset: {dataset}")

        if not isinstance(max_positions, dict):
            max_positions = {k: max_positions for k in self.datasets.keys()}
        ignored_some = False
        for key, dataset in self.datasets.items():
            dataset = _deep_until_language_pair(dataset)
            self._ordered_indices[key], ignored = dataset.filter_indices_by_size(
                self._ordered_indices[key], max_positions[key]
            )
            if len(ignored) > 0:
                ignored_some = True
                logger.warning(
                    f"{len(ignored)} samples from {key} have invalid sizes and will be skipped, "
                    f"max_positions={max_positions[key]}, first few sample ids={ignored[:10]}"
                )
        # Since we modify self._ordered_indices in place, we can no longer
        # return the true set of ignored indices. Hopefully the warning logged
        # above is enough to debug. Ideally we would receive
        # ignore_invalid_inputs here so that we could raise a proper error.
        return (np.arange(len(self)), [0] if ignored_some else [])

    @property
    def supports_prefetch(self):
        return all(
            getattr(dataset, "supports_prefetch", False)
            for dataset in self.datasets.values()
        )

    def prefetch(self, indices):
        for key, dataset in self.datasets.items():
            dataset.prefetch([self._map_index(key, index) for index in indices])
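
A minimal usage sketch (not part of the fairseq source): ToyDataset and its toy item lists are hypothetical stand-ins for real sub-datasets such as LanguagePairDataset, and the sketch assumes FairseqDataset and RoundRobinZipDatasets are importable from fairseq.data as shown. It illustrates the intended call order: construct the zipped dataset, call ordered_indices() once so the round-robin index mapping is initialized, then index and collate.

from collections import OrderedDict

import numpy as np

from fairseq.data import FairseqDataset, RoundRobinZipDatasets


class ToyDataset(FairseqDataset):
    """Hypothetical stand-in dataset; each item is a list of token ids."""

    def __init__(self, items):
        self.items = items

    def __getitem__(self, index):
        return self.items[index]

    def __len__(self):
        return len(self.items)

    def collater(self, samples):
        # A real dataset would pad and tensorize here; returning the raw list
        # keeps the sketch short.
        return samples

    def num_tokens(self, index):
        return len(self.items[index])

    def size(self, index):
        return len(self.items[index])

    def ordered_indices(self):
        return np.arange(len(self))


long_ds = ToyDataset([[1, 2, 3], [4, 5], [6], [7, 8], [9]])  # 5 examples
short_ds = ToyDataset([[10, 11], [12]])                      # 2 examples
zipped = RoundRobinZipDatasets(
    OrderedDict([("long", long_ds), ("short", short_ds)])
)

indices = zipped.ordered_indices()  # must be called before indexing
assert len(zipped) == 5             # length of the longest dataset
sample = zipped[3]                  # OrderedDict: long[3] and short[3 % 2] == short[1]
batch = zipped.collater([zipped[i] for i in indices[:2]])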