import numpy as np
import torch


class PiPSSampler:
    """Round a vector of inclusion probabilities to a 0/1 sample (pi-PS sampling).

    The scheme is pairwise dependent rounding, also called the pivotal method.
    It repeatedly takes two still-fractional entries and shifts probability mass
    between them until at least one of the two is exactly 0 or 1. Each step:

    * keeps the pair's sum unchanged, so ``probs.sum()`` is preserved and the
      sample size is fixed when the total is an integer;
    * keeps every entry's expected value unchanged, so entry ``k`` ends up
      selected with marginal probability equal to its initial ``probs[k]``.

    If the total is not an integer, one entry is necessarily left fractional.

    Reference: R. Gandhi, S. Khuller, S. Parthasarathy, A. Srinivasan,
    "Dependent Rounding and Its Applications to Approximation Algorithms",
    J. ACM 53(3), 2006.

    All randomness (both the pairing order and the rounding decisions) comes
    from the ``torch.Generator`` given to the constructor. The same seed
    therefore yields the same sample regardless of the global torch RNG state.
    """

    # An entry p counts as resolved (0 or 1) once p * (1 - p) <= this value.
    _RESOLVED_TOL = 1e-8

    def __init__(self, generator: torch.Generator):
        self.gen = generator

    def update(self, probs: torch.Tensor, i, j) -> None:
        """Apply one dependent-rounding step to ``probs[i]`` and ``probs[j]`` in place.

        After the step, at least one of the two entries is exactly 0.0 or 1.0.
        The pair's sum and each entry's expected value are preserved.

        Args:
            probs: tensor of probabilities, modified in place.
            i, j: indices into ``probs`` (ints or 0-d integer tensors).
        """
        # Draw from the sampler's own generator so results are reproducible.
        rnd = torch.rand(1, generator=self.gen).item()
        total = probs[i] + probs[j]  # new tensor, unaffected by the writes below
        if total < 1:
            # Put all mass on one entry and zero the other.
            # i keeps the mass with probability p_i / (p_i + p_j).
            if total * rnd < probs[i]:
                probs[i] = total
                probs[j] = 0.0
            else:
                probs[j] = total
                probs[i] = 0.0
        else:
            # Saturate one entry at 1 and give the remainder to the other.
            # j is saturated with probability (1 - p_i) / (2 - p_i - p_j).
            if (2 - total) * rnd < 1 - probs[i]:
                probs[i] = total - 1
                probs[j] = 1.0
            else:
                probs[j] = total - 1
                probs[i] = 1.0

    def sample(self, probs: torch.Tensor) -> torch.Tensor:
        """Round ``probs`` in place to a 0/1 vector and return it.

        Args:
            probs: 1-D floating tensor of inclusion probabilities in [0, 1].
                It is modified in place.

        Returns:
            The same ``probs`` tensor after rounding. Every entry is exactly
            0 or 1, except possibly one entry when the total is not an integer.

        Raises:
            AssertionError: if the total probability drifted during rounding.
        """
        sum_probs = probs.sum()
        inds = torch.randperm(len(probs), generator=self.gen)
        while inds.shape[0] > 1:
            pair = inds[:2]
            self.update(probs, pair[0], pair[1])

            pair_probs = probs[pair]
            still_fractional = pair_probs * (1 - pair_probs) > self._RESOLVED_TOL
            resolved = ~still_fractional
            # Snap resolved entries to exactly 0/1 so that no tiny floating-point
            # residue (e.g. 1e-9) survives and is mistaken for a selection.
            probs[pair[resolved]] = pair_probs[resolved].round()

            # Drop resolved entries. An unresolved one stays at the front and
            # is paired with the next index.
            keep = torch.ones_like(inds, dtype=torch.bool)
            keep[:2] = still_fractional
            inds = inds[keep]

        assert torch.isclose(probs.sum(), sum_probs), (
            f"Probabilities changed during sampling: {probs.sum()} vs {sum_probs}"
        )
        return probs
