import os
import json
import numpy as np

from argparse import ArgumentParser
from transformers import AutoTokenizer, AutoModelForCausalLM

from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from datasets import load_dataset
import torch

import os

# Prompt templates for ARC-Challenge, keyed "prompt_1".."prompt_12".
# Each template is filled with str.format(question=..., textA=..., textB=...,
# textC=..., textD=...) and a space plus the gold answer is appended after
# formatting, so the last token of every prompt is the scored answer.
# The twelve templates form four groups of three (blank lines below):
#   prompt_1-3   bare question + options, ending in "Answer:"
#   prompt_4-6   instruction-style requests ("Could you provide...", "Please ...")
#   prompt_7-9   AI-assistant persona preambles
#   prompt_10-12 instruction-style requests that also ask for a reply of A, B, C or D
# The exact bytes of these strings are experimental inputs; editing them
# changes the prompts being measured.
ARC_Challenge_prompts = dict(
    prompt_1="{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    prompt_2="Question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    prompt_3="Question:\n{question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",

    prompt_4="Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    prompt_5="Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    prompt_6="Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",

    prompt_7="You are a very helpful AI assistant. Please answer the following questions: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    prompt_8="As an exceptionally resourceful AI assistant, I'm at your service. Address the questions below:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    prompt_9="As a helpful Artificial Intelligence Assistant, please answer the following questions\n{question} A. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",

    prompt_10="Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer the question by replying A, B, C or D.",
    prompt_11="Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer the question by replying A, B, C or D.",
    prompt_12="Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer this question by replying A, B, C or D.",
)
# Prompt templates for CommonSenseQA (five options), keyed "prompt_1".."prompt_12".
# Each template is filled with str.format(question=..., A=..., B=..., C=...,
# D=..., E=...) and a space plus the gold answer label is appended after
# formatting. Same four groups of three as the four-option template sets:
# bare question (1-3), instruction-style (4-6), assistant persona (7-9), and
# instruction-style with an explicit "reply A, B, C, D or E" request (10-12).
# NOTE(review): prompt_5 has a stray space before the last option ("\n E. {E}")
# that no other template has; it looks like a typo, but fixing it changes the
# measured prompt, so confirm against already-collected results before editing.
CommonSenseQA_prompts = dict(
    prompt_1="{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer:",
    prompt_2="Question:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer:",
    prompt_3="Question:\n{question} A. {A} B. {B} C. {C} D. {D} E. {E}\nAnswer:",

    prompt_4="Could you provide a response to the following question: {question} A. {A} B. {B} C. {C} D. {D} E. {E}",
    prompt_5="Please answer the following question:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\n E. {E}",
    prompt_6="Please address the following question:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer:",

    prompt_7="You are a very helpful AI assistant. Please answer the following questions: {question} A. {A} B. {B} C. {C} D. {D} E. {E}",
    prompt_8="As an exceptionally resourceful AI assistant, I'm at your service. Address the questions below:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}",
    prompt_9="As a helpful Artificial Intelligence Assistant, please answer the following questions\n{question} A. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}",

    prompt_10="Could you provide a response to the following question: {question} A. {A} B. {B} C. {C} D. {D} E. {E}\nAnswer the question by replying A, B, C, D or E.",
    prompt_11="Please answer the following question:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer the question by replying A, B, C, D or E.",
    prompt_12="Please address the following question:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer this question by replying A, B, C, D or E.",
)
# Prompt templates for MMLU, keyed "prompt_1".."prompt_12".
# The strings are currently identical to ARC_Challenge_prompts; they are kept
# as a separate dict so per-dataset wording can diverge later.
# Each template is filled with str.format(question=..., textA=..., textB=...,
# textC=..., textD=...), and a space plus the gold option letter (mapped from
# the row's integer "answer") is appended after formatting.
# Groups of three: bare question (1-3), instruction-style (4-6), assistant
# persona (7-9), instruction-style with an explicit "reply A, B, C or D" (10-12).
MMLU_prompts = dict(
    prompt_1="{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    prompt_2="Question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    prompt_3="Question:\n{question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",

    prompt_4="Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    prompt_5="Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    prompt_6="Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",

    prompt_7="You are a very helpful AI assistant. Please answer the following questions: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    prompt_8="As an exceptionally resourceful AI assistant, I'm at your service. Address the questions below:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    prompt_9="As a helpful Artificial Intelligence Assistant, please answer the following questions\n{question} A. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",

    prompt_10="Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer the question by replying A, B, C or D.",
    prompt_11="Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer the question by replying A, B, C or D.",
    prompt_12="Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer this question by replying A, B, C or D.",
)
# Prompt templates for OpenBookQA, keyed "prompt_1".."prompt_12".
# The strings are currently identical to ARC_Challenge_prompts.
# Each template is filled with str.format(question=<row "question_stem">,
# textA=..., textB=..., textC=..., textD=...), and a space plus the row's
# "answerKey" is appended after formatting.
# Groups of three: bare question (1-3), instruction-style (4-6), assistant
# persona (7-9), instruction-style with an explicit "reply A, B, C or D" (10-12).
OpenBookQA_prompts = dict(
    prompt_1="{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    prompt_2="Question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    prompt_3="Question:\n{question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",

    prompt_4="Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    prompt_5="Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    prompt_6="Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",

    prompt_7="You are a very helpful AI assistant. Please answer the following questions: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    prompt_8="As an exceptionally resourceful AI assistant, I'm at your service. Address the questions below:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    prompt_9="As a helpful Artificial Intelligence Assistant, please answer the following questions\n{question} A. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",

    prompt_10="Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer the question by replying A, B, C or D.",
    prompt_11="Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer the question by replying A, B, C or D.",
    prompt_12="Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer this question by replying A, B, C or D.",
)

# Target device for the model and for encoded inputs: the default CUDA device
# when PyTorch reports one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def arguments():
    """Parse this script's command-line options from ``sys.argv``.

    Options:
        --model_name_or_path: model id or path handed to ``from_pretrained``
            (default ``"Qwen/Qwen1.5-0.5B"``).
        --dataset: one of ARC_Challenge, CommonSenseQA, MMLU, OpenBookQA
            (default ``OpenBookQA``).
        --cache_path: forwarded as ``cache_dir`` when loading the model and
            tokenizer (default ``''``).
            NOTE(review): an empty string is forwarded as-is rather than as
            ``None``; confirm it resolves to the intended cache location.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    dataset_names = ["ARC_Challenge", "CommonSenseQA", "MMLU", "OpenBookQA"]
    option_specs = (
        ('--model_name_or_path', dict(type=str, default="Qwen/Qwen1.5-0.5B")),
        ('--dataset', dict(type=str, default="OpenBookQA", choices=dataset_names)),
        ('--cache_path', dict(default='')),
    )
    parser = ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()


def load_model(args):
    """Load the tokenizer and causal LM named by ``args.model_name_or_path``.

    Both are loaded with ``trust_remote_code=True`` and ``cache_dir=args.cache_path``;
    the tokenizer uses the fast implementation. No dtype is passed, so weights
    load in the library's default dtype. The model is moved to the module-level
    ``device`` and put in eval mode.

    Returns:
        ``(model, tokenizer)``.
    """
    common = dict(trust_remote_code=True, cache_dir=args.cache_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, **common)
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **common)
    model = model.to(device)
    model.eval()
    return model, tokenizer

def build_prompt_ARC_Challenge(data, prompts=None):
    """Build one scored prompt per ARC-Challenge template for a single example.

    Each prompt is a filled-in template followed by a space and the gold
    option letter, so the final token is the answer the model is scored on.

    Args:
        data: one dataset row with ``question`` (str), ``choices`` (a dict with
            parallel ``text`` and ``label`` lists) and ``answerKey`` (one of
            the labels).
        prompts: optional ``{prompt_key: template}`` mapping; defaults to
            ``ARC_Challenge_prompts``. Templates are formatted with
            ``question`` and ``textA``..``textD``.

    Returns:
        A list of ``{"prompt_key": ..., "prompt": ...}`` dicts in template
        order, or ``[]`` when the example does not have exactly four choices.

    Raises:
        ValueError: if ``answerKey`` is not one of the choice labels.
    """
    templates = ARC_Challenge_prompts if prompts is None else prompts
    question = data['question']
    choices = data['choices']
    choices_text = choices['text']
    # The check does not depend on the template, so do it once up front.
    if len(choices_text) != 4:
        return []

    # Some ARC questions are labelled "1".."4" (with a numeric answerKey)
    # instead of "A".."D", but every template shows the options as A-D.
    # Translate the gold label to the letter displayed at its position so the
    # scored token matches the option the model actually sees.
    labels = list(choices['label'])
    answer_key = data['answerKey']
    if answer_key not in labels:
        raise ValueError(
            f"answerKey {answer_key!r} is not one of the choice labels {labels}")
    answer_letter = "ABCD"[labels.index(answer_key)]

    options = dict(textA=choices_text[0], textB=choices_text[1],
                   textC=choices_text[2], textD=choices_text[3])
    return [
        {
            "prompt_key": prompt_key,
            "prompt": template.format(question=question, **options) + " " + answer_letter,
        }
        for prompt_key, template in templates.items()
    ]

def build_prompt_CSQA(data):
    """Return one prompt per CommonSenseQA template, each suffixed with the gold label.

    The row's five choice texts fill the ``{A}``..``{E}`` slots of every
    template in ``CommonSenseQA_prompts``. A space and ``answerKey`` are then
    appended so the last token is the scored answer. Rows that do not have
    exactly five choices yield an empty list.
    """
    question, choices, gold = data['question'], data['choices'], data['answerKey']
    options = choices['text']
    if len(options) != 5:
        return []
    slots = dict(zip("ABCDE", options))
    return [
        {"prompt_key": key, "prompt": template.format(question=question, **slots) + " " + gold}
        for key, template in CommonSenseQA_prompts.items()
    ]

def build_prompt_MMLU(data):
    """Return one prompt per MMLU template, each suffixed with the gold option letter.

    ``data['choices']`` is a plain list of four option texts and
    ``data['answer']`` is the 0-based index of the correct one. The index is
    mapped to A-D and appended after a space. Rows without exactly four
    choices yield an empty list.
    """
    question = data['question']
    options = data['choices']
    gold_index = data['answer']
    index_to_letter = dict(enumerate("ABCD"))
    if len(options) != 4:
        return []
    slots = {f"text{letter}": text for letter, text in zip("ABCD", options)}
    return [
        {
            "prompt_key": key,
            "prompt": template.format(question=question, **slots) + " " + index_to_letter[gold_index],
        }
        for key, template in MMLU_prompts.items()
    ]

def build_prompt_OpenBookQA(data):
    """Return one prompt per OpenBookQA template, each suffixed with the gold label.

    The question text comes from ``question_stem``. The four choice texts fill
    ``{textA}``..``{textD}``, and a space plus ``answerKey`` is appended.
    Rows that do not have exactly four choices yield an empty list.
    """
    question, choices, gold = data['question_stem'], data['choices'], data['answerKey']
    options = choices['text']
    if len(options) != 4:
        return []
    slots = {f"text{letter}": text for letter, text in zip("ABCD", options)}
    return [
        {"prompt_key": key, "prompt": template.format(question=question, **slots) + " " + gold}
        for key, template in OpenBookQA_prompts.items()
    ]

def encode_ids_and_mask(tokenizer, text):
    """Tokenize ``text`` without adding special tokens.

    Returns:
        ``(input_ids, attention_mask)`` as PyTorch tensors (batched, so a
        single string gives a batch of one) moved to the module-level ``device``.
    """
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    return tuple(encoded[field].to(device) for field in ("input_ids", "attention_mask"))




def saliency(model, input_ids, attention_mask):
    """Score a prompt's last token and differentiate that score w.r.t. every hidden state.

    The last token of ``input_ids`` is the target. The model runs on the
    prefix (all earlier tokens), and the score is the log-probability the
    model assigns to the target at the final prefix position. That scalar is
    then differentiated with respect to each tensor in ``outputs.hidden_states``.

    Args:
        model: a causal LM called as ``model(input_ids=..., attention_mask=...,
            output_hidden_states=True, use_cache=False, return_dict=True)``
            whose output exposes ``.logits`` and ``.hidden_states``.
        input_ids: ``(1, seq_len)`` token ids; only row 0 is used.
        attention_mask: ``(1, seq_len)`` mask aligned with ``input_ids``.

    Returns:
        ``(log_prob, grads, hidden_states)``:
        - ``log_prob`` is a Python float.
        - ``grads`` and ``hidden_states`` are lists with one float32 NumPy
          array per hidden-state entry, batch axis squeezed (for Hugging Face
          causal LMs, ``(seq_len - 1, hidden_size)``).
        - A ``grads`` entry is ``None`` if the score does not depend on that
          layer.

    Raises:
        ValueError: if the sequence has fewer than two tokens (no prefix).
    """
    if input_ids.shape[-1] < 2:
        raise ValueError("saliency needs at least two tokens: a prefix and a target")

    target_id = input_ids[0, -1]
    prefix_ids = input_ids[0, :-1].unsqueeze(0)
    prefix_mask = attention_mask[0, :-1].unsqueeze(0)

    # Drop any stale parameter gradients. The gradients below are taken
    # w.r.t. activations only, so parameter .grad is not re-populated.
    model.zero_grad(set_to_none=True)

    # `torch.enable_grad()` only takes effect as a context manager; a bare
    # call is a no-op, which would leave no graph if the caller is in no_grad.
    with torch.enable_grad():
        outputs = model(
            input_ids=prefix_ids,
            attention_mask=prefix_mask,
            output_hidden_states=True,
            use_cache=False,
            return_dict=True,
        )
        hidden_states = list(outputs.hidden_states)
        last_logprobs = torch.log_softmax(outputs.logits[:, -1, :], dim=-1)
        score = last_logprobs[0, target_id]
        # autograd.grad yields d(score)/d(h) for each hidden state: the same
        # values retain_grad() + backward() would leave in h.grad. It does so
        # without accumulating a model-sized gradient into every parameter.
        layer_grads = torch.autograd.grad(score, hidden_states, allow_unused=True)

    log_prob = score.item()

    # .float() lets bf16/fp16 activations convert to NumPy (a no-op for fp32).
    grads = [
        None if g is None else g.detach().float().cpu().squeeze(0).numpy()
        for g in layer_grads
    ]
    output_hidden_states = [
        h.detach().float().cpu().squeeze(0).numpy()
        for h in hidden_states
    ]
    return log_prob, grads, output_hidden_states

def frobenius_norm(grads, length_norm=False):
    """Return the Frobenius norm (flattened L2 norm) of an array.

    For a 2-D ``(seq_len, hidden_size)`` matrix this is the usual Frobenius
    norm. Arrays of any other rank are flattened first, so vectors and
    higher-rank tensors are accepted as well.

    Args:
        grads: array-like of numbers (NumPy array or nested lists).
        length_norm: if True, divide the norm by the size of the first axis
            (the sequence length for per-token matrices).

    Returns:
        The norm as a NumPy floating-point scalar.
    """
    arr = np.asarray(grads)
    # With ord=None and axis=None NumPy flattens and takes the 2-norm. That is
    # the same computation as ord='fro' on 2-D input, but unlike 'fro' it does
    # not reject 1-D or higher-rank arrays.
    fro = np.linalg.norm(arr)
    if length_norm:
        length = arr.shape[0] if arr.ndim > 0 else 1
        fro = fro / length
    return fro

def forward_score_from_ids(model, input_ids, input_mask, correct_id=None):
    """Return the target-token log-probability and prefix hidden states, without gradients.

    ``input_ids`` holds the full sequence, including the target as its last
    token. The model runs on every earlier token (the prefix), and the
    log-softmax at the final prefix position is read at ``correct_id``.

    Args:
        model: causal LM called as ``model(ids, attention_mask=...,
            output_hidden_states=True)`` whose output exposes ``.logits`` and
            ``.hidden_states``.
        input_ids: 1-D sequence of token ids (list, NumPy array or tensor).
        input_mask: 1-D attention mask aligned with ``input_ids``.
        correct_id: token id to score; defaults to the last id of ``input_ids``.

    Returns:
        ``(log_prob, hidden_states)``: a Python float (a log-probability, not a
        raw logit) and a list with one float32 NumPy array per hidden-state
        entry, batch axis squeezed.

    Raises:
        ValueError: if ``input_ids`` has fewer than two tokens (no prefix).
    """
    if len(input_ids) < 2:
        raise ValueError("need at least two tokens: a prefix and a target")
    device = next(model.parameters()).device
    if correct_id is None:
        correct_id = input_ids[-1]
    # torch.as_tensor accepts lists, arrays and existing tensors (torch.tensor
    # warns when copy-constructing from a tensor) and moves them to the model.
    ids_t = torch.as_tensor(input_ids[:-1], dtype=torch.long, device=device).unsqueeze(0)
    mask_t = torch.as_tensor(input_mask[:-1], dtype=torch.long, device=device).unsqueeze(0)
    with torch.no_grad():
        outputs = model(
            ids_t,
            attention_mask=mask_t,
            output_hidden_states=True
        )
        last_logprobs = torch.log_softmax(outputs.logits[:, -1, :], dim=-1)  # [1, V]
        log_prob = last_logprobs[0, correct_id].item()
    # .float() lets bf16/fp16 activations convert to NumPy (a no-op for fp32).
    hidden_states = [
        h.detach().float().cpu().squeeze(0).numpy()
        for h in outputs.hidden_states
    ]
    return log_prob, hidden_states

def calculate_gradient(model, tokenizer, prompt):
    """Encode ``prompt``, run ``saliency`` on it and summarise the per-layer gradients.

    Returns:
        A dict with:
        - ``input_ids`` / ``attention_mask``: the encoded prompt, squeezed.
        - ``grads``: per-layer gradient arrays returned by ``saliency``.
        - ``grads_norms``: Frobenius norm of each layer's gradient.
        - ``grads_norms_length_norm``: the same norms divided by each
          gradient's first-axis length.
        - ``log_prob``: log-probability of the prompt's final token.
        - ``hidden_states``: per-layer hidden-state arrays from ``saliency``.
    """
    ids, mask = encode_ids_and_mask(tokenizer, prompt)
    log_prob, per_layer_grads, per_layer_hidden = saliency(model, ids, mask)

    def layer_norms(length_norm):
        return [float(frobenius_norm(g, length_norm=length_norm)) for g in per_layer_grads]

    raw_norms = layer_norms(False)
    per_token_norms = layer_norms(True)
    return dict(
        input_ids=ids.squeeze(),
        attention_mask=mask.squeeze(),
        grads=per_layer_grads,
        grads_norms=raw_norms,
        grads_norms_length_norm=per_token_norms,
        log_prob=log_prob,
        hidden_states=per_layer_hidden,
    )

def pad_to_len(arr, target_len):
    """Return the 2-D array ``arr`` of shape ``(T, D)`` with exactly ``target_len`` rows.

    Surplus rows are cut off (returning a slice of ``arr``). Missing rows are
    appended at the end as zeros of the same dtype.
    """
    cur_len, _ = arr.shape  # unpacking also rejects anything that is not 2-D
    shortfall = target_len - cur_len
    if shortfall <= 0:
        return arr[:target_len, :]
    return np.pad(arr, ((0, shortfall), (0, 0)))

def _load_eval_dataset(dataset_name, limit=500):
    """Load the split used for ``dataset_name``, keep well-formed rows, and take the first ``limit``.

    Raises:
        ValueError: for an unknown dataset name.
    """
    if dataset_name == "ARC_Challenge":
        dataset = load_dataset("ai2_arc", "ARC-Challenge")["train"]
        dataset = dataset.filter(lambda x: len(x["choices"]["label"]) == 4)
    elif dataset_name == "CommonSenseQA":
        dataset = load_dataset("commonsense_qa")["train"]
        dataset = dataset.filter(lambda x: len(x["choices"]["label"]) == 5)
    elif dataset_name == "MMLU":
        # cais/mmlu configs: ['abstract_algebra', 'all', 'anatomy', 'astronomy', 'auxiliary_train', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions']
        dataset = load_dataset("cais/mmlu", "all")["test"]
        dataset = dataset.filter(lambda x: len(x["choices"]) == 4)
    elif dataset_name == "OpenBookQA":
        dataset = load_dataset("openbookqa", "main")["train"]
        dataset = dataset.filter(lambda x: len(x["choices"]["label"]) == 4)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name!r}")
    # select() raises on out-of-range indices, so cap at the split size.
    return dataset.select(range(min(limit, len(dataset))))


def _build_prompt_list(dataset_name, row):
    """Dispatch one dataset row to its dataset-specific prompt builder."""
    builders = {
        "ARC_Challenge": build_prompt_ARC_Challenge,
        "CommonSenseQA": build_prompt_CSQA,
        "MMLU": build_prompt_MMLU,
        "OpenBookQA": build_prompt_OpenBookQA,
    }
    return builders[dataset_name](row)


def _compare_pair(result0, result1):
    """Compare two ``calculate_gradient`` results layer by layer.

    result1's hidden states are truncated or zero-padded to result0's length
    so the two can be subtracted position-wise.

    Returns:
        ``(linear_preds, delta_hiddens)``, one float per layer:
        - ``linear_preds`` is the first-order estimate sum(grad0 * (h1 - h0)).
        - ``delta_hiddens`` is the Frobenius norm of h0 - h1.

    NOTE(review): by Taylor expansion, sum(grad0 * (h1 - h0)) approximates
    log_prob1 - log_prob0, i.e. the NEGATIVE of the ``delta_log_prob``
    (log_prob0 - log_prob1) recorded next to it. Confirm that the downstream
    analysis expects this sign convention.
    """
    seq_len = result0['hidden_states'][0].shape[0]
    linear_preds = []
    delta_hiddens = []
    for layer, h0_full in enumerate(result0['hidden_states']):
        h0 = h0_full[:seq_len, :]
        h1 = pad_to_len(result1['hidden_states'][layer], seq_len)
        delta_hiddens.append(float(frobenius_norm(h0 - h1)))

        grads0 = pad_to_len(result0['grads'][layer], seq_len)
        linear_preds.append(float(np.sum(grads0 * (h1 - h0))))
    return linear_preds, delta_hiddens


def main(args):
    """Run the prompt-sensitivity gradient analysis for ``args.dataset`` and write a JSONL report.

    Steps:
    1. Up to 500 examples are loaded and also dumped to
       ``dataset/<dataset>/data_500.jsonl``.
    2. Every prompt template of each example is scored with
       ``calculate_gradient``.
    3. Every unordered pair of templates is compared, and one record per pair
       is written to
       ``results/data_results/real_dataset/<model_name_or_path>/<dataset>_result.jsonl``.

    ``args`` must carry ``model`` and ``tokenizer`` attributes (set by the
    ``__main__`` block) in addition to the parsed command-line options.
    """
    # Take the model from args rather than relying on module globals, which
    # only exist when this file is executed as a script.
    model = args.model
    tokenizer = args.tokenizer

    dataset = _load_eval_dataset(args.dataset)
    # Make sure the output directory exists before writing the subset.
    os.makedirs(f"dataset/{args.dataset}", exist_ok=True)
    dataset.to_json(f"dataset/{args.dataset}/data_500.jsonl", orient="records", lines=True)

    eval_data_list = []
    for index, row in enumerate(dataset):
        # Rows without an "id" field (presumably MMLU) fall back to their position.
        question_id = index if row.get("id") is None else row['id']
        eval_data_list.append(
            {
                "question_id": question_id,
                "prompt_list": _build_prompt_list(args.dataset, row),
            }
        )

    print(len(eval_data_list))

    result_list = []
    for eval_data in tqdm(eval_data_list, total=len(eval_data_list)):
        question_id = eval_data['question_id']
        group_result_list = [
            {
                "prompt_key": item['prompt_key'],
                "result": calculate_gradient(model, tokenizer, item['prompt']),
            }
            for item in eval_data['prompt_list']
        ]

        # Compare every unordered pair (pair_start < later index) of prompts.
        for pair_start, rst0 in enumerate(group_result_list):
            for rst1 in group_result_list[pair_start + 1:]:
                if rst0['prompt_key'] == rst1['prompt_key']:
                    continue
                result0 = rst0['result']
                result1 = rst1['result']
                linear_preds, delta_hiddens = _compare_pair(result0, result1)
                result_list.append(
                    {
                        "question_id": question_id,
                        "prompt_key0": rst0['prompt_key'],
                        "prompt_key1": rst1['prompt_key'],
                        "len0": len(result0['input_ids']),
                        "len1": len(result1['input_ids']),
                        "delta_log_prob": result0['log_prob'] - result1['log_prob'],
                        "grads": result0['grads_norms'],
                        "grads_length_norm": result0['grads_norms_length_norm'],
                        "linear_approx": linear_preds,
                        "delta_hiddens": delta_hiddens,
                    }
                )

    save_path = f"results/data_results/real_dataset/{args.model_name_or_path}/{args.dataset}_result.jsonl"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as f:
        for item in result_list:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

if __name__ == '__main__':
    # Parse options, load the model once, then hand everything to main().
    args = arguments()
    model, tokenizer = load_model(args)
    # The loaded objects are kept both as module-level names and as
    # attributes on args, so main() can reach them either way.
    args.model, args.tokenizer = model, tokenizer
    main(args)