from typing import List, Sequence, Tuple

import numpy as onp
from grain._src.python.data_sources import SupportsIndex

from egxc.dataloading.datasets.base import BaseDataset, RawSample


def random_index_split(
    num_elements: int, fractions: Sequence[float], seed: int
) -> Tuple[List[int], ...]:
    """Randomly partition ``range(num_elements)`` into disjoint index lists.

    The indices ``0..num_elements-1`` are shuffled with a
    ``numpy.random.RandomState(seed)``, so identical arguments always give
    the same split. The split for each fraction except the last takes the
    next ``int(num_elements * frac)`` shuffled indices (rounded down). The
    last split takes every remaining index, so each index lands in exactly
    one split.

    Args:
        num_elements: Number of indices to split.
        fractions: Relative size of each split. Must sum to 1 (within
            ``numpy.isclose``'s default tolerance).
        seed: Seed for the shuffling permutation.

    Returns:
        A tuple with one list of Python ``int`` indices per entry in
        ``fractions``.

    Raises:
        AssertionError: If ``fractions`` does not sum to 1.
    """
    sum_fractions = sum(fractions)
    # An exact ``== 1.0`` check rejects valid splits whose float sum lands
    # one ulp away from 1.0 (e.g. 0.7 + 0.2 + 0.1 == 0.9999999999999999),
    # so compare with a tolerance instead.
    assert onp.isclose(sum_fractions, 1.0), (
        f"fractions must sum to 1.0, got {sum_fractions}"
    )
    indices = onp.random.RandomState(seed).permutation(onp.arange(num_elements))
    splits = []
    start = 0
    for frac in fractions[:-1]:
        # Leading split sizes are floored; the last split absorbs the remainder.
        stop = start + int(num_elements * frac)
        splits.append(indices[start:stop].tolist())
        start = stop
    splits.append(indices[start:].tolist())
    return tuple(splits)


class IndexWrapper(BaseDataset):
    """View of ``dataset`` restricted to, and ordered by, ``indices``.

    Position ``i`` of the wrapper resolves to ``dataset[indices[i]]``, and the
    wrapper's length is ``len(indices)``. On construction the wrapper calls
    ``copy_params_from_dataset(dataset)`` (inherited from ``BaseDataset``).
    """

    # NOTE(review): presumably populated by copy_params_from_dataset — confirm.
    directory: str

    def __init__(self, dataset: BaseDataset, indices: List[int]):
        self.__dataset, self.__indices = dataset, indices
        self.copy_params_from_dataset(dataset)

    def __getitem__(self, idx: SupportsIndex) -> RawSample:
        """Return the wrapped dataset's sample at position ``indices[idx]``."""
        return self.__dataset[self.__indices[idx]]

    def __len__(self) -> int:
        """Return the number of indices exposed by this view."""
        return len(self.__indices)


class RandomSubsetWrapper(BaseDataset):
    """Fixed random subset of ``n_samples`` samples drawn from ``dataset``.

    The subset is chosen once, at construction. It is the first ``n_samples``
    entries of a ``numpy.random.RandomState(sampling_seed)`` permutation of
    ``range(len(dataset))``, so samples are drawn without replacement and the
    same seed always selects the same samples in the same order.

    Args:
        dataset: Dataset to draw samples from.
        n_samples: Size of the subset; must satisfy
            ``0 <= n_samples <= len(dataset)``.
        sampling_seed: Seed for the permutation.

    Raises:
        AssertionError: If ``n_samples`` is outside ``[0, len(dataset)]``.
    """

    def __init__(self, dataset: BaseDataset, n_samples: int, sampling_seed: int):
        self.__original_dataset = dataset
        self.__original_length = len(dataset)
        # A negative n_samples would slice from the end of the permutation and
        # make __len__ return a negative value (len() then raises ValueError),
        # so the lower bound is checked as well.
        assert 0 <= n_samples <= self.__original_length, (
            f"n_samples must be in [0, {self.__original_length}], got {n_samples}"
        )
        self.__indices = (
            onp.random.RandomState(sampling_seed)
            .permutation(self.__original_length)[:n_samples]
            .tolist()
        )
        self.__n_samples = n_samples
        self.__seed = sampling_seed
        # Mirror IndexWrapper: carry the wrapped dataset's parameters over to
        # the subset (whatever BaseDataset.copy_params_from_dataset copies).
        self.copy_params_from_dataset(dataset)

    def __getitem__(self, idx: SupportsIndex) -> RawSample:
        """Return the original dataset's sample at the ``idx``-th sampled index."""
        idx = self.__indices[idx]
        return self.__original_dataset[idx]

    def __len__(self) -> int:
        """Return the subset size ``n_samples``."""
        return self.__n_samples
