import torch.distributions as dist

import numpy as np
import torch
import torch.nn.functional as F


class GenCompl:
    """Sample complementary ("negative") classes for each row of a label matrix.

    For every row, a set of distinct classes is drawn without replacement from
    the classes that carry positive probability, and is returned as a 0/1
    multi-hot row.
    """

    def __init__(self, num_samples, num_class, device="cuda") -> None:
        """Create the sampler.

        Args:
            num_samples: number of rows in the per-sample buffers below.
            num_class: number of classes (width of every label row).
            device: device for the buffers and for sampled outputs. Defaults to
                ``"cuda"`` (the original hard-coded behaviour); pass ``"cpu"``
                to run without a GPU.
        """
        self.num_class = num_class
        self.device = torch.device(device)
        # Per-sample buffers. They are not read by the methods of this class
        # (full_space_partial only has a commented-out use of self.probs);
        # kept because external code may rely on them.
        self.keep_list = torch.ones(num_samples, num_class, device=self.device)
        self.probs = torch.ones(num_samples, num_class, device=self.device) / num_class

    def full_space_partial(self, labels, index=None, num=None):
        """Sample negatives uniformly from the classes a row does not label.

        Class ``c`` of a row gets weight ``1 - labels[row, c]``, normalised per
        row. This presumably expects 0/1 multi-hot labels; TODO confirm for
        soft labels.

        Args:
            labels: (batch, num_class) tensor of labels.
            index: unused; kept for interface compatibility.
            num: optional per-row multipliers forwarded to
                ``_gen_noreplace_uniform_neg``.

        Returns:
            (batch, num_class) int64 multi-hot tensor of sampled negatives.
        """
        complement = 1 - labels
        total = complement.sum(1, keepdim=True)
        # A row with no candidate negatives (every class labelled) would give
        # 0 / 0 = NaN and make np.random.choice raise. Dividing by 1 there
        # leaves an all-zero row, for which the sampler draws nothing.
        probs = complement / total.masked_fill(total == 0, 1)
        return self._gen_noreplace_uniform_neg(probs, num)

    def _gen_noreplace_uniform_neg(self, probs, num):
        """Draw, per row, a set of distinct classes distributed by ``probs``.

        Args:
            probs: (batch, num_class) tensor. Each row with positive entries is
                a categorical distribution; rows with no positive entry yield
                no samples.
            num: per-row multipliers indexable by row, or ``None``. Row ``i``
                draws ``round(num[i] * n_i)`` classes, where ``n_i`` is the
                number of positive-probability classes in that row; the count
                is clamped to ``[0, n_i]``. If ``None``, each row draws a
                uniformly random count in ``[0, num_class - 2]``, also capped
                at ``n_i``.

        Returns:
            (batch, num_class) int64 tensor on ``self.device`` with a 1 at each
            sampled class and 0 elsewhere.
        """
        # One device->host copy for the whole batch instead of one per row.
        probs_cpu = probs.detach().cpu()
        valid_counts = (probs_cpu > 0).sum(dim=1)
        valid_counts_list = valid_counts.tolist()
        probs_np = probs_cpu.numpy()
        batch_size = probs_np.shape[0]

        neg = np.zeros((batch_size, self.num_class), dtype=np.int64)
        for i in range(batch_size):
            n_valid = valid_counts_list[i]
            if num is not None:
                size = int(
                    torch.round(num[i] * valid_counts[i])
                    .clamp(min=0.0, max=n_valid)
                    .item()
                )
            else:
                # Default integer dtype: np.uint8 raised for num_class > 257.
                # max(..., 1) keeps randint valid when num_class == 1.
                size = min(np.random.randint(0, max(self.num_class - 1, 1)), n_valid)
            # Sampling without replacement cannot return more classes than have
            # positive probability. A size-0 draw consumes no randomness, so
            # skipping it leaves the RNG stream unchanged.
            if size > 0:
                chosen = np.random.choice(
                    self.num_class,
                    size,
                    replace=False,
                    p=probs_np[i],
                )
                neg[i, chosen] = 1
        return torch.from_numpy(neg).to(self.device)

