# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import math
import numpy as np
import torch
from . import data_utils
class CountingIterator(object):
    """Wrapper around an iterable that maintains the iteration count.

    The wrapper is its own iterator: ``iter(obj)``, ``next(obj)``, ``for``
    loops and :func:`skip` all advance the same underlying pass over
    *iterable*, so :attr:`count` always reflects the number of elements
    consumed so far.

    Args:
        iterable (iterable): iterable to wrap; must support ``len()``
        start (int, optional): initial value of the counter, i.e. the number
            of elements considered already consumed before *iterable*
            begins (default: 0).

    Attributes:
        count (int): number of elements consumed from this iterator
    """

    def __init__(self, iterable, start=0):
        self.iterable = iterable
        self.count = start
        # A generator keeps the wrapped iterable lazy: iter(iterable) is only
        # called when the first element is actually requested.
        self.itr = self._counted()
        # *iterable* only holds the elements that come after *start*.
        self.len = start + len(iterable)

    def _counted(self):
        """Yield the wrapped elements, bumping :attr:`count` for each one."""
        for x in self.iterable:
            self.count += 1
            yield x

    def __len__(self):
        return self.len

    def __iter__(self):
        # Return self (not a fresh generator) so that iteration resumes where
        # next()/skip() left off instead of restarting the wrapped iterable.
        return self

    def __next__(self):
        return next(self.itr)

    def has_next(self):
        """Whether the iterator still has elements left to consume."""
        return self.count < len(self)

    def skip(self, num_to_skip):
        """Fast-forward the iterator by skipping *num_to_skip* elements.

        Skipped elements are counted as consumed. Returns ``self`` so that
        calls can be chained.
        """
        # An empty islice whose start equals its stop consumes exactly
        # *num_to_skip* items and then finishes; None absorbs the exhaustion.
        next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
        return self


class EpochBatchIterator(object):
    """A multi-epoch iterator over a :class:`torch.utils.data.Dataset`.

    Compared to :class:`torch.utils.data.DataLoader`, this iterator:

    - can be reused across multiple epochs with the :func:`next_epoch_itr`
      method (optionally shuffled between epochs)
    - can be serialized/deserialized with the :func:`state_dict` and
      :func:`load_state_dict` methods
    - supports sharding with the *num_shards* and *shard_id* arguments

    Args:
        dataset (~torch.utils.data.Dataset): dataset from which to load the data
        collate_fn (callable): merges a list of samples to form a mini-batch
        batch_sampler (~torch.utils.data.Sampler): an iterator over batches of
            indices
        seed (int, optional): seed for random number generator for
            reproducibility (default: 1).
        num_shards (int, optional): shard the data iterator into N
            shards (default: 1).
        shard_id (int, optional): which shard of the data iterator to
            return (default: 0).
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means the data will be loaded in the main process
            (default: 0).
        epoch (int, optional): the epoch to start the iterator from
            (default: 0).

    Attributes:
        shuffle (bool): the *shuffle* flag used for the most recently built
            (or restored) epoch iterator; saved by :func:`state_dict`.
    """

    def __init__(
        self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
        num_workers=0, epoch=0,
    ):
        assert isinstance(dataset, torch.utils.data.Dataset)
        self.dataset = dataset
        self.collate_fn = collate_fn
        # Materialise the sampler once so every epoch sees the same batches.
        self.frozen_batches = tuple(batch_sampler)
        self.seed = seed
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.num_workers = num_workers
        self.epoch = epoch
        # Matches the default that load_state_dict() assumes for checkpoints
        # that do not carry a 'shuffle' entry.
        self.shuffle = True
        self._cur_epoch_itr = None
        self._next_epoch_itr = None
        self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)

    def __len__(self):
        return len(self.frozen_batches)

    def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
        """Return a new iterator over the dataset.

        If :func:`load_state_dict` left a partially consumed iterator
        pending, that iterator is returned instead and both arguments are
        ignored; otherwise the epoch counter is incremented and a fresh
        iterator is built.

        Args:
            shuffle (bool, optional): shuffle batches before returning the
                iterator (default: True).
            fix_batches_to_gpus: ensure that batches are always
                allocated to the same shards across epochs. Requires
                that :attr:`dataset` supports prefetching (default: False).
        """
        if self._next_epoch_itr is not None:
            self._cur_epoch_itr = self._next_epoch_itr
            self._next_epoch_itr = None
        else:
            self.epoch += 1
            # Remember the flag so state_dict() can reproduce this ordering.
            self.shuffle = shuffle
            self._cur_epoch_itr = self._get_iterator_for_epoch(
                self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
            )
        return self._cur_epoch_itr

    def end_of_epoch(self):
        """Returns whether the most recent epoch iterator has been exhausted.

        When no epoch iterator has been handed out yet, the iterator restored
        by :func:`load_state_dict` (if any) is consulted instead; with no
        iterator at all there is nothing left to consume and True is returned.
        """
        itr = self._cur_epoch_itr
        if itr is None:
            itr = self._next_epoch_itr
        if itr is None:
            return True
        return not itr.has_next()

    @property
    def iterations_in_epoch(self):
        """The number of consumed batches in the current epoch."""
        if self._cur_epoch_itr is not None:
            return self._cur_epoch_itr.count
        elif self._next_epoch_itr is not None:
            return self._next_epoch_itr.count
        return 0

    def state_dict(self):
        """Returns a dictionary containing a whole state of the iterator."""
        return {
            'epoch': self.epoch,
            'iterations_in_epoch': self.iterations_in_epoch,
            # Needed to rebuild the same batch order when resuming mid-epoch.
            'shuffle': self.shuffle,
        }

    def load_state_dict(self, state_dict):
        """Copies the state of the iterator from the given *state_dict*.

        A missing ``'shuffle'`` entry (older checkpoints) is treated as True.
        """
        self.epoch = state_dict['epoch']
        itr_pos = state_dict.get('iterations_in_epoch', 0)
        self.shuffle = state_dict.get('shuffle', True)
        if itr_pos > 0:
            # fast-forward epoch iterator; this is None when the saved
            # position is at or past the end of the epoch, in which case the
            # next call to next_epoch_itr() simply starts a new epoch
            self._next_epoch_itr = self._get_iterator_for_epoch(
                self.epoch,
                shuffle=self.shuffle,
                offset=itr_pos,
            )

    def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False, offset=0):
        """Build the iterator over this shard's batches for *epoch*.

        Args:
            epoch (int): epoch number, mixed into the shuffling seed
            shuffle (bool): whether to shuffle the batches
            fix_batches_to_gpus (bool, optional): shard first and shuffle
                within the shard; only honoured when the dataset supports
                prefetching (default: False).
            offset (int, optional): number of leading batches of the shard to
                leave out (default: 0).

        Returns:
            a :class:`CountingIterator` starting at *offset*, or None when a
            positive *offset* is at or beyond the end of the shard.
        """

        def shuffle_batches(batches, seed):
            # set seed based on the seed and epoch number so that we get
            # reproducible results when resuming from checkpoints
            # NOTE(review): data_utils.numpy_seed presumably seeds NumPy's
            # global RNG for the duration of the block - confirm in data_utils.
            with data_utils.numpy_seed(seed):
                np.random.shuffle(batches)
            return batches

        if self._supports_prefetch:
            batches = self.frozen_batches

            if shuffle and not fix_batches_to_gpus:
                batches = shuffle_batches(list(batches), self.seed + epoch)

            batches = list(ShardedIterator(
                batches, self.num_shards, self.shard_id, fill_value=[]
            ))
            # Prefetch every index that belongs to this shard.
            self.dataset.prefetch([i for s in batches for i in s])

            if shuffle and fix_batches_to_gpus:
                # Shuffle after sharding so the shard keeps the same batches
                # every epoch and only their order changes.
                batches = shuffle_batches(batches, self.seed + epoch + self.shard_id)
        else:
            if shuffle:
                batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch)
            else:
                batches = self.frozen_batches
            batches = list(ShardedIterator(
                batches, self.num_shards, self.shard_id, fill_value=[]
            ))

        if offset > 0 and offset >= len(batches):
            return None

        return CountingIterator(
            torch.utils.data.DataLoader(
                self.dataset,
                collate_fn=self.collate_fn,
                batch_sampler=batches[offset:],
                num_workers=self.num_workers,
            ),
            start=offset,
        )


class GroupedIterator(object):
    """Wrapper around an iterable that returns groups (chunks) of items.

    The wrapped object is advanced with ``next()`` and measured with
    ``len()``, so it must be a sized iterator. The final chunk may be
    shorter than *chunk_size*.

    Args:
        iterable (iterable): iterable to wrap
        chunk_size (int): size of each chunk

    Attributes:
        offset (int): the wrapped object's ``count`` attribute (0 if it has
            none) expressed in chunks, rounded up
    """

    def __init__(self, iterable, chunk_size):
        chunk_size_f = float(chunk_size)
        self._len = int(math.ceil(len(iterable) / chunk_size_f))
        self.offset = int(math.ceil(getattr(iterable, 'count', 0) / chunk_size_f))
        self.itr = iterable
        self.chunk_size = chunk_size

    def __len__(self):
        return self._len

    def __iter__(self):
        return self

    def __next__(self):
        chunk = []
        try:
            for _ in range(self.chunk_size):
                chunk.append(next(self.itr))
        except StopIteration:
            # Only propagate exhaustion when nothing was collected; otherwise
            # hand back the trailing partial chunk first.
            if not chunk:
                raise
        return chunk


class ShardedIterator(object):
    """A sharded wrapper around an iterable, padded to length.

    Shard *shard_id* receives every *num_shards*-th element starting at
    index *shard_id*. Every shard reports the same length (the element count
    divided by *num_shards*, rounded up); shards that run short are padded
    with *fill_value*.

    Args:
        iterable (iterable): iterable to wrap; must support ``len()``
        num_shards (int): number of shards to split the iterable into
        shard_id (int): which shard to iterate over
        fill_value (Any, optional): padding value when the iterable doesn't
            evenly divide *num_shards* (default: None).

    Raises:
        ValueError: if *shard_id* is not in ``[0, num_shards)``
    """

    def __init__(self, iterable, num_shards, shard_id, fill_value=None):
        if shard_id < 0 or shard_id >= num_shards:
            raise ValueError('shard_id must be between 0 and num_shards')

        total = len(iterable)
        self._sharded_len = total // num_shards
        if total % num_shards > 0:
            self._sharded_len += 1

        # Pairing with range() fixes the number of items produced; when the
        # strided slice runs out early zip_longest supplies *fill_value*.
        self.itr = itertools.zip_longest(
            range(self._sharded_len),
            itertools.islice(iterable, shard_id, total, num_shards),
            fillvalue=fill_value,
        )

    def __len__(self):
        return self._sharded_len

    def __iter__(self):
        return self

    def __next__(self):
        # Drop the positional index and return only the element (or padding).
        return next(self.itr)[1]