import torch
import hashlib
import numpy as np
from math import sqrt
from transformers import LogitsProcessor
from torch import Tensor


class WatermarkBase:
    """Common state for the Unigram watermark.

    Stores the tokenizer, the watermark hyper-parameters read from ``args``
    and a fixed boolean green-list mask with one entry per vocabulary item.
    The mask is a pure function of ``gamma``, the vocabulary size and
    ``hash_key``, so two instances built with the same values share the same
    green list.
    """

    def __init__(
        self,
        tokenizer,
        args,
    ):
        """Read the watermark parameters and build the green-list mask.

        Args:
            tokenizer: object exposing ``get_vocab()``; the values of the
                returned mapping are stored as the vocabulary.
            args: object exposing ``gamma``, ``delta``, ``seeding_scheme``,
                ``hash_key`` and ``alpha`` attributes.
        """
        self.tokenizer = tokenizer
        self.args = args

        token_ids = list(tokenizer.get_vocab().values())
        self.vocab = token_ids
        self.vocab_size = len(token_ids)

        # Watermark hyper-parameters, copied verbatim from the args object.
        self.gamma = args.gamma
        self.delta = args.delta
        self.seeding_scheme = args.seeding_scheme
        self.hash_key = args.hash_key
        self.alpha = args.alpha

        # A gamma fraction of the entries start out green (True), the rest
        # red (False); the split is then permuted with a generator seeded
        # from the hash key.
        green_count = int(self.gamma * self.vocab_size)
        red_count = self.vocab_size - green_count
        self.mask = np.array([True] * green_count + [False] * red_count)

        seed = self._hash_fn(self.hash_key)
        self.rng = np.random.default_rng(seed)
        self.rng.shuffle(self.mask)

    @staticmethod
    def _hash_fn(x: int) -> int:
        """Derive a 32-bit seed from ``x``.

        ``x`` is cast to ``np.int64``, hashed with SHA-256, and the first four
        digest bytes are read as a little-endian unsigned integer. Approach
        taken from
        https://stackoverflow.com/questions/67219691/python-hash-function-that-returns-32-or-64-bits
        """
        digest = hashlib.sha256(np.int64(x)).digest()
        return int.from_bytes(digest[:4], 'little')


class UnigramLogitsProcessor(WatermarkBase, LogitsProcessor):
    """Logits processor that adds an adaptive bias to green-list tokens.

    At every call the gap between the top-1 logit and the mean of the
    rank 2-5 logits (averaged over the batch) is recorded. ``delta`` is then
    set to ``alpha`` times the running mean of those gaps and added to the
    logits of every token marked green in the shared mask.
    """

    def __init__(self, **kwargs):
        """Build the base watermark state and the per-generation history.

        Args:
            **kwargs: forwarded to ``WatermarkBase.__init__`` (``tokenizer``
                and ``args``).
        """
        super().__init__(**kwargs)
        self.green_list_mask = torch.tensor(self.mask, dtype=torch.float32)
        # Running state for the adaptive bias; created here so that
        # ``__call__`` does not need ``hasattr`` checks on every step.
        self.prev_input_len = 0
        self.logit_mean_list = []

    def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
        """Add ``greenlist_bias`` in place to the entries of ``scores`` selected by ``greenlist_mask``."""
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Bias the green-list logits of ``scores`` and return it.

        Args:
            input_ids: token ids generated so far; only ``shape[1]`` (the
                current sequence length) is read.
            scores: next-token logits, indexed as ``(batch, vocab)``. It is
                modified in place.

        Returns:
            ``scores`` with ``self.delta`` added to every green-list column.

        Side effects:
            Appends to ``self.logit_mean_list`` and overwrites ``self.delta``
            and ``self.prev_input_len``.
        """
        curr_input_len = input_ids.shape[1]
        # A sequence that did not grow since the previous call is treated as
        # the start of a new generation, so the gap history starts afresh.
        # NOTE(review): a new prompt longer than the previous sequence is not
        # detected by this check - confirm whether callers rely on that.
        if curr_input_len <= self.prev_input_len:
            self.logit_mean_list = []
        self.prev_input_len = curr_input_len

        # topk returns values sorted in descending order, so column 0 is the
        # per-row maximum and columns 1..4 are ranks 2-5.
        top_values = scores.topk(5, dim=-1).values
        top1_logit = top_values[:, 0].mean().item()
        runner_up_logit = top_values[:, 1:].mean().item()
        self.logit_mean_list.append(top1_logit - runner_up_logit)

        self.delta = self.alpha * (sum(self.logit_mean_list) / len(self.logit_mean_list))

        # Build one boolean row on the logits' device and broadcast it over
        # the batch instead of filling a float mask row by row. The shared
        # prefix is clamped so a logits width that differs from the
        # tokenizer-derived mask length no longer raises a shape error;
        # columns beyond the mask stay unbiased.
        vocab_dim = scores.shape[-1]
        shared = min(vocab_dim, self.green_list_mask.shape[0])
        row_mask = torch.zeros(vocab_dim, dtype=torch.bool, device=scores.device)
        row_mask[:shared] = self.green_list_mask[:shared].to(scores.device).bool()
        greenlist_mask = row_mask.expand_as(scores)

        return self._bias_greenlist_logits(scores=scores, greenlist_mask=greenlist_mask, greenlist_bias=self.delta)


class UnigramWatermarkDetector(WatermarkBase):
    """Detector for the Unigram watermark.

    Counts how many scored tokens fall in the green list built by
    ``WatermarkBase`` and turns that count into a one-proportion z-score
    (and, on request, a one-sided p-value).
    """

    def __init__(
        self,
        device: torch.device = None,
        **kwargs,
    ):
        """Set up the detector.

        Args:
            device: device that tokenized text is moved to; required.
            **kwargs: forwarded to ``WatermarkBase.__init__`` (``tokenizer``
                and ``args``).

        Raises:
            AssertionError: if ``device`` is missing.
            NotImplementedError: if ``seeding_scheme`` is not ``"simple_1"``.
        """
        super().__init__(**kwargs)
        assert device, "Must pass device"

        self.device = device
        # Replaces the NumPy generator created by the base class.
        self.rng = torch.Generator(device=self.device)

        if self.seeding_scheme == "simple_1":
            self.min_prefix_len = 1
        else:
            raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}")

    def _compute_z_score(self, observed_count: int, T: int) -> float:
        """Return the z-score of ``observed_count`` green tokens out of ``T``.

        Under the null hypothesis each token is green with probability
        ``gamma``, so the count has mean ``gamma * T`` and variance
        ``T * gamma * (1 - gamma)``.
        """
        green_fraction = self.gamma
        numer = observed_count - green_fraction * T
        denom = sqrt(T * green_fraction * (1 - green_fraction))
        return numer / denom

    @staticmethod
    def _compute_p_value(z: float) -> float:
        """Return the one-sided p-value ``P(Z >= z)`` for a standard normal ``Z``."""
        # Local import: the file only imports ``sqrt`` from ``math``.
        from math import erfc

        return 0.5 * erfc(z / sqrt(2))

    def _score_sequence(
        self,
        input_ids: Tensor,
        return_z_score: bool = True,
        return_p_value: bool = False,
    ):
        """Score a token sequence against the green list.

        The first ``min_prefix_len`` tokens are skipped; every remaining
        token counts as green when its id is marked in ``self.mask``. Ids
        outside the mask are counted as not green rather than raising.

        Args:
            input_ids: 1-D sequence of token ids (tensor or list).
            return_z_score: include ``"z_score"`` in the result.
            return_p_value: include ``"p_value"`` in the result.

        Returns:
            dict holding the requested statistics.

        Raises:
            ValueError: if no token is left to score after the prefix.
        """
        num_tokens_scored = len(input_ids) - self.min_prefix_len
        if num_tokens_scored < 1:
            raise ValueError(
                (
                    f"Must have at least {1} token to score after "
                    f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme."
                )
            )

        mask_size = len(self.mask)
        green_token_count = 0
        for token in input_ids[self.min_prefix_len:]:
            # int() accepts both plain ints and 0-d tensors.
            token_id = int(token)
            if 0 <= token_id < mask_size and self.mask[token_id]:
                green_token_count += 1

        score_dict = {}
        if return_z_score or return_p_value:
            z_score = self._compute_z_score(green_token_count, num_tokens_scored)
            if return_z_score:
                score_dict["z_score"] = z_score
            if return_p_value:
                score_dict["p_value"] = self._compute_p_value(z_score)
        return score_dict

    def detect(
        self,
        text: str = None,
        tokenized_text: list[int] = None,
        **kwargs,
    ) -> dict:
        """Score either a raw string or an already tokenized sequence.

        Exactly one of ``text`` and ``tokenized_text`` must be given. A
        leading BOS token is dropped before scoring.

        Args:
            text: raw string, tokenized with ``self.tokenizer`` without
                special tokens.
            tokenized_text: token ids to score directly.
            **kwargs: forwarded to ``_score_sequence``.

        Returns:
            dict with the statistics produced by ``_score_sequence``.

        Raises:
            AssertionError: if both or neither input is given.
            ValueError: if the input is too short to score, including empty
                input.
        """
        assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"

        if tokenized_text is None:
            # The message is one concatenated string; the original passed a
            # tuple of fragments by accident.
            assert self.tokenizer is not None, (
                "Watermark detection on raw string "
                "requires an instance of the tokenizer "
                "that was used at generation time."
            )
            tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)

        # Drop a leading BOS token if present. The length check lets empty
        # input reach _score_sequence, which reports it with a ValueError
        # instead of failing here with an IndexError.
        if (
            self.tokenizer is not None
            and len(tokenized_text) > 0
            and tokenized_text[0] == self.tokenizer.bos_token_id
        ):
            tokenized_text = tokenized_text[1:]

        return dict(self._score_sequence(tokenized_text, **kwargs))