"""
Extension of PyTorch's batch sampler that also yields a random item (e.g. a scalar drawn from a given range) shared by every index in a batch.
"""

from typing import Any, Callable, Sequence

from torch.utils.data import Sampler


class BatchSamplerSyncRandom(Sampler):
    r"""Extending the Batch Sampler to also pass a random
        item that is generated and passed to each of the loaders.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
            with ``__len__`` implemented.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``
        random_generator (Callable[Any, [None,]]): Will return an item that was generated
            and pass to all the dataloaders

    Example:
        >>> list(BatchSamplerRandScale(SequentialSampler(range(10)),
                batch_size=3, drop_last=False, random_generator=lambda: random.uniform(0.5,1)))
        [[(0, 0.65), (1, 0.65), (2, 0.65)],
         [(3, 0.8), (4, 0.8), (5, 0.8)],
         [(6, 0.93), (7, 0.93), (8, 0.93)],
         [(9, 0.54)]]
    """

    def __init__(
        self,
        sampler: Sampler | Sequence[Any],
        batch_size: int,
        drop_last: bool,
        random_generator: Callable[[], Any],
    ):
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        super().__init__(None)
        if (
            not isinstance(batch_size, int)
            or isinstance(batch_size, bool)
            or batch_size <= 0
        ):
            raise ValueError(
                f"batch_size should be a positive integer value, but got {batch_size}"
            )
        if not isinstance(drop_last, bool):
            raise ValueError(
                f"drop_last should be a boolean value, but got {drop_last}"
            )

        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.rand_gen = random_generator

    def __iter__(self):
        batch = []
        rand_sample = self.rand_gen()
        for idx in self.sampler:
            batch.append((idx, rand_sample))
            if len(batch) == self.batch_size:
                yield batch
                batch = []
                rand_sample = self.rand_gen()

        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        len_ = len(self.sampler) // self.batch_size  # rounds down
        return len_ if self.drop_last else len_ + 1
