#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
from torch import FloatTensor, LongTensor,BoolTensor
from torch.nn import functional as F
import time
from typing import Union

from . import AbstractWatermarkCode, AbstractReweight, AbstractScore
import json


class Splitmark_WatermarkCode(AbstractWatermarkCode):
    """Watermark code pairing a vocabulary permutation with a boolean split selector.

    Attributes:
        shuffle: permutation of token ids along the last dim; ``from_random``
            yields shape ``(vocab_size,)`` or ``(len(rng), vocab_size)``.
        split_k: boolean selector; ``from_random`` yields shape ``(1, 1)`` or
            ``(len(rng), 1)``.
        unshuffle: inverse permutation of ``shuffle`` along the last dim.
    """

    def __init__(self, shuffle: LongTensor, split_k: BoolTensor):
        self.shuffle = shuffle
        self.split_k = split_k
        # argsort of a permutation yields its inverse permutation.
        self.unshuffle = torch.argsort(shuffle, dim=-1)

    @classmethod
    def from_random(
        cls,
        rng: Union[torch.Generator, list[torch.Generator]],
        vocab_size: int,
    ):
        """Draw a random code from ``rng``.

        A single generator gives ``shuffle`` of shape ``(vocab_size,)`` and
        ``split_k`` of shape ``(1, 1)``. A list of generators gives one row per
        generator: ``shuffle`` is ``(len(rng), vocab_size)`` and ``split_k`` is
        ``(len(rng), 1)``. Every tensor is created on its generator's device.
        """
        if not isinstance(rng, list):
            perm = torch.randperm(vocab_size, generator=rng, device=rng.device)
            coin = torch.randint(
                low=0,
                high=2,
                size=(1, 1),
                dtype=torch.bool,
                device=rng.device,
                generator=rng,
            )
            return cls(perm, coin)

        # Draw every permutation before any coin flip. This fixes the draw
        # order per generator, which matters if one generator appears more
        # than once in ``rng``.
        perms = [
            torch.randperm(vocab_size, generator=gen, device=gen.device)
            for gen in rng
        ]
        coins = [
            torch.randint(
                low=0,
                high=2,
                size=(1, 1),
                dtype=torch.bool,
                generator=gen,
                device=gen.device,
            )
            for gen in rng
        ]
        return cls(torch.stack(perms), torch.cat(coins, dim=0))


class Split_Reweight(AbstractReweight):
    """Reweight a next-token distribution by rescaling two halves of a shuffled vocabulary.

    The vocabulary is permuted with ``code.shuffle`` and split at
    ``round(vocab_size / 2)``. Two candidate distributions are formed:

    * "a": first half scaled by ``min(2, 1 / P1)``, second half by
      ``2 - min(2, 1 / P2)``;
    * "b": first half scaled by ``2 - min(2, 1 / P1)``, second half by
      ``min(2, 1 / P2)``;

    Here ``P1``/``P2`` are the probability masses of the two halves. Each
    candidate sums to 1 (up to floating-point rounding). Because the per-half
    factors of "a" and "b" add up to 2, averaging the two candidates recovers
    the original distribution. ``code.split_k`` selects "a" (True) or "b"
    (False) per row.
    """

    watermark_code_type = Splitmark_WatermarkCode

    def __init__(self, alpha: float):
        # NOTE(review): ``alpha`` is stored and shown in ``__repr__`` but is not
        # read by ``reweight_logits`` in this class.
        self.alpha = alpha

    def __repr__(self):
        return f"Split_Reweight(alpha={self.alpha})"

    def reweight_logits(
        self, code: AbstractWatermarkCode, p_logits: FloatTensor
    ) -> FloatTensor:
        """Return watermarked logits for ``p_logits``.

        Args:
            code: watermark code providing ``shuffle``/``unshuffle``
                (``(batch, vocab)`` or a single ``(vocab,)`` permutation shared
                by all rows) and a boolean ``split_k`` broadcastable to
                ``(batch, 1)``.
            p_logits: logits of shape ``(batch, vocab)``.

        Returns:
            Log of the reweighted probabilities, in the original vocabulary
            order. The dtype matches ``p_logits``. Tokens whose reweighted
            probability is 0 get ``-inf``.
        """
        shuffle, unshuffle = code.shuffle, code.unshuffle
        if shuffle.dim() == 1:
            # A code drawn from a single generator holds one permutation;
            # share it across every row of the batch.
            rows = p_logits.shape[0]
            shuffle = shuffle.expand(rows, -1)
            unshuffle = unshuffle.expand(rows, -1)

        # s_ means shuffled
        s_logits = torch.gather(p_logits, -1, shuffle)
        s_probs = torch.softmax(s_logits, dim=-1)
        _, vocab_size = s_probs.shape

        # Split point. Python's round() is round-half-to-even, so odd vocab
        # sizes do not all split the same way. It is kept unchanged because any
        # detector must use the identical partition.
        half = round(vocab_size / 2)
        first, second = s_probs[:, :half], s_probs[:, half:]  # views, no copy
        mass_first = first.sum(dim=-1, keepdim=True)  # (batch, 1)
        mass_second = second.sum(dim=-1, keepdim=True)  # (batch, 1)

        # Boost factor for each half, capped at 2. A half with zero mass gives
        # 1/0 = inf, which the clamp maps to 2. clamp keeps the input dtype and
        # device, so no constant tensor is allocated and transferred per call.
        scale_a_1 = torch.clamp(1 / mass_first, max=2.0)
        scale_b_2 = torch.clamp(1 / mass_second, max=2.0)
        # Complementary factors: each half's "a" and "b" factors sum to 2.
        scale_a_2 = 2 - scale_b_2
        scale_b_1 = 2 - scale_a_1

        # Select each row's factors first (split_k broadcasts over rows), then
        # scale once. This avoids materialising both candidate distributions.
        use_a = code.split_k.to(s_probs.device)
        scale_1 = torch.where(use_a, scale_a_1, scale_b_1)
        scale_2 = torch.where(use_a, scale_a_2, scale_b_2)

        reweighted_s_probs = torch.cat([first * scale_1, second * scale_2], dim=-1)
        reweighted_s_logits = torch.log(reweighted_s_probs)
        # Undo the shuffle so logits are back in original token order.
        return torch.gather(reweighted_s_logits, -1, unshuffle)
