from __future__ import annotations
import torch
from transformers import LogitsProcessor
from torch import Tensor
import scipy.stats
from math import sqrt
import torch.nn.functional as F
import numpy as np

class WatermarkBase:
    """Shared configuration and green-list derivation for the watermark classes.

    The constructor reads ``gamma``, ``delta``, ``seeding_scheme`` and
    ``hash_key`` from ``args`` and the vocabulary ids from
    ``tokenizer.get_vocab()``. It also draws ``self.prf``, a random permutation
    of ``range(vocab_size)`` that ``_f_time`` uses as a lookup table.
    """

    def __init__(
        self,
        tokenizer,
        args,
    ):
        # watermarking parameters
        self.tokenizer = tokenizer
        self.args = args
        self.vocab = list(self.tokenizer.get_vocab().values())
        self.vocab_size = len(self.vocab)
        self.gamma = self.args.gamma
        self.delta = self.args.delta
        self.seeding_scheme = self.args.seeding_scheme
        self.hash_key = self.args.hash_key
        # CUDA stays the preferred device; fall back to CPU so the class can
        # still be constructed on hosts without a GPU.
        # NOTE: CPU and CUDA generators do not produce the same random stream
        # for a given seed, so generation and detection should run on the same
        # device type to obtain matching green lists.
        rng_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.rng = torch.Generator(device=rng_device)
        self.f_scheme = "time"
        self.f_scheme_map = {"time": self._f_time}
        self.window_scheme_map = {"left": self._get_greenlist_ids_left}
        self.prefix_length = 1
        # The generator is not explicitly seeded before this draw.
        # NOTE(review): confirm that the processor and the detector end up with
        # the same permutation (this relies on torch's initial generator state).
        self.prf = torch.randperm(self.vocab_size, device=rng_device, generator=self.rng)

    def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
        """Seed ``self.rng`` from the last token of ``input_ids``.

        Args:
            input_ids: 1-D tensor of token ids; only the last element is read.
            seeding_scheme: Overrides ``self.seeding_scheme`` when given.

        Raises:
            NotImplementedError: For any scheme other than ``"simple_1"``.
        """
        if seeding_scheme is None:
            seeding_scheme = self.seeding_scheme

        if seeding_scheme != "simple_1":
            raise NotImplementedError(f"Unexpected seeding_scheme: {seeding_scheme}")

        assert input_ids.shape[-1] >= 1, f"seeding_scheme={seeding_scheme} requires at least a 1 token prefix sequence to seed rng"
        prev_token = input_ids[-1].item()
        self.rng.manual_seed(self.hash_key * prev_token)

    def _f(self, input_ids: torch.LongTensor) -> int:
        """Return the active ``f_scheme`` applied to ``input_ids``, as an int."""
        return int(self.f_scheme_map[self.f_scheme](input_ids))

    def _f_time(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Look up ``self.prf`` at the product of the last ``prefix_length`` token ids.

        The product is reduced modulo ``vocab_size`` so that it is a valid
        index into ``self.prf``. The result is a 0-dim tensor.
        """
        time_result = 1
        for offset in range(self.prefix_length):
            time_result *= input_ids[-1 - offset].item()
        return self.prf[time_result % self.vocab_size]

    def _get_greenlist_ids_left(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Get greenlist ids for the input_ids via leftHash scheme.

        Reseeds ``self.rng`` with ``(hash_key * _f(input_ids)) % vocab_size``,
        draws a permutation of the vocabulary and keeps its first
        ``int(vocab_size * gamma)`` entries.

        Returns:
            1-D tensor of green token ids, placed on ``input_ids.device``.
        """
        self.rng.manual_seed((self.hash_key * self._f(input_ids)) % self.vocab_size)
        greenlist_size = int(self.vocab_size * self.gamma)
        # randperm requires the output device to match the generator's device,
        # so draw on the generator's device and move the result afterwards.
        vocab_permutation = torch.randperm(self.vocab_size, device=self.rng.device, generator=self.rng)
        return vocab_permutation[:greenlist_size].to(input_ids.device)

    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Return the green-list token ids for ``input_ids`` using the "left" window scheme."""
        return self.window_scheme_map["left"](input_ids)


class MorphmarkWatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
    """Logits processor that moves probability mass from red-list to green-list tokens.

    For every sequence in the batch the green list is derived from the prefix
    (see ``_get_greenlist_ids``). The amount of mass moved depends on how much
    probability the model already places on the green list.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Which ratio function to apply: "linear", "exp" or "log".
        self.type = "exp"
        # Slope constants of the three ratio functions.
        self.k_linear = 1.55
        self.k_exp = 1.30
        self.k_log = 2.15
        # Rows whose green-list probability is below this value are left unchanged.
        self.p_0 = 0.15

    def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
        """Build a boolean mask shaped like ``scores`` that is True at green token ids.

        ``greenlist_token_ids[b]`` holds the green ids for row ``b`` of ``scores``.
        """
        green_tokens_mask = torch.zeros_like(scores, dtype=torch.bool)
        for b_idx, token_ids in enumerate(greenlist_token_ids):
            green_tokens_mask[b_idx][token_ids] = True
        return green_tokens_mask

    def _boost_ratio(self, p_green: float) -> float:
        """Map a row's green-list probability to the ratio ``r`` used for the shift.

        Returns 0.0 below the ``p_0`` threshold; otherwise applies the function
        selected by ``self.type``.

        Raises:
            ValueError: If ``self.type`` is not "linear", "exp" or "log".
        """
        if p_green < self.p_0:
            return 0.0
        if self.type == "linear":
            return self.k_linear * p_green
        if self.type == "exp":
            return float(np.exp(self.k_exp * p_green) - 1)
        if self.type == "log":
            return float(np.log(self.k_log * p_green + 1))
        raise ValueError(f"{self.type} is not defined.")

    def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor) -> torch.Tensor:
        """Return log-probabilities with mass shifted toward the green list.

        For each row, with ``P_G`` the green-list probability and ``r`` the
        ratio from ``_boost_ratio``, ``beta = r * (1 - P_G)`` is added to the
        green tokens and subtracted from the remaining tokens, each in
        proportion to the tokens' current probabilities. The result is cleaned
        of NaNs, clamped at zero, renormalised and returned as logs.

        ``P_G`` and the normalisations are computed per row. Summing over the
        whole batch would mix sequences and yield a wrong shift whenever the
        batch holds more than one row.

        Args:
            scores: Logits whose last dimension is the vocabulary.
            greenlist_mask: Boolean mask with the same shape as ``scores``.

        Returns:
            Tensor shaped like ``scores`` holding the adjusted log-probabilities.
        """
        # NOTE(review): the softmax runs in the dtype of ``scores``; confirm
        # whether half-precision inputs need an upcast to float32 first.
        probs = torch.softmax(scores, dim=-1)
        rows = probs.reshape(-1, probs.shape[-1])
        masks = greenlist_mask.reshape(-1, greenlist_mask.shape[-1])

        for row_idx in range(rows.shape[0]):
            row = rows[row_idx]  # view: writes below land in ``rows``
            mask = masks[row_idx]
            green = row[mask]
            red = row[~mask]

            p_green = green.sum().item()
            beta = self._boost_ratio(p_green) * (1 - p_green)

            # A zero sum on either side produces NaNs here; they are zeroed below.
            row[mask] = green + green / green.sum() * beta
            row[~mask] = red - red / red.sum() * beta

        rows = torch.nan_to_num(rows, nan=0)
        # beta can exceed the mass available on the red side; clip negatives.
        rows = torch.clamp(rows, min=0)
        rows = rows / rows.sum(dim=-1, keepdim=True)
        return torch.log(rows).reshape(scores.shape)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the watermark to one decoding step.

        Args:
            input_ids: Token ids generated so far, one row per sequence.
            scores: Next-token logits, one row per sequence.

        Returns:
            The adjusted scores (log-probabilities) with the same shape as ``scores``.
        """
        # Defensive: recreate the generator if it has been cleared.
        if self.rng is None:
            self.rng = torch.Generator(device=input_ids.device)

        batched_greenlist_ids = [
            self._get_greenlist_ids(input_ids[b_idx])
            for b_idx in range(input_ids.shape[0])
        ]
        green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=batched_greenlist_ids)
        return self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask)
    


class MorphmarkWatermarkDetector(WatermarkBase):
    """Scores a token sequence by counting how many tokens fall in their green list.

    The count is converted to a one-proportion z-score against the expected
    green fraction ``gamma`` and, optionally, to a p-value.
    """

    def __init__(
        self,
        device: torch.device = None,
        **kwargs,
    ):
        """Create a detector.

        Args:
            device: Device that token ids are moved to before scoring. Required.
            **kwargs: Forwarded to ``WatermarkBase.__init__``.

        Raises:
            NotImplementedError: If the seeding scheme is not ``"simple_1"``.
        """
        super().__init__(**kwargs)
        assert device, "Must pass device"

        self.device = device

        if self.seeding_scheme == "simple_1":
            # One preceding token is needed before a token can be scored.
            self.min_prefix_len = 1
        else:
            raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}")

    def _compute_z_score(self, observed_count, T):
        """Return the z-score of ``observed_count`` green tokens out of ``T`` scored tokens."""
        green_fraction = self.gamma
        numer = observed_count - green_fraction * T
        denom = sqrt(T * green_fraction * (1 - green_fraction))
        return numer / denom

    def _compute_p_value(self, z):
        """Return the upper-tail probability of ``z`` under a standard normal."""
        return scipy.stats.norm.sf(z)

    def _score_sequence(
        self,
        input_ids: Tensor,
        return_z_score: bool = True,
        return_p_value: bool = False,
    ):
        """Count green tokens in ``input_ids`` and compute the requested statistics.

        Args:
            input_ids: 1-D tensor of token ids.
            return_z_score: Include ``"z_score"`` in the result.
            return_p_value: Include ``"p_value"`` in the result.

        Returns:
            Dict containing the requested entries.

        Raises:
            ValueError: If no token remains to score after the required prefix.
        """
        num_tokens_scored = len(input_ids) - self.min_prefix_len
        if num_tokens_scored < 1:
            raise ValueError(
                (
                    f"Must have at least {1} token to score after "
                    f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme."
                )
            )

        green_token_count = 0
        for idx in range(self.min_prefix_len, len(input_ids)):
            # The green list for position idx is derived from the tokens before it.
            greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
            if input_ids[idx] in greenlist_ids:
                green_token_count += 1

        score_dict = {}
        if return_z_score or return_p_value:
            z_score = self._compute_z_score(green_token_count, num_tokens_scored)
            if return_z_score:
                score_dict["z_score"] = z_score
            if return_p_value:
                score_dict["p_value"] = self._compute_p_value(z_score)
        return score_dict

    def detect(
        self,
        text: str = None,
        tokenized_text: list[int] = None,
        **kwargs,
    ) -> dict:
        """Score either a raw string or an already tokenized sequence.

        Exactly one of ``text`` and ``tokenized_text`` must be given.

        Args:
            text: Raw string; tokenized with ``self.tokenizer`` without special tokens.
            tokenized_text: Token ids as a list or tensor.
            **kwargs: Forwarded to ``_score_sequence``.

        Returns:
            Dict with the statistics produced by ``_score_sequence``.

        Raises:
            ValueError: If too few tokens remain to score.
        """
        assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"

        if tokenized_text is None:
            assert self.tokenizer is not None, (
                "Watermark detection on raw string "
                "requires an instance of the tokenizer "
                "that was used at generation time."
            )
            tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
        else:
            # Scoring needs tensor operations, so accept plain lists of ids too.
            tokenized_text = torch.as_tensor(tokenized_text, device=self.device)

        # try to remove the bos_tok at beginning if it's there
        bos_token_id = self.tokenizer.bos_token_id if self.tokenizer is not None else None
        if (
            bos_token_id is not None
            and tokenized_text.numel() > 0
            and tokenized_text[0] == bos_token_id
        ):
            tokenized_text = tokenized_text[1:]

        # call score method
        output_dict = {}
        output_dict.update(self._score_sequence(tokenized_text, **kwargs))
        return output_dict