import numpy as np
from numpy.typing import NDArray
from typing import Iterator
from .dataset import DynamicalSystemDataset
from .transforms import Transform

def index_generation(num_samples: int, shuffle: bool, seed: int) -> NDArray[np.int_]:
    """Return the sample indices ``0 .. num_samples - 1``, optionally shuffled.

    Args:
        num_samples: Number of indices to generate.
        shuffle: When true, the indices are permuted in place with a
            generator seeded by ``seed``; otherwise they stay in
            ascending order.
        seed: Seed for the shuffling generator. Unused when ``shuffle``
            is false.

    Returns:
        A one-dimensional array containing each index exactly once.
    """
    order = np.arange(num_samples)
    if not shuffle:
        return order

    # A fresh generator per call makes the permutation a pure function of
    # (num_samples, seed).
    np.random.default_rng(seed).shuffle(order)
    return order

def batch_indices(
    indices: NDArray[np.int_],
    batch_size: int,
    drop_last: bool,
) -> Iterator[NDArray[np.int_]]:
    """Yield consecutive slices of ``indices`` holding ``batch_size`` entries.

    Args:
        indices: Sequence of sample indices to split into batches.
        batch_size: Number of indices per batch; must be at least 1.
        drop_last: When true, a trailing batch shorter than ``batch_size``
            is not yielded.

    Yields:
        Slices of ``indices`` in order. Every slice has ``batch_size``
        entries except possibly the last one when ``drop_last`` is false.

    Raises:
        ValueError: If ``batch_size`` is less than 1. Because this is a
            generator, the check runs when iteration starts rather than
            when the function is called.
    """
    # A negative step would make ``range`` empty and silently produce no
    # batches at all, so reject non-positive sizes explicitly.
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size}"
        )

    num_samples = len(indices)
    # Only the final slice can be short; when it must be dropped, stop the
    # range at the last multiple of batch_size instead of checking each slice.
    stop = num_samples - num_samples % batch_size if drop_last else num_samples

    for start in range(0, stop, batch_size):
        yield indices[start : start + batch_size]

def batch_iterator(
    dataset: DynamicalSystemDataset,
    batch_size: int,
    seed: int,
    shuffle: bool,
    transforms: list[Transform],
    drop_last: bool,
):
    """Iterate over ``dataset`` in batches, applying ``transforms`` to each.

    Args:
        dataset: Source of samples; must support ``len()`` and
            ``get_batch(indices)`` returning a ``(trajectories, params)``
            pair.
        batch_size: Number of samples per batch.
        seed: Seeds both the index ordering and the generator from which
            per-batch transform seeds are drawn.
        shuffle: Whether the sample order is shuffled before batching.
        transforms: Callables applied in order to every batch as
            ``transform(trajectories, params, rng)``; each returns the
            updated ``(trajectories, params)`` pair.
        drop_last: Whether a trailing batch smaller than ``batch_size`` is
            skipped.

    Yields:
        The ``(trajectories, params)`` pair for each batch after all
        transforms have been applied.
    """
    total = len(dataset)
    seed_source = np.random.default_rng(seed)
    order = index_generation(num_samples=total, shuffle=shuffle, seed=seed)

    for selection in batch_indices(order, batch_size, drop_last):
        trajectories, parameters = dataset.get_batch(selection)

        # Every batch gets its own generator, seeded from the shared source,
        # and that one generator is handed to all transforms of the batch.
        batch_rng = np.random.default_rng(seed_source.integers(0, 2**32))
        for transform in transforms:
            trajectories, parameters = transform(
                trajectories, parameters, batch_rng
            )

        yield trajectories, parameters