from torch.utils.data import Sampler, DataLoader



class GraphBatchSampler(Sampler):
    r"""Group indices from a base sampler into size-bounded mini-batches.

    Indices are consumed in the order the base ``sampler`` yields them and are
    packed greedily. Each index contributes ``block_list[idx]`` to the running
    size of the current batch. A new batch is started as soon as adding the
    next index would push that running size above ``batch_size``. An index
    whose own size exceeds ``batch_size`` is therefore emitted as a batch on
    its own. Empty batches are never yielded.

    ``__len__`` is deliberately not defined. The number of batches depends on
    the order in which ``sampler`` yields indices, so it cannot be derived
    from ``len(sampler)`` and ``batch_size`` alone.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable
            object whose items are valid keys into ``block_list``.
        batch_size (int): Maximum total size of a mini-batch, measured in the
            same units as the values stored in ``block_list``.
        block_list (Sequence or Mapping): ``block_list[idx]`` is the size
            contributed by index ``idx``. NOTE(review): presumably a per-graph
            size such as a node count; confirm against callers.

    Example:
        >>> list(GraphBatchSampler(range(5), batch_size=4, block_list=[1, 2, 3, 1, 5]))
        [[0, 1], [2, 3], [4]]
    """

    def __init__(self, sampler, batch_size, block_list):
        # bool is a subclass of int, so reject it explicitly.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}")

        self.sampler = sampler
        self.batch_size = batch_size
        self.block_list = block_list

        # Not used by this class; retained so external code reading it keeps working.
        self.pending_idx = None

    def __iter__(self):
        """Yield lists of indices whose summed ``block_list`` sizes fit ``batch_size``.

        An index that is larger than ``batch_size`` by itself is yielded
        alone. Every index produced by ``sampler`` appears in exactly one
        yielded batch, in the sampler's order.
        """
        batch = []
        current_size = 0
        for idx in self.sampler:
            d_size = self.block_list[idx]
            # Close the current batch before it would overflow, but only if it
            # already holds something. Otherwise an oversized first index would
            # produce an empty batch.
            if batch and current_size + d_size > self.batch_size:
                yield batch
                batch = []
                current_size = 0
            batch.append(idx)
            current_size += d_size
        # Flush the remainder. Test the list, not the accumulated size, so that
        # indices with size 0 are not silently dropped.
        if batch:
            yield batch
