#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
from torch import FloatTensor, LongTensor
from transformers import LogitsProcessor

from .base import AbstractReweight, AbstractContextCodeExtractor, AbstractScore, AbstractWatermarkKey
from typing import List
from .beta import Beta_Reweight

class WatermarkLogitsProcessor(LogitsProcessor):
    """Logits processor that reweights next-token scores with a keyed watermark.

    For every sequence in the batch a seed is derived from ``private_key`` plus
    whatever keys the configured watermark keys contribute for that sequence.
    The seed drives a ``torch.Generator`` from which the reweight strategy's
    watermark code is sampled; that code is then used to reweight the scores.
    Rows whose mask entry is truthy keep their original scores.
    """

    def __init__(
        self,
        private_key: bytes,
        reweight: AbstractReweight,  # sample strategy
        watermark_key_list: List[AbstractWatermarkKey],
    ):
        """Store the configuration.

        Args:
            private_key: Bytes-like secret; it is fed to ``hashlib`` as the
                first element of every seed derivation.
            reweight: Strategy providing ``watermark_code_type`` and
                ``reweight_logits``.
            watermark_key_list: Objects providing ``reset(batch_size)`` and
                ``generate_key_and_mask(input_ids_row, batch_idx)``.
        """
        self.watermark_key_list = watermark_key_list
        self.private_key = private_key
        self.reweight = reweight

    def __repr__(self):
        watermark_str = ', '.join(repr(watermark_key) for watermark_key in self.watermark_key_list)
        return (
            f"WatermarkLogitsProcessor(private_key={self.private_key!r}, "
            f"reweight={self.reweight!r}, watermark_key_list=[{watermark_str}])"
        )

    def get_rng_seed(self, key_list) -> int:
        """Hash ``key_list`` into a deterministic RNG seed.

        Args:
            key_list: Iterable of bytes-like objects, hashed in order with
                SHA-256. ``self.private_key`` is NOT mixed in here; callers
                that want it must put it in ``key_list`` (``_get_codes`` does).

        Returns:
            The digest interpreted as a big-endian integer, reduced modulo
            ``2**32 - 1``.
        """
        # Local import keeps this method independent of where the module-level
        # ``import hashlib`` is placed in the file.
        import hashlib

        digest = hashlib.sha256()
        for key in key_list:
            digest.update(key)
        return int.from_bytes(digest.digest(), "big") % (2**32 - 1)

    def reset_watermark_key(self, batch_size):
        """Call ``reset(batch_size)`` on every configured watermark key."""
        for watermark_key in self.watermark_key_list:
            watermark_key.reset(batch_size)

    def _get_codes(self, input_ids: LongTensor):
        """Compute the per-sequence skip mask and RNG seed.

        Args:
            input_ids: Batch of token ids; the first dimension is the batch.

        Returns:
            ``(mask, seeds)``: two lists with one entry per batch row. A mask
            entry is truthy when any watermark key returned a truthy mask for
            that row; the seed hashes ``private_key`` followed by every
            non-``None`` key returned by the watermark keys.
        """
        mask = []
        seeds = []
        for batch_idx in range(input_ids.size(0)):
            # Start from a real boolean so an empty watermark_key_list yields
            # ``False`` rather than the integer 0.
            skip = False
            key_list = [self.private_key]
            # Every watermark key is queried (no short-circuit on the call
            # itself), exactly one call per key per row.
            for watermark_key in self.watermark_key_list:
                wm_mask, wm_key = watermark_key.generate_key_and_mask(
                    input_ids[batch_idx], batch_idx
                )
                if wm_key is not None:
                    key_list.append(wm_key)
                skip = skip or wm_mask
            mask.append(skip)
            seeds.append(self.get_rng_seed(key_list))
        return mask, seeds

    def _sample_watermark_code(self, seeds, device, vocab_size):
        """Sample one watermark code per seed using per-row seeded generators."""
        rng = [torch.Generator(device=device).manual_seed(seed) for seed in seeds]
        return self.reweight.watermark_code_type.from_random(rng, vocab_size)

    def _core(self, input_ids: LongTensor, scores: FloatTensor):
        """Return ``(mask, reweighted_scores)`` for one decoding step.

        ``mask`` is a boolean tensor on ``scores.device`` with one entry per
        batch row; ``reweighted_scores`` is the output of
        ``self.reweight.reweight_logits`` for all rows (masking is applied by
        the caller).
        """
        mask, seeds = self._get_codes(input_ids)
        watermark_code = self._sample_watermark_code(
            seeds, scores.device, scores.size(1)
        )
        mask = torch.tensor(mask, device=scores.device, dtype=torch.bool)
        reweighted_scores = self.reweight.reweight_logits(watermark_code, scores)
        return mask, reweighted_scores

    def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> FloatTensor:
        """Return watermarked scores, keeping the original row where masked."""
        mask, reweighted_scores = self._core(input_ids, scores)
        # mask[:, None] broadcasts the per-row flag across the vocabulary axis.
        return torch.where(mask[:, None], scores, reweighted_scores)

    def get_green_token_quantile(
        self, input_ids: LongTensor, vocab_size: int, current_token, debug=False
    ) -> List[torch.Tensor]:
        """Locate each row's current token inside its sampled shuffle.

        Args:
            input_ids: Batch of context token ids used to derive the seeds.
            vocab_size: Size passed to ``watermark_code_type.from_random`` and
                used as the denominator of the quantile.
            current_token: Indexable by batch row; ``current_token[i]`` is
                compared against ``watermark_code.shuffle[i]``.
            debug: Currently unused; kept for interface compatibility.

        Returns:
            A list with one tensor per batch row holding
            ``(index of current_token[i] in shuffle[i] + 1) / vocab_size``.

        Raises:
            AssertionError: if ``self.reweight`` is not a ``Beta_Reweight``.
        """
        # Fail fast, before _get_codes calls into the watermark keys.
        # NOTE(review): presumably required because only Beta_Reweight's code
        # type exposes ``shuffle`` -- confirm against beta.py.
        assert isinstance(self.reweight, Beta_Reweight), (
            "get_green_token_quantile requires a Beta_Reweight strategy"
        )
        # The skip mask is not needed for the quantile computation.
        _, seeds = self._get_codes(input_ids)
        watermark_code = self._sample_watermark_code(
            seeds, input_ids.device, vocab_size
        )

        # calculate the score here
        token_quantile = [
            (torch.where(watermark_code.shuffle[i] == current_token[i])[0] + 1) / vocab_size
            for i in range(input_ids.shape[0])
        ]
        return token_quantile

    def get_score(
        self,
        labels: LongTensor,
        old_logits: FloatTensor,
        new_logits: FloatTensor,
        scorer,
    ) -> FloatTensor:
        """Disabled: always raises ``NotImplementedError``.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError
        # NOTE(review): everything below is unreachable and kept only as a
        # reference. It uses ``decoder_input_ids``, which is not defined in
        # this scope, so it cannot be re-enabled as-is.
        from unbiased_watermark import (
            RobustLLR_Score_Batch_v1,
            RobustLLR_Score_Batch_v2,
        )

        if isinstance(scorer, RobustLLR_Score_Batch_v1):
            all_scores = scorer.score(old_logits, new_logits)
            query_ids = labels.unsqueeze(-1).expand(
                tuple(-1 for _ in range(decoder_input_ids.ndim))
                + (all_scores.size(-2),)
            )
            #  scores: [batch_size, query_size]
            scores = torch.gather(all_scores, -1, query_ids.unsqueeze(-1)).squeeze(-1)
        elif isinstance(scorer, RobustLLR_Score_Batch_v2):
            llr, max_llr, min_llr = scorer.score(old_logits, new_logits)
            query_ids = labels
            unclipped_scores = torch.gather(llr, -1, query_ids.unsqueeze(-1)).squeeze(
                -1
            )
            #  scores: [batch_size, query_size]
            scores = torch.clamp(unclipped_scores.unsqueeze(-1), min_llr, max_llr)
        return scores

    def get_la_score(
        self,
        input_ids: LongTensor,
        labels: LongTensor,
        vocab_size: int,
    ) -> FloatTensor:
        """Disabled: always raises ``NotImplementedError``.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError
        # NOTE(review): unreachable reference implementation of the
        # likelihood-agnostic score; masked rows would be zeroed out.
        assert "get_la_score" in dir(
            self.reweight
        ), "Reweight does not support likelihood agnostic detection"
        mask, seeds = self._get_codes(input_ids)
        rng = [
            torch.Generator(device=input_ids.device).manual_seed(seed) for seed in seeds
        ]
        mask = torch.tensor(mask, device=input_ids.device)
        watermark_code = self.reweight.watermark_code_type.from_random(rng, vocab_size)
        all_scores = self.reweight.get_la_score(watermark_code)
        scores = torch.gather(all_scores, -1, labels.unsqueeze(-1)).squeeze(-1)
        scores = torch.logical_not(mask).float() * scores
        return scores


def get_score(
    text: str,
    watermark_processor: WatermarkLogitsProcessor,
    score: AbstractScore,
    model,
    tokenizer,
    temperature=0.2,
    prompt: str = "",
    **kwargs,
) -> tuple[FloatTensor, int]:
    """Score ``text`` token by token for watermark evidence (currently disabled).

    The function raises ``NotImplementedError`` as its first statement, so the
    remainder of the body is unreachable and serves only as a reference.

    Args:
        text: Text to score; tokenized with ``tokenizer.encode``.
        watermark_processor: Processor re-applied to every prefix of ``text``.
        score: Scorer whose ``score(logits, new_logits)`` is gathered at the
            observed token ids.
        model: Callable returning an object with a ``logits`` attribute; its
            ``device`` is used for the input tensor.
        tokenizer: Object providing ``encode``.
        temperature: Divisor applied to the shifted logits.
        prompt: Its token count is returned as the second tuple element.
        **kwargs: Not used by the body.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError
    # NOTE(review): unreachable from here on. The code below calls
    # ``watermark_processor.reset_history()``, which WatermarkLogitsProcessor
    # in this file does not define (it has ``reset_watermark_key(batch_size)``),
    # so it needs updating before it can be re-enabled.
    input_ids = tokenizer.encode(text)
    prompt_len = len(tokenizer.encode(prompt))
    input_ids = torch.tensor(input_ids, device=model.device).unsqueeze(0)
    outputs = model(input_ids)
    # Shift the logits right by one position (zeros at position 0) so that
    # index i holds the prediction made before seeing token i.
    logits = (
        torch.cat(
            [torch.zeros_like(outputs.logits[:, :1]), outputs.logits[:, :-1]],
            dim=1,
        )
        / temperature
    )
    new_logits = torch.clone(logits)
    for i in range(logits.size(1)):
        if i == prompt_len:
            watermark_processor.reset_history()
        if i == 0:
            # Position 0 has no context; only reset and leave its logits as-is.
            watermark_processor.reset_history()
            continue
        new_logits[:, i] = watermark_processor(input_ids[:, :i], logits[:, i])
    all_scores = score.score(logits, new_logits)
    if input_ids.ndim + 2 == all_scores.ndim:
        # score is RobustLLR_Score_Batch
        input_ids = input_ids.unsqueeze(-1).expand(
            tuple(-1 for _ in range(input_ids.ndim)) + (all_scores.size(-2),)
        )
    # Pick, for each position, the score of the token that actually occurred.
    scores = torch.gather(all_scores, -1, input_ids.unsqueeze(-1)).squeeze(-1)
    return scores[0], prompt_len


import hashlib


class WatermarkLogitsProcessor_Kuditipudi(LogitsProcessor):
    """Logits processor that seeds each row from a randomly chosen private key.

    Unlike a context-derived seed, the key for every batch row is drawn
    uniformly from ``private_key_set`` on each call, using torch's global
    random number generator.
    """

    def __init__(
        self,
        private_key_set: List[bytes],
        reweight: AbstractReweight,
    ):
        """Store the configuration.

        Args:
            private_key_set: Indexable, sized collection of bytes-like keys;
                each selected key is fed to ``hashlib.sha256``.
            reweight: Strategy providing ``watermark_code_type`` and
                ``reweight_logits``.
        """
        self.private_key_set = private_key_set
        self.reweight = reweight

    def __repr__(self):
        return f"WatermarkLogitsProcessor_Kuditipudi(Key set length:{len(self.private_key_set)}, {repr(self.reweight)})"

    def get_rng_seed(self) -> int:
        """Pick a random key from the key set and hash it into an RNG seed.

        Returns:
            The SHA-256 digest of the selected key, interpreted as a
            big-endian integer and reduced modulo ``2**32 - 1``.
        """
        # Convert the sampled 1-element tensor to a plain int so the key set
        # is indexed explicitly instead of relying on the tensor's implicit
        # ``__index__`` conversion. The draw itself is unchanged.
        selected_key_idx = torch.randint(
            low=0, high=len(self.private_key_set), size=(1,)
        ).item()
        digest = hashlib.sha256()
        digest.update(self.private_key_set[selected_key_idx])
        return int.from_bytes(digest.digest(), "big") % (2**32 - 1)

    def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> FloatTensor:
        """Return reweighted scores; ``input_ids`` only supplies the batch size."""
        bsz = input_ids.shape[0]
        # One independent key draw, and therefore one generator, per batch row.
        rng = [
            torch.Generator(device=scores.device).manual_seed(self.get_rng_seed())
            for _ in range(bsz)
        ]
        watermark_code = self.reweight.watermark_code_type.from_random(
            rng, scores.size(1)
        )
        return self.reweight.reweight_logits(watermark_code, scores)


class WatermarkLogitsProcessor_Baseline(LogitsProcessor):
    """Pass-through logits processor that leaves the scores untouched."""

    def __repr__(self):
        return "WatermarkLogitsProcessor_Baseline()"

    def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> FloatTensor:
        """Return ``scores`` unchanged; ``input_ids`` is ignored."""
        return scores