from __future__ import annotations

import math
from typing import Iterable, List

import numpy as np
from torch.utils.data import Sampler


class UniformComboBatchSampler(Sampler[List[int]]):
    """Fixed-size, shape-homogeneous batches using compact NumPy indices.

    Every yielded batch holds dataset indices that share a single combo id.
    Batches from different combos are interleaved round-robin: each round
    visits every combo once (in an order shuffled per epoch when ``shuffle``
    is true) and emits that combo's next batch if it still has one.

    The per-epoch RNG is seeded with ``seed + epoch`` and the stored index
    buckets are never mutated. So for a fixed epoch (see ``set_epoch``) the
    batch sequence is identical no matter how many times, or after how many
    earlier epochs, the sampler is iterated.

    combo_ids: array-like of shape [N], integer values in [0..K-1]
    num_combos: K distinct combos (N_C,D)
    batch_size: items per batch; must be positive
    shuffle: shuffle in-bucket indices and the combo visiting order per epoch
    drop_last: drop each combo's trailing partial batch instead of yielding it
    seed: base seed; the epoch's RNG uses ``seed + epoch``

    Raises:
        ValueError: if ``combo_ids`` is not 1-D, ``batch_size`` is not
            positive, or any combo id lies outside ``[0, num_combos)``.
    """

    def __init__(
        self,
        combo_ids: np.ndarray,
        num_combos: int,
        batch_size: int,
        shuffle: bool = True,
        drop_last: bool = True,
        seed: int = 123,
    ):
        combo_ids = np.asarray(combo_ids)
        if combo_ids.ndim != 1:
            raise ValueError(f"combo_ids must be 1-D, got shape {combo_ids.shape}")
        num_combos = int(num_combos)
        batch_size = int(batch_size)
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        # Validate before the int32 cast so out-of-range values cannot wrap
        # into range; out-of-range ids would otherwise be silently dropped.
        if combo_ids.size and (combo_ids.min() < 0 or combo_ids.max() >= num_combos):
            raise ValueError(
                f"combo_ids must lie in [0, {num_combos}), "
                f"got range [{combo_ids.min()}, {combo_ids.max()}]"
            )

        self.combo_ids = combo_ids.astype(np.int32, copy=False)
        self.num_combos = num_combos
        self.bs = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.seed = seed
        self.epoch = 0
        # One read-only bucket of dataset indices per combo id k.
        self.idxs = [
            np.flatnonzero(self.combo_ids == k).astype(np.int32)
            for k in range(self.num_combos)
        ]

    def set_epoch(self, epoch: int):
        """Select the epoch whose seed (``seed + epoch``) drives shuffling."""
        self.epoch = int(epoch)

    def _batches_in(self, n: int) -> int:
        """Return how many batches a bucket of ``n`` indices yields."""
        if self.drop_last:
            return n // self.bs
        return (n + self.bs - 1) // self.bs

    def __iter__(self) -> Iterable[List[int]]:
        """Yield batches (lists of dataset indices) for the current epoch."""
        rng = np.random.default_rng(self.seed + self.epoch)
        # Shuffle *copies* of the buckets: shuffling self.idxs in place would
        # compound permutations across epochs and break reproducibility.
        if self.shuffle:
            buckets = [rng.permutation(arr) for arr in self.idxs]
        else:
            buckets = self.idxs

        order = np.arange(self.num_combos)
        if self.shuffle:
            rng.shuffle(order)

        bs = self.bs
        per_bucket = [self._batches_in(len(arr)) for arr in buckets]
        max_rounds = max(per_bucket, default=0)

        # Each round takes one batch from every combo that still has one;
        # a combo's r-th batch therefore always starts at offset r * bs.
        for r in range(max_rounds):
            start = r * bs
            for k in order:
                if r < per_bucket[k]:
                    yield buckets[k][start:start + bs].tolist()

    def __len__(self) -> int:
        """Total number of batches yielded per epoch."""
        return sum(self._batches_in(len(arr)) for arr in self.idxs)

