Source code for fairseq.data.concat_dataset

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import bisect

import numpy as np

from . import FairseqDataset


class ConcatDataset(FairseqDataset):

    @staticmethod
    def cumsum(sequence, sample_ratios):
        r, s = [], 0
        for e, ratio in zip(sequence, sample_ratios):
            curr_len = int(ratio * len(e))
            r.append(curr_len + s)
            s += curr_len
        return r

    def __init__(self, datasets, sample_ratios=1):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, "datasets should not be an empty iterable"
        self.datasets = list(datasets)
        if isinstance(sample_ratios, int):
            sample_ratios = [sample_ratios] * len(self.datasets)
        self.sample_ratios = sample_ratios
        self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
        self.real_sizes = [len(d) for d in self.datasets]

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx][sample_idx]

    def _get_dataset_and_sample_index(self, idx: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        sample_idx = sample_idx % self.real_sizes[dataset_idx]
        return dataset_idx, sample_idx

    def collater(self, samples):
        # For now only supports datasets with same underlying collater implementations
        return self.datasets[0].collater(samples)

    def size(self, idx: int):
        """
        Return an example's size as a float or tuple.
        """
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx].size(sample_idx)

    def num_tokens(self, index: int):
        return np.max(self.size(index))

    @property
    def sizes(self):
        return np.concatenate(
            [np.tile(ds.sizes, sr) for ds, sr in zip(self.datasets, self.sample_ratios)]
        )

    @property
    def supports_prefetch(self):
        return all(d.supports_prefetch for d in self.datasets)
    def ordered_indices(self):
        """
        Returns indices sorted by length, so that less padding is needed.
        """
        return np.argsort(self.sizes)
    def prefetch(self, indices):
        frm = 0
        for to, ds in zip(self.cumulative_sizes, self.datasets):
            real_size = len(ds)
            if getattr(ds, 'supports_prefetch', False):
                ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
            frm = to
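
# The snippet below is not part of the fairseq source; it is a minimal sketch
# that replays the index arithmetic from cumsum() and
# _get_dataset_and_sample_index() on two plain Python lists standing in for
# FairseqDataset instances (hypothetical names ds_a / ds_b), to show how a
# sample_ratio > 1 virtually repeats a dataset and how indices past its real
# length wrap around via the modulo.

import bisect

ds_a = ["a0", "a1", "a2"]          # real size 3
ds_b = ["b0", "b1"]                # real size 2
sample_ratios = [2, 1]             # upsample ds_a twice, take ds_b once

# cumsum(): virtual sizes are int(ratio * len(dataset)), accumulated.
cumulative_sizes, s = [], 0
for ds, ratio in zip([ds_a, ds_b], sample_ratios):
    s += int(ratio * len(ds))
    cumulative_sizes.append(s)
print(cumulative_sizes)            # [6, 8] -> total virtual length 8

def lookup(idx):
    # Same logic as _get_dataset_and_sample_index().
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
    real_sizes = [len(ds_a), len(ds_b)]
    return dataset_idx, sample_idx % real_sizes[dataset_idx]

print(lookup(4))                   # (0, 1): wraps back into ds_a's second element
print(lookup(6))                   # (1, 0): first element of ds_b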