from typing import Sequence, Set, Union, Tuple, Mapping, Optional

from torch.utils.data.dataset import ConcatDataset
from torch.utils.data import Dataset


class ConcatDatasetWithSplitIds(ConcatDataset):
    """Subclass of torch's ConcatDataset, augmented with a ``split_ids`` attribute.

    ``split_ids`` is the union of the ``split_ids`` of every concatenated dataset.
    It is computed once, at construction time, from the datasets passed in.
    """

    def __init__(self, datasets):
        """Concatenate ``datasets`` and collect the union of their split ids.

        Args:
            datasets: The datasets to concatenate; forwarded as-is to ``ConcatDataset``.
                Each of them must expose a ``split_ids`` attribute.

        Raises:
            ValueError: If one of the datasets has no ``split_ids`` attribute.
        """
        super().__init__(datasets=datasets)
        self.split_ids = self._get_split_ids()

    def _get_split_ids(self) -> Set[int]:
        """Return a new set holding the union of ``split_ids`` over ``self.datasets``.

        Raises:
            ValueError: If one of the datasets has no ``split_ids`` attribute.
        """
        split_ids = set()
        for dataset in self.datasets:
            if not hasattr(dataset, "split_ids"):
                message = f"Datasets used in {self.__class__.__name__} must have the 'split_ids' attribute. "
                raise ValueError(message)
            # `update` accepts any iterable, so a dataset that stores its split ids
            # as a list or tuple is handled as well (`|=` only works with sets).
            split_ids.update(dataset.split_ids)
        return split_ids


def get_dataset_split(dataset: Union[Dataset, Sequence[Dataset]]) -> Tuple[Set[int], Optional[Mapping[int, Set[int]]]]:
    """Extract the splits present in given dataset(s) and return which dataset idx corresponds to which split idx.

    Args:
        dataset: A single pytorch dataset, or a list/tuple of pytorch datasets.
            Every dataset must expose a ``split_ids`` attribute.

    Returns:
        A pair ``(split_ids, dataset_idx_2_split_ids)``:

        * For a single dataset: its ``split_ids`` attribute (the same object, not a
          copy) and ``None``.
        * For a list/tuple of datasets: a new set with the union of all split ids,
          and a dict mapping each dataset's position in the sequence to that
          dataset's own ``split_ids`` object. An empty sequence yields ``(set(), {})``.

    Raises:
        TypeError: If ``dataset`` is neither a pytorch dataset nor a list/tuple,
            or if an element of the sequence is not a pytorch dataset.
        AttributeError: If a dataset has no ``split_ids`` attribute.
    """
    type_message = "This function only accepts pytorch datasets, or a list of pytorch datasets. "
    attribute_message = "Datasets used in OOD generalization experiments should have the attribute 'split_ids'. "

    def _validated_split_ids(dset):
        # The type check comes first so that a non-dataset element is reported as a
        # TypeError even when it also lacks the 'split_ids' attribute.
        if not isinstance(dset, Dataset):
            raise TypeError(type_message)
        if not hasattr(dset, "split_ids"):
            raise AttributeError(attribute_message)
        return dset.split_ids

    # A single dataset: there is no dataset-index mapping to report.
    if isinstance(dataset, Dataset):
        return _validated_split_ids(dataset), None

    # Tuples are accepted alongside lists, in line with the Sequence annotation.
    if not isinstance(dataset, (list, tuple)):
        raise TypeError(type_message)

    # Gather all the splits across the sequence of datasets.
    all_split_ids = set()
    dataset_idx_2_split_ids = {}
    for i, curr_dataset in enumerate(dataset):
        curr_split_ids = _validated_split_ids(curr_dataset)
        all_split_ids.update(curr_split_ids)
        dataset_idx_2_split_ids[i] = curr_split_ids
    return all_split_ids, dataset_idx_2_split_ids
