# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import FairseqDataset


class ConcatSentencesDataset(FairseqDataset):
    """Dataset whose items are the concatenation of aligned items.

    Item ``i`` of this dataset is ``torch.cat`` applied to item ``i`` of
    every wrapped dataset, in the order the datasets were passed in. All
    wrapped datasets must therefore have the same length.

    Collation and index ordering are delegated to the first wrapped
    dataset; size-related queries are summed over all wrapped datasets.
    """

    def __init__(self, *datasets):
        """Wrap ``datasets``.

        Args:
            *datasets: one or more indexable datasets of equal length.

        Raises:
            AssertionError: if no dataset is given, or if the datasets do
                not all have the same length.
        """
        super().__init__()
        # Fail fast: with zero datasets the equal-length check below is
        # vacuously true, and the problem would only surface later as an
        # IndexError from ``self.datasets[0]``.
        assert len(datasets) > 0, "datasets must not be empty"
        assert all(
            len(ds) == len(datasets[0]) for ds in datasets
        ), "datasets must have the same length"
        self.datasets = datasets

    def __getitem__(self, index):
        """Return the items at ``index`` from every dataset, concatenated.

        Uses ``torch.cat`` with its default dimension (0).
        NOTE(review): presumably each item is a 1-D token tensor -- confirm
        against the wrapped datasets.
        """
        return torch.cat([ds[index] for ds in self.datasets])

    def __len__(self):
        """Return the length of the first dataset (all lengths are equal)."""
        return len(self.datasets[0])

    def collater(self, samples):
        """Collate ``samples`` using the first wrapped dataset's collater."""
        return self.datasets[0].collater(samples)

    @property
    def sizes(self):
        """Sum of the ``sizes`` attributes of all wrapped datasets.

        Computed with the built-in ``sum`` (integer start value), so the
        wrapped ``sizes`` objects must support addition with ``0`` and with
        each other. NOTE(review): presumably these are per-example size
        arrays that add elementwise -- confirm.
        """
        return sum(ds.sizes for ds in self.datasets)

    def num_tokens(self, index):
        """Return the sum of ``num_tokens(index)`` over all datasets."""
        return sum(ds.num_tokens(index) for ds in self.datasets)

    def size(self, index):
        """Return the sum of ``size(index)`` over all datasets."""
        return sum(ds.size(index) for ds in self.datasets)

    def ordered_indices(self):
        """Return the index ordering proposed by the first dataset."""
        return self.datasets[0].ordered_indices()

    @property
    def supports_prefetch(self):
        """True if at least one wrapped dataset reports prefetch support."""
        return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)

    def prefetch(self, indices):
        """Forward ``prefetch(indices)`` to each dataset that supports it."""
        for ds in self.datasets:
            # Datasets without the attribute (or with it falsy) are skipped.
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch(indices)

    def set_epoch(self, epoch):
        """Notify the base class and every wrapped dataset of a new epoch.

        Wrapped datasets that do not define ``set_epoch`` are skipped.
        """
        super().set_epoch(epoch)
        for ds in self.datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)
