##!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
from torch import FloatTensor, LongTensor,BoolTensor
from torch.nn import functional as F
import time
from typing import Union

from . import AbstractWatermarkCode, AbstractReweight, AbstractScore
import json



class SynthID_Text_WatermarkCode(AbstractWatermarkCode):
    """Watermark code holding a random 0/1 lookup table ("g-values") per layer.

    ``hash_table`` is produced by ``from_random`` with shape
    ``[batch_size, 30, vocab_size]``: for every batch element, 30 rows of
    vocabulary-sized 0/1 values.
    """

    def __init__(self, hash_table: LongTensor):
        # 0/1 integer table; see ``from_random`` for the shape it is built with.
        self.hash_table = hash_table

    @classmethod
    def from_random(
        cls,
        rng: Union[torch.Generator, list[torch.Generator]],
        vocab_size: int,
    ):
        """Build a code with a random 0/1 hash table.

        Args:
            rng: A list with one ``torch.Generator`` per batch element. Each
                element's table is drawn from its own generator, on that
                generator's device. A single (non-list) generator is not
                supported.
            vocab_size: Size of the vocabulary (last dimension of the table).

        Returns:
            An instance whose ``hash_table`` has shape
            ``[len(rng), 30, vocab_size]`` with values in {0, 1}.

        Raises:
            NotImplementedError: If ``rng`` is not a list of generators.
        """
        if not isinstance(rng, list):
            # Consumers index the table as [batch, layer, token]; a single
            # generator has no defined batch layout, so reject it explicitly.
            raise NotImplementedError(
                "SynthID_Text_WatermarkCode.from_random requires a list of "
                "torch.Generator objects (one per batch element)."
            )
        hash_table = torch.stack(
            [
                torch.randint(
                    low=0,
                    high=2,
                    size=(30, vocab_size),
                    generator=gen,
                    device=gen.device,
                )
                for gen in rng
            ],
            dim=0,
        )
        return cls(hash_table)


class SynthID_Text_Reweight(AbstractReweight):
    """SynthID-Text tournament-sampling reweight with ``m`` layers.

    Conceptually, ``2**m`` candidates are drawn from the model distribution
    and paired off layer by layer. In layer ``i`` the candidate with the
    larger g-value ``hash_table[:, i, token]`` wins its match. The single
    surviving token is returned as one-hot logits.
    """

    watermark_code_type = SynthID_Text_WatermarkCode

    def __init__(self, m):
        self.m = m
        # The hash table built by SynthID_Text_WatermarkCode.from_random has
        # exactly 30 layers, so the tournament depth must match it.
        assert self.m == 30

    def __repr__(self):
        return f"SynthID_Text_Reweight(m={self.m})"

    def reweight_logits(
        self, code: AbstractWatermarkCode, p_logits: FloatTensor
    ) -> FloatTensor:
        """Pick one token per row via the ``m``-layer tournament.

        Args:
            code: Watermark code whose ``hash_table`` holds 0/1 g-values
                indexed as ``[batch, layer, token]``.
            p_logits: Next-token logits of shape ``[bsz, vocab_size]``.

        Returns:
            A tensor shaped and typed like ``p_logits``: 0 at the selected
            token of each row and ``-inf`` everywhere else.
        """
        bsz, vocab_size = p_logits.shape
        hash_table = code.hash_table  # [bsz, layers, >= vocab_size]

        # Why not simulate the tournament: with m=30 that would need 2**30
        # sampled indices per row (~8 GiB of int64 each).
        #
        # For two iid candidates from p and binary g, the match winner is
        # distributed as p(x) * (1 + g(x) - E_p[g]). Winners of disjoint
        # matches are again iid, so applying this update once per layer, in
        # layer order, gives exactly the distribution of the final survivor.
        probs = torch.softmax(p_logits.float(), dim=-1)
        for layer_idx in range(self.m):
            # Only token ids < vocab_size can ever be candidates, so slice the
            # table to the logits' width.
            g = hash_table[:, layer_idx, :vocab_size].to(
                device=probs.device, dtype=probs.dtype
            )
            mean_g = (probs * g).sum(dim=-1, keepdim=True)
            # clamp_min guards against tiny negatives from float rounding
            # (mean_g marginally above 1), which multinomial would reject.
            probs = (probs * (1.0 + g - mean_g)).clamp_min(0.0)
            probs = probs / probs.sum(dim=-1, keepdim=True)

        final_idx = torch.multinomial(probs, num_samples=1)  # [bsz, 1]

        modified_logits = torch.where(
            torch.arange(vocab_size, device=p_logits.device) == final_idx,
            torch.full_like(p_logits, 0),
            torch.full_like(p_logits, float("-inf")),
        )
        return modified_logits
  