from dataclasses import dataclass
from typing import Sequence, Tuple

import grain.python as grain

from egxc.dataloading.datasets.base import BaseDataset
from egxc.dataloading.datasets.ensemble import DatasetEnsemble
from egxc.dataloading.utils import RandomSubsetWrapper
from egxc.systems import PreloadSystem


class GrainDataLoaderWrapper:
    """
    Wrapper around a grain DataLoader that exposes a fixed length.

    Each pass over the wrapper (e.g. one ``for`` loop) yields at most ``length`` items and
    then raises ``StopIteration``. The underlying dataloader iterator is created only once
    and shared by all passes, so the loader (and its worker processes) keeps running across
    passes instead of being restarted for every epoch.

    Calling ``iter()`` again does NOT reset the per-pass counter. If a pass is interrupted
    (e.g. via ``break``), the next pass yields only the items remaining in that pass.
    """

    def __init__(self, dataloader: "grain.DataLoader", length: int):
        """
        Args:
            dataloader: Loader to draw items from. Only ``iter(dataloader)`` is used on it.
            length: Maximum number of items yielded per pass; also what ``len()`` returns.
        """
        self.__dataloader = dataloader
        self.__length = length
        # Created lazily, on the first iter() or next() that needs an item.
        self.__iterator = None
        # Number of next() calls made in the current pass.
        self.__counter = 0

    def __underlying_iterator(self):
        """Return the shared iterator over the dataloader, creating it on first use."""
        if self.__iterator is None:
            # only initialize the iterator once
            self.__iterator = iter(self.__dataloader)
        return self.__iterator

    def __iter__(self):
        self.__underlying_iterator()
        return self

    def __len__(self):
        return self.__length

    def __next__(self):
        self.__counter += 1
        if self.__counter > self.__length:
            # End of this pass; the next pass starts counting from zero again.
            self.__counter = 0
            raise StopIteration
        # Initializing here makes next() safe without a preceding iter() call.
        return next(self.__underlying_iterator())


@dataclass
class DataLoaders:
    """Container holding one length-aware dataloader wrapper per dataset split."""

    # Loader over the training split.
    train: GrainDataLoaderWrapper
    # Loader over the validation split.
    val: GrainDataLoaderWrapper
    # Loader over the test split.
    test: GrainDataLoaderWrapper


def _get_sample_for_model_init(
    dataset: BaseDataset, preload_transform: Sequence[grain.Transformation]
) -> PreloadSystem:
    """
    Fetch one transformed sample from ``dataset`` for model initialization.

    The sample goes through the same ``preload_transform`` operations as the main
    dataloaders. The loader runs with ``worker_count=0`` to avoid the overhead of
    creating worker threads/processes for a single sample.

    NOTE(review): the trailing ``[0]`` assumes the first item produced by the
    transformations is indexable, with the PreloadSystem at position 0. Confirm this
    against the transformation pipeline.
    """
    # Sequential, unsharded sampler over a single record, so only the first record is read.
    one_record_sampler = grain.SequentialSampler(1, shard_options=grain.NoSharding())
    single_shot_loader = grain.DataLoader(
        data_source=dataset,
        operations=preload_transform,
        sampler=one_record_sampler,
        worker_count=0,
    )
    first_item = next(iter(single_shot_loader))
    return first_item[0]


def get_individual_dataloader(
    dataset: BaseDataset,
    transformations: Sequence[grain.Transformation],
    shuffle: bool,
    workers: int | None,
    worker_buffer_size: int,
    random_seed: int,
) -> GrainDataLoaderWrapper:
    """
    Build a grain DataLoader for one dataset split and wrap it with a fixed length.

    Args:
        dataset: Data source the loader reads records from.
        transformations: Operations applied to every record.
        shuffle: Whether the index sampler shuffles record order.
        workers: Forwarded to grain as ``worker_count``.
        worker_buffer_size: Forwarded to grain as ``worker_buffer_size``.
        random_seed: Seed for the index sampler.

    Returns:
        A ``GrainDataLoaderWrapper`` whose length is ``len(dataset)``.

    NOTE(review): ``num_epochs`` is left at grain's ``IndexSampler`` default. Confirm
    that this default lets the wrapper be iterated for as many epochs as training needs.
    """
    num_records = len(dataset)
    index_sampler = grain.IndexSampler(
        num_records=num_records,
        shard_options=grain.NoSharding(),
        shuffle=shuffle,
        seed=random_seed,
    )

    loader_kwargs = dict(
        data_source=dataset,
        operations=transformations,
        sampler=index_sampler,
        worker_count=workers,
        worker_buffer_size=worker_buffer_size,
    )
    return GrainDataLoaderWrapper(grain.DataLoader(**loader_kwargs), num_records)


def get_psys_and_dataloaders(
    datasets: DatasetEnsemble,
    transformations: Sequence[grain.Transformation],
    shuffle: bool,
    workers: int | None,
    worker_buffer_size: int,
    shuffling_seed: int,
    n_test_samples: int | None = None,
) -> Tuple[PreloadSystem, DataLoaders]:
    """
    Build the train/val/test dataloaders and fetch one preloaded system for model init.

    Args:
        datasets (DatasetEnsemble): Ensemble providing the 'train', 'val' and 'test'
            BaseDataset splits.
        transformations (Sequence[grain.Transformation]): Operations applied to every
            sample of every split.
        shuffle (bool): Whether each loader shuffles its split. All splits use the same seed.
        workers (int | None): Forwarded to every DataLoader as its worker count.
            None leaves the choice to the grain implementation.
        worker_buffer_size (int): Forwarded to every DataLoader as its worker buffer size.
        shuffling_seed (int): Seed used by every sampler and, if requested, for the test subset.
        n_test_samples (int | None, optional): When given, the test loader draws from a
            random subset of this size (fixed by `shuffling_seed`). Otherwise the full
            test split is used.

    Returns:
        Tuple[PreloadSystem, DataLoaders]:
            - `PreloadSystem`: one sample from the training split with `transformations`
              applied, e.g. for shape inference / model initialization.
            - `DataLoaders`: wrappers for the train, val and test loaders, all built with
              the settings above.
    """
    testset = (
        datasets.test
        if n_test_samples is None
        else RandomSubsetWrapper(datasets.test, n_test_samples, shuffling_seed)
    )

    def _loader_for(split):
        # Every split shares the same transformations, worker settings and seed.
        return get_individual_dataloader(
            split,
            transformations,
            shuffle,
            workers,
            worker_buffer_size,
            shuffling_seed,
        )

    loaders = DataLoaders(
        train=_loader_for(datasets.train),
        val=_loader_for(datasets.val),
        test=_loader_for(testset),
    )
    init_system = _get_sample_for_model_init(datasets.train, transformations)
    return init_system, loaders
