import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from tqdm import tqdm
from argparse import ArgumentParser
from datasets import load_dataset

# Target device for the model and all tensors created by this script:
# "cuda" when PyTorch reports an available CUDA device, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

def arguments():
    """Parse the command-line options of the perturbation experiment.

    Returns:
        The parsed namespace with ``model_name_or_path``, ``dataset_name``
        (restricted to the supported benchmarks) and ``cache_path``.
    """
    supported_datasets = ["ARC_Challenge", "CommonSenseQA", "MMLU", "OpenBookQA"]

    cli = ArgumentParser()
    cli.add_argument('--model_name_or_path', type=str, default="Qwen/Qwen1.5-0.5B")
    cli.add_argument('--dataset_name', default="ARC_Challenge", choices=supported_datasets)
    cli.add_argument('--cache_path', default="")
    return cli.parse_args()


def load_dataset_by_name(dataset_name):
    """Load a benchmark split, keep the first rows, and dump them as JSONL.

    Only rows with the expected number of answer options are kept (five for
    CommonSenseQA, four for the others). At most ten of those rows are
    selected and written to ``dataset/<dataset_name>/turb_10.jsonl``.

    Args:
        dataset_name: One of "ARC_Challenge", "CommonSenseQA", "MMLU",
            "OpenBookQA".

    Returns:
        The selected subset of the dataset.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    # name -> (load_dataset arguments, split to use, row filter)
    specs = {
        "ARC_Challenge": (("ai2_arc", "ARC-Challenge"), "train",
                          lambda x: len(x["choices"]["label"]) == 4),
        "CommonSenseQA": (("commonsense_qa",), "train",
                          lambda x: len(x["choices"]["label"]) == 5),
        "MMLU": (("cais/mmlu", "all"), "test",
                 lambda x: len(x["choices"]) == 4),
        "OpenBookQA": (("openbookqa", "main"), "train",
                       lambda x: len(x["choices"]["label"]) == 4),
    }
    if dataset_name not in specs:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    hub_args, split, has_expected_choices = specs[dataset_name]
    dataset = load_dataset(*hub_args)[split]
    dataset = dataset.filter(has_expected_choices)

    # Guard against splits that end up with fewer than ten usable rows.
    dataset = dataset.select(range(min(10, len(dataset))))

    # Use the function argument (not the script-level ``args``) so the function
    # also works when imported, and make sure the target directory exists.
    out_path = f"dataset/{dataset_name}/turb_10.jsonl"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    dataset.to_json(out_path, orient="records", lines=True)
    return dataset


def load_model_and_tokenizer(model_name_or_path, cache_path):
    """Load the tokenizer and causal LM, move the model to ``device``, set eval mode.

    Args:
        model_name_or_path: Identifier or path forwarded to ``from_pretrained``.
        cache_path: Value forwarded as ``cache_dir`` to both loaders.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    shared_kwargs = dict(trust_remote_code=True, cache_dir=cache_path)

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, use_fast=True, **shared_kwargs
    )
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **shared_kwargs)
    model = model.to(device)
    model.eval()
    return model, tokenizer

# Prompt templates for ARC-Challenge (four options, fields: question, textA..textD).
ARC_Challenge_prompts = {
    # Bare question / "Question:" header variants.
    "prompt_1": "{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    "prompt_2": "Question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    "prompt_3": "Question:\n{question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
    # Polite request variants.
    "prompt_4": "Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    "prompt_5": "Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    "prompt_6": "Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    # Assistant-persona variants.
    "prompt_7": "You are a very helpful AI assistant. Please answer the following questions: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    "prompt_8": "As an exceptionally resourceful AI assistant, I'm at your service. Address the questions below:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    "prompt_9": "As a helpful Artificial Intelligence Assistant, please answer the following questions\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    # Variants with an explicit answer-format instruction.
    "prompt_10": "Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer the question by replying A, B, C or D.",
    "prompt_11": "Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer the question by replying A, B, C or D.",
    "prompt_12": "Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer this question by replying A, B, C or D.",
}
# Prompt templates for CommonSenseQA (five options, fields: question, A..E).
# The layout of each template mirrors the four-option templates used for the
# other benchmarks: options are either all on one line or one per line.
CommonSenseQA_prompts = dict(
    prompt_1="{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer:",
    prompt_2="Question:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer:",
    # Single-line variant: option E stays on the same line as A-D.
    prompt_3="Question:\n{question} A. {A} B. {B} C. {C} D. {D} E. {E}\nAnswer:",

    prompt_4="Could you provide a response to the following question: {question} A. {A} B. {B} C. {C} D. {D} E. {E}",
    # One-option-per-line variant: no stray space before "E.".
    prompt_5="Please answer the following question:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}",
    prompt_6="Please address the following question:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer:",

    prompt_7="You are a very helpful AI assistant. Please answer the following questions: {question} A. {A} B. {B} C. {C} D. {D} E. {E}",
    prompt_8="As an exceptionally resourceful AI assistant, I'm at your service. Address the questions below:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}",
    prompt_9="As a helpful Artificial Intelligence Assistant, please answer the following questions\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}",

    prompt_10="Could you provide a response to the following question: {question} A. {A} B. {B} C. {C} D. {D} E. {E}\nAnswer the question by replying A, B, C, D or E.",
    prompt_11="Please answer the following question:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer the question by replying A, B, C, D or E.",
    prompt_12="Please address the following question:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer this question by replying A, B, C, D or E.",
)
# Prompt templates for MMLU (four options, fields: question, textA..textD).
MMLU_prompts = {
    # Bare question / "Question:" header variants.
    "prompt_1": "{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    "prompt_2": "Question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    "prompt_3": "Question:\n{question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
    # Polite request variants.
    "prompt_4": "Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    "prompt_5": "Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    "prompt_6": "Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    # Assistant-persona variants.
    "prompt_7": "You are a very helpful AI assistant. Please answer the following questions: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    "prompt_8": "As an exceptionally resourceful AI assistant, I'm at your service. Address the questions below:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    "prompt_9": "As a helpful Artificial Intelligence Assistant, please answer the following questions\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    # Variants with an explicit answer-format instruction.
    "prompt_10": "Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer the question by replying A, B, C or D.",
    "prompt_11": "Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer the question by replying A, B, C or D.",
    "prompt_12": "Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer this question by replying A, B, C or D.",
}
# Prompt templates for OpenBookQA (four options, fields: question, textA..textD).
OpenBookQA_prompts = {
    # Bare question / "Question:" header variants.
    "prompt_1": "{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    "prompt_2": "Question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    "prompt_3": "Question:\n{question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
    # Polite request variants.
    "prompt_4": "Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    "prompt_5": "Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    "prompt_6": "Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    # Assistant-persona variants.
    "prompt_7": "You are a very helpful AI assistant. Please answer the following questions: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    "prompt_8": "As an exceptionally resourceful AI assistant, I'm at your service. Address the questions below:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    "prompt_9": "As a helpful Artificial Intelligence Assistant, please answer the following questions\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    # Variants with an explicit answer-format instruction.
    "prompt_10": "Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer the question by replying A, B, C or D.",
    "prompt_11": "Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer the question by replying A, B, C or D.",
    "prompt_12": "Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer this question by replying A, B, C or D.",
}

def build_prompt_ARC_Challenge(data):
    """Render every ARC-Challenge template for one example, answer appended.

    Args:
        data: One dataset row with ``question``, ``choices`` (a mapping with a
            ``text`` list and, when available, a parallel ``label`` list) and
            ``answerKey``.

    Returns:
        A list of ``{"prompt_key": ..., "prompt": ...}`` dicts, one per
        template in ``ARC_Challenge_prompts``. The list is empty when the
        example does not have exactly four options.
    """
    question = data['question']
    choices = data['choices']
    answer_key = data['answerKey']

    option_texts = choices['text']
    if len(option_texts) != 4:
        return []

    # The templates always present the options as A-D. Rows whose labels use
    # another scheme (e.g. "1".."4") would otherwise end with an answer token
    # that does not appear among the options, so translate the key by its
    # position in the label list. Keys that cannot be located are kept as-is.
    labels = list(choices.get('label') or [])
    if answer_key in labels:
        answer_letter = "ABCD"[labels.index(answer_key)]
    else:
        answer_letter = answer_key

    text_a, text_b, text_c, text_d = option_texts
    return [
        {
            "prompt_key": prompt_key,
            "prompt": template.format(
                question=question,
                textA=text_a,
                textB=text_b,
                textC=text_c,
                textD=text_d,
            ) + " " + answer_letter,
        }
        for prompt_key, template in ARC_Challenge_prompts.items()
    ]

def build_prompt_CSQA(data):
    """Render every CommonSenseQA template for one example, answer appended.

    Args:
        data: One dataset row with ``question``, ``choices`` (containing a
            ``text`` list) and ``answerKey``.

    Returns:
        A list of ``{"prompt_key": ..., "prompt": ...}`` dicts, one per
        template in ``CommonSenseQA_prompts``; empty when the example does not
        have exactly five options.
    """
    stem = data['question']
    options = data['choices']
    gold = data['answerKey']

    rendered = []
    for key, template in CommonSenseQA_prompts.items():
        texts = options['text']
        if len(texts) == 5:
            body = template.format(
                question=stem,
                A=texts[0],
                B=texts[1],
                C=texts[2],
                D=texts[3],
                E=texts[4],
            )
            # The gold answer is appended after a single space.
            rendered.append({"prompt_key": key, "prompt": body + " " + gold})
    return rendered

def build_prompt_MMLU(data):
    """Render every MMLU template for one example, answer letter appended.

    Args:
        data: One dataset row with ``question``, ``choices`` (a list of option
            texts) and ``answer`` (an index 0-3 into the options).

    Returns:
        A list of ``{"prompt_key": ..., "prompt": ...}`` dicts, one per
        template in ``MMLU_prompts``; empty when the example does not have
        exactly four options.
    """
    stem = data['question']
    options = data['choices']
    gold_index = data['answer']
    # MMLU stores the answer as an index; the prompts use letters.
    index_to_letter = {0: 'A', 1: 'B', 2: 'C', 3: 'D'}

    rendered = []
    for key, template in MMLU_prompts.items():
        if len(options) == 4:
            body = template.format(
                question=stem,
                textA=options[0],
                textB=options[1],
                textC=options[2],
                textD=options[3],
            )
            rendered.append(
                {"prompt_key": key, "prompt": body + " " + index_to_letter[gold_index]}
            )
    return rendered

def build_prompt_OpenBookQA(data):
    """Render every OpenBookQA template for one example, answer appended.

    Args:
        data: One dataset row with ``question_stem``, ``choices`` (containing
            a ``text`` list) and ``answerKey``.

    Returns:
        A list of ``{"prompt_key": ..., "prompt": ...}`` dicts, one per
        template in ``OpenBookQA_prompts``; empty when the example does not
        have exactly four options.
    """
    stem = data['question_stem']
    options = data['choices']
    gold = data['answerKey']

    rendered = []
    for key, template in OpenBookQA_prompts.items():
        texts = options['text']
        if len(texts) == 4:
            body = template.format(
                question=stem,
                textA=texts[0],
                textB=texts[1],
                textC=texts[2],
                textD=texts[3],
            )
            # The gold answer is appended after a single space.
            rendered.append({"prompt_key": key, "prompt": body + " " + gold})
    return rendered
def load_eval_data_list(dataset_name, dataset):
    """Build the list of evaluation items (id + rendered prompts) for a dataset.

    Args:
        dataset_name: One of "ARC_Challenge", "CommonSenseQA", "MMLU",
            "OpenBookQA"; selects the prompt builder.
        dataset: Iterable of rows; each row must support ``.get("id")`` and
            provide the fields the selected builder reads.

    Returns:
        A list of ``{"question_id": ..., "prompt_list": ...}`` dicts in dataset
        order. ``question_id`` is the row's ``id`` when present, otherwise the
        row's position in ``dataset``.

    Raises:
        ValueError: If ``dataset_name`` is not a supported name. Previously an
            unknown name left the prompt list unbound (or silently reused the
            one from the previous row).
    """
    # Resolve the builder once instead of re-dispatching for every row.
    if dataset_name == "ARC_Challenge":
        build_prompts = build_prompt_ARC_Challenge
    elif dataset_name == "CommonSenseQA":
        build_prompts = build_prompt_CSQA
    elif dataset_name == "MMLU":
        build_prompts = build_prompt_MMLU
    elif dataset_name == "OpenBookQA":
        build_prompts = build_prompt_OpenBookQA
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    eval_data_list = []
    for index, example in enumerate(dataset):
        question_id = example.get("id")
        if question_id is None:
            # Rows without an id are identified by their position.
            question_id = index
        eval_data_list.append(
            {
                "question_id": question_id,
                "prompt_list": build_prompts(example),
            }
        )
    return eval_data_list

def encode_ids_and_mask(tokenizer, text, device):
    """Tokenize ``text`` without special tokens and move the tensors to ``device``.

    Args:
        tokenizer: Callable invoked as
            ``tokenizer(text, add_special_tokens=False, return_tensors="pt")``
            that returns a mapping with ``input_ids`` and ``attention_mask``.
        text: The string to encode.
        device: Target passed to ``.to(...)`` for both tensors.

    Returns:
        An ``(input_ids, attention_mask)`` tuple.
    """
    encoded = tokenizer(text, add_special_tokens=False, return_tensors="pt")
    ids_on_device = encoded["input_ids"].to(device)
    mask_on_device = encoded["attention_mask"].to(device)
    return ids_on_device, mask_on_device

def saliency(model, input_ids, attention_mask):
    """Compute gradients of the target token's log-probability per hidden state.

    The last token of ``input_ids`` is treated as the target; the model is run
    on the embeddings of all preceding tokens and the log-probability of the
    target at the final position is back-propagated to every hidden state the
    model returns.

    Args:
        model: Causal LM exposing ``get_input_embeddings()`` and accepting
            ``inputs_embeds``/``attention_mask``/``output_hidden_states``.
        input_ids: Token ids; only row 0 is used.
        attention_mask: Mask matching ``input_ids``; only row 0 is used.

    Returns:
        A tuple ``(grads, hidden_states, log_prob, correct_id)``:
        ``grads`` holds one flattened gradient per hidden state (``None`` where
        no gradient was produced), ``hidden_states`` the corresponding detached
        hidden states, ``log_prob`` the target log-probability as a float and
        ``correct_id`` the target token id (a 0-dim tensor).
    """
    correct_id = input_ids[0][-1]
    context_ids = input_ids[0][:-1]
    context_mask = attention_mask[0][:-1]

    # A bare ``torch.enable_grad()`` call only constructs the context manager
    # and leaves the grad mode untouched, so it must be entered with ``with``.
    # This makes the function work even when the caller disabled gradients.
    with torch.enable_grad():
        embedding_layer = model.get_input_embeddings()
        # Detach so the embeddings become a fresh leaf: gradients stop here
        # instead of flowing into the embedding matrix lookup.
        inputs_embeds = embedding_layer(context_ids).detach()
        inputs_embeds.requires_grad_(True)

        model.zero_grad(set_to_none=True)
        outputs = model(
            inputs_embeds=inputs_embeds.unsqueeze(0),
            attention_mask=context_mask.unsqueeze(0),
            output_hidden_states=True,
            use_cache=False,
            return_dict=True,
        )

        hidden_states = list(outputs.hidden_states)  # tuple -> list
        # Hidden states are non-leaf tensors; ask autograd to keep their grads.
        for h in hidden_states:
            h.retain_grad()

        last_logits = outputs.logits[:, -1, :]
        last_logprobs = torch.log_softmax(last_logits, dim=-1)
        score = last_logprobs[0, correct_id]
        score.backward()

    grads = [
        h.grad.detach().reshape(-1) if (h.grad is not None) else None
        for h in hidden_states
    ]
    # The graph has been consumed by backward(); hand back plain tensors so
    # later arithmetic on them does not build new autograd graphs.
    detached_hidden_states = [h.detach() for h in hidden_states]

    log_prob = float(score.detach().item())
    return grads, detached_hidden_states, log_prob, correct_id


@torch.inference_mode()
def forward_score_from_embeds(model, inputs_embeds, attention_mask, correct_id):
    """Score a target token from un-batched input embeddings, without gradients.

    Args:
        model: Causal LM accepting ``inputs_embeds``/``attention_mask`` and
            returning ``logits`` and ``hidden_states``.
        inputs_embeds: Embeddings without a batch dimension; one is added here.
        attention_mask: Mask without a batch dimension; one is added here.
        correct_id: Index of the target token in the vocabulary dimension.

    Returns:
        A tuple ``(log_prob, hidden_states)``: the log-probability of
        ``correct_id`` at the last position as a float, and the model's hidden
        states as a list.
    """
    batched_embeds = inputs_embeds.unsqueeze(0)
    batched_mask = attention_mask.unsqueeze(0)

    result = model(
        inputs_embeds=batched_embeds,
        attention_mask=batched_mask,
        output_hidden_states=True,
        use_cache=False,
        return_dict=True,
    )

    # Log-softmax over the vocabulary at the final sequence position.
    final_logprobs = torch.log_softmax(result.logits[:, -1, :], dim=-1)
    target_logprob = final_logprobs[0, correct_id]

    return float(target_logprob.detach().item()), list(result.hidden_states)


def frobenius_norm(grads, length_norm=False):
    """Return the Frobenius norm of ``grads`` as a Python float.

    Args:
        grads: Tensor whose last two dimensions form the matrix; the result
            must contain a single element so it can be converted to a float.
        length_norm: When true, divide the norm by ``grads.shape[0]``.

    Returns:
        The (optionally length-normalized) norm.
    """
    magnitude = torch.linalg.matrix_norm(grads, ord='fro')
    if not length_norm:
        return float(magnitude.item())
    normalized = magnitude / grads.shape[0]
    return float(normalized.item())

def get_result_by_example(model, tokenizer, prompt, radius=0.1,  turb_times=100, seed=42):
    """Compare true vs. first-order log-probability changes under perturbations.

    The prompt's last token is the target. Its log-probability and the
    gradients with respect to every hidden state are computed once; then the
    input embeddings are perturbed ``turb_times`` times at a few random
    (position, dimension) entries with a perturbation of L2 norm ``radius``.
    For each perturbation the actual change in log-probability is compared
    with the per-layer linear estimate ``grad . (hidden_perturbed - hidden)``.

    Args:
        model: Causal LM used for both the gradient and the perturbed passes.
        tokenizer: Tokenizer used to encode ``prompt``.
        prompt: Full prompt text; its final token is the scored target.
        radius: L2 norm of each perturbation (anything ``float()`` accepts).
        turb_times: Number of random perturbations.
        seed: Seed for the generator driving positions, dimensions and values.

    Returns:
        A dict of lists with one entry per perturbation:
        ``delta_log_prob_list`` (float), and per-layer lists
        ``linear_approx_list``, ``loss_list`` (absolute approximation error),
        ``delta_z_list`` (norm of each hidden-state change) and ``grads_list``
        (norm of each layer's gradient, identical for every perturbation).
    """
    input_ids, attention_mask = encode_ids_and_mask(tokenizer, prompt, device)
    layer_grads0, layer_hidden_states0, log_prob0, correct_id = saliency(model, input_ids, attention_mask)

    # Work on graph-free copies so the differences below stay plain tensors.
    layer_hidden_states0 = [h.detach() for h in layer_hidden_states0]
    inputs_embeds0 = layer_hidden_states0[0][0]
    seq_len, hidden_dim = inputs_embeds0.shape

    # The reference pass ran on every token except the target, so the
    # perturbed passes must use the same trimmed, un-batched mask
    # (forward_score_from_embeds adds the batch dimension itself).
    context_mask = attention_mask[0, :-1]

    # Never ask for more distinct positions than the sequence has.
    num_perturbed = min(5, seq_len)

    generator = torch.Generator(device=device)
    generator.manual_seed(seed)

    # The reference gradients do not change across perturbations.
    grad_norms = [frobenius_norm(grad.unsqueeze(0).unsqueeze(0)) for grad in layer_grads0]

    delta_log_prob_list = []
    linear_approx_list = []
    loss_list = []
    delta_z_list = []
    grads_list = []
    for _ in tqdm(range(turb_times), total=turb_times, desc=f"r={radius}"):
        rows = torch.randperm(seq_len, generator=generator, device=device)[:num_perturbed]
        cols = torch.randint(0, hidden_dim, (num_perturbed,), generator=generator, device=device)
        # Draw the values from the seeded generator as well, otherwise the
        # perturbations differ between runs despite ``seed``.
        delta = torch.randn(num_perturbed, generator=generator, device=device)
        delta = delta * (float(radius) / (delta.norm() + 1e-12))

        inputs_embeds1 = inputs_embeds0.clone()
        inputs_embeds1[rows, cols] += delta

        # Each hidden state carries a leading batch dimension of size 1.
        log_prob1, layer_hidden_states1 = forward_score_from_embeds(model, inputs_embeds1, context_mask, correct_id)

        delta_log_prob = log_prob1 - log_prob0

        diffs = [hs1 - hs0 for hs0, hs1 in zip(layer_hidden_states0, layer_hidden_states1)]

        # First-order estimate of the log-prob change, one value per layer.
        layer_estimates = [
            torch.dot(grad, diff.reshape(-1)).item()
            for grad, diff in zip(layer_grads0, diffs)
        ]

        delta_log_prob_list.append(delta_log_prob)
        linear_approx_list.append(layer_estimates)
        loss_list.append([float(abs(delta_log_prob - estimate)) for estimate in layer_estimates])
        delta_z_list.append([frobenius_norm(diff[0]) for diff in diffs])
        grads_list.append(list(grad_norms))

    result_dict = {
        "delta_log_prob_list": delta_log_prob_list,
        "linear_approx_list": linear_approx_list,
        "loss_list": loss_list,
        "delta_z_list": delta_z_list,
        "grads_list": grads_list,
    }
    return result_dict

if __name__ == "__main__":
    # Script entry point: sweep perturbation radii 0.01..1.00 over every
    # prompt variant of every selected question and store one JSON line per
    # (radius, question, prompt) combination.
    import json

    # ``args`` is intentionally a module-level name; other code in this file
    # may read it.
    args = arguments()
    model, tokenizer = load_model_and_tokenizer(
        model_name_or_path=args.model_name_or_path,
        cache_path=args.cache_path
    )
    dataset = load_dataset_by_name(args.dataset_name)
    eval_data_list = load_eval_data_list(args.dataset_name, dataset)
    print(len(eval_data_list))

    # Radii are kept as two-decimal strings; they are stored verbatim in the
    # output and converted with float() where they are used numerically.
    radius_list = [f"{i/100:.2f}" for i in range(1,  101 , 1)]

    save_path = f"results/data_results/turb/{args.model_name_or_path}/d-{args.dataset_name}_results.jsonl"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    result_keys = (
        "delta_log_prob_list",
        "linear_approx_list",
        "loss_list",
        "delta_z_list",
        "grads_list",
    )

    # Write each record as soon as it is produced: the sweep is long, and
    # buffering every result until the end would both grow memory without
    # bound and lose all finished work if the run is interrupted.
    with open(save_path, "w", encoding="utf-8") as f:
        for radius in radius_list:
            for eval_data in tqdm(eval_data_list, total=len(eval_data_list)):
                question_id = eval_data['question_id']

                for item in eval_data['prompt_list']:
                    result_dict = get_result_by_example(model, tokenizer, item['prompt'], radius = radius)

                    save_item = {
                        "question_id": question_id,
                        "prompt_key": item['prompt_key'],
                        "radius": radius,
                    }
                    for key in result_keys:
                        save_item[key] = result_dict[key]
                    f.write(json.dumps(save_item) + "\n")

    print(f"Results saved to {save_path}")