import os
import numpy as np
import pandas as pd

from argparse import ArgumentParser
from transformers import AutoTokenizer, AutoModelForCausalLM

from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from datasets import load_dataset



# Default compute device for the script: CUDA when torch reports it, else CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

def arguments(argv=None):
    """Parse the command-line options of the script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            ``sys.argv[1:]`` is parsed, which is the behaviour of a plain
            ``arguments()`` call.

    Returns:
        An ``argparse.Namespace`` with ``model_name_or_path``, ``dataset``
        and ``cache_path``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        default="meta-llama/Llama-3.2-3B",
        help="Model identifier or local path passed to from_pretrained.",
    )
    parser.add_argument(
        '--dataset',
        type=str,
        default="ARC_Challenge",
        choices=["ARC_Challenge", "CommonSenseQA", "MMLU", "OpenBookQA"],
        help="Which multiple-choice benchmark to evaluate.",
    )
    parser.add_argument(
        '--cache_path',
        type=str,
        default='',
        help="Directory used as the model/tokenizer cache (empty = not set).",
    )
    # Passing argv through lets tests and notebooks call this without
    # touching sys.argv; None keeps the original command-line behaviour.
    return parser.parse_args(argv)


# Twelve prompt templates for ARC-Challenge (four options, A-D).
# Placeholders: {question}, {textA}, {textB}, {textC}, {textD}.
ARC_Challenge_prompts = {
    # 1-3: bare question, ending in "Answer:".
    "prompt_1": "{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    "prompt_2": "Question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    "prompt_3": "Question:\n{question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
    # 4-6: polite request wording.
    "prompt_4": "Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    "prompt_5": "Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    "prompt_6": "Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    # 7-9: assistant-persona preamble.
    "prompt_7": "You are a very helpful AI assistant. Please answer the following questions: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    "prompt_8": "As an exceptionally resourceful AI assistant, I'm at your service. Address the questions below:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    "prompt_9": "As a helpful Artificial Intelligence Assistant, please answer the following questions\n{question} A. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    # 10-12: request plus an explicit "reply with a letter" instruction.
    "prompt_10": "Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer the question by replying A, B, C or D.",
    "prompt_11": "Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer the question by replying A, B, C or D.",
    "prompt_12": "Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer this question by replying A, B, C or D.",
}
# Twelve prompt templates for CommonSenseQA (five options, A-E).
# Placeholders: {question}, {A}, {B}, {C}, {D}, {E}.
CommonSenseQA_prompts = {
    # 1-3: bare question, ending in "Answer:".
    "prompt_1": "{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer:",
    "prompt_2": "Question:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer:",
    "prompt_3": "Question:\n{question} A. {A} B. {B} C. {C} D. {D} E. {E}\nAnswer:",
    # 4-6: polite request wording.
    "prompt_4": "Could you provide a response to the following question: {question} A. {A} B. {B} C. {C} D. {D} E. {E}",
    # NOTE(review): unlike the other options, "E." is preceded by a space
    # here ("\n E."); kept unchanged so existing results stay comparable.
    "prompt_5": "Please answer the following question:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\n E. {E}",
    "prompt_6": "Please address the following question:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer:",
    # 7-9: assistant-persona preamble.
    "prompt_7": "You are a very helpful AI assistant. Please answer the following questions: {question} A. {A} B. {B} C. {C} D. {D} E. {E}",
    "prompt_8": "As an exceptionally resourceful AI assistant, I'm at your service. Address the questions below:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}",
    "prompt_9": "As a helpful Artificial Intelligence Assistant, please answer the following questions\n{question} A. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}",
    # 10-12: request plus an explicit "reply with a letter" instruction.
    "prompt_10": "Could you provide a response to the following question: {question} A. {A} B. {B} C. {C} D. {D} E. {E}\nAnswer the question by replying A, B, C, D or E.",
    "prompt_11": "Please answer the following question:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer the question by replying A, B, C, D or E.",
    "prompt_12": "Please address the following question:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer this question by replying A, B, C, D or E.",
}
# Twelve prompt templates for MMLU (four options, A-D).
# Placeholders: {question}, {textA}, {textB}, {textC}, {textD}.
MMLU_prompts = {
    # 1-3: bare question, ending in "Answer:".
    "prompt_1": "{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    "prompt_2": "Question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    "prompt_3": "Question:\n{question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
    # 4-6: polite request wording.
    "prompt_4": "Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    "prompt_5": "Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    "prompt_6": "Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    # 7-9: assistant-persona preamble.
    "prompt_7": "You are a very helpful AI assistant. Please answer the following questions: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    "prompt_8": "As an exceptionally resourceful AI assistant, I'm at your service. Address the questions below:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    "prompt_9": "As a helpful Artificial Intelligence Assistant, please answer the following questions\n{question} A. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    # 10-12: request plus an explicit "reply with a letter" instruction.
    "prompt_10": "Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer the question by replying A, B, C or D.",
    "prompt_11": "Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer the question by replying A, B, C or D.",
    "prompt_12": "Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer this question by replying A, B, C or D.",
}
# Twelve prompt templates for OpenBookQA (four options, A-D).
# Placeholders: {question}, {textA}, {textB}, {textC}, {textD}.
OpenBookQA_prompts = {
    # 1-3: bare question, ending in "Answer:".
    "prompt_1": "{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    "prompt_2": "Question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    "prompt_3": "Question:\n{question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
    # 4-6: polite request wording.
    "prompt_4": "Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    "prompt_5": "Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    "prompt_6": "Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:",
    # 7-9: assistant-persona preamble.
    "prompt_7": "You are a very helpful AI assistant. Please answer the following questions: {question} A. {textA} B. {textB} C. {textC} D. {textD}",
    "prompt_8": "As an exceptionally resourceful AI assistant, I'm at your service. Address the questions below:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    "prompt_9": "As a helpful Artificial Intelligence Assistant, please answer the following questions\n{question} A. {textA}\nB. {textB}\nC. {textC}\nD. {textD}",
    # 10-12: request plus an explicit "reply with a letter" instruction.
    "prompt_10": "Could you provide a response to the following question: {question} A. {textA} B. {textB} C. {textC} D. {textD}\nAnswer the question by replying A, B, C or D.",
    "prompt_11": "Please answer the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer the question by replying A, B, C or D.",
    "prompt_12": "Please address the following question:\n{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer this question by replying A, B, C or D.",
}

def load_model(args):
    """Load the causal LM and its tokenizer named by ``args``.

    Args:
        args: Namespace providing ``model_name_or_path`` and ``cache_path``.

    Returns:
        ``(model, tokenizer)``; the model is moved to the module-level
        ``device`` and switched to eval mode.
    """
    # The CLI default for --cache_path is '' (empty string). Passing '' on as
    # cache_dir is not the same as leaving it unset, so map it to None and let
    # the library fall back to its default cache location.
    cache_dir = args.cache_path or None

    # NOTE(review): trust_remote_code=True lets the loaded repository run its
    # own Python code; only point this at model sources you trust.
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        use_fast=True,
        cache_dir=cache_dir,
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        cache_dir=cache_dir,
    ).to(device)
    model.eval()
    return model, tokenizer

def build_prompt_ARC_Challenge(data):
    """Render one ARC-Challenge row with every ARC prompt template.

    Each rendered prompt ends with a space followed by the gold answer
    letter, so the answer is the final piece of text in the prompt.

    Args:
        data: Row with ``question``, ``choices`` (a mapping holding a
            ``text`` list and, when available, a ``label`` list) and
            ``answerKey``.

    Returns:
        A list of ``{"prompt_key", "prompt"}`` dicts, one per template, or
        an empty list when the row does not have exactly four options.
    """
    question = data['question']
    choices = data['choices']
    answerKey = data['answerKey']

    # The option count does not depend on the template, so check it once.
    choices_text = choices['text']
    if len(choices_text) != 4:
        return []

    # Some ARC rows label their options "1".."4" and give a numeric
    # answerKey, while the templates always display the options as A-D.
    # Translate the key to the letter at the same position so the appended
    # answer matches what the prompt shows; keys that are not found in the
    # label list are passed through unchanged.
    answer = answerKey
    labels = list(choices.get('label') or [])
    if answerKey in labels:
        position = labels.index(answerKey)
        if position < 4:
            answer = "ABCD"[position]

    prompt_list = []
    for prompt_key, prompt_template in ARC_Challenge_prompts.items():
        prompt = prompt_template.format(
            question=question,
            textA=choices_text[0],
            textB=choices_text[1],
            textC=choices_text[2],
            textD=choices_text[3],
        )
        prompt_list.append({
            "prompt_key": prompt_key,
            "prompt": prompt + " " + answer,
        })
    return prompt_list

def build_prompt_CSQA(data):
    """Render one CommonSenseQA row with every CommonSenseQA template.

    Every prompt is the filled template followed by a space and the row's
    ``answerKey``. Rows without exactly five options yield an empty list.
    """
    question = data['question']
    choices = data['choices']
    answerKey = data['answerKey']

    option_texts = choices['text']
    if len(option_texts) != 5:
        # No template can be filled for this row.
        return []

    return [
        {
            "prompt_key": key,
            "prompt": template.format(
                question=question,
                A=option_texts[0],
                B=option_texts[1],
                C=option_texts[2],
                D=option_texts[3],
                E=option_texts[4],
            ) + " " + answerKey,
        }
        for key, template in CommonSenseQA_prompts.items()
    ]

def build_prompt_MMLU(data):
    """Render one MMLU row with every MMLU template.

    ``data['answer']`` is looked up in a 0-3 -> A-D table and the resulting
    letter is appended (after a space) to each filled template. Rows whose
    ``choices`` list does not hold exactly four entries yield an empty list.
    """
    question = data['question']
    options = data['choices']
    answer_index = data['answer']
    index_to_letter = {0: 'A', 1: 'B', 2: 'C', 3: 'D'}

    if len(options) != 4:
        return []

    rendered = []
    for key, template in MMLU_prompts.items():
        body = template.format(
            question=question,
            textA=options[0],
            textB=options[1],
            textC=options[2],
            textD=options[3],
        )
        rendered.append({
            "prompt_key": key,
            "prompt": body + " " + index_to_letter[answer_index],
        })
    return rendered

def build_prompt_OpenBookQA(data):
    """Render one OpenBookQA row with every OpenBookQA template.

    The question text is read from ``question_stem``. Every prompt is the
    filled template followed by a space and the row's ``answerKey``. Rows
    without exactly four options yield an empty list.
    """
    stem = data['question_stem']
    choices = data['choices']
    answerKey = data['answerKey']

    option_texts = choices['text']
    if len(option_texts) != 4:
        return []

    def fill(template):
        # Substitute the stem and the four option texts into one template.
        return template.format(
            question=stem,
            textA=option_texts[0],
            textB=option_texts[1],
            textC=option_texts[2],
            textD=option_texts[3],
        )

    return [
        {"prompt_key": key, "prompt": fill(template) + " " + answerKey}
        for key, template in OpenBookQA_prompts.items()
    ]

def encode_ids_and_mask(tokenizer, text):
    """Tokenize ``text`` without special tokens and move the result to ``device``.

    Returns:
        ``(input_ids, attention_mask)`` as returned by the tokenizer with
        ``return_tensors="pt"``, each moved to the module-level ``device``.
    """
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    ids_on_device = encoded['input_ids'].to(device)
    mask_on_device = encoded['attention_mask'].to(device)
    return ids_on_device, mask_on_device

def saliency(model, input_ids, input_mask, correct=None, foil=None):
    """Compute gradients w.r.t. input embeddings for next-token prediction.

    The last element of ``input_ids`` is treated as the target token and is
    dropped from the model input; the score is the log-probability the model
    assigns to ``correct`` at the final prefix position (or the difference
    to ``foil`` when a distinct foil is given). Works across architectures
    that accept ``inputs_embeds`` (GPT-2, Qwen, LLaMA-like).

    Args:
        model: Causal LM exposing ``get_input_embeddings()`` and accepting
            ``inputs_embeds`` / ``attention_mask``.
        input_ids: 1-D sequence (list or tensor) of token ids, length >= 2.
        input_mask: 1-D attention mask aligned with ``input_ids``.
        correct: Target token id; defaults to ``input_ids[-1]``.
        foil: Optional contrast token id.

    Returns:
        ``(grads, embs)`` as NumPy arrays of shape ``[L-1, d]``: the
        gradient of the score w.r.t. each prefix embedding, and the prefix
        embeddings themselves.

    Raises:
        ValueError: If ``input_ids`` has fewer than two tokens.
    """
    if len(input_ids) < 2:
        raise ValueError(
            "saliency needs at least two tokens: a non-empty prefix and a target"
        )

    model.eval()

    if correct is None:
        correct = input_ids[-1]

    prefix_ids = input_ids[:-1]
    prefix_mask = input_mask[:-1]

    device = next(model.parameters()).device
    # as_tensor accepts lists and tensors alike and avoids the copy-construct
    # warning torch.tensor() emits when it is handed an existing tensor.
    input_ids_t = torch.as_tensor(prefix_ids, dtype=torch.long, device=device).unsqueeze(0)
    input_mask_t = torch.as_tensor(prefix_mask, dtype=torch.long, device=device).unsqueeze(0)

    # enable_grad must be *entered* to have any effect; a bare
    # torch.enable_grad() call only builds the context manager. Entering it
    # keeps backward() working even when the caller is inside torch.no_grad().
    with torch.enable_grad():
        embedding_layer = model.get_input_embeddings()
        # Detach so the embeddings become a leaf tensor that receives .grad.
        emb = embedding_layer(input_ids_t).detach()
        emb.requires_grad_(True)

        model.zero_grad(set_to_none=True)
        outputs = model(inputs_embeds=emb, attention_mask=input_mask_t)
        last_logits = outputs.logits[:, -1, :]  # [1, vocab_size]
        last_logprobs = torch.log_softmax(last_logits, dim=-1)

        if foil is not None and correct != foil:
            score = last_logprobs[0, correct] - last_logprobs[0, foil]
        else:
            score = last_logprobs[0, correct]

        score.backward()

    # Drop only the batch axis; a blanket squeeze() would also collapse the
    # sequence axis when the prefix is a single token.
    grads = emb.grad.detach().cpu().numpy()[0]  # [L-1, d]
    embs = emb.detach().cpu().numpy()[0]        # [L-1, d]
    return grads, embs


def frobenius_norm(grads, length_norm=False):
    """Return the Frobenius norm of the 2-D array ``grads``.

    With ``length_norm=True`` the norm is divided by the number of rows
    (``grads.shape[0]``).
    """
    norm_value = np.linalg.norm(grads, ord='fro')
    if not length_norm:
        return norm_value
    row_count = grads.shape[0]
    return norm_value / row_count


def forward_score_from_ids(model, input_ids, input_mask, correct_id=None):
    """Score the target token given its prefix and collect hidden states.

    The last element of ``input_ids`` is dropped from the model input; the
    score is the log-probability the model assigns to ``correct_id`` at the
    final prefix position.

    Args:
        model: Causal LM called as ``model(ids, attention_mask=...,
            output_hidden_states=True)``.
        input_ids: 1-D sequence (list or tensor) of token ids, length >= 2.
        input_mask: 1-D attention mask aligned with ``input_ids``.
        correct_id: Target token id; defaults to ``input_ids[-1]``.

    Returns:
        ``(score, hidden_states)`` where ``score`` is a Python float and
        ``hidden_states`` is a list with one NumPy array per entry of
        ``outputs.hidden_states``, batch axis removed.

    Raises:
        ValueError: If ``input_ids`` has fewer than two tokens.
    """
    if len(input_ids) < 2:
        raise ValueError(
            "forward_score_from_ids needs at least two tokens: "
            "a non-empty prefix and a target"
        )

    device = next(model.parameters()).device
    if correct_id is None:
        correct_id = input_ids[-1]

    # Prefix only. as_tensor handles lists and tensors without the
    # copy-construct warning torch.tensor() raises for tensor inputs.
    ids_t = torch.as_tensor(input_ids[:-1], dtype=torch.long, device=device).unsqueeze(0)
    mask_t = torch.as_tensor(input_mask[:-1], dtype=torch.long, device=device).unsqueeze(0)

    with torch.no_grad():
        outputs = model(
            ids_t,
            attention_mask=mask_t,
            output_hidden_states=True,
        )
        last_logits = outputs.logits[:, -1, :]  # [1, V]
        last_logprobs = torch.log_softmax(last_logits, dim=-1)
        score = last_logprobs[0, correct_id].item()

    hidden_states = [
        h.detach().cpu().squeeze(0).numpy()   # [seq_len, hidden_size]
        for h in outputs.hidden_states
    ]
    return score, hidden_states

def calculate_logit(model, tokenizer, prompt):
    """Return the log-probability of the prompt's last token given the rest.

    The prompt is tokenized with ``encode_ids_and_mask``; its final token is
    held out as the target and the remaining tokens are fed to the model.
    Despite the name, the value returned is a log-softmax score, not a raw
    logit.

    Args:
        model: Causal LM returning ``logits``.
        tokenizer: Tokenizer matching ``model``.
        prompt: Full text whose last token is the one to score.

    Returns:
        The log-probability (Python float) of the held-out token.

    Raises:
        ValueError: If the prompt tokenizes to fewer than two tokens, which
            would leave no context to condition on.
    """
    input_ids, attention_mask = encode_ids_and_mask(tokenizer, prompt)
    if input_ids.shape[-1] < 2:
        raise ValueError(
            "prompt must tokenize to at least two tokens (context + target)"
        )

    target_id = input_ids[0, -1]
    context_ids = input_ids[:, :-1]
    context_mask = attention_mask[:, :-1]

    # Only a float leaves this function, so no autograd graph is needed;
    # no_grad here keeps memory bounded even if the caller forgot it.
    with torch.no_grad():
        # Hidden states were requested before but never read, so they are
        # no longer asked for; the logits are unaffected.
        outputs = model(
            input_ids=context_ids,
            attention_mask=context_mask,
            use_cache=False,
            return_dict=True,
        )
        last_logits = outputs.logits[:, -1, :]
        last_logprobs = torch.log_softmax(last_logits, dim=-1)

    return last_logprobs[0, target_id].item()

def main(args):
    """Score every prompt template on up to 500 questions and save a CSV.

    Steps: load and filter the chosen dataset, dump the selected rows to
    ``dataset/<dataset>/logit_500.jsonl``, build all prompt variants per
    question, score each prompt with ``calculate_logit`` and write the
    results to
    ``results/data_results/template_vs_question/<model>/<dataset>_logit.csv``.

    Args:
        args: Namespace with ``dataset`` and ``model_name_or_path``. When it
            also carries ``model`` and ``tokenizer`` attributes those are
            used; otherwise they are loaded with ``load_model(args)``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    # name -> (load_dataset arguments, split, row filter, prompt builder)
    dataset_specs = {
        "ARC_Challenge": (
            ("ai2_arc", "ARC-Challenge"), "train",
            lambda x: len(x["choices"]["label"]) == 4,
            build_prompt_ARC_Challenge,
        ),
        "CommonSenseQA": (
            ("commonsense_qa",), "train",
            lambda x: len(x["choices"]["label"]) == 5,
            build_prompt_CSQA,
        ),
        "MMLU": (
            ("cais/mmlu", "all"), "test",
            lambda x: len(x["choices"]) == 4,
            build_prompt_MMLU,
        ),
        "OpenBookQA": (
            ("openbookqa", "main"), "train",
            lambda x: len(x["choices"]["label"]) == 4,
            build_prompt_OpenBookQA,
        ),
    }
    if args.dataset not in dataset_specs:
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")
    load_args, split, keep_row, build_prompts = dataset_specs[args.dataset]

    # Use the model/tokenizer handed over on args instead of relying on
    # module globals that only exist when the file is run as a script.
    model = getattr(args, "model", None)
    tokenizer = getattr(args, "tokenizer", None)
    if model is None or tokenizer is None:
        model, tokenizer = load_model(args)

    dataset = load_dataset(*load_args)[split].filter(keep_row)
    # Cap at 500 rows but never ask for more rows than the split holds.
    dataset = dataset.select(range(min(500, len(dataset))))

    # Make sure the output directory exists before writing the snapshot.
    dump_path = f"dataset/{args.dataset}/logit_500.jsonl"
    os.makedirs(os.path.dirname(dump_path), exist_ok=True)
    dataset.to_json(dump_path, orient="records", lines=True)

    eval_data_list = []
    for index, row in enumerate(dataset):
        # Rows without an "id" field are identified by their position.
        question_id = index if row.get("id") is None else row['id']
        eval_data_list.append(
            {
                "question_id": question_id,
                "prompt_list": build_prompts(row),
            }
        )

    print(len(eval_data_list))

    result_list = []
    with torch.no_grad():
        for eval_data in tqdm(eval_data_list, total=len(eval_data_list)):
            question_id = eval_data['question_id']
            for item in eval_data['prompt_list']:
                logit = calculate_logit(model, tokenizer, item['prompt'])
                result_list.append([question_id, item['prompt_key'], logit])

    header = ["question_id", "prompt_id", "logit"]

    df = pd.DataFrame(result_list, columns=header)
    save_path = f"results/data_results/template_vs_question/{args.model_name_or_path}/{args.dataset}_logit.csv"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    df.to_csv(save_path, index=False, encoding="utf-8")
    print(f"Results saved to {save_path}")

if __name__ == '__main__':
    # Script entry point: parse the CLI, load the model once, attach the
    # model and tokenizer to the namespace, then run the evaluation.
    args = arguments()
    model, tokenizer = load_model(args)
    args.model, args.tokenizer = model, tokenizer
    main(args)