# Source code for fairseq.data.monolingual_dataset

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import FairseqDataset, data_utils


def collate(samples, pad_idx, eos_idx, fixed_pad_length=None, pad_to_bsz=None):
    if len(samples) == 0:
        return {}

    def merge(key, is_list=False):
        if is_list:
            res = []
            for i in range(len(samples[0][key])):
                res.append(
                    data_utils.collate_tokens(
                        [s[key][i] for s in samples],
                        pad_idx,
                        eos_idx,
                        left_pad=False,
                        pad_to_length=fixed_pad_length,
                        pad_to_bsz=pad_to_bsz,
                    )
                )
            return res
        else:
            return data_utils.collate_tokens(
                [s[key] for s in samples],
                pad_idx,
                eos_idx,
                left_pad=False,
                pad_to_length=fixed_pad_length,
                pad_to_bsz=pad_to_bsz,
            )

    src_tokens = merge("source")
    if samples[0]["target"] is not None:
        is_target_list = isinstance(samples[0]["target"], list)
        target = merge("target", is_target_list)
    else:
        target = src_tokens

    return {
        "id": torch.LongTensor([s["id"] for s in samples]),
        "nsentences": len(samples),
        "ntokens": sum(len(s["source"]) for s in samples),
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": torch.LongTensor([s["source"].numel() for s in samples]),
        },
        "target": target,
    }
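
# Example usage of ``collate`` (an illustrative sketch, not part of fairseq;
# the token IDs are hypothetical, with pad_idx=1 and eos_idx=2 as in
# fairseq.data.Dictionary):
#
#   samples = [
#       {"id": 0, "source": torch.LongTensor([5, 6, 7]),
#        "target": torch.LongTensor([6, 7, 2])},
#       {"id": 1, "source": torch.LongTensor([5, 6]),
#        "target": torch.LongTensor([6, 2])},
#   ]
#   batch = collate(samples, pad_idx=1, eos_idx=2)
#   # batch["net_input"]["src_tokens"] is right-padded:
#   #   tensor([[5, 6, 7],
#   #           [5, 6, 1]])
#   # batch["nsentences"] == 2 and batch["ntokens"] == 5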


class MonolingualDataset(FairseqDataset):
    """A wrapper around torch.utils.data.Dataset for monolingual data.

    Args:
        dataset (torch.utils.data.Dataset): dataset to wrap
        sizes (List[int]): sentence lengths
        src_vocab (~fairseq.data.Dictionary): source vocabulary
        tgt_vocab (~fairseq.data.Dictionary, optional): target vocabulary
            (defaults to *src_vocab*)
        add_eos_for_other_targets (bool, optional): append EOS to the source
            when the "self" or "past" targets are requested (default: False)
        shuffle (bool, optional): shuffle the elements before batching
            (default: False)
        targets (List[str], optional): list of target types, each one of
            "self", "future" or "past" (default: None)
    """

    def __init__(
        self,
        dataset,
        sizes,
        src_vocab,
        tgt_vocab=None,
        add_eos_for_other_targets=False,
        shuffle=False,
        targets=None,
        add_bos_token=False,
        fixed_pad_length=None,
        pad_to_bsz=None,
        src_lang_idx=None,
        tgt_lang_idx=None,
    ):
        self.dataset = dataset
        self.sizes = np.array(sizes)
        self.vocab = src_vocab
        self.tgt_vocab = tgt_vocab or src_vocab
        self.add_eos_for_other_targets = add_eos_for_other_targets
        self.shuffle = shuffle
        self.add_bos_token = add_bos_token
        self.fixed_pad_length = fixed_pad_length
        self.pad_to_bsz = pad_to_bsz
        self.src_lang_idx = src_lang_idx
        self.tgt_lang_idx = tgt_lang_idx

        assert targets is None or all(
            t in {"self", "future", "past"} for t in targets
        ), "targets must be None or a subset of {'self', 'future', 'past'}"
        if targets is not None and len(targets) == 0:
            targets = None
        self.targets = targets

    def __getitem__(self, index):
        if self.targets is not None:
            # *future_target* is the original sentence
            # *source* is shifted right by 1 (maybe left-padded with eos)
            # *past_target* is shifted right by 2 (left-padded as needed)
            #
            # Left-to-right language models should condition on *source* and
            # predict *future_target*.
            # Right-to-left language models should condition on *source* and
            # predict *past_target*.
            source, future_target, past_target = self.dataset[index]
            source, target = self._make_source_target(
                source, future_target, past_target
            )
        else:
            source = self.dataset[index]
            target = None

        source, target = self._maybe_add_bos(source, target)
        return {"id": index, "source": source, "target": target}

    def __len__(self):
        return len(self.dataset)

    def _make_source_target(self, source, future_target, past_target):
        if self.targets is not None:
            target = []

            if (
                self.add_eos_for_other_targets
                and (("self" in self.targets) or ("past" in self.targets))
                and source[-1] != self.vocab.eos()
            ):
                # append eos at the end of source
                source = torch.cat([source, source.new([self.vocab.eos()])])

                if "future" in self.targets:
                    future_target = torch.cat(
                        [future_target, future_target.new([self.vocab.pad()])]
                    )
                if "past" in self.targets:
                    # the first token is before the start of sentence, which is
                    # only used in "none" break mode when
                    # add_eos_for_other_targets is False
                    past_target = torch.cat(
                        [
                            past_target.new([self.vocab.pad()]),
                            past_target[1:],
                            source[-2, None],
                        ]
                    )

            for t in self.targets:
                if t == "self":
                    target.append(source)
                elif t == "future":
                    target.append(future_target)
                elif t == "past":
                    target.append(past_target)
                else:
                    raise Exception("invalid target " + t)

            if len(target) == 1:
                target = target[0]
        else:
            target = future_target

        return source, self._filter_vocab(target)

    def _maybe_add_bos(self, source, target):
        if self.add_bos_token:
            source = torch.cat([source.new([self.vocab.bos()]), source])
            if target is not None:
                target = torch.cat([target.new([self.tgt_vocab.bos()]), target])
        return source, target
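
    # Worked example of the target layout (illustrative; w1..w3 are
    # hypothetical token IDs): for an original sentence
    #   future_target = [w1, w2, w3, eos]
    # the underlying dataset yields
    #   source        = [eos, w1, w2, w3]   (shifted right by 1)
    #   past_target   = [pad, eos, w1, w2]  (shifted right by 2)
    # so a left-to-right LM conditions on *source* and predicts
    # *future_target*, while a right-to-left LM conditions on *source* and
    # predicts *past_target*.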

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by
        indices. This value is used to enforce ``--max-tokens`` during
        batching."""
        return self.sizes[indices]

    def _filter_vocab(self, target):
        if len(self.tgt_vocab) != len(self.vocab):

            def _filter(target):
                mask = target.ge(len(self.tgt_vocab))
                if mask.any():
                    target[mask] = self.tgt_vocab.unk()
                return target

            if isinstance(target, list):
                return [_filter(t) for t in target]
            return _filter(target)
        return target
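
    # Example (illustrative): with len(self.tgt_vocab) == 100 and
    # len(self.vocab) == 104, a target containing token ID 103 would be out
    # of range on the target side, so _filter_vocab replaces it with
    # self.tgt_vocab.unk().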

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `nsentences` (int): number of sentences in the batch
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding
                    will appear on the right.

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will
                  appear on the right.
        """
        return collate(
            samples,
            self.vocab.pad(),
            self.vocab.eos(),
            self.fixed_pad_length,
            self.pad_to_bsz,
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used
        when filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed
        based on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        order.append(self.sizes)
        return np.lexsort(order)
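
    # Example (illustrative): with sizes = [3, 1, 3], np.lexsort uses its
    # *last* key (the sizes) as the primary sort key, so ordered_indices
    # returns [1, 0, 2]: shortest sentences first, with the first key (a
    # random permutation when shuffle=True, else an arange) breaking ties.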

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)