# Combines multiple datasets into a single dataset for continual learning.
# All datasets must share the same format: each must expose `targets`
# and `classes_names` attributes.
import copy
from typing import Callable, Optional, Iterable
from torch.utils.data import Dataset


class multiDatasets(Dataset):
    """Concatenate several datasets into one with a unified label space.

    Each entry of ``datasets`` is called as
    ``dataset(root, train, transform, target_transform, download)``. The
    resulting object must expose ``classes_names`` (a sequence of class
    labels) and ``targets`` (a sequence of indices into ``classes_names``),
    and must be indexable so that ``ds[i][0]`` is the i-th sample's input.

    Class names are compared after ``str()`` conversion. Identical names,
    whether they come from different datasets or repeat inside one, are
    merged into a single global class, in order of first appearance.

    Attributes:
        datasets: the instantiated sub-datasets, in input order.
        classes: global class names (list of str).
        classes_names: the same list object as ``classes``.
        items_list: one ``(dataset_index, item_index, global_target)`` tuple
            per sample.
        targets: global target of every sample, aligned with ``items_list``.
    """

    def __init__(
        self,
        datasets: Iterable[Callable[..., Dataset]],
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        """Instantiate every dataset and build the merged sample index.

        Args:
            datasets: dataset constructors (e.g. dataset classes), each
                called with ``(root, train, transform, target_transform,
                download)``.
            root: forwarded to every constructor.
            train: forwarded to every constructor.
            transform: a separate deep copy is forwarded to every
                constructor, so sub-datasets never share transform state.
            target_transform: forwarded to every constructor.
                NOTE(review): ``__getitem__`` returns the remapped global
                target, not the sub-dataset's target, so this transform does
                not affect the returned label — confirm this is intended.
            download: forwarded to every constructor.
        """
        super().__init__()
        self.datasets = []
        self.classes = []
        # Global class name -> global index; O(1) lookups instead of
        # scanning the list for every class and every sample.
        class_index = {}

        for dataset in datasets:
            ds = dataset(root, train, copy.deepcopy(transform),
                         target_transform, download)
            self.datasets.append(ds)
            for name in map(str, ds.classes_names):
                if name not in class_index:
                    class_index[name] = len(self.classes)
                    self.classes.append(name)
        # Alias of ``classes`` (same list object) for callers using either.
        self.classes_names = self.classes

        self.items_list, self.targets = [], []
        for ds_id, ds in enumerate(self.datasets):
            local_names = ds.classes_names
            for item_idx, local_target in enumerate(ds.targets):
                # Apply the same str() normalization used when building the
                # global classes, so non-string names (e.g. ints) resolve.
                target = class_index[str(local_names[local_target])]
                self.targets.append(target)
                self.items_list.append((ds_id, item_idx, target))

    def __getitem__(self, index):
        """Return ``(input, global_target)`` for the sample at ``index``.

        The input is element 0 of the owning sub-dataset's item; the
        sub-dataset's own target is discarded in favor of the global one.
        """
        ds_id, item_idx, target = self.items_list[index]
        return self.datasets[ds_id][item_idx][0], target

    def __len__(self):
        """Return the total number of samples across all sub-datasets."""
        return len(self.items_list)

    @property
    def n_classes(self):
        """Number of distinct global classes."""
        return len(self.classes_names)
