import os
import torch
import numpy as np
import pickle
import sys
sys.path.append('../')
from utils import get_activations_bau, tokenized_tqa, tokenized_tqa_gen, tokenized_tqa_gen_end_q
from utils import alt_tqa_evaluate, flattened_idx_to_layer_head, layer_head_to_flattened_idx, get_interventions_dict, get_top_heads, get_separated_activations, get_com_directions, svd_decomposition,svd_decomposition_2
import llama
import qwen2
import argparse
import json
from tqdm import tqdm
from einops import rearrange
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
from transformers import AutoModelForCausalLM, AutoTokenizer
# Command-line options selecting the model, the evaluation dataset and the
# run mode. Parsed once at import time; the result is read as `args` below.
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default="Qwen2.5-7B-Instruct", help='name of the model to evaluate')
# The help text previously said "100 samples", but the debug branch further
# down keeps only the first 10 entries of the dataset.
parser.add_argument('--debug', type=int, default=1, help='if set to 1, only use the first 10 samples for debugging')
parser.add_argument('--main_steer_style', type=str, default="DRC", help="main style")
parser.add_argument('--second_steer_style', type=str, default="None", help="second style")
parser.add_argument('--dataset_name', type=str, default="DRC", help='test dataset')
parser.add_argument('--num_heads', type=int, default=64, help='K, number of top heads to intervene on')
args = parser.parse_args()



# ===================== Prepare parameters ====================
# Checkpoint directory for every model this script knows how to load.
# Add an entry here to support another model.
_MODEL_PATHS = {
    "Qwen1.5-14B-Chat": "/new_disk1/xxxx-3/projects/PretrainModels/Qwen1.5-14B-Chat",
}

model_name = args.model_name
if model_name not in _MODEL_PATHS:
    # Fail here with an actionable message; previously an unknown model name
    # left `model_path` undefined and surfaced later as a NameError when the
    # tokenizer was loaded.
    raise NotImplementedError(
        f"No checkpoint path is configured for model '{model_name}'. "
        f"Known models: {sorted(_MODEL_PATHS)}"
    )
model_path = _MODEL_PATHS[model_name]

dataset_name = args.dataset_name

# `dump_path` is simply the number of intervened heads as a string
# (e.g. "64" for the default --num_heads).
print(args.num_heads)
dump_path = str(args.num_heads)
print(dump_path)

# ===================== Prepare model ====================
# Tokenizer and causal LM are both loaded from the same checkpoint directory
# through the `qwen2` module imported at the top of the file. The stock
# `AutoTokenizer` / `AutoModelForCausalLM` loaders are imported as well but
# are not used here.
device = "cuda"
tokenizer = qwen2.Qwen2Tokenizer.from_pretrained(model_path)
model = qwen2.Qwen2ForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# ===================== Prepare test set ====================
# Validation file for every dataset this script knows how to load
# (paths are relative to the current working directory).
_DATASET_FILES = {
    "tqa_gen": "dataset/Valid_tqa_gen.json",
    "tqa_gen_zh": "dataset/Valid_tqa_gen_zh.json",
}

if dataset_name not in _DATASET_FILES:
    # Previously an unknown dataset name left `data_list` undefined and the
    # script crashed a few lines later with a NameError.
    raise ValueError(
        f"No dataset file is configured for dataset '{dataset_name}'. "
        f"Known datasets: {sorted(_DATASET_FILES)}"
    )
with open(_DATASET_FILES[dataset_name], 'r', encoding='utf-8') as file:
    data_list = json.load(file)
main_steer_style = args.main_steer_style

# ===================== data split ====================
if args.debug == 1:
    # Debug mode: keep only the first 10 samples for a quick run.
    data_list = data_list[:10]
# Each dataset entry is expected to carry its prompt under the "question" key.
questions = [QA["question"] for QA in data_list]
# Filled by main() with one generated completion per question.
answers = []


def my_generate(q_tokens, inputs, max_new_tokens=600):
    """Greedily decode a completion for a single prompt and return it as text.

    Uses the module-level ``model`` and ``tokenizer``. Decoding re-runs the
    model on the full sequence at every step and stops as soon as a stop
    token is produced or ``max_new_tokens`` tokens have been generated.

    Args:
        q_tokens: Prompt token ids. Not used by the decoding loop; kept so
            existing callers do not have to change.
        inputs: Mapping (e.g. a tokenizer output) whose ``"input_ids"`` entry
            holds the prompt ids. Only a batch of one prompt is supported,
            because each new token is read back with ``.item()``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation (prompt excluded), with special tokens
        stripped by the tokenizer.
    """
    prompt_ids = inputs["input_ids"]
    generated = prompt_ids

    # Build the stop set once instead of on every step. The original lookup
    # for "<|im_start|>" carried stray quote characters and so never matched
    # the real token.
    stop_ids = {
        tokenizer.convert_tokens_to_ids(tok)
        for tok in ("<|endoftext|>", "<|im_start|>", "<|im_end|>")
    }
    stop_ids.discard(None)
    # Hard-coded ids the original also stopped on — presumably the Qwen2 ids
    # of the three tokens above; TODO confirm against the tokenizer vocab.
    stop_ids.update((151643, 151644, 151645))

    with torch.no_grad():
        for _ in range(max_new_tokens):
            next_token_logits = model(generated).logits[:, -1, :]
            # Greedy choice: argmax over the logits equals argmax over their
            # softmax, so the softmax is skipped. Move the token to wherever
            # the running sequence lives rather than assuming 'cuda:0'.
            token = next_token_logits.argmax(dim=-1, keepdim=True).to(generated.device)
            generated = torch.cat((generated, token), dim=1)
            if token.item() in stop_ids:
                break

    # Decode only the newly generated part, not the prompt.
    generated_text = tokenizer.decode(generated[0][prompt_ids.shape[-1]:], skip_special_tokens=True)
    return generated_text


def _system_prompt_for(dataset_name):
    """Return the system prompt for *dataset_name*, or None if none is defined."""
    if "DRC" in dataset_name:
        return "请你以对话的形式直接对下面的语句作出回应："
    if dataset_name == "tqa_gen":
        return "Please respond to the following statement, and do not output any unnecessary content:"
    if "tqa_gen_zh" in dataset_name:
        return "请你以对话的形式直接对下面的语句作出回应："
    if dataset_name == "Shakespeare":
        return "Please respond to the following statement, and do not output any unnecessary content:"
    return None


def main():
    """Generate a completion for every question and dump the results as JSON.

    For model names containing "instruct" the prompt is built with the
    tokenizer's chat template (system prompt + user question); otherwise a
    plain-text instruction prompt is used. Each completion is printed,
    appended to the module-level ``answers`` list, and finally written
    together with its question to
    ``results_log/{model_name}_eval{dataset_name}_origin_result.json``.

    Raises:
        ValueError: If a chat-template prompt is needed but no system prompt
            is defined for the selected dataset.
    """
    system_content = _system_prompt_for(args.dataset_name)
    # NOTE(review): a model name such as "...-Chat" does not contain
    # "instruct" and therefore takes the plain-prompt path — confirm intended.
    use_chat_template = "instruct" in args.model_name.lower()

    for index, question in enumerate(questions):
        if use_chat_template:
            if system_content is None:
                # Previously this case crashed with an UnboundLocalError.
                raise ValueError(
                    f"No system prompt is defined for dataset '{args.dataset_name}'"
                )
            messages = [
                {"role": "system", "content": system_content},
                {"role": "user", "content": f"{question}"},
            ]
            prompt_str = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            model_inputs = tokenizer(prompt_str, return_tensors='pt').to(model.device)
            q_tokens = model_inputs.input_ids
        else:
            if "zh" in args.dataset_name:
                prompt = f"Instruction:\n请你对下面的语句作出回复：\n\n### Input:\n{question}\n\n### Response:\n以下是我对该语句的回复：\n"
            else:
                prompt = f"Please respond to the following statement, and do not output any unnecessary content: \n{question}\nOkay, my answer is as follows:\n"

            inputs = tokenizer(prompt, return_tensors='pt')
            model_inputs = {k: v.to(model.device) for k, v in inputs.items()}
            q_tokens = model_inputs["input_ids"]

        completion = my_generate(q_tokens, model_inputs)
        print(index, question)
        print(completion)
        answers.append(completion)

    output_data = [
        {"question": question, "origin_model_completion": answer}
        for question, answer in zip(questions, answers)
    ]

    # ============== Save ============
    # Create the output directory up front so a missing folder does not throw
    # away the generated completions at the very end of the run.
    os.makedirs("results_log", exist_ok=True)
    with open(f"results_log/{model_name}_eval{dataset_name}_origin_result.json", 'w', encoding='utf-8') as new_file:
        json.dump(output_data, new_file, ensure_ascii=False, indent=4)

#==================================================================================================
# Call main() only when this file is executed as a script.
if __name__ == "__main__":
    main()

