import torch
from src.attacks.kgw_detection import find_sublist_idx, find_common_prefix
from src.utils import generate_random_prefix
import re
import numpy as np


def get_cacheTest_phase1(model_name):
    """Look up the phase-1 cache-test prompt settings for a model.

    Args:
        model_name: Model identifier in ``org_model`` form
            (e.g. ``"meta-llama_Llama-2-7b-chat-hf"``).

    Returns:
        A newly built dict with keys ``"word_list"`` (candidate words the model
        must pick from), ``"example"`` (the example word shown in the prompt) and
        ``"format"`` (text prepended to the expected answer). Returns ``None``
        when the model has no configuration.
    """

    def _entry(words, answer_format):
        # All models share the same example word; only the candidate words and
        # the expected answer prefix vary per model.
        return {
            "word_list": list(words),
            "example": "strawberries",
            "format": answer_format,
        }

    # Built on every call so callers can mutate the result without affecting
    # later lookups.
    table = {
        "mistralai_Mistral-7B-Instruct-v0.1": _entry(("apples", "pears"), ""),
        "meta-llama_Llama-2-13b-chat-hf": _entry(("peaches", "plums"), "Sure!"),
        "meta-llama_Llama-2-70b-chat-hf": _entry(("cherries", "apricots"), "Sure!"),
        "meta-llama_Llama-2-7b-chat-hf": _entry(("peaches", "plums"), "Sure!"),
        "meta-llama_Meta-Llama-3-8B-Instruct": _entry(("cherries", "apricots"), "Sure!"),
        "BAAI_Infinity-Instruct-3M-0625-Yi-1.5-9B": _entry(("cherries", "apricots"), "Sure!"),
        "Qwen_Qwen2-7B-Instruct": _entry(("cherries", "apples"), ""),
    }
    return table.get(model_name)
    
    
def generate_cache_detection_data(watermark, temperature, word_list, example, format: str = "Sure! ", max_trials: int = 1, max_attempts=None):
    """Generate data for the cache detection test.

    The model is repeatedly asked to complete "I ate <random prefix>" with a
    word chosen from ``word_list``. A trial is accepted when exactly one
    candidate word appears once in the response and the highest probability
    assigned to the candidates' distinguishing tokens lies in [0.5, 0.9]. If
    ``max_trials`` is more than 1, the data is used for estimating alpha as well.

    Args:
        watermark: Object exposing ``tokenizer``, ``generate(...)`` and
            ``randomize_every`` (a project type; see its definition).
        temperature: Sampling temperature forwarded to ``watermark.generate``.
            The returned logits are divided by it before the softmax, so it
            must be non-zero.
        word_list: Candidate words the model must choose from.
        example: Example word shown in the prompt's answer format.
        format: Text prepended to the expected answer in the prompt.
        max_trials: Number of accepted trials to collect.
        max_attempts: Optional cap on the total number of generation attempts.
            ``None`` (the default) retries without limit.

    Returns:
        ``(out, out_logits, n_trials_phase1)``:
        - ``out``: array of shape ``(max_trials, len(word_list))`` holding the
          candidate-token probabilities of each accepted trial.
        - ``out_logits``: array of shape ``(max_trials, 1, len(word_list))``
          holding the temperature-scaled logits of the same tokens.
        - ``n_trials_phase1``: total number of generation attempts made.

    Raises:
        RuntimeError: If ``max_attempts`` is set and is reached before
            ``max_trials`` trials have been accepted.
    """
    prefix = "I ate"
    max_new_tokens = 65
    # Loop invariants, computed once rather than on every attempt.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mix = watermark.randomize_every > 1

    out = []
    out_logits = []
    n_sucess = 0
    n_trials_phase1 = 0

    while n_sucess < max_trials:
        if max_attempts is not None and n_trials_phase1 >= max_attempts:
            raise RuntimeError(
                f"cache detection: only {n_sucess}/{max_trials} trials accepted "
                f"after {n_trials_phase1} attempts"
            )

        k = generate_random_prefix()
        n_trials_phase1 += 1

        # NOTE: the prompt text (including its typos) is kept verbatim so that
        # results stay comparable with previously collected data.
        prompt = f"Complete the sentence \"{prefix} {k}\" using only and exacty a random word from the list: {word_list}.  Answer in this speific format: {format} {prefix} {k} {example}. (here I chose an other fruit for the sake of the example, you have to choose among {word_list})."

        print(word_list)

        messages = [{"role": "user", "content": prompt}]
        model_inputs = watermark.tokenizer.apply_chat_template(
            messages, return_tensors="pt"
        ).to(device)

        generation_output = watermark.generate(
            model_inputs,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            mix=mix,
        )[0]

        # Keep only the newly generated tokens (drop the prompt).
        generated = generation_output.sequences[0, model_inputs.shape[-1]:]
        response = watermark.tokenizer.decode(generated)
        print("response:", response)
        output = list(generated.cpu().detach().numpy())

        # Token ids of "<prefix> <k> <word>." for each word. The first two ids
        # are dropped; presumably BOS/leading special tokens, which is
        # tokenizer-specific (TODO confirm).
        candidates = [
            watermark.tokenizer.encode(f"{prefix} {k} {word}.")[2:]
            for word in word_list
        ]
        n_prefix = find_common_prefix(candidates)
        # The first token at which the candidates diverge identifies the word.
        tokens_of_interest = [candidate[n_prefix] for candidate in candidates]

        idxs = [find_sublist_idx(output, candidate) for candidate in candidates]
        # Literal, non-overlapping counts; words are not regex patterns.
        number_of_repetition = [response.count(word) for word in word_list]

        if all(idx == -1 for idx in idxs):
            # No candidate completion found in the output; retry.
            continue
        if sum(number_of_repetition) >= 2:
            # We don't want the model to first pick its choice and then write it.
            continue

        output_idx = max(idxs) + n_prefix

        # Per-step scores stacked along the generation axis, then scaled by
        # temperature; presumably (batch, steps, vocab) — TODO confirm.
        logits = torch.stack(generation_output.logits, dim=1) / temperature
        probs = torch.softmax(logits, dim=-1)

        probs_row = probs[0].cpu()
        logits_row = logits[0].cpu()

        print(probs_row[output_idx, tokens_of_interest])
        print(logits_row[output_idx, tokens_of_interest])

        probs_of_interest = probs_row.numpy()[output_idx, tokens_of_interest]
        max_prob = np.max(probs_of_interest)
        if max_prob > 0.9 or max_prob < 0.5:
            # Failure of phase 1 (too deterministic, or too little mass on the
            # candidates); retry.
            continue

        out.append(probs_of_interest)
        # Wrapped in a list to keep the historical (max_trials, 1, n_words)
        # layout of the returned logits array.
        out_logits.append([logits_row.numpy()[output_idx, tokens_of_interest]])
        n_sucess += 1

        print("-----------------------")

    out = np.array(out).reshape((max_trials, len(word_list)))
    out_logits = np.array(out_logits)
    return out, out_logits, n_trials_phase1