import sys
import os
import json
import numpy as np
import random
from tqdm import tqdm
# Expose only GPU index 4 to this process. The variable is set before the
# project modules below are imported so that whatever CUDA runtime they load
# sees this restriction.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

# NOTE: this resolves to the *parent of the current working directory*, not
# the directory containing this file, so the script presumably has to be
# launched from a sub-directory of the repository root -- verify before moving it.
current_path = os.path.abspath(os.path.dirname(os.getcwd()))

# Make both that directory and its "src" sub-directory importable.
sys.path.extend((current_path, os.path.join(current_path, "src")))
from decoding_algorithm import ContrastiveDecoding
from utils.format_data_mmlu import get_mmlu_data
import matplotlib.pyplot as plt
import matplotlib
# Set matplotlib's default font family for figures created after this point.
matplotlib.rcParams['font.family'] = 'Times New Roman'
import seaborn as sns
if __name__ == "__main__":
    # Experiment: for each number of answer options, repeatedly sample MMLU
    # examples, pick the two most frequently voted layers, ask the model which
    # attention heads it selects in those layers, and plot how often each
    # (layer, head) pair was selected as one heatmap panel per option count.

    # Number of answer options per question; one heatmap panel per entry.
    choice_number_list = [1, 2, 3, 4]
    # model_name = "/mnt/llms/model/meta-llama/Llama-2-7b-hf"
    # model_name = "/mnt/llms/model/meta-llama/Llama-3-8b-hf"
    model_name = "/mnt/llms/model/google-gemma/gemma-2b"
    # model_name= "/mnt/llms/model/meta-llama/Llama-2-13b-hf"
    llm = ContrastiveDecoding(model_name)
    stop_word_list = ["Q:"]
    llm.set_stop_words(stop_word_list)

    n = 10  # batches per round used for the layer vote
    m = 8  # examples per batch
    SAMPLE_NUM = n * m  # examples drawn per round
    ROUND = 20  # independent resampling rounds per option count

    # head_count[k][layer][head]: number of rounds (later: fraction of rounds)
    # in which that head was returned for choice_number_list[k].
    head_count = [
        np.zeros((llm.decoder_layer_num, llm.decoder_head_num))
        for _ in choice_number_list
    ]

    for k, n_choice in enumerate(choice_number_list):
        print("=====option num {} =======".format(n_choice))
        all_data = get_mmlu_data(llm, n_choice=n_choice)
        if len(all_data) < SAMPLE_NUM:
            # Fail early with a clear message instead of an IndexError while
            # slicing batches further down.
            raise ValueError(
                f"need at least {SAMPLE_NUM} examples for n_choice={n_choice}, "
                f"got {len(all_data)}"
            )

        for _ in tqdm(range(ROUND)):
            # Draw SAMPLE_NUM distinct examples for this round.
            sampled_indexes = random.sample(range(len(all_data)), SAMPLE_NUM)
            data_train = [all_data[i] for i in sampled_indexes]

            # Vote: every batch of m examples nominates layers; count how many
            # batches nominated each candidate layer.
            layer_importance = {layer: 0 for layer in llm.choose_layer_list}
            for idx in range(n):
                batch = data_train[idx * m:(idx + 1) * m]
                layer_list = llm.contrast_find_layer_list(
                    batch, T=0.5, threshold=0, bias=True
                )
                for layer in layer_list:
                    layer_importance[layer] += 1

            # Keep the two layers with the most votes. sorted() is stable, so
            # ties are resolved by the order of llm.choose_layer_list.
            ranked_layers = sorted(
                layer_importance.items(), key=lambda item: item[1], reverse=True
            )
            choose_layer_list = [layer for layer, _ in ranked_layers[:2]]

            attn_t = llm.contrast_find_head_list(
                data_train, T=0.5, layer_list=choose_layer_list
            )
            # attn_t[1] maps a layer to a sequence whose first element is the
            # list of head indices selected in that layer.
            for layer, head_info in attn_t[1].items():
                for head in head_info[0]:
                    head_count[k][layer][head] += 1

        # Turn counts into per-round selection frequencies. This has to run for
        # every option count (i.e. inside this loop): normalising once after
        # the loop only rescales the last array and leaves the other panels
        # as raw counts against a colour scale capped at 1.
        head_count[k] /= ROUND

    # squeeze=False keeps `axes` two-dimensional even for a single panel, so
    # the loop below works for any length of choice_number_list.
    fig, axes = plt.subplots(
        ncols=len(choice_number_list), figsize=(20, 10), squeeze=False
    )
    for i, (n_choice, ax) in enumerate(zip(choice_number_list, axes[0])):
        # Colour scale is fixed to [0, 1] to match the normalised frequencies.
        sns.heatmap(
            head_count[i],
            annot=True,
            fmt=".2f",
            cmap='Blues',
            square=True,
            cbar=False,
            ax=ax,
            vmin=0,
            vmax=1,
            annot_kws={"fontsize": 15},
        )
        # Label the panel with the real option count rather than its position.
        ax.set_title(f'({chr(97+i)}) option number = {n_choice}', fontsize=20, pad=20)
        ax.set_xlabel('head', fontsize=20)
        ax.set_ylabel('layer', fontsize=20)
    fig.tight_layout()  # automatically adjust subplot parameters
    fig.savefig("find_head_heatmap.pdf")
