#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
from torch import FloatTensor, LongTensor
from transformers import LogitsProcessor

from .base import AbstractReweight, AbstractContextCodeExtractor, AbstractScore, AbstractWatermarkKey
from typing import List
from .beta import Beta_Reweight
from .dipmark import Dip_Reweight
from .splitmark import Split_Reweight
from .trimark import Tri_Reweight
from .nmark import N_Reweight
from .gumbelmax import GumbelMax_Reweight
from .sta import STA_Reweight
from .unigram import Unigram_Reweight
from .EXP_edit import EXP_edit_Reweight
from .ITS_edit import ITS_edit_Reweight
from .synthid_text import SynthID_Text_Reweight
import json

class WatermarkLogitsProcessor(LogitsProcessor):
    """Logits processor that reweights next-token scores with a keyed watermark.

    For every sequence in the batch an RNG seed is derived from the private
    key plus the keys produced by each entry of ``watermark_key_list``.  The
    seeded generators are passed to ``reweight.watermark_code_type.from_random``
    and the resulting watermark code is handed to ``reweight.reweight_logits``.

    The ``get_*`` methods rebuild the watermark code for a given context and
    score an observed token against it.
    """

    def __init__(
        self,
        private_key: bytes,
        reweight: AbstractReweight,  # sample strategy
        watermark_key_list: List[AbstractWatermarkKey],
    ):
        """Store the configuration; no other work is done here.

        Args:
            private_key: Secret mixed into every seed.  It is fed to
                ``hashlib``'s ``update`` and therefore has to be bytes-like.
            reweight: Strategy object providing ``watermark_code_type`` and
                ``reweight_logits``.
            watermark_key_list: Objects providing ``reset`` and
                ``generate_key_and_mask``.
        """
        self.watermark_key_list = watermark_key_list
        self.private_key = private_key
        self.reweight = reweight

    def __repr__(self):
        watermark_str = ', '.join(
            repr(watermark_key) for watermark_key in self.watermark_key_list
        )
        return (
            f"WatermarkLogitsProcessor(private_key={repr(self.private_key)}, "
            f"reweight={repr(self.reweight)}, watermark_key_list=[{watermark_str}])"
        )

    def get_rng_seed(self, key_list) -> int:
        """Hash the bytes-like items of ``key_list`` into an RNG seed.

        The items are fed in order into one SHA-256 digest, which is then
        reduced modulo ``2**32 - 1``.
        """
        import hashlib

        m = hashlib.sha256()
        for key in key_list:
            m.update(key)
        return int.from_bytes(m.digest(), "big") % (2**32 - 1)

    def reset_watermark_key(self, batch_size):
        """Call ``reset(batch_size)`` on every watermark key."""
        for watermark_key in self.watermark_key_list:
            watermark_key.reset(batch_size)

    def _get_codes(self, input_ids: LongTensor):
        """Collect one mask flag and one RNG seed per sequence in the batch.

        Each watermark key is asked for a ``(mask, key)`` pair for the
        sequence.  Keys that are not ``None`` are appended after the private
        key and hashed into the seed; the masks are combined with ``or``.

        Returns:
            ``(mask, seeds)``: two lists with one entry per batch row.
        """
        batch_size = input_ids.size(0)

        mask = []
        seeds = []
        for batch_idx in range(batch_size):
            cur_mask = 0
            key_list = [self.private_key]
            for watermark_key in self.watermark_key_list:
                cur_wm_mask, cur_wm_key = watermark_key.generate_key_and_mask(
                    input_ids[batch_idx], batch_idx
                )
                if cur_wm_key is not None:
                    key_list.append(cur_wm_key)
                cur_mask = (cur_mask or cur_wm_mask)
            mask.append(cur_mask)
            seeds.append(self.get_rng_seed(key_list))

        return mask, seeds

    def _is_cluster_n_reweight(self) -> bool:
        """Tell whether ``self.reweight`` is a ``Cluster_N_Reweight``.

        ``Cluster_N_Reweight`` is not imported by this module, so naming it in
        ``isinstance`` raised ``NameError``.  Matching on the class name along
        the MRO gives the same answer without needing the import.
        TODO: replace with a real ``isinstance`` once the class is imported.
        """
        return any(
            cls.__name__ == "Cluster_N_Reweight"
            for cls in type(self.reweight).__mro__
        )

    @staticmethod
    def _make_rng(seeds, device):
        """Create one seeded ``torch.Generator`` per seed on ``device``."""
        return [torch.Generator(device=device).manual_seed(seed) for seed in seeds]

    def _watermark_code_for(self, input_ids: LongTensor, vocab_size, *code_args):
        """Rebuild the watermark code for ``input_ids``.

        Calls ``_get_codes`` exactly once, seeds generators on the device of
        ``input_ids`` and forwards ``vocab_size`` plus any extra ``code_args``
        to ``from_random``.  The mask returned by ``_get_codes`` is not needed
        for scoring and is dropped.
        """
        _, seeds = self._get_codes(input_ids)
        rng = self._make_rng(seeds, input_ids.device)
        return self.reweight.watermark_code_type.from_random(
            rng, vocab_size, *code_args
        )

    @staticmethod
    def _even_splits(vocab_size, cur_n, device):
        """Return ``arange(vocab_size)`` reshaped to ``cur_n`` equal rows."""
        return torch.arange(start=0, end=vocab_size).reshape(
            cur_n, vocab_size // cur_n
        ).to(device)

    @staticmethod
    def _green_list_hits(watermark_code, green_list_size, current_token, batch_size):
        """Per row: float 1/0 tensor, token within the first ``green_list_size`` of ``shuffle``."""
        return [
            torch.tensor(
                current_token[i] in watermark_code.shuffle[i][:green_list_size]
            ).float()
            for i in range(batch_size)
        ]

    @staticmethod
    def _split_hits(watermark_code, splits, current_token, batch_size, use_shuffle):
        """Per row: 1 if the token lies in the split chosen by ``split_k``, else 0.

        With ``use_shuffle`` the split's entries are used as indices into the
        row's ``shuffle`` before the membership test; otherwise they are
        compared with the token directly.
        """
        scores = []
        for bsz_idx in range(batch_size):
            cur_k = watermark_code.split_k[bsz_idx]
            if use_shuffle:
                current_split = watermark_code.shuffle[bsz_idx][splits[cur_k]]
            else:
                current_split = splits[cur_k]
            scores.append(1 if current_token[bsz_idx] in current_split else 0)
        return scores

    def _core(self, input_ids: LongTensor, scores: FloatTensor):
        """Compute the boolean row mask and the reweighted scores.

        Returns:
            ``(mask, reweighted_scores)`` where ``mask`` is a bool tensor with
            one entry per batch row on the device of ``scores``.
        """
        mask, seeds = self._get_codes(input_ids)
        rng = self._make_rng(seeds, scores.device)
        mask = torch.tensor(mask, device=scores.device, dtype=torch.bool)

        # The N-way codes take the number of splits as an extra argument.
        code_args = (scores.size(1),)
        if isinstance(self.reweight, N_Reweight) or self._is_cluster_n_reweight():
            code_args += (self.reweight.n,)
        watermark_code = self.reweight.watermark_code_type.from_random(
            rng, *code_args
        )

        reweighted_scores = self.reweight.reweight_logits(watermark_code, scores)
        return mask, reweighted_scores

    def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> FloatTensor:
        """Return ``scores`` for masked rows and the reweighted scores otherwise."""
        mask, reweighted_scores = self._core(input_ids, scores)
        return torch.where(mask[:, None], scores, reweighted_scores)

    def get_green_token_quantile(self, input_ids: LongTensor, vocab_size, current_token, debug=False):
        """Locate each observed token inside the row's ``shuffle``.

        Only valid for ``Beta_Reweight`` and ``Dip_Reweight``.

        Args:
            input_ids: Context used to rebuild the watermark code.
            vocab_size: Size passed to ``from_random``; also the divisor.
            current_token: Indexable with one observed token per batch row.
            debug: Unused; kept for interface compatibility.

        Returns:
            A list with one tensor per row holding
            ``(position of the token in shuffle + 1) / vocab_size``.
        """
        assert isinstance(self.reweight, (Beta_Reweight, Dip_Reweight))
        watermark_code = self._watermark_code_for(input_ids, vocab_size)

        return [
            (torch.where(watermark_code.shuffle[i] == current_token[i])[0] + 1) / vocab_size
            for i in range(input_ids.shape[0])
        ]

    def get_gumbelmax_score(self, input_ids: LongTensor, vocab_size, current_token, debug=False):
        """Return ``-log(1 - v[i, token_i])`` for each batch row ``i``.

        Only valid for ``GumbelMax_Reweight``; ``v`` is an attribute of the
        rebuilt watermark code.  ``debug`` is unused.
        """
        assert isinstance(self.reweight, GumbelMax_Reweight)
        watermark_code = self._watermark_code_for(input_ids, vocab_size)

        return [
            -torch.log(1 - watermark_code.v[i, current_token[i]])
            for i in range(input_ids.shape[0])
        ]

    def get_sta_score(self, input_ids: LongTensor, vocab_size, current_token, debug=False):
        """Return a float 1/0 tensor per row: token inside the green list or not.

        Only valid for ``STA_Reweight``.  The green list is the first
        ``round(gamma * vocab_size)`` entries of the row's ``shuffle``.
        ``debug`` is unused.
        """
        assert isinstance(self.reweight, STA_Reweight)
        watermark_code = self._watermark_code_for(input_ids, vocab_size)

        green_list_size = round(self.reweight.gamma * vocab_size)
        return self._green_list_hits(
            watermark_code, green_list_size, current_token, input_ids.shape[0]
        )

    def get_unigram_score(self, input_ids: LongTensor, vocab_size, current_token, debug=False):
        """Same green-list membership score as ``get_sta_score``.

        Only valid for ``Unigram_Reweight``.  ``debug`` is unused.
        """
        assert isinstance(self.reweight, Unigram_Reweight)
        watermark_code = self._watermark_code_for(input_ids, vocab_size)

        green_list_size = round(self.reweight.gamma * vocab_size)
        return self._green_list_hits(
            watermark_code, green_list_size, current_token, input_ids.shape[0]
        )

    def get_synthid_text_res(self, input_ids: LongTensor, vocab_size, current_token, debug=False):
        """Return the mean of ``hash_table`` over dim 1 at each row's token.

        Only valid for ``SynthID_Text_Reweight``.  ``debug`` is unused.
        """
        assert isinstance(self.reweight, SynthID_Text_Reweight)
        watermark_code = self._watermark_code_for(input_ids, vocab_size)
        # Per the original author's note: [bsz, 30, vocab_size].
        hash_table = watermark_code.hash_table

        bsz = input_ids.shape[0]
        return hash_table.float().mean(dim=1)[range(bsz), current_token]  # [bsz]

    def get_n_res(self, input_ids: LongTensor, vocab_size, current_token, cur_n, debug=False):
        """Return 1/0 per row: token inside the split selected by ``split_k``.

        Only valid for ``N_Reweight`` whose ``n`` equals ``cur_n``.  Index
        positions ``0..vocab_size-1`` are cut into ``cur_n`` consecutive
        blocks (equal rows when divisible, ``round``-based bounds otherwise)
        and mapped through the row's ``shuffle``.  ``debug`` is unused.
        """
        assert isinstance(self.reweight, N_Reweight)
        assert self.reweight.n == cur_n
        watermark_code = self._watermark_code_for(
            input_ids, vocab_size, self.reweight.n
        )

        if vocab_size % cur_n == 0:
            splits = self._even_splits(vocab_size, cur_n, input_ids.device)
        else:
            splits = [
                list(range(
                    round(vocab_size * n_idx / cur_n),
                    round(vocab_size * (n_idx + 1) / cur_n),
                ))
                for n_idx in range(cur_n)
            ]

        return self._split_hits(
            watermark_code, splits, current_token, input_ids.shape[0], True
        )

    def get_cluster_n_res(self, input_ids: LongTensor, vocab_size, current_token, cur_n, cluster_dict):
        """Cluster-aware variant of ``get_n_res``.

        Only valid for a ``Cluster_N_Reweight`` whose ``n`` equals ``cur_n``.

        Args:
            input_ids: Context used to rebuild the watermark code.
            vocab_size: Size passed to ``from_random``.
            current_token: Indexable with one observed token per batch row.
            cur_n: Number of splits.
            cluster_dict: Mapping ``token_id -> split index``.

        Returns:
            A list of 1/0, one per batch row.
        """
        assert self._is_cluster_n_reweight()
        assert self.reweight.n == cur_n

        watermark_code = self._watermark_code_for(
            input_ids, vocab_size, self.reweight.n
        )

        if vocab_size % cur_n == 0:
            # NOTE(review): cluster_dict is not consulted on this branch.
            # Confirm this matches how the reweight builds its splits.
            splits = self._even_splits(vocab_size, cur_n, input_ids.device)
        else:
            splits = [[] for _ in range(cur_n)]

            # Tokens listed in cluster_dict go to their assigned split (the
            # original comment refers to these as audio tokens).
            for token_id, split_index in cluster_dict.items():
                splits[split_index].append(token_id)

            # Every other token is distributed in consecutive blocks.
            remaining_tokens = sorted(set(range(vocab_size)) - set(cluster_dict.keys()))
            for n_idx in range(cur_n):
                start = round(len(remaining_tokens) * n_idx / cur_n)
                end = round(len(remaining_tokens) * (n_idx + 1) / cur_n)
                splits[n_idx].extend(remaining_tokens[start:end])

        return self._split_hits(
            watermark_code,
            splits,
            current_token,
            input_ids.shape[0],
            self.reweight.shuffle,
        )


import hashlib


class WatermarkLogitsProcessor_Kuditipudi_old(LogitsProcessor):
    """Reweights scores using a key drawn at random from a fixed key set.

    Every call draws, for each batch row, one index into ``private_key_set``
    from torch's global RNG, hashes the selected key into a seed and uses the
    seeded generators to sample the watermark code.
    """

    def __init__(
        self,
        private_key_set: any,
        reweight: AbstractReweight,
    ):
        """Keep the key set and the reweight strategy for later calls."""
        self.reweight = reweight
        self.private_key_set = private_key_set

    def __repr__(self):
        key_count = len(self.private_key_set)
        return f"WatermarkLogitsProcessor_Kuditipudi(Key set length:{key_count}, {repr(self.reweight)})"

    def get_rng_seed(self) -> any:
        """Pick one key at random and turn its SHA-256 digest into a seed.

        The digest is read as a big-endian integer and reduced modulo
        ``2**32 - 1``.
        """
        key_pos = torch.randint(
            low=0, high=len(self.private_key_set), size=(1,)
        )
        digest = hashlib.sha256(self.private_key_set[key_pos]).digest()
        return int.from_bytes(digest, "big") % (2**32 - 1)

    def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> FloatTensor:
        """Return the scores reweighted with a freshly sampled watermark code."""
        row_count = input_ids.shape[0]

        # Draw all seeds first, then build one generator per row.
        row_seeds = []
        for _ in range(row_count):
            row_seeds.append(self.get_rng_seed())

        generators = []
        for row_seed in row_seeds:
            generators.append(
                torch.Generator(device=scores.device).manual_seed(row_seed)
            )

        code = self.reweight.watermark_code_type.from_random(
            generators, scores.size(1)
        )
        return self.reweight.reweight_logits(code, scores)



class WatermarkLogitsProcessor_Kuditipudi_OriImplement(LogitsProcessor):
    """Position-keyed watermark with an edit-distance based detection test.

    Generation: the key for a step is ``private_key_set[(random_offset +
    sequence_length) % key_set_size]``; it is hashed into a seed that drives
    ``reweight.watermark_code_type.from_random``.

    Detection (``get_p_val``): per-token/per-key costs are aligned with a
    dynamic-programming edit distance, and the result is compared against the
    same statistic under random permutations of the key order.
    """

    def __init__(
        self,
        key_set_size: int,
        reweight: AbstractReweight,
    ):
        """Build the key set ``4200 .. 4200 + key_set_size - 1`` and draw an offset.

        Args:
            key_set_size: Number of integer keys.  ``torch.randint`` below
                raises when this is not positive.
            reweight: Strategy object providing ``watermark_code_type`` and
                ``reweight_logits``.
        """
        self.key_set_size = key_set_size
        self.private_key_set = torch.arange(
            start=4200, end=4200 + self.key_set_size
        ).tolist()
        self.reweight = reweight

        # Deliberately drawn from the uncontrolled global RNG.
        # Warning: under this implementation, all sentences in the same batch
        # have the same random offset.
        self.random_offset = torch.randint(
            low=0, high=len(self.private_key_set), size=(1,)
        ).item()

    def __repr__(self):
        return f"WatermarkLogitsProcessor_Kuditipudi_OriImplement(key_set_size={self.key_set_size}, reweight={repr(self.reweight)})"

    def get_rng_seed(self, idx) -> int:
        """Return the seed of the key at position ``idx`` shifted by the random offset."""
        key = self.private_key_set[(self.random_offset + idx) % self.key_set_size]
        return self.get_rng_seed_with_key(key)

    def get_rng_seed_with_key(self, key: int):
        """Hash an integer key into a seed.

        The key is encoded as 64 big-endian bytes, hashed with SHA-256 and the
        digest is reduced modulo ``2**32 - 1``.
        """
        m = hashlib.sha256()
        m.update(key.to_bytes(length=64, byteorder='big'))
        full_hash = m.digest()
        seed = int.from_bytes(full_hash, "big") % (2**32 - 1)
        return seed

    def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> FloatTensor:
        """Return the scores reweighted with the key for the current length."""
        bsz = input_ids.shape[0]
        cur_idx = input_ids.shape[-1]
        # The seed depends only on the sequence length, so it is identical for
        # every row: compute it once and give each row its own generator.
        seed = self.get_rng_seed(cur_idx)
        rng = [
            torch.Generator(device=scores.device).manual_seed(seed)
            for _ in range(bsz)
        ]
        watermark_code = self.reweight.watermark_code_type.from_random(
            rng, scores.size(1)
        )
        reweighted_scores = self.reweight.reweight_logits(watermark_code, scores)
        return reweighted_scores

    def get_cur_dist(self, input_ids, base_dist, permutations, offsets, gamma):
        """Edit-distance DP between the tokens and the (shifted, permuted) keys.

        ``res[i, j, o, p]`` is the cheapest alignment of the first ``i``
        tokens with the first ``j`` key positions for offset ``o`` and
        permutation ``p``.  Skipping a token or a key position costs
        ``gamma``; matching token ``i`` with key position ``j`` costs
        ``base_dist[i - 1, permutations[(j - 1 + offset) % l2, p]]``.

        Args:
            input_ids: Tensor whose first dimension must be 1; only its last
                dimension (``l1``) and device are used.
            base_dist: Cost table indexed ``[token position, key index]``.
            permutations: Index tensor whose rows are key positions and whose
                last dimension enumerates permutations.
            offsets: 1-D tensor of key offsets.
            gamma: Skip cost.  Original author's note: 0.4 for ITS_edit,
                0.0 for EXP_edit.

        Returns:
            ``res[l1, l1]`` with shape ``[len(offsets), num permutations]``,
            i.e. all ``l1`` tokens aligned against ``l1`` key positions.
            Requires ``l1 <= key_set_size``, otherwise the index is out of
            range.
        """
        assert input_ids.shape[0] == 1  # bsz==1

        l1 = input_ids.shape[-1]
        l2 = self.key_set_size
        l3 = offsets.shape[-1]
        l4 = permutations.shape[-1]  # num of permutations

        res = torch.zeros(l1 + 1, l2 + 1, l3, l4).to(input_ids.device)

        # Borders: aligning against an empty prefix costs gamma per skip.
        res[0, :, :, :] = gamma * torch.arange(start=0, end=l2 + 1).to(input_ids.device).view(-1, 1, 1).repeat(1, l3, l4)
        res[:, 0, :, :] = gamma * torch.arange(start=0, end=l1 + 1).to(input_ids.device).view(-1, 1, 1).repeat(1, l3, l4)

        # Cells on one anti-diagonal (i + j == idx_sum) only depend on earlier
        # anti-diagonals, so each diagonal is filled in one vectorised step.
        for idx_sum in range(2, l1 + l2 + 1):
            min_idx1 = max(idx_sum - l2, 1)
            min_idx2 = max(idx_sum - l1, 1)

            max_idx1 = idx_sum - min_idx2

            idx1_list = torch.arange(start=min_idx1, end=max_idx1 + 1).to(input_ids.device)
            idx2_list = idx_sum - idx1_list

            # Skip a token, or skip a key position.
            res[idx1_list, idx2_list, :, :] = res[idx1_list - 1, idx2_list, :, :] + gamma
            res[idx1_list, idx2_list, :, :] = torch.minimum(res[idx1_list, idx2_list, :, :], res[idx1_list, idx2_list - 1, :, :] + gamma)
            # Key index used at each key position for every offset/permutation.
            ori_idx2 = permutations[(idx2_list.view(-1, 1) - 1 + offsets.view(1, -1)) % l2, :]  # [idx_list_len,l3,l4]
            expanded_idx1_list = (idx1_list - 1).view(-1, 1, 1).repeat(1, l3, l4)  # [idx_list_len,l3,l4]

            # Match the token with the key position.
            res[idx1_list, idx2_list, :, :] = torch.minimum(res[idx1_list, idx2_list, :, :],
                                                       res[idx1_list - 1, idx2_list - 1, :, :] + base_dist[expanded_idx1_list, ori_idx2])

        return res[l1, l1, :, :]

    def get_total_dist(self, input_ids, base_dist, permutations, gamma):
        """Return, per permutation, the minimum of ``get_cur_dist`` over all offsets.

        ``input_ids`` is truncated to its first ``key_set_size`` positions and
        every offset ``0 .. key_set_size - 1`` is evaluated.
        """
        input_ids = input_ids[:, :self.key_set_size]

        offsets = torch.arange(start=0, end=self.key_set_size).to(input_ids.device)
        cur_d = self.get_cur_dist(input_ids, base_dist, permutations, offsets, gamma)  # [l3,l4]

        return torch.min(cur_d, dim=0)[0]  # [l4]

    def get_p_val(self, input_ids: LongTensor, vocab_size, gamma, max_repeat_t: int = 5000, permute_bsz: int = 20):
        """Permutation-test p-value for a single sequence.

        Args:
            input_ids: Token ids with batch size 1.
            vocab_size: Size passed to ``from_random``.
            gamma: Skip cost forwarded to the edit-distance DP.
            max_repeat_t: Total number of random key permutations.
            permute_bsz: Permutations evaluated per DP call; must be positive
                and divide ``max_repeat_t``.

        Returns:
            ``(1 + #permutations with distance <= observed) / (1 +
            max_repeat_t)``.  This is a 0-dim tensor whenever at least one
            batch of permutations is evaluated.

        Raises:
            TypeError: If ``self.reweight`` is neither ``ITS_edit_Reweight``
                nor ``EXP_edit_Reweight``.
            ValueError: If ``permute_bsz`` is not positive or does not divide
                ``max_repeat_t``.
        """
        bsz = input_ids.shape[0]
        assert bsz == 1

        if permute_bsz <= 0 or max_repeat_t % permute_bsz != 0:
            raise ValueError(
                f'max_repeat_t ({max_repeat_t}) must be a multiple of a positive '
                f'permute_bsz ({permute_bsz})'
            )

        # Validate once up front and raise instead of killing the interpreter.
        use_its = isinstance(self.reweight, ITS_edit_Reweight)
        if not use_its and not isinstance(self.reweight, EXP_edit_Reweight):
            raise TypeError(f'Unknown reweight type: {type(self.reweight)}')

        l1 = input_ids.shape[-1]
        l2 = self.key_set_size
        # base_dist[t, k]: cost of token t under the code generated by key k.
        base_dist = torch.zeros((l1, l2)).to(input_ids.device)
        for key_idx in range(l2):
            cur_seed = self.get_rng_seed_with_key(key=self.private_key_set[key_idx])
            rng = [torch.Generator(device=input_ids.device).manual_seed(cur_seed)]
            watermark_code = self.reweight.watermark_code_type.from_random(
                rng, vocab_size
            )

            if use_its:
                cur_indices = watermark_code.unshuffle[0, input_ids[0, :]]
                cur_d = torch.abs(watermark_code.u[0] - cur_indices / (vocab_size - 1))
                base_dist[:, key_idx] = cur_d
            else:
                cur_v = watermark_code.v[0, input_ids[0, :]]
                base_dist[:, key_idx] = torch.log(1 - cur_v)

        # Observed statistic: keys in their original order.
        ori_permutation = torch.arange(start=0, end=l2).to(input_ids.device).view(-1, 1)
        pos_d = self.get_total_dist(input_ids, base_dist=base_dist, permutations=ori_permutation, gamma=gamma)[0]

        # Starts at 1 so the observed ordering counts as one sample.
        p_cnt = 1

        for _ in range(max_repeat_t // permute_bsz):
            rand_perms = torch.stack(
                [torch.randperm(self.key_set_size).to(input_ids.device) for _ in range(permute_bsz)],
                dim=-1,
            )
            cur_key_set_d = self.get_total_dist(input_ids, base_dist=base_dist, permutations=rand_perms, gamma=gamma)

            p_cnt += (cur_key_set_d <= pos_d).sum()

        return p_cnt / (1 + max_repeat_t)
        

class WatermarkLogitsProcessor_Baseline(LogitsProcessor):
    """No-op processor: hands the incoming scores back unchanged."""

    def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> FloatTensor:
        """Return ``scores`` as received; ``input_ids`` is ignored."""
        return scores

    def __repr__(self):
        return "WatermarkLogitsProcessor_Baseline()"