import os  
import sys  
import json  
import random  
import numpy as np  
import torch  
import matplotlib.pyplot as plt  
from tqdm import tqdm  
from transformers import AutoModelForCausalLM, AutoTokenizer  
import warnings  
  
# Silence everything emitted through Python's ``warnings`` module for the rest
# of the process; no category or module filter is given, so nothing is exempt.
warnings.filterwarnings('ignore')
  
def get_output(model, instruction, tokenizer, input=None, temperature=0.5, top_p=0.2, top_k=40, num_beams=4, max_new_tokens=1, device='cuda'):
    """Tokenize a prompt, run ``model.generate`` on it and return the raw output.

    The prompt is ``instruction`` immediately followed by ``input`` (no
    separator is inserted) when ``input`` is truthy, otherwise it is just
    ``instruction``.

    Args:
        model: Object exposing a Hugging Face style ``generate`` method.
        instruction: Prompt text.
        tokenizer: Callable that, given the prompt and ``return_tensors="pt"``,
            returns a mapping containing an ``"input_ids"`` tensor.
        input: Optional text appended to ``instruction``.
        temperature: Forwarded to ``GenerationConfig``.
        top_p: Forwarded to ``GenerationConfig``.
        top_k: Forwarded to ``GenerationConfig``.
        num_beams: Accepted for interface compatibility only; it is not
            forwarded to the generation call.
        max_new_tokens: Forwarded to ``model.generate``.
        device: Device the input ids are moved to before generation.

    Returns:
        Whatever ``model.generate`` returns when called with
        ``return_dict_in_generate=True`` and ``output_hidden_states=True``.
    """
    # The module-level transformers import only brings in the model and
    # tokenizer classes, so GenerationConfig has to be imported here; without
    # it every call fails with NameError.
    from transformers import GenerationConfig

    prompt = instruction + input if input else instruction
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        pad_token_id=0,
    )
    return model.generate(
        input_ids=input_ids,
        output_hidden_states=True,
        generation_config=generation_config,
        return_dict_in_generate=True,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
    )
  
def select_different_items(list1, list2):
    """Pick one item from each list such that the two positions differ.

    A position is drawn for ``list1`` first, then a different position is
    drawn for ``list2``. Both draws use the module-level ``random`` generator,
    so seeding it beforehand makes the selection reproducible.

    Args:
        list1: Sequence the first item is taken from.
        list2: Sequence of the same length the second item is taken from.

    Returns:
        Tuple ``(item1, item2)`` with ``item1 = list1[i]``, ``item2 = list2[j]``
        and ``i != j``.

    Raises:
        ValueError: If the lists differ in length or hold fewer than two
            items (no pair of distinct positions exists).
    """
    size = len(list1)
    if size != len(list2):
        raise ValueError("Both lists must have the same length")
    if size < 2:
        # Without this guard a one-element list fails inside random.choice
        # with an unrelated "empty sequence" IndexError.
        raise ValueError("Both lists must contain at least two items")

    index1 = random.randint(0, size - 1)
    # Keep the randint-then-choice sequence so that a given seed keeps
    # producing the same pair as before.
    index2 = random.choice([i for i in range(size) if i != index1])
    return list1[index1], list2[index2]
  
def select_same_index_items(list1, list2):
    """Pick the items found at one randomly drawn position shared by both lists.

    The position is drawn with the module-level ``random`` generator, so
    seeding it beforehand makes the selection reproducible. Previously the
    position was fixed at 0, which made every call return the same pair
    regardless of the seed.

    Args:
        list1: Sequence the first item is taken from.
        list2: Sequence of the same length the second item is taken from.

    Returns:
        Tuple ``(list1[i], list2[i])`` for the drawn position ``i``.

    Raises:
        ValueError: If the lists differ in length or are empty.
    """
    size = len(list1)
    if size != len(list2):
        raise ValueError("Both lists must have the same length")
    if size == 0:
        raise ValueError("Both lists must contain at least one item")

    index = random.randint(0, size - 1)
    return list1[index], list2[index]
  
def _read_json_file(path):
    """Load and return the JSON document stored at ``path`` (read as UTF-8)."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _last_token_layer_states(model, tokenizer, instruction):
    """Return one hidden-state vector per layer entry (the first entry is skipped).

    For every entry after the first in ``hidden_states[0]`` of the generation
    output, the vector at ``[0][-1]`` is collected.
    NOTE(review): with the Hugging Face ``generate`` output layout this is
    presumably the last prompt token of the first sequence, and the skipped
    first entry is presumably the embedding output -- confirm.
    """
    generation_output = get_output(model=model, instruction=instruction, tokenizer=tokenizer)
    first_step = generation_output['hidden_states'][0]
    return [layer_state[0][-1] for layer_state in first_step[1:]]


def _cosine_similarity(vector_a, vector_b):
    """Return the cosine similarity of two 1-D torch tensors as a Python float."""
    arrays = []
    for vector in (vector_a, vector_b):
        vector = vector.detach().cpu()
        # NumPy has no bfloat16, and float16 has a very small dynamic range,
        # so half-precision tensors are upcast explicitly instead of probing
        # for a failure with a bare ``except``.
        if vector.dtype in (torch.float16, torch.bfloat16):
            vector = vector.to(torch.float32)
        arrays.append(vector.numpy())
    a, b = arrays
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


def get_r_lists_cossim(model, tokenizer, datapath1, datapath2, seed, r=500, same_index=False):
    """Collect layer-wise cosine similarities for ``r`` sampled prompt pairs.

    Each draw reseeds ``random`` with ``seed + draw`` (draw = 0..r-1), picks
    one entry from each dataset, runs both ``'instruction'`` strings through
    the model and compares the per-layer vectors pairwise.

    Pair selection:
        * ``datapath1 == datapath2`` or ``same_index`` is false: the two
          entries come from different positions (``select_different_items``).
        * otherwise: both entries come from the same position
          (``select_same_index_items``).

    Args:
        model: Model passed through to ``get_output``.
        tokenizer: Tokenizer passed through to ``get_output``.
        datapath1: Path to a JSON list of objects with an ``'instruction'`` key.
        datapath2: Path to a second JSON list of the same length.
        seed: Seed used for the first draw; incremented by one per draw.
        r: Number of pairs to sample.
        same_index: Request same-position pairs (ignored when both paths are
            equal).

    Returns:
        List of ``r`` lists; each inner list holds one cosine similarity
        (float) per compared layer.
    """
    sentences_1 = _read_json_file(datapath1)
    # Identical paths name the same data, so the file is parsed only once.
    sentences_2 = sentences_1 if datapath1 == datapath2 else _read_json_file(datapath2)

    if same_index and datapath1 != datapath2:
        pick_pair = select_same_index_items
    else:
        pick_pair = select_different_items

    allcos = []
    for draw in range(r):
        random.seed(seed + draw)
        entry1, entry2 = pick_pair(sentences_1, sentences_2)
        instruction1 = entry1['instruction']
        instruction2 = entry2['instruction']

        states1 = _last_token_layer_states(model, tokenizer, instruction1)
        states2 = _last_token_layer_states(model, tokenizer, instruction2)

        allcos.append([
            _cosine_similarity(state1, state2)
            for state1, state2 in zip(states1, states2)
        ])

    print('end')
    return allcos
  
def compute_stats(data):
    """Return the column-wise mean and a one-standard-deviation band around it.

    Args:
        data: Array-like reduced along axis 0.

    Returns:
        Tuple ``(mean, mean + std, mean - std)`` where ``std`` is
        ``np.std(data, axis=0)``.
    """
    center = np.mean(data, axis=0)
    spread = np.std(data, axis=0)
    return center, center + spread, center - spread
  
def compute_differ(mean_list):
    """Return consecutive differences of ``mean_list`` with a leading NaN.

    The NaN placeholder keeps the result the same length as a 1-D input, so
    element ``k`` is ``mean_list[k] - mean_list[k - 1]`` for ``k >= 1``.

    Args:
        mean_list: Array-like of numbers.

    Returns:
        NumPy float array: ``[nan, d1, d2, ...]``.
    """
    values = np.asarray(mean_list)
    # NaN cannot be stored in an integer (or boolean) array, so non-float
    # input is promoted first; float input keeps its dtype.
    if not np.issubdtype(values.dtype, np.inexact):
        values = values.astype(float)
    differences = np.diff(values)
    return np.insert(differences, 0, np.nan)  # Insert NaN at the start
  
def plot_similarity(results, save_dir, language, fill=True):
    """Plot the mean layer-wise cosine similarity of both pair groups and save it.

    Two curves are drawn: ``results["English-English-differ"]`` in red and
    ``results["language-English-same"]`` in green. With ``fill`` enabled a
    band of one standard deviation around each mean is shaded wherever the
    upper bound is strictly above the lower bound. The figure is written to
    ``{save_dir}/comparison_plot.png`` at 500 dpi and then closed.

    Args:
        results: Mapping with the two keys above; each value is a 2-D
            array-like whose second axis is used as the layer axis.
        save_dir: Existing directory the PNG is written into.
        language: Currently unused; kept so existing callers keep working.
        fill: Whether to shade the standard-deviation bands.

    Returns:
        Tuple ``(0, mean_English_English, mean_language_English_same, 0)``;
        the two zeros are fixed placeholders.
    """
    English_English = np.array(results["English-English-differ"])
    language_English_same = np.array(results["language-English-same"])

    mean_English_English, en_language_up, en_language_down = compute_stats(English_English)
    mean_language_English_same, language_English_same_up, language_English_same_down = compute_stats(language_English_same)

    len_layers = English_English.shape[1]
    layer_axis = np.arange(len_layers)

    curves = (
        (mean_English_English, en_language_up, en_language_down, "Q1 & Q3 similarity", "red"),
        (mean_language_English_same, language_English_same_up, language_English_same_down, "Q1 & Q2 similarity", "green"),
    )

    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        for mean, upper, lower, label, color in curves:
            ax.plot(mean, label=label, color=color, linewidth=1.2)
            if fill:
                ax.fill_between(layer_axis, upper, lower, where=(upper > lower), color=color, alpha=0.2)

        ax.set_title("Layer-wise Average Cosine Similarity", fontsize="medium")
        ax.set_xlabel("Layer", fontsize="small")
        ax.set_ylabel("Cosine Similarity Value", fontsize="small")
        ax.tick_params(axis="both", labelsize="x-small")
        ax.set_xticks(np.arange(0, len_layers, 1))

        ax.legend(loc="lower left", fontsize="small")

        save_path = f"{save_dir}/comparison_plot.png"
        fig.savefig(save_path, dpi=500)
    finally:
        # pyplot keeps every figure it creates alive until it is closed;
        # release it even when plotting or saving fails.
        plt.close(fig)

    return 0, mean_English_English, mean_language_English_same, 0
  
def main(en_path='normal.json', language_path='malicious.json', model_path='meta-llama/Llama-2-7bf', save_dir='cos_sims/', r=500, language='Chinese'):
    """Run the whole similarity experiment and write its artefacts to ``save_dir``.

    Steps:
        1. Load the causal LM and its tokenizer from ``model_path``.
        2. Sample ``r`` different-position pairs from ``en_path`` (seed 1000)
           and ``r`` same-position pairs across ``en_path`` / ``language_path``
           (seed 2000), collecting layer-wise cosine similarities.
        3. Dump the raw similarities to ``all_cos.json``.
        4. Plot them via ``plot_similarity``.
        5. Dump the per-layer means to ``mean_language.json``.

    Args:
        en_path: JSON dataset used for both sides of the first comparison.
        language_path: JSON dataset paired with ``en_path`` in the second one.
        model_path: Model identifier or path given to ``from_pretrained``.
        save_dir: Output directory; created when missing.
        r: Number of pairs sampled per comparison.
        language: Passed through to ``plot_similarity``.
    """
    # NOTE(review): the default model_path looks like a misspelt model id --
    # confirm the intended value.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map='auto',
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="right", use_fast=False)

    # Dict entries are evaluated top to bottom, so the two sampling runs
    # happen in this order.
    results = {
        "English-English-differ": get_r_lists_cossim(model, tokenizer, en_path, en_path, 1000, r, False),
        "language-English-same": get_r_lists_cossim(model, tokenizer, en_path, language_path, 2000, r, True),
    }

    os.makedirs(save_dir, exist_ok=True)

    def _dump_json(filename, payload):
        # Serialise ``payload`` into ``save_dir/filename``.
        with open(os.path.join(save_dir, filename), 'w') as handle:
            json.dump(payload, handle)

    _dump_json('all_cos.json', results)

    _head, english_mean, language_mean, _tail = plot_similarity(results, save_dir, language)

    _dump_json('mean_language.json', {
        "mean_English_English": english_mean.tolist(),
        "mean_language_English_same": language_mean.tolist(),
    })
  
if __name__ == "__main__":
    # Settings for the toy run: one sampled pair per comparison.
    # NOTE(review): model_path is an empty string here -- presumably a
    # placeholder to be filled in before running; confirm.
    toy_run_settings = {
        'en_path': './EN_toy/toy_question.json',
        'language_path': './CH_toy/toy_question.json',
        'model_path': "",
        'save_dir': 'toy_cos_sims/',
        'r': 1,
        'language': 'Chinese',
    }
    main(**toy_run_settings)