# TODO: implement the Gumbel-max watermark from Aaronson


import torch
from torch import FloatTensor,LongTensor
from torch.nn import functional as F
from typing import Union

from . import AbstractWatermarkCode, AbstractReweight, AbstractScore


class GumbelMax_WatermarkCode(AbstractWatermarkCode):
    """Watermark code holding uniform random draws, one per vocabulary entry.

    As produced by ``from_random``, ``v`` has shape ``(vocab_size,)`` for a
    single generator, or ``(batch_size, vocab_size)`` for a list of
    generators (one row per generator).
    """

    def __init__(self, v: FloatTensor):
        self.v = v

    @classmethod
    def from_random(
        cls,
        rng: Union[torch.Generator, list[torch.Generator]],
        vocab_size: int,
    ):
        """Draw a fresh code of ``torch.rand`` samples in [0, 1).

        Args:
            rng: A single generator, or a list with one generator per batch
                row. Each row is drawn on its generator's own device.
            vocab_size: Number of uniform samples drawn per generator.

        Returns:
            A code whose ``v`` is ``(vocab_size,)`` for a single generator,
            or ``(len(rng), vocab_size)`` for a list of generators.
        """
        if isinstance(rng, list):
            # Generators are consumed in list order, one row each; note that
            # torch.stack raises if the list is empty.
            v = torch.stack(
                [
                    torch.rand((vocab_size,), device=gen.device, generator=gen)
                    for gen in rng
                ]
            )
        else:
            v = torch.rand((vocab_size,), device=rng.device, generator=rng)
        return cls(v)


class GumbelMax_Reweight(AbstractReweight):
    """Reweight that replaces next-token logits with a keyed Gumbel-max choice.

    With ``is_baseline`` truthy, logits pass through unchanged. Otherwise the
    token ``argmin_i(-log(v_i) / p_i)`` is selected, where ``v`` comes from a
    ``GumbelMax_WatermarkCode`` and ``p = softmax(p_logits)``. The returned
    logits are 0 at that token and -inf everywhere else, so downstream
    sampling is forced to pick it.
    """

    watermark_code_type = GumbelMax_WatermarkCode

    def __init__(self, is_baseline):
        # When truthy, reweight_logits returns its input unchanged.
        self.is_baseline = is_baseline

    def __repr__(self):
        return f"GumbelMax_Reweight(is_baseline={self.is_baseline})"

    def reweight_logits(
        self, code: AbstractWatermarkCode, p_logits: FloatTensor
    ) -> FloatTensor:
        """Return logits that deterministically select the Gumbel-max token.

        Args:
            code: Must be a ``GumbelMax_WatermarkCode`` (asserted). Its ``v``
                is assumed to broadcast against ``p_logits`` — TODO confirm
                for callers mixing batched and unbatched codes.
            p_logits: Next-token logits; the last dimension is the vocabulary.

        Returns:
            ``p_logits`` itself when ``is_baseline`` is truthy. Otherwise a
            tensor with the same shape and dtype as ``p_logits``, holding 0
            at the selected index along the last dimension and -inf elsewhere.
        """
        assert isinstance(code, GumbelMax_WatermarkCode)

        if self.is_baseline:
            return p_logits

        # Compute probabilities in at least float32. A softmax in fp16/bf16
        # rounds p coarsely and flushes small probabilities to 0. That would
        # bias the sample away from the model's distribution. For float32 and
        # float64 inputs this cast is a no-op.
        compute_dtype = torch.promote_types(p_logits.dtype, torch.float32)
        probs = torch.softmax(p_logits.to(compute_dtype), dim=-1)

        # Exponential-race form of Gumbel-max: for i.i.d. uniform v, the index
        # minimizing -log(v_i) / p_i is distributed according to p. Tokens
        # with p_i == 0 get +inf and are never selected (unless all are +inf).
        index = torch.argmin(-torch.log(code.v) / probs, dim=-1)

        is_selected = (
            torch.arange(p_logits.shape[-1], device=p_logits.device)
            == index.unsqueeze(-1)
        )
        return torch.where(
            is_selected,
            torch.full_like(p_logits, 0),
            torch.full_like(p_logits, float("-inf")),
        )

