from __future__ import annotations
import torch
from transformers import LogitsProcessor
from torch import Tensor
import scipy.stats
from math import sqrt
import torch.nn.functional as F
from .network_model import UPVGenerator, UPVDetector

class WatermarkBase:
    """Shared state and green-list logic for the UPV watermark.

    Holds the tokenizer, the watermark hyper-parameters read from ``args``,
    and two pretrained networks loaded from fixed paths and moved to CUDA:
    a generator (``UPVGenerator``) that classifies a window made of the
    ``prefix_length`` previous token ids plus one candidate token id as
    green or not, and a detector network (``UPVDetector``).
    """

    def __init__(
        self,
        tokenizer,
        args,
    ):
        """Read the watermark configuration and load both networks.

        Args:
            tokenizer: tokenizer whose ``get_vocab()`` defines the vocabulary.
            args: object exposing ``gamma``, ``delta``, ``seeding_scheme`` and
                ``hash_key`` attributes.
        """
        # watermarking parameters
        self.tokenizer = tokenizer
        self.args = args
        self.vocab = list(self.tokenizer.get_vocab().values())
        self.vocab_size = len(self.vocab)
        self.gamma = self.args.gamma
        self.delta = self.args.delta
        self.seeding_scheme = self.args.seeding_scheme
        self.hash_key = self.args.hash_key
        self.rng = None
        # Number of preceding tokens fed to the generator alongside the candidate.
        self.prefix_length = 1
        # Generator verdicts keyed by the (prefix ids..., candidate id) window tuple.
        # NOTE(review): unbounded; grows with every distinct window seen.
        self.cache = {}
        # Bits per encoded token id. The checkpoint file names carry "b16",
        # presumably this value -- keep them in sync.
        self.bit_number = 16
        self.num_beams = None
        self.top_k = 20

        # Both networks are only used for inference, so switch them to eval
        # mode; this makes layers such as dropout / batch-norm (if the
        # architectures contain any) behave deterministically.
        self.generator_model = (
            self._get_generator_model(self.bit_number, self.prefix_length + 1).to('cuda').eval()
        )
        self.detector_model = self._get_detector_model(self.bit_number).to('cuda').eval()

    def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
        """Seed ``self.rng`` from the last token of ``input_ids``.

        Only the ``"simple_1"`` scheme is supported: the seed is
        ``hash_key * last_token_id``. ``self.rng`` must already be a
        ``torch.Generator``.

        Raises:
            NotImplementedError: for any other seeding scheme.
        """
        if seeding_scheme is None:
            seeding_scheme = self.seeding_scheme

        if seeding_scheme == "simple_1":
            assert input_ids.shape[-1] >= 1, f"seeding_scheme={seeding_scheme} requires at least a 1 token prefix sequence to seed rng"
            prev_token = input_ids[-1].item()
            self.rng.manual_seed(self.hash_key * prev_token)
        else:
            raise NotImplementedError(f"Unexpected seeding_scheme: {seeding_scheme}")
        return

    def _get_generator_model(self, input_dim: int, window_size: int) -> UPVGenerator:
        """Build a ``UPVGenerator`` and load its weights from the bundled checkpoint."""
        model = UPVGenerator(input_dim, window_size)
        model.load_state_dict(torch.load("./watermarking/processor/upv/model/generator_model_b16_p1.pt"))
        return model

    def _get_detector_model(self, bit_number: int) -> UPVDetector:
        """Build a ``UPVDetector`` and load its weights from the bundled checkpoint."""
        model = UPVDetector(bit_number)
        model.load_state_dict(torch.load("./watermarking/processor/upv/model/detector_model_b16_p1_z4.pt"))
        return model

    def _get_predictions_from_generator(self, input_x: torch.Tensor) -> bool:
        """Run the generator on one encoded window; True when its output exceeds 0.5.

        ``.item()`` requires the generator output to hold a single element.
        """
        with torch.no_grad():
            output = self.generator_model(input_x)
            output = (output > 0.5).bool().item()
        return output

    def int_to_bin_list(self, n: int, length=8) -> list[int]:
        """Encode ``n`` as a list of ``length`` binary digits, most significant first.

        Shorter encodings are left-padded with zeros. If ``n`` needs more than
        ``length`` bits, only its leading ``length`` bits are kept.
        """
        bin_str = format(n, 'b')[:length].zfill(length)
        return [int(b) for b in bin_str]

    def _select_candidates(self, scores: torch.Tensor) -> torch.Tensor:
        """Return the indices of the tokens whose green status is queried.

        With ``num_beams`` set, every token whose score is within ``delta`` of
        the ``num_beams``-th largest score is a candidate (presumably because
        adding the green bias ``delta`` can lift such a token into the beam).
        Otherwise the ``top_k`` highest-scoring tokens are used. ``k`` is
        clamped to the vocabulary size so ``torch.topk`` cannot fail on a
        small vocabulary.

        Args:
            scores: 1-D logits for a single sequence.
        """
        vocab_len = scores.shape[-1]
        if self.num_beams is not None:
            k = min(self.num_beams, vocab_len)
            # sorted=True is required: only then is the k-th largest value
            # guaranteed to be the last element of the returned values.
            threshold_score = torch.topk(scores, k, largest=True, sorted=True).values[-1]
            return (scores >= (threshold_score - self.delta)).nonzero(as_tuple=True)[0]
        k = min(self.top_k, vocab_len)
        return torch.topk(scores, k, largest=True, sorted=False).indices

    def _get_greenlist_ids(self, input_ids: torch.Tensor, scores: torch.Tensor) -> list[int]:
        """Return the candidate token ids that the generator labels green.

        Args:
            input_ids: token ids generated so far for one sequence (tensor or
                list); only the last ``prefix_length`` ids are used.
            scores: 1-D logits for the next token of that sequence; they
                determine the candidate set via ``_select_candidates``.

        Returns:
            Green candidate ids as Python ints. Verdicts are memoized in
            ``self.cache``.
        """
        candidate_tokens = self._select_candidates(scores)

        # Work with plain ints so the window tuples are value-hashable cache keys.
        input_ids_list = input_ids.tolist() if isinstance(input_ids, torch.Tensor) else list(input_ids)
        # The prefix is the same for every candidate, so slice it once.
        prefix = input_ids_list[-self.prefix_length:] if self.prefix_length > 0 else []

        greenlist_ids = []
        for candidate in candidate_tokens.tolist():
            window = prefix + [candidate]
            key = tuple(window)
            if key not in self.cache:
                bin_list = [self.int_to_bin_list(num, self.bit_number) for num in window]
                # scores.device is used because input_ids may be a plain list.
                self.cache[key] = self._get_predictions_from_generator(
                    torch.tensor(bin_list, device=scores.device).float().unsqueeze(0)
                )
            if self.cache[key]:
                greenlist_ids.append(int(candidate))

        return greenlist_ids

class UPVWatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
    """Logits processor that boosts the logits of UPV green tokens by ``delta``.

    For every sequence in the batch the green candidate ids are obtained from
    ``_get_greenlist_ids``; their logits are raised by ``self.delta`` in place.
    """

    def __init__(self, **kwargs):
        # All configuration (tokenizer, args) is handled by WatermarkBase.
        super().__init__(**kwargs)

    def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
        """Return a bool mask shaped like ``scores``, True at each row's green token ids."""
        mask = torch.zeros_like(scores, dtype=torch.bool)
        for row, token_ids in enumerate(greenlist_token_ids):
            mask[row][token_ids] = True
        return mask

    def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
        """Add ``greenlist_bias`` to the masked entries of ``scores`` (modified in place) and return it."""
        boosted = scores[greenlist_mask] + greenlist_bias
        scores[greenlist_mask] = boosted
        return scores

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the watermark bias to a batch of next-token logits.

        Args:
            input_ids: batch of token ids generated so far, one row per sequence.
            scores: next-token logits with one row per sequence.

        Returns:
            ``scores`` with green-token logits raised by ``self.delta``.
        """
        if self.rng is None:
            self.rng = torch.Generator(device=input_ids.device)

        per_row_green_ids = [
            self._get_greenlist_ids(input_ids[row], scores=scores[row])
            for row in range(input_ids.shape[0])
        ]
        mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=per_row_green_ids)
        return self._bias_greenlist_logits(scores=scores, greenlist_mask=mask, greenlist_bias=self.delta)
    


class UPVWatermarkDetector(WatermarkBase):
    """Detect the UPV watermark in text.

    ``detect`` re-classifies every token after the first ``prefix_length``
    tokens with the generator network, counts the green ones and reports a
    z-score against the expected green fraction ``gamma``.
    ``_detect_watermark_network_mode`` is an alternative that runs the
    detector network over the whole token sequence instead.
    """

    def __init__(
        self,
        device: torch.device | None = None,
        **kwargs,
    ):
        """Load the shared watermark state and validate the seeding scheme.

        Args:
            device: device that tokenized input is moved to in ``detect``; required.
            **kwargs: forwarded to ``WatermarkBase`` (``tokenizer``, ``args``).

        Raises:
            AssertionError: if ``device`` is not given.
            NotImplementedError: for any seeding scheme other than ``"simple_1"``.
        """
        super().__init__(**kwargs)
        assert device, "Must pass device"

        self.device = device
        self.rng = torch.Generator(device=self.device)

        if self.seeding_scheme == "simple_1":
            self.min_prefix_len = 1
        else:
            raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}")

    def _compute_z_score(self, observed_count, T):
        """Return the one-proportion z-statistic for ``observed_count`` green tokens out of ``T``.

        Under the null hypothesis each token is green with probability
        ``gamma``. Callers must ensure ``T >= 1``.
        """
        expected_fraction = self.gamma
        numer = observed_count - expected_fraction * T
        denom = sqrt(T * expected_fraction * (1 - expected_fraction))
        return numer / denom

    def _compute_p_value(self, z):
        """Return the one-sided (upper-tail) standard-normal p-value of ``z``."""
        return scipy.stats.norm.sf(z)

    def _detect_watermark_network_mode(self, encoded_text: torch.Tensor) -> tuple[bool, float | None]:
        """Detect the watermark with the detector network.

        Every token id is encoded as ``bit_number`` bits and the whole sequence
        is passed to the detector. The text counts as watermarked when any
        output exceeds 0.5.

        Returns:
            ``(is_watermarked, None)``; no z-score exists in this mode.
        """
        inputs_bin = [self.int_to_bin_list(int(token_id), self.bit_number) for token_id in encoded_text]
        inputs_bin = torch.tensor(inputs_bin, device='cuda')

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            outputs = self.detector_model(inputs_bin.unsqueeze(dim=0).float())
        predicted = (outputs.reshape([-1]) > 0.5).int()

        is_watermarked = (predicted == 1).sum().item() > 0
        return is_watermarked, None

    def _score_sequence(
        self,
        input_ids: Tensor,
        return_z_score: bool = True,
        return_p_value: bool = False,
    ):
        """Score ``input_ids`` token by token.

        Each token after the first ``min_prefix_len`` tokens is judged green or
        not by ``_judge_green``, given the tokens preceding it.

        Returns:
            dict with ``z_score`` (if ``return_z_score``) and/or ``p_value``
            (if ``return_p_value``).

        Raises:
            ValueError: if no token remains to be scored after the prefix.
        """
        num_tokens_scored = len(input_ids) - self.min_prefix_len
        if num_tokens_scored < 1:
            raise ValueError(
                (
                    f"Must have at least {1} token to score after "
                    f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme."
                )
            )
        # Detection has no logits to pass to _get_greenlist_ids (it requires
        # `scores`), so judge each observed token directly.
        green_token_count = sum(
            1
            for idx in range(self.min_prefix_len, len(input_ids))
            if self._judge_green(input_ids[:idx], input_ids[idx])
        )

        score_dict = dict()
        z_score = self._compute_z_score(green_token_count, num_tokens_scored)
        if return_z_score:
            score_dict["z_score"] = z_score
        if return_p_value:
            score_dict["p_value"] = self._compute_p_value(z_score)
        return score_dict

    def _judge_green(self, input_ids: torch.Tensor, current_number: int) -> bool:
        """Return whether ``current_number`` is green given the tokens before it.

        Args:
            input_ids: preceding token ids (tensor or list); the last
                ``prefix_length`` of them form the generator's context.
            current_number: the token id being judged (int or 0-dim tensor).
        """
        prefix = input_ids[-self.prefix_length:] if self.prefix_length > 0 else []
        # Convert to plain ints: tensors hash by identity, so tensor-valued
        # tuples would never hit the cache and would pile up in it.
        window = [int(token) for token in prefix] + [int(current_number)]
        key = tuple(window)

        if key not in self.cache:
            bin_list = [self.int_to_bin_list(num, self.bit_number) for num in window]
            self.cache[key] = self._get_predictions_from_generator(
                torch.tensor(bin_list, device='cuda').float().unsqueeze(0)
            )
        return self.cache[key]

    def green_token_mask_and_stats(self, input_ids: torch.Tensor) -> dict:
        """Count green tokens in ``input_ids`` and return ``{"z_score": ...}``.

        The first ``prefix_length`` tokens only serve as context; each later
        token is judged by ``_judge_green`` against the tokens preceding it.

        Raises:
            ValueError: if there is no token after the prefix. Otherwise the
                z-score would divide by zero or take the root of a negative number.
        """
        num_tokens_scored = len(input_ids) - self.prefix_length
        if num_tokens_scored < 1:
            raise ValueError(
                f"Must have at least 1 token to score after the first "
                f"prefix_length={self.prefix_length} tokens used as context."
            )

        green_token_count = sum(
            1
            for idx in range(self.prefix_length, len(input_ids))
            if self._judge_green(input_ids[:idx], input_ids[idx])
        )
        return {"z_score": self._compute_z_score(green_token_count, num_tokens_scored)}

    def detect(
        self,
        text: str | None = None,
        tokenized_text: list[int] | None = None,
        **kwargs,
    ) -> dict:
        """Detect the watermark in raw ``text`` or in ``tokenized_text``.

        Exactly one of the two inputs must be given. A leading BOS token is
        dropped before scoring. ``**kwargs`` is accepted for interface
        compatibility and currently ignored.

        Returns:
            dict containing ``z_score``.
        """
        assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"

        if tokenized_text is None:
            assert self.tokenizer is not None, (
                "Watermark detection on raw string "
                "requires an instance of the tokenizer "
                "that was used at generation time."
            )
            tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
            if len(tokenized_text) > 0 and tokenized_text[0] == self.tokenizer.bos_token_id:
                tokenized_text = tokenized_text[1:]
        elif (
            self.tokenizer is not None
            and len(tokenized_text) > 0
            and tokenized_text[0] == self.tokenizer.bos_token_id
        ):
            # remove the bos token at the beginning if it's there
            tokenized_text = tokenized_text[1:]

        output_dict = {}
        output_dict.update(self.green_token_mask_and_stats(tokenized_text))
        return output_dict