import torch
from torch import FloatTensor,LongTensor
from torch.nn import functional as F
from typing import Union

from . import AbstractWatermarkCode, AbstractReweight, AbstractScore


class InverseSampling_WatermarkCode(AbstractWatermarkCode):
    """Watermark code for inverse-transform sampling.

    Holds a uniform value ``u`` plus a permutation ``shuffle`` of the
    vocabulary indices, together with the inverse permutation ``unshuffle``.
    """

    def __init__(self, u: FloatTensor, shuffle: LongTensor):
        # Uniform value(s); from_random produces a 0-d tensor for a single
        # generator or a 1-d (batch,) tensor for a list of generators.
        self.u = u
        # Vocabulary permutation along the last axis.
        self.shuffle = shuffle
        # Inverse permutation: argsort of a permutation satisfies
        # shuffle[..., unshuffle[..., t]] == t.
        self.unshuffle = torch.argsort(shuffle, dim=-1)

    @classmethod
    def from_random(
        cls,
        rng: Union[torch.Generator, list[torch.Generator]],
        vocab_size: int,
    ):
        """Draw a fresh code from ``rng``.

        Args:
            rng: a single generator, or a list of generators (one per batch row).
            vocab_size: number of vocabulary entries to permute.

        Returns:
            An instance whose ``u`` is 0-d and ``shuffle`` is ``(vocab_size,)``
            for a single generator, or ``(batch,)`` and ``(batch, vocab_size)``
            for a list of generators.
        """
        if not isinstance(rng, list):
            # Arguments are evaluated left to right: rand() is drawn before randperm().
            return cls(
                torch.rand((), generator=rng, device=rng.device),
                torch.randperm(vocab_size, generator=rng, device=rng.device),
            )

        # Two separate passes preserve the draw order: every rand() first,
        # then every randperm(). This matters if a generator appears twice.
        uniforms = [torch.rand((), generator=g, device=g.device) for g in rng]
        perms = [
            torch.randperm(vocab_size, generator=g, device=g.device) for g in rng
        ]
        return cls(torch.stack(uniforms), torch.stack(perms))


class InverseSampling_Reweight(AbstractReweight):
    """Reweight logits by inverse-transform sampling over a shuffled vocabulary.

    The token distribution is laid out in the code's shuffled order. The token
    whose cumulative-probability interval contains the code's uniform ``u`` is
    selected. The returned logits are 0 at that token and -inf elsewhere.
    """

    watermark_code_type = InverseSampling_WatermarkCode

    def __repr__(self):
        return "InverseSampling_Reweight()"

    def reweight_logits(
        self, code: AbstractWatermarkCode, p_logits: FloatTensor
    ) -> FloatTensor:
        """Return one-hot (0 / -inf) logits picking the inverse-CDF token.

        Args:
            code: expected to provide ``u``, ``shuffle`` and ``unshuffle``
                (as InverseSampling_WatermarkCode does). ``shuffle`` must have
                the same rank as ``p_logits`` because it is used as a
                ``gather`` index along the last axis.
            p_logits: logits whose last axis indexes the vocabulary.

        Returns:
            A tensor with the dtype of ``p_logits``, equal to 0 at the selected
            token and -inf everywhere else, in the original vocabulary order.
        """
        # Move the vocabulary axis into the code's shuffled order.
        s_p_logits = torch.gather(p_logits, -1, code.shuffle)

        # In fp16/bf16, a softmax + cumsum over a large vocabulary stops
        # accumulating small probabilities and saturates well below 1, which
        # distorts the CDF. Compute in at least float32 (float64 stays float64).
        work_dtype = torch.promote_types(s_p_logits.dtype, torch.float32)
        probs = F.softmax(s_p_logits.to(work_dtype), dim=-1)
        cumsum = torch.cumsum(probs, dim=-1)

        # Query at u scaled by the accumulated total, so a final cumsum that
        # rounds slightly away from 1 does not push u past the end of the CDF.
        target = code.u[..., None].to(work_dtype) * cumsum[..., -1:]
        # right=True returns the first index whose cumsum exceeds target, so it
        # never lands on a zero-probability token, except when target >= total.
        index = torch.searchsorted(cumsum, target, right=True)

        positions = torch.arange(s_p_logits.shape[-1], device=s_p_logits.device)
        # In the overflow case, fall back to the last token with positive
        # probability rather than whatever token sits last in the shuffled
        # order, which may be masked (-inf logit, zero probability).
        last_valid = ((probs > 0) * positions).amax(dim=-1, keepdim=True)
        index = torch.minimum(index, last_valid)

        s_modified_logits = torch.where(
            positions == index,
            torch.full_like(s_p_logits, 0),
            torch.full_like(s_p_logits, float("-inf")),
        )
        # Undo the shuffle so position t again refers to token t.
        return torch.gather(s_modified_logits, -1, code.unshuffle)

