from typing import Any, Callable, Iterable, Optional
import numpy as np
import bisect
from omegaconf import DictConfig

from pado.core.base.dataset import PadoDataset

# Public API of this module (names exported by ``from ... import *``).
__all__ = ["ConcatDataset"]


class ConcatDataset(PadoDataset):
    """
    Concatenation of multiple datasets.

    Samples are addressed by a single global index: indices
    ``0 .. len(datasets[0]) - 1`` map to the first dataset, the next
    ``len(datasets[1])`` indices map to the second, and so on. Negative
    indices count from the end, as for Python sequences.
    This special dataset is not registered.
    """

    def __init__(self, datasets: Iterable[PadoDataset]) -> None:
        """
        Args:
            datasets: The datasets to concatenate, in order. The iterable is
                consumed once and stored as a list.

        Raises:
            ValueError: If any of the given datasets has length zero.
        """
        super().__init__()
        self.datasets = list(datasets)

        # Query each dataset's length only once and reuse it for validation
        # and for the cumulative index table.
        self.dataset_lengths = [len(d) for d in self.datasets]
        for d, length in zip(self.datasets, self.dataset_lengths):
            if length == 0:
                raise ValueError(f"ConcatDataset {d.__class__.__name__} is empty.")
        # cumsum[i] is the number of samples in datasets[0..i], i.e. the
        # exclusive global end index of datasets[i]. ``tolist`` turns the
        # numpy integers back into plain Python ints.
        self.cumsum = np.cumsum(self.dataset_lengths).tolist()

    @property
    def num(self) -> int:
        """Number of concatenated datasets."""
        return len(self.datasets)

    def __len__(self) -> int:
        """Total number of samples across all datasets (0 if none were given)."""
        # With no datasets ``cumsum`` is empty; indexing ``[-1]`` would raise
        # IndexError from inside ``len()``.
        return self.cumsum[-1] if self.cumsum else 0

    def __getitem__(self, index: int) -> Any:
        """
        Return the sample at global position ``index``.

        Args:
            index: Global sample index; negative values count from the end.

        Returns:
            Whatever the owning dataset's ``__getitem__`` returns for the
            corresponding local index.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if not -total <= index < total:
            raise IndexError("Index length overflow, exceed dataset length.")
        if index < 0:
            index += total

        # bisect_right yields the first dataset whose exclusive end index
        # (its cumsum entry) is strictly greater than ``index``.
        dataset_index = bisect.bisect_right(self.cumsum, index)
        # Global index of that dataset's first sample.
        offset = self.cumsum[dataset_index - 1] if dataset_index > 0 else 0
        return self.datasets[dataset_index][index - offset]

    @classmethod
    def from_config(cls,
                    cfg: DictConfig,
                    transform: Optional[Callable] = None,
                    target_transform: Optional[Callable] = None):
        """
        Not supported: a ConcatDataset is built directly from dataset
        instances via ``__init__``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
