import numpy as np
from typing import List
import torch
from src.utils import logit
import re
from scipy.stats import mood

### Utility functions to check that a model output has the expected format (outputs that do not are discarded)

def find_sublist_idx(response: List[int], candidate_answer: List[int]):
    """Return the index of the first occurrence of ``candidate_answer`` inside ``response``.

    Args:
        response: token ids to search in.
        candidate_answer: contiguous token-id sequence to look for.

    Returns:
        The smallest ``k`` such that ``response[k:k + len(candidate_answer)] == candidate_answer``,
        or -1 if there is no such ``k``. An empty ``candidate_answer`` matches at index 0
        (same convention as ``str.find("")``).
    """
    n = len(candidate_answer)
    if n == 0:
        return 0

    first = candidate_answer[0]
    # Only start positions that leave room for the whole candidate can match.
    for k in range(len(response) - n + 1):
        # Cheap first-element check before comparing the whole slice.
        if response[k] == first and response[k:k + n] == candidate_answer:
            return k
    return -1

def find_common_prefix(l: List[List[int]]):
    """Return the length of the longest prefix shared by every list in ``l``.

    Sublists may have different lengths: the shared prefix is never longer than the
    shortest sublist. An empty ``l`` yields 0.
    """
    k = 0
    # zip(*l) yields the i-th element of every sublist and stops at the shortest one,
    # so shorter sublists can never be indexed out of range.
    for column in zip(*l):
        first = column[0]
        if any(item != first for item in column):
            break
        k += 1
    return k

def get_model_info(model_name):
    """Look up the prompt configuration used for ``model_name``.

    Returns:
        A fresh dict with keys ``"word_list"`` (candidate words offered to the model),
        ``"example"`` (word used in the prompt's formatting example) and ``"format"``
        (text the answer should start with), or None for an unknown model name.
    """
    def entry(words, example_word, answer_format):
        return {"word_list": words, "example": example_word, "format": answer_format}

    # Built on every call, so callers may mutate the returned lists safely.
    tropical = ["mangoes", "pineapples", "papayas", "kiwis"]
    stone_fruits = ["peaches", "plums", "cherries", "apricots"]
    berries = ["strawberries", "blueberries", "raspberries", "blackberries"]

    configurations = {
        "meta-llama_Llama-2-7b-chat-hf": entry(tropical, "strawberries", "Sure!"),
        "meta-llama_Llama-2-13b-chat-hf": entry(stone_fruits, "strawberries", "Sure!"),
        "meta-llama_Llama-2-70b-chat-hf": entry(tropical, "strawberries", "Sure!"),
        "mistralai_Mistral-7B-Instruct-v0.1": entry(berries, "apples", ""),
        "meta-llama_Meta-Llama-3-8B-Instruct": entry(berries, "apples", "Sure!"),
        "BAAI_Infinity-Instruct-3M-0625-Yi-1.5-9B": entry(tropical, "strawberries", "Sure!"),
        "Qwen_Qwen2-7B-Instruct": entry(stone_fruits, "strawberries", ""),
    }

    return configurations.get(model_name)
    

### Main function to generate data

def _build_prompt(prefix, k, word_list, example, answer_format):
    """Return the user prompt asking the model to complete "{prefix} {k}" with a word from word_list."""
    # The wording (typos included) is kept verbatim: it is the model input, so editing it
    # would change the collected statistics.
    return f"Complete the sentence \"{prefix} {k}\" using only and exacty a random word from the list: {word_list}.  Answer in this speific format: {answer_format} {prefix} {k} {example}. (here I chose an other fruit for the sake of the example, you have to choose among {word_list})."


def _collect_key_choices(watermark, temperature, word_list, prefix, k, prompt, device, max_new_tokens=65):
    """Get one well-formed answer per watermark key and extract the word-choice distribution.

    ``watermark.generate`` returns one generation output per key. An output is replaced by
    ``watermark.generate_key(..., key_number=key_idx)`` and re-checked until the response
    contains one of the tokenized "{prefix} {k} {word}." answers and mentions words of
    ``word_list`` fewer than two times.

    Returns:
        (probs_per_key, logits_per_key): two lists with one entry per key. Each entry holds,
        for every word of ``word_list``, the softmax probability (resp. the logit divided by
        ``temperature``) of that word's first distinguishing token, at the position where
        the model wrote its answer.
    """
    messages = [
        {
            "role": "user",
            "content": prompt
        }
    ]
    model_inputs = watermark.tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
    n_prompt_tokens = model_inputs.shape[-1]

    generation_outputs = watermark.generate(
        model_inputs,
        temperature=temperature,
        max_new_tokens=max_new_tokens)

    # The expected answers depend only on (prefix, k, word_list): tokenize them once, not per key.
    # NOTE(review): [2:] drops the first two tokens of every encoding (presumably BOS plus a
    # leading-word artefact) — confirm this holds for every supported tokenizer.
    candidates = [watermark.tokenizer.encode(f"{prefix} {k} {word}.")[2:] for word in word_list]
    n_prefix = find_common_prefix(candidates)
    # First token at which the candidate answers differ, i.e. the token carrying the choice.
    tokens_of_interest = [candidate[n_prefix] for candidate in candidates]
    # Words are literal text, not regular expressions.
    word_patterns = [re.compile(re.escape(word)) for word in word_list]

    probs_per_key = []
    logits_per_key = []
    n_keys = len(generation_outputs)
    key_idx = 0
    # NOTE(review): there is no retry limit; a model that never answers in the expected
    # format makes this loop run forever.
    while key_idx != n_keys:
        generation_output = generation_outputs[key_idx]

        generated = generation_output.sequences[0, n_prompt_tokens:]
        response = watermark.tokenizer.decode(generated)
        print("response:", response)

        output = list(generated.cpu().detach().numpy())
        idxs = [find_sublist_idx(output, candidate) for candidate in candidates]
        n_word_mentions = sum(len(pattern.findall(response)) for pattern in word_patterns)

        # Retry when no expected answer is present, or when list words are mentioned more
        # than once (we don't want the model to first pick its choice and then write it).
        if all(idx == -1 for idx in idxs) or n_word_mentions >= 2:
            generation_outputs[key_idx] = watermark.generate_key(
                model_inputs,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                key_number=key_idx)
            continue

        output_idx = max(idxs) + n_prefix

        logits = torch.stack(generation_output.logits, dim=1) / temperature
        probs = torch.softmax(logits, dim=-1)[0].cpu()
        logits_cpu = logits[0].cpu()

        print(probs[output_idx, tokens_of_interest])
        print(logits_cpu[output_idx, tokens_of_interest])

        probs_per_key.append(probs.numpy()[output_idx, tokens_of_interest])
        logits_per_key.append(logits_cpu.numpy()[output_idx, tokens_of_interest])
        key_idx += 1

    return probs_per_key, logits_per_key


def generate_data_custom_delta(watermark, temperature, word_list, context, example, format: str = "Sure! "):
    """Generate data for the Red-Green detection test.

    For each of the 10 sentence prefixes below and each digit i in 1..9, the model is asked
    to complete "{prefix} {k}" (k = the digit i repeated ``context`` times) with a word from
    ``word_list``.

    Args:
        watermark: object exposing ``tokenizer``, ``generate`` and ``generate_key``.
        temperature: sampling temperature; logits are divided by it before the softmax.
        word_list: candidate words the model must choose from.
        context: number of times each digit is repeated to build k.
        example: word shown in the prompt's formatting example.
        format: text the model is asked to start its answer with.

    Returns:
        (out, out_logits): ``out`` has shape (len(prefixes), 9, len(word_list)) and holds the
        word probabilities averaged over keys for every (prefix, digit). ``out_logits`` is
        ``np.array`` of the per-key temperature-scaled logits for every (prefix, digit),
        i.e. (90, n_keys, len(word_list)) assuming each generate call returns n_keys outputs.
    """
    print("Generating data for the Red-Green detection test.")

    prefixes = ["I ate", "I chose", "I picked", "I selected", "I took", "I went for", "I settled on", "I got", "I gathered", "I harvested"]
    digits = range(1, 10)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    out = []
    out_logits = []
    for prefix in prefixes:
        for i in digits:
            k = int(str(i) * context)
            prompt = _build_prompt(prefix, k, word_list, example, format)
            probs_per_key, logits_per_key = _collect_key_choices(
                watermark, temperature, word_list, prefix, k, prompt, device)
            out.append(np.mean(probs_per_key, axis=0))
            out_logits.append(logits_per_key)

    out = np.array(out).reshape(len(prefixes), len(digits), len(word_list))
    return out, np.array(out_logits)


def generate_data_context_size_estimation(watermark, temperature, word_list, example, contexts: List[int] = [1,2,3,4,5], digits_of_interest: List[int] = [1,2], format: str = "Sure! "):
    """Generate data for the context size estimator.

    For every context size c in ``contexts``, every perturbation digit p in 1..9 and every
    digit t in ``digits_of_interest``, the model is asked to complete "I ate {k}" with
    k = str(p) + str(t) * c (e.g. p=5, t=1, c=3 gives "5111") using a word from ``word_list``.

    Args:
        watermark: object exposing ``tokenizer``, ``generate`` and ``generate_key``.
        temperature: sampling temperature; logits are divided by it before the softmax.
        word_list: candidate words the model must choose from.
        example: word shown in the prompt's formatting example.
        contexts: context sizes to probe (the default list is never mutated).
        digits_of_interest: digits repeated after the perturbation digit.
        format: text the model is asked to start its answer with.

    Returns:
        (out, out_logits): ``out`` has shape
        (len(contexts), 9, len(digits_of_interest), len(word_list)) and holds the word
        probabilities averaged over keys. ``out_logits`` is ``np.array`` of the per-key
        temperature-scaled logits, one entry per generated k.
    """
    prefix = "I ate"
    perturbations = range(1, 10)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    out = []
    out_logits = []
    for context in contexts:
        for perturbation in perturbations:
            for t2 in digits_of_interest:
                k = str(perturbation) + str(t2) * context
                prompt = _build_prompt(prefix, k, word_list, example, format)
                probs_per_key, logits_per_key = _collect_key_choices(
                    watermark, temperature, word_list, prefix, k, prompt, device)
                out.append(np.mean(probs_per_key, axis=0))
                out_logits.append(logits_per_key)

    out = np.array(out).reshape(len(contexts), len(perturbations), len(digits_of_interest), len(word_list))
    return out, np.array(out_logits)


def test_kgw_detection(data: np.array, num_permutations: int, rng=None):
    """Permutation test for a red/green-list (KGW-style) watermark.

    Args:
        data: array whose last axis indexes the candidate words and whose second axis has
            length 9 (presumably the ``out`` probabilities of ``generate_data_custom_delta``,
            shape (n_prefixes, 9, n_choices) — confirm against callers). It is mapped
            through ``logit`` first.
        num_permutations: number of random permutations used for the null distribution.
        rng: optional ``numpy.random.Generator`` / ``RandomState`` used to draw the
            permutations (for reproducibility). Defaults to the global ``numpy.random`` state.

    Returns:
        (observed_statistics, statistics, p_value): the statistic on the real data, the
        statistics on the ``num_permutations`` permuted copies, and the fraction of
        permuted statistics that are >= the observed one.
    """
    permuter = np.random if rng is None else rng

    data = logit(data)

    # Select the candidate word (last axis) with the largest summed logit.
    chosen = np.argmax(np.sum(data, axis=(0, 1)))

    # NOTE(review): the "* 0" makes this median-centering a no-op for finite data (an
    # infinite median would still turn entries into NaN). Kept as-is to preserve results.
    data = data - np.median(data, axis=1, keepdims=True) * 0

    observed_statistics = statistic(data, chosen)

    n_choices = data.shape[-1]
    statistics = np.zeros(num_permutations)
    for i in range(num_permutations):
        # Shuffle whole choice-vectors across all prefixes and digits, then restore the
        # (-1, 9, n_choices) layout expected by `statistic`.
        permuted_data = permuter.permutation(data.reshape(-1, n_choices)).reshape(-1, 9, n_choices)
        statistics[i] = statistic(permuted_data, chosen)

    p_value = np.mean(statistics >= observed_statistics)

    return observed_statistics, statistics, p_value
    
    
def statistic(data: np.array, chosen: int, r: float = 1.96):
    """Red/green test statistic for the candidate word ``chosen``.

    ``data`` is reshaped to (-1, 9, n_choices), where n_choices = ``data.shape[-1]``.
    Axis 1 is the last-digit axis and the last axis indexes the candidate words.
    Only the values of word ``chosen`` are used: a (n_rows, 9) array.

    A value is "red" when it is more than ``r`` * spread below its row median, and "green"
    when it is more than ``r`` * spread above it. The spread is the median, over the 9
    digits, of the across-row standard deviation. Red/green counts are taken per digit.

    Args:
        data: array whose last axis indexes the candidate words.
        chosen: index of the candidate word to analyse.
        r: threshold multiplier on the spread (default 1.96).

    Returns:
        max(max_red, max_green) - max(min_red, min_green), computed over the per-digit
        red and green counts.
    """
    n_choices = data.shape[-1]
    data = data.reshape(-1, 9, n_choices)  # Axis 1 is the last digit axis
    result_array = data[:, :, chosen]  # (n_rows, 9)

    median = np.median(result_array, axis=1)  # one centre per row
    std = np.median(np.std(result_array, axis=0))  # one pooled spread estimate

    deviation = result_array.T - median  # (9, n_rows): each row's values minus its median
    red = deviation < -r * std
    green = deviation > r * std

    red_score = np.sum(red, axis=1)
    green_score = np.sum(green, axis=1)

    max_common = np.max([np.max(red_score), np.max(green_score)])
    # NOTE(review): named "min_common" but takes the max of the two minima; kept as-is to
    # preserve existing results — confirm whether np.min was intended.
    min_common = np.max([np.min(red_score), np.min(green_score)])

    return max_common - min_common

def estimate_context_size(data: np.array, significance: float = 0.05):
    """Estimate the watermark context size with Mood's two-sample scale test.

    ``data`` must be 4-D: (n_context, n_perturb, n_t, n_choices). Presumably these are the
    ``out`` probabilities of ``generate_data_context_size_estimation`` — confirm against
    callers. The tracked word is the one with the largest total value, computed before
    ``logit`` is applied.

    For every t along axis 2, the values at context position 1 are compared with
    each later position in turn using ``mood(..., alternative='greater')``. The first
    position with p-value < ``significance`` is recorded.

    Returns:
        A list of the recorded (1-based) positions along axis 0, or -1 if none was recorded.
        NOTE(review): positions equal context sizes only if the data was generated with
        contexts = [1, 2, ..., n_context]. A t without a detection adds nothing, so list
        entries are not aligned with t.
    """
    n_context, _n_perturb, n_t, _n_choices = data.shape

    choice = np.argmax(np.sum(data, axis=(0, 1, 2)))
    scores = logit(data)

    if n_context < 2:
        # No later context to compare the first one against.
        return -1

    detected = []
    for t in range(n_t):
        reference = scores[0, :, t, choice]
        for position in range(2, n_context + 1):
            result = mood(reference, scores[position - 1, :, t, choice], alternative='greater')
            if result.pvalue < significance:
                detected.append(position)
                break

    return detected if detected else -1