import argparse
import torch 
import numpy as np
import json
import seaborn as sns
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from eval.needle.utils import load_context, insert_needle
from my_utils.my_generation import set_topk, my_greedy_generate_with_probe, my_greedy_generate_standard
from my_utils.load_model import load_model

# Command-line interface.
#   --model    : key looked up in eval/LongBench/config/model2path.json below.
#   --modified : attention variant forwarded to load_model. Only 'gemfilter'
#                switches the generation loop to the probing generator; the
#                other choices use standard greedy generation.
#   --topk     : KV cache size. NOTE(review): not applied anywhere while the
#                set_topk call after model loading is commented out.
#   --ctx_len  : target haystack length forwarded to load_context.
_MODEL_CHOICES = [
    'meta-llama/Meta-Llama-3.1-8B-Instruct',
    'mistralai/Mistral-Nemo-Instruct-2407',
    "codellama-7b-hf",
    'microsoft/Phi-3.5-mini-instruct',
]
_ATTENTION_CHOICES = ['gemfilter', 'snapkv', 'h2o']

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='codellama-7b-hf', choices=_MODEL_CHOICES)
parser.add_argument('--modified', type=str, default='gemfilter', choices=_ATTENTION_CHOICES)
parser.add_argument('--topk', type=int, default=320, help='KV cache size')
parser.add_argument('--ctx_len', type=int, default=16000, help='haystack context token length')
args = parser.parse_args()

# Resolve the chosen model id to a checkpoint path and unpack the remaining
# CLI options into the module-level names used by the rest of the script.
model_id = args.model

# Read the id -> path mapping inside a context manager so the file handle is
# closed deterministically instead of being left to the garbage collector.
with open("eval/LongBench/config/model2path.json", "r", encoding="utf-8") as config_file:
    model2path = json.load(config_file)

model_path = model2path[model_id]  # raises KeyError if the id is missing from the config
modified = args.modified
topk = args.topk
ctx_len = args.ctx_len

# The H2O variant is loaded without FlashAttention-2; every other variant enables it.
flash_attention_2 = not (args.modified == 'h2o')

# Layer index forwarded to the probing generator as `select_layer_idx`.
# Models not listed here (e.g. the default codellama-7b-hf) fall back to 31.
# NOTE(review): a previous comment cited layers "13, 14 out of 32" for
# Llama-3.1 while the value used is 31 — confirm the intended layer.
_SELECT_LAYER_BY_MODEL = {
    'meta-llama/Meta-Llama-3.1-8B-Instruct': 31,
    'mistralai/Mistral-Nemo-Instruct-2407': 19,
    'microsoft/Phi-3.5-mini-instruct': 19,
}
select_layer_idx = _SELECT_LAYER_BY_MODEL.get(model_id, 31)


# Load model and tokenizer in fp16 with the attention variant and the
# FlashAttention-2 switch chosen above.
torch_dtype = torch.float16
model, tokenizer = load_model(
    model_path,
    modified=modified,
    torch_dtype=torch_dtype,
    flash_attention_2=flash_attention_2,
)
# The KV-cache budget is not applied: with this call disabled, --topk has no effect.
# if modified:
#     set_topk(model, topk, mode=modified)

# Needle-in-a-Haystack probe: insert a known "needle" sentence at several
# relative depths of a long essay context, ask the model about it, and (for
# GemFilter) collect the attention scores returned by the probing generator
# for the needle's token span.
needle = "\nThe best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day.\n"
question = "What is the best thing to do in San Francisco?"  # identical for every depth

att_scores_probes = []
depths = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
for depth in depths:
    # NOTE(review): the haystack is reloaded for every depth; hoist this out of
    # the loop if load_context is deterministic.
    context = load_context(fpath="eval/needle/PaulGrahamEssays/*.txt", ctx_len=ctx_len)
    context = insert_needle(context, needle, depth=depth)

    prompt = "\n<|im_start|> This is a very long story book: <book> %s </book>.\n" % context
    prompt += "Based on the content of the book, Question: %s\nAnswer:" % question

    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded['input_ids'].to(model.device)
    attn_mask = encoded["attention_mask"].to(model.device)

    print("After tokenization, there is %d tokens" % input_ids.shape[1])

    if modified == 'gemfilter':
        # Locate the needle's token span inside the full prompt. A -1 from find()
        # would silently turn into prompt[:-1] and yield a meaningless offset, so fail loudly.
        needle_idx = prompt.find(needle)
        if needle_idx == -1:
            raise ValueError(
                f"needle not found verbatim in the prompt at depth {depth}; "
                "insert_needle may have altered it"
            )
        # Token count of the text before the needle == index of the needle's first token.
        begin_token_id = len(tokenizer(prompt[:needle_idx]).input_ids)
        # NOTE(review): the "- 2" presumably drops special/boundary tokens that
        # tokenizer(needle) adds on its own — confirm for each tokenizer.
        needle_token_len = len(tokenizer(needle).input_ids) - 2
        # print(tokenizer.decode(input_ids[:, begin_token_id:begin_token_id+needle_token_len-2][0]))

    with torch.no_grad():
        if modified == 'gemfilter':
            response, compressed_context, att_scores = my_greedy_generate_with_probe(
                input_ids, attn_mask, model, tokenizer, max_gen_len=50,
                select_layer_idx=select_layer_idx, print_context=False,
                probe_context=(begin_token_id, needle_token_len))
            att_scores_probes.append(att_scores)
        else:
            response = my_greedy_generate_standard(input_ids, attn_mask, model, tokenizer, max_gen_len=50)
    print("Response:", response.split("\n")[0])

# Aggregate the per-depth attention probes (only collected for GemFilter):
# average them, plot a heatmap, then rank layers and heads by score.
# Rows are treated as layers and columns as heads, matching the
# 'layer order' / 'head order of selected layer' reports below.
if att_scores_probes:
    # Average in float32 so the result is not limited by half-precision inputs.
    avg_score = torch.stack(att_scores_probes, dim=0).float().mean(0)
    num_layers, num_heads = avg_score.shape  # seaborn.heatmap needs 2-D data anyway

    # matplotlib figsize is (width, height): width follows the heatmap's
    # columns (heads), height follows its rows (layers).
    plt.figure(figsize=(num_heads, num_layers))
    # annot=True writes each value into its cell; fmt sets that number format.
    sns.heatmap(avg_score.cpu().numpy(), annot=True, fmt=".2f", cmap='viridis')

    plt.title('Heatmap Example')
    plt.savefig('heatmap.png', dpi=300, bbox_inches='tight')
    # Display the figure.
    plt.show()

    # Rank layers by total score (descending), then rank heads inside the best layer.
    score_layers = avg_score.sum(1)
    layer_order = torch.argsort(score_layers, descending=True)
    print('layer order', layer_order)
    score_selected_layer = avg_score[layer_order[0]]  # top 1 layer
    head_order = torch.argsort(score_selected_layer, descending=True)
    print('head order of selected layer', head_order)

    # Share of the selected layer's total score held by its top heads; the
    # count is capped at the number of heads so torch.topk cannot fail.
    select_num = min(8, num_heads)
    top_heads = torch.topk(score_selected_layer, select_num)
    import_heads = top_heads.values.sum() / score_selected_layer.sum()
    print(f'score:{import_heads}. import_heads:', top_heads.indices)


