#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
from torch import FloatTensor, LongTensor,BoolTensor
from torch.nn import functional as F
import time
from typing import Union

from . import AbstractWatermarkCode, AbstractReweight, AbstractScore
import json


class Trimark_WatermarkCode(AbstractWatermarkCode):
    """Watermark code: a vocabulary permutation plus a three-way selector.

    Attributes set by ``__init__``:
        shuffle: the permutation of vocabulary indices passed in.
        split_k: integer selector; ``from_random`` draws it from {0, 1, 2}.
        unshuffle: ``argsort`` of ``shuffle`` along the last dimension, i.e.
            the inverse permutation used to map shuffled positions back.
    """

    def __init__(self, shuffle: LongTensor, split_k: LongTensor):
        """Store the permutation and selector and precompute the inverse.

        Args:
            shuffle: permutation of vocabulary indices (last dim = vocab).
            split_k: selector tensor. ``from_random`` creates it with
                ``dtype=torch.long``, hence the ``LongTensor`` annotation.
        """
        self.shuffle = shuffle
        self.split_k = split_k
        # argsort of a permutation is its inverse permutation.
        self.unshuffle = torch.argsort(shuffle, dim=-1)

    @classmethod
    def from_random(
        cls,
        rng: Union[torch.Generator, list[torch.Generator]],
        vocab_size: int,
    ):
        """Draw a random code from one generator or a list of generators.

        Args:
            rng: a single ``torch.Generator`` or a list of them (one per
                batch element). Tensors are created on each generator's
                device.
            vocab_size: length of the permutation to draw.

        Returns:
            A new instance. With a list of ``B`` generators ``shuffle`` has
            shape ``(B, vocab_size)`` and ``split_k`` has shape ``(B, 1)``;
            with a single generator ``shuffle`` has shape ``(vocab_size,)``
            and ``split_k`` has shape ``(1, 1)``.
        """
        if isinstance(rng, list):
            # All permutations are drawn before any selector so that the
            # order in which the generators are consumed stays unchanged.
            shuffle = torch.stack(
                [
                    torch.randperm(vocab_size, generator=gen, device=gen.device)
                    for gen in rng
                ]
            )
            split_k = torch.cat(
                [
                    torch.randint(
                        low=0,
                        high=3,  # exclusive upper bound -> values 0, 1, 2
                        size=(1, 1),
                        dtype=torch.long,
                        generator=gen,
                        device=gen.device,
                    )
                    for gen in rng
                ],
                dim=0,
            )
        else:
            shuffle = torch.randperm(vocab_size, generator=rng, device=rng.device)
            split_k = torch.randint(
                low=0,
                high=3,  # exclusive upper bound -> values 0, 1, 2
                size=(1, 1),
                dtype=torch.long,
                device=rng.device,
                generator=rng,
            )
        return cls(shuffle, split_k)


class Tri_Reweight(AbstractReweight):
    """Reweight a next-token distribution using a three-way vocabulary split.

    The shuffled vocabulary is cut into three contiguous splits. Three
    candidate distributions are built (one per split); candidate ``d``
    scales split ``d`` by ``min(3, 1 / mass_d)`` and redistributes the
    remaining probability over the other splits. The scales are chosen so
    that every candidate sums to one and, for each split, the three
    candidates' scales add up to 3. ``code.split_k`` selects which candidate
    is returned for each row.
    """

    watermark_code_type = Trimark_WatermarkCode

    def __init__(self, alpha: float):
        # Stored and shown by __repr__; reweight_logits does not read it.
        self.alpha = alpha

    def __repr__(self):
        return f"Tri_Reweight(alpha={self.alpha})"

    @staticmethod
    def _scale_table(own, mass, pivot):
        """Build the scale table for rows whose boosted split is ``pivot``.

        Args:
            own: ``(bsz, 3)`` tensor, ``own[:, d] = min(3, 1 / mass[:, d])``.
            mass: ``(bsz, 3)`` probability mass of each split.
            pivot: index (0, 1 or 2) of the split treated as saturated.

        Returns:
            ``(bsz, 3, 3)`` tensor where ``table[:, d, s]`` is the factor
            candidate distribution ``d`` applies to split ``s``.
        """

        def finite_or_zero(x):
            # A zero-mass split produces 0/0 (NaN) or residue/0 (+-inf).
            # Every probability in such a split is 0, so any finite scale
            # is equivalent; 0 keeps the later arithmetic finite.
            return torch.where(torch.isfinite(x), x, torch.zeros_like(x))

        first, second = (idx for idx in range(3) if idx != pivot)
        table = own.new_zeros((own.shape[0], 3, 3))
        for idx in range(3):
            table[:, idx, idx] = own[:, idx]
        # The two non-pivot candidates put no mass on the pivot split, so
        # table[:, first, pivot] and table[:, second, pivot] stay 0.

        # Each non-pivot candidate sends its leftover mass to the other
        # non-pivot split so that the candidate sums to one.
        first_on_second = finite_or_zero(
            (1 - own[:, first] * mass[:, first]) / mass[:, second]
        )
        second_on_first = finite_or_zero(
            (1 - own[:, second] * mass[:, second]) / mass[:, first]
        )
        table[:, first, second] = first_on_second
        table[:, second, first] = second_on_first

        # The pivot candidate takes whatever is left so that the three
        # scales applied to each split add up to 3.
        table[:, pivot, first] = 3 - own[:, first] - second_on_first
        table[:, pivot, second] = 3 - first_on_second - own[:, second]
        return table

    def reweight_logits(
        self, code: AbstractWatermarkCode, p_logits: FloatTensor
    ) -> FloatTensor:
        """Return reweighted log-probabilities in the original token order.

        Args:
            code: watermark code providing ``shuffle``, ``unshuffle`` and
                ``split_k`` (values 0, 1 or 2 choose the candidate
                distribution; any other value yields all-zero
                probabilities, i.e. ``-inf`` everywhere).
            p_logits: 2-D ``(batch, vocab)`` logits.

        Returns:
            ``log`` of the selected reweighted distribution, un-shuffled
            back to the order of ``p_logits``. Tokens whose reweighted
            probability is zero get ``-inf``. The result dtype is the
            promotion of the input dtype with float32.

        Raises:
            NotImplementedError: if the softmax of the input contains NaN.
        """
        # s_ means shuffled
        s_logits = torch.gather(p_logits, -1, code.shuffle)
        s_probs = torch.softmax(s_logits, dim=-1)
        bsz, vocab_size = s_probs.shape

        # Three contiguous splits of the shuffled vocabulary. Slices are
        # views, so no vocabulary-sized Python index lists are needed.
        cut1 = round(vocab_size / 3)
        cut2 = round(vocab_size * 2 / 3)
        thirds = (s_probs[:, :cut1], s_probs[:, cut1:cut2], s_probs[:, cut2:])

        # mass[:, s] = total probability of split s.
        mass = torch.stack([third.sum(dim=-1) for third in thirds], dim=-1)
        if torch.isnan(mass).any():
            raise NotImplementedError("Unexpected result")

        # Scale each candidate applies to its own split, capped at 3. The
        # float32 cap promotes half-precision inputs to float32.
        cap = torch.full_like(mass, 3.0, dtype=torch.float32)
        own = torch.minimum(cap, 1 / mass)
        mass = mass.to(own.dtype)

        # Pivot = first split whose own scale hit the cap (mass <= 1/3).
        # Mathematically one always exists; if rounding leaves every mass
        # marginally above 1/3, fall back to the lightest split.
        pivot = mass.argmin(dim=-1)
        saturated = own == 3
        for idx in (2, 1, 0):  # applied in reverse so the lowest index wins
            pivot = torch.where(
                saturated[:, idx], torch.full_like(pivot, idx), pivot
            )

        # Per-row scale table, chosen according to that row's pivot.
        scales = own.new_zeros((bsz, 3, 3))
        for idx in range(3):
            scales = torch.where(
                (pivot == idx).view(-1, 1, 1),
                self._scale_table(own, mass, idx),
                scales,
            )

        # Keep only the candidate distribution selected by split_k.
        split_k = code.split_k.to(s_probs.device).reshape(-1, 1)
        chosen = torch.zeros_like(own)
        for idx in range(3):
            chosen = torch.where(split_k == idx, scales[:, idx], chosen)

        reweighted_s_probs = torch.cat(
            [third * chosen[:, idx : idx + 1] for idx, third in enumerate(thirds)],
            dim=-1,
        )
        # Rounding can push a scale slightly below zero.
        reweighted_s_probs = reweighted_s_probs.clamp(min=0)

        reweighted_s_logits = torch.log(reweighted_s_probs)
        return torch.gather(reweighted_s_logits, -1, code.unshuffle)
