from __future__ import annotations
import collections
from math import sqrt

import scipy.stats
import numpy as np
import torch
from torch import Tensor
from tokenizers import Tokenizer
from transformers import LogitsProcessor, LogitsProcessorList
from transformers import BertTokenizer, BertModel, AutoTokenizer
import torch.nn.functional as F
from scipy import stats
import time, json, os
from train_watermark_network import TransformModel
from utils.config import DATA_DIR, MODEL_DIR


class WatermarkBase:
    """Common state and statistics helpers shared by watermark schemes.

    The class stores the watermark hyper-parameters and the target
    tokenizer, and offers z-score / p-value helpers.  ``_get_greenlist_ids``
    and ``detect`` are placeholder hooks here: both do nothing and return
    ``None``; subclasses are expected to override them.
    """

    def __init__(
        self,
        gamma: float,
        delta: float,
        target_tokenizer,
    ):
        """Store the hyper-parameters and derive the vocabulary size.

        Args:
            gamma: Used by ``_compute_z_score`` as the expected fraction of
                green tokens.
            delta: Stored as-is; not read by this class itself.
            target_tokenizer: Any object supporting ``len()``; its length is
                taken as the vocabulary size.
        """
        self.target_tokenizer = target_tokenizer
        # The bias vector built by ``_get_bias`` has exactly this many slots.
        self.vocab_size = len(target_tokenizer)
        self.gamma = gamma
        self.delta = delta

    def _get_greenlist_ids(self, input_ids: torch.LongTensor):
        """Placeholder hook: the base implementation returns ``None``."""
        return None

    def _compute_z_score(self, observed_count, T):
        """Return ``(observed_count - gamma*T) / sqrt(T*gamma*(1-gamma))``.

        Args:
            observed_count: Number of green tokens that were observed.
            T: Total number of tokens that were scored.

        Note:
            With plain Python numbers, ``T == 0`` (or ``gamma`` equal to 0
            or 1) makes the denominator zero and raises
            ``ZeroDivisionError``.
        """
        green_fraction = self.gamma
        deviation = observed_count - green_fraction * T
        spread = sqrt(T * green_fraction * (1 - green_fraction))
        return deviation / spread

    def _compute_p_value(self, z):
        """Return the standard-normal survival function evaluated at ``z``."""
        return scipy.stats.norm.sf(z)

    def detect(self, text):
        """Placeholder hook: the base implementation returns ``None``."""
        return None

    def _get_bias(self, input_ids: torch.LongTensor) -> list[int]:
        """Build a 0/1 vector over the vocabulary marking green-list ids.

        Calls ``_get_greenlist_ids(input_ids)`` and expects a tensor back
        (``.cpu().numpy()`` is invoked on the result).  Despite the return
        annotation, the value returned is a NumPy integer array of length
        ``vocab_size`` with ones at the green-list positions.
        """
        green_ids = self._get_greenlist_ids(input_ids)
        green_positions = green_ids.cpu().numpy()
        mask = np.zeros(self.vocab_size, dtype=int)
        mask[green_positions] = 1
        return mask
    

class WatermarkContext(WatermarkBase):
    """Watermark whose per-token bias is derived from an embedding of the preceding text.

    The preceding text is embedded with a BERT model, passed through a
    learned ``TransformModel``, squashed by ``scale_vector`` and finally
    re-indexed through ``mapping`` so that every target-vocabulary id gets
    one value.  The negated result is used both as the generation-time bias
    (``_get_bias``) and as the per-token detection score (``detect``).
    """

    def __init__(
        self,
        device: torch.device,
        chunk_length,
        target_tokenizer,
        delta: float = 4.0,
        gamma: float = 0.5,
        embedding_model: str = "bert-large",
        transform_model_path: str = "transform_model.pth",
        mapping_file: str = "mapping/mapping_gpt2.json"
    ):
        """Load the embedding model, the transform model and the id mapping.

        Args:
            device: Device on which both models run.
            chunk_length: Number of whitespace-separated words per chunk;
                must be a positive integer.
            target_tokenizer: Tokenizer of the model being watermarked.
            delta: Strength multiplier stored on the instance.
            gamma: Stored on the instance (used by the inherited z-score).
            embedding_model: Directory name under ``MODEL_DIR`` holding the
                embedding tokenizer and BERT weights.
            transform_model_path: Checkpoint file name under ``MODEL_DIR``.
            mapping_file: JSON file under ``DATA_DIR``; its content is used
                to index the transform-model output once per target token.
        """
        super().__init__(gamma, delta, target_tokenizer)
        self.device = device
        self.chunk_length = chunk_length

        embedding_dir = os.path.join(MODEL_DIR, embedding_model)
        self.embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_dir)
        self.embedding_model = BertModel.from_pretrained(embedding_dir).to(self.device)

        transform_model = TransformModel()
        # Load onto CPU first so a checkpoint saved on a GPU can still be
        # restored on a machine without that device; it is moved afterwards.
        state_dict = torch.load(
            os.path.join(MODEL_DIR, transform_model_path), map_location="cpu"
        )
        transform_model.load_state_dict(state_dict)
        self.transform_model = transform_model.to(self.device)
        # The model is only used for inference here; eval mode switches off
        # dropout / batch-norm updates if the model has any such layers.
        self.transform_model.eval()

        with open(os.path.join(DATA_DIR, mapping_file), 'r', encoding="utf-8") as f:
            self.mapping = json.load(f)

    def get_embedding(self, sentence):
        """Return the position-0 vector of the embedding model's first output.

        The sentence is truncated to at most 512 embedding-tokenizer tokens.
        For ``BertModel`` the first output is the last hidden state, so this
        is the vector at the first token position.
        """
        input_ids = self.embedding_tokenizer.encode(
            sentence, return_tensors="pt", max_length=512, truncation="longest_first"
        )
        input_ids = input_ids.to(self.device)
        with torch.no_grad():
            output = self.embedding_model(input_ids)
        return output[0][:, 0, :]

    def get_context_sentence(self, input_ids: torch.LongTensor):
        """Decode ``input_ids`` and keep only the complete word chunks.

        A trailing chunk with fewer than ``chunk_length`` words is dropped.
        Returns an empty string when the decoded text contains no words
        (previously this raised ``IndexError``).
        """
        sentence = self.target_tokenizer.decode(input_ids, skip_special_tokens=True)
        words_2d = self.get_text_split(sentence)
        if words_2d and len(words_2d[-1]) != self.chunk_length:
            words_2d = words_2d[:-1]
        return ' '.join(word for group in words_2d for word in group)

    def get_text_split(self, sentence):
        """Split ``sentence`` on whitespace into lists of ``chunk_length`` words.

        The final list may be shorter than ``chunk_length``.
        """
        words = sentence.split()
        size = self.chunk_length
        return [words[start:start + size] for start in range(0, len(words), size)]

    def scale_vector(self, v):
        """Centre ``v`` on its mean and squash with ``tanh(1000 * x)``.

        The large factor saturates ``tanh`` for all but tiny deviations, so
        the result is close to the sign of each centred component.
        """
        centred = v - np.mean(v)
        return np.tanh(1000 * centred)

    def _bias_from_context(self, context_sentence):
        """Return the per-target-token bias array for one context sentence."""
        context_embedding = self.get_embedding(context_sentence)
        # Inference only: without no_grad the output can require grad, and
        # calling .numpy() on such a tensor raises RuntimeError.
        with torch.no_grad():
            output = self.transform_model(context_embedding).cpu()[0].numpy()
        return -self.scale_vector(output)[self.mapping]

    def detect(
        self,
        text: str = None,
        min_context_words: int = 40,
    ):
        """Return the mean per-token bias value of ``text``.

        The text is split into word chunks.  Every chunk after the first is
        tokenized with the target tokenizer and each token is scored with
        the bias computed from all words that precede the chunk.

        Args:
            text: Text to score.  Must be a string; passing ``None`` fails
                with ``AttributeError`` when the text is split.
            min_context_words: Chunks whose preceding context has fewer
                words than this are skipped.

        Returns:
            The mean of all token scores, or NaN when no token was scored
            (text too short for any chunk to have enough context).
            Presumably larger values indicate watermarked text, since the
            same values are added to the logits during generation.
        """
        chunks = self.get_text_split(text)
        all_value = []
        context_words = []
        for previous_chunk, current_chunk in zip(chunks, chunks[1:]):
            # Context for the current chunk is every word before it.
            context_words.extend(previous_chunk)
            if len(context_words) < min_context_words:
                continue
            bias = self._bias_from_context(' '.join(context_words))
            tokens = self.target_tokenizer.encode(
                ' '.join(current_chunk), return_tensors="pt", add_special_tokens=False
            )
            token_ids = tokens[0].tolist()
            all_value.extend(bias[token_ids].tolist())

        if not all_value:
            # np.mean([]) would also give NaN, but with a RuntimeWarning.
            return np.float64("nan")
        return np.mean(all_value)

    def _get_bias(self, input_ids: torch.LongTensor) -> np.ndarray:
        """Return the bias array for the next token given the generated prefix.

        Only complete word chunks of the decoded prefix are used as context
        (see ``get_context_sentence``).
        """
        context_sentence = self.get_context_sentence(input_ids)
        return self._bias_from_context(context_sentence)

        
class WatermarkLogitsProcessor(LogitsProcessor):
    """Logits processor that adds a watermark bias to the next-token scores.

    For each sequence in the batch the wrapped watermark object supplies a
    bias vector via ``_get_bias``; the vector is scaled by the watermark's
    ``delta`` and added to the scores.
    """

    def __init__(self, watermark_base: WatermarkBase, *args, **kwargs):
        """Keep a reference to the watermark object.

        Extra positional and keyword arguments are accepted and ignored.
        """
        self.watermark_base = watermark_base

    def _bias_logits(
        self,
        scores: torch.Tensor,
        batched_bias: list[np.ndarray] | torch.Tensor,
        greenlist_bias: float,
    ) -> torch.Tensor:
        """Return ``scores + batched_bias * greenlist_bias``.

        Args:
            scores: Next-token scores, one row per sequence.
            batched_bias: One bias vector per sequence, given either as a
                sequence of array-likes or as an already stacked tensor.
            greenlist_bias: Scalar multiplier applied to the bias.

        The bias is moved to the device and dtype of ``scores`` so the
        result keeps the dtype of ``scores`` (e.g. half precision) and the
        addition works regardless of where the watermark object lives.
        """
        if isinstance(batched_bias, torch.Tensor):
            bias = batched_bias
        else:
            # Stack in NumPy first: building a tensor straight from a list
            # of arrays converts element by element and is very slow.
            bias = torch.as_tensor(np.stack([np.asarray(row) for row in batched_bias]))
        bias = bias.to(device=scores.device, dtype=scores.dtype)
        return scores + bias * greenlist_bias

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Bias ``scores`` using one bias vector per row of ``input_ids``."""
        if input_ids.shape[0] == 0:
            # Nothing to bias for an empty batch.
            return scores

        batched_bias = [self.watermark_base._get_bias(sequence) for sequence in input_ids]
        return self._bias_logits(
            scores=scores,
            batched_bias=batched_bias,
            greenlist_bias=self.watermark_base.delta,
        )
