from typing import List

import numpy as np
import torch
from torch.utils.data import Sampler


class IPPSampler(Sampler[List[int]]):
    r"""
    Batch sampler in which every sample has its own inclusion probability.

    For each batch, sample ``i`` is included independently with probability
    ``per_sample_sampling_rates[i]``, so batch sizes vary from step to step
    and a batch may be empty. Iterating the sampler yields ``steps`` batches,
    each a list of the selected sample indices in ascending order.
    """

    def __init__(
        self, *, n_samples: int, per_sample_sampling_rates: List[float], generator=None, n_steps=None
    ):
        r"""
        Args:
            n_samples: number of samples in the dataset being sampled from.
            per_sample_sampling_rates: one inclusion probability per sample;
                a sequence of floats or a 1-D tensor of length ``n_samples``.
            generator: ``torch.Generator`` used to draw the random numbers
                (``None`` uses the global default generator).
            n_steps: number of batches produced per iteration of the sampler.
                When ``None``, defaults to ``n_samples / sum(rates)`` rounded
                to the nearest integer, i.e. the number of steps after which
                the expected number of drawn samples equals ``n_samples``.

        Raises:
            ValueError: if ``n_samples`` is not positive, if the number of
                rates differs from ``n_samples``, or if ``n_steps`` is
                ``None`` and the rates do not sum to a positive value.
        """
        self.num_samples = n_samples
        # Store the rates as a tensor so the elementwise comparison in
        # ``__iter__`` works for plain Python sequences as well as tensors.
        # ``as_tensor`` returns an existing tensor unchanged (no copy).
        self.sample_rates = torch.as_tensor(per_sample_sampling_rates)
        self.generator = generator

        if self.num_samples <= 0:
            raise ValueError(
                "num_samples should be a positive integer "
                "value, but got num_samples={}".format(self.num_samples)
            )

        if self.num_samples != len(self.sample_rates):
            raise ValueError(
                "Number of data {} is different from number of sample rates {}".format(
                    self.num_samples, len(self.sample_rates)
                )
            )

        if n_steps is not None:
            self.steps = n_steps
        else:
            # The expected batch size is the sum of the individual rates.
            expected_batch_size = float(self.sample_rates.sum())
            # ``not > 0`` also rejects NaN, which ``<= 0`` would let through.
            if not expected_batch_size > 0:
                raise ValueError(
                    "Cannot derive a default number of steps: the sample rates "
                    "sum to {}; pass n_steps explicitly".format(expected_batch_size)
                )
            # round() rather than int(): float32 sums are often a hair above
            # or below the exact value, and truncation would then be off by one.
            self.steps = round(self.num_samples / expected_batch_size)

    def __len__(self):
        """Return the number of batches produced per iteration."""
        return self.steps

    def __iter__(self):
        """Yield ``steps`` lists of indices, drawing every sample independently."""
        for _ in range(self.steps):
            # One uniform draw in [0, 1) per sample; sample i is selected
            # when its draw falls below its own rate.
            mask = (
                torch.rand(self.num_samples, generator=self.generator)
                < self.sample_rates
            )
            yield mask.nonzero(as_tuple=False).reshape(-1).tolist()
