import os
import csv
import json
import numpy as np

from argparse import ArgumentParser
from transformers import AutoTokenizer, AutoModelForCausalLM

from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import ast
import re
from datasets import load_dataset
import torch.nn.functional as F


import os



# Prompt templates for ARC-Challenge questions with exactly four options (A-D).
# Each template is filled via str.format(question=..., textA=..., textB=..., textC=..., textD=...).
# prompt_1..3 are the unchanged baseline; the later groups each alter one part of it.
ARC_Challenge_prompts = {
    # base
    "prompt_1": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    "prompt_2": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    "prompt_3": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",

    # first half token: the adjective in the opening sentence changes
    "prompt_4": "You are a very useful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    "prompt_5": "You are a very smart AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    "prompt_6": "You are a very friendly AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",

    # latter half token: one word in the closing instruction changes
    "prompt_7": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the suitable answer (A, B, C, or D).\nAnswer:",
    "prompt_8": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the letter of the correct answer (A, B, C, or D).\nAnswer:",
    "prompt_9": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the choice of the correct answer (A, B, C, or D).\nAnswer:",

    # fewer misalignment: "correct" dropped and a trailing word appended to the instruction
    "prompt_10": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the answer (A, B, C, or D) below.\nAnswer:",
    "prompt_11": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the answer (A, B, C, or D) carefully.\nAnswer:",
    "prompt_12": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the answer (A, B, C, or D) now.\nAnswer:",

    # more misalignment: the instruction is moved ahead of the question
    "prompt_13": "Please choose the best option and respond only with the option of the correct answer (A, B, C, or D) below.\nYou are a helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
    "prompt_14": "Please choose the best option and respond only with the option of the correct answer (A, B, C, or D) carefully.\nYou are a helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
    "prompt_15": "Please choose the best option and respond only with the option of the correct answer (A, B, C, or D) now.\nYou are a very helpful AI assistant. Answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
}

# Prompt templates for CommonsenseQA questions with exactly five options (A-E).
# Each template is filled via str.format(question=..., textA=..., ..., textE=...).
# Same five groups of three variants as the four-option datasets.
CommonSenseQA_prompts = {
    # base
    "prompt_1": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD} E. {textE}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, D, or E).\nAnswer:",
    "prompt_2": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD} E. {textE}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, D, or E).\nAnswer:",
    "prompt_3": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD} E. {textE}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, D, or E).\nAnswer:",

    # first half token
    "prompt_4": "You are a very useful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD} E. {textE}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, D, or E).\nAnswer:",
    "prompt_5": "You are a very smart AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD} E. {textE}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, D, or E).\nAnswer:",
    "prompt_6": "You are a very friendly AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD} E. {textE}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, D, or E).\nAnswer:",

    # latter half token
    "prompt_7": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD} E. {textE}\nPlease choose the best option and respond only with the option of the suitable answer (A, B, C, D, or E).\nAnswer:",
    "prompt_8": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD} E. {textE}\nPlease choose the best option and respond only with the letter of the correct answer (A, B, C, D, or E).\nAnswer:",
    "prompt_9": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD} E. {textE}\nPlease choose the best option and respond only with the choice of the correct answer (A, B, C, D, or E).\nAnswer:",

    # fewer misalignment
    "prompt_10": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD} E. {textE}\nPlease choose the best option and respond only with the option of the answer (A, B, C, D, or E) below.\nAnswer:",
    "prompt_11": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD} E. {textE}\nPlease choose the best option and respond only with the option of the answer (A, B, C, D, or E) carefully.\nAnswer:",
    "prompt_12": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD} E. {textE}\nPlease choose the best option and respond only with the option of the answer (A, B, C, D, or E) now.\nAnswer:",

    # more misalignment
    "prompt_13": "Please choose the best option and respond only with the option of the correct answer (A, B, C, D, or E) below.\nYou are a helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD} E. {textE}\nAnswer:",
    "prompt_14": "Please choose the best option and respond only with the option of the correct answer (A, B, C, D, or E) carefully.\nYou are a helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD} E. {textE}\nAnswer:",
    "prompt_15": "Please choose the best option and respond only with the option of the correct answer (A, B, C, D, or E) now.\nYou are a very helpful AI assistant. Answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD} E. {textE}\nAnswer:",
}

# Prompt templates for four-option (A-D) MMLU questions; same variant grid as ARC.
# NOTE: MMLU_prompts is assigned again further down in this file; that later
# assignment is the binding in effect once the module has finished loading.
MMLU_prompts = {
    "prompt_1": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    "prompt_2": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    "prompt_3": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",

    "prompt_4": "You are a very useful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    "prompt_5": "You are a very smart AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    "prompt_6": "You are a very friendly AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",

    "prompt_7": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the suitable answer (A, B, C, or D).\nAnswer:",
    "prompt_8": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the letter of the correct answer (A, B, C, or D).\nAnswer:",
    "prompt_9": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the choice of the correct answer (A, B, C, or D).\nAnswer:",

    "prompt_10": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the answer (A, B, C, or D) below.\nAnswer:",
    "prompt_11": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the answer (A, B, C, or D) carefully.\nAnswer:",
    "prompt_12": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the answer (A, B, C, or D) now.\nAnswer:",

    "prompt_13": "Please choose the best option and respond only with the option of the correct answer (A, B, C, or D) below.\nYou are a helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
    "prompt_14": "Please choose the best option and respond only with the option of the correct answer (A, B, C, or D) carefully.\nYou are a helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
    "prompt_15": "Please choose the best option and respond only with the option of the correct answer (A, B, C, or D) now.\nYou are a very helpful AI assistant. Answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
}
# Prompt templates for four-option (A-D) OpenBookQA questions; same variant grid as ARC.
# NOTE: OpenBookQA_prompts is assigned again further down in this file; that later
# assignment is the binding in effect once the module has finished loading.
OpenBookQA_prompts = {
    "prompt_1": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    "prompt_2": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    "prompt_3": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",

    "prompt_4": "You are a very useful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    "prompt_5": "You are a very smart AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    "prompt_6": "You are a very friendly AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",

    "prompt_7": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the suitable answer (A, B, C, or D).\nAnswer:",
    "prompt_8": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the letter of the correct answer (A, B, C, or D).\nAnswer:",
    "prompt_9": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the choice of the correct answer (A, B, C, or D).\nAnswer:",

    "prompt_10": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the answer (A, B, C, or D) below.\nAnswer:",
    "prompt_11": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the answer (A, B, C, or D) carefully.\nAnswer:",
    "prompt_12": "You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the answer (A, B, C, or D) now.\nAnswer:",

    "prompt_13": "Please choose the best option and respond only with the option of the correct answer (A, B, C, or D) below.\nYou are a helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
    "prompt_14": "Please choose the best option and respond only with the option of the correct answer (A, B, C, or D) carefully.\nYou are a helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
    "prompt_15": "Please choose the best option and respond only with the option of the correct answer (A, B, C, or D) now.\nYou are a very helpful AI assistant. Answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
}

# NOTE(review): this re-assigns MMLU_prompts, replacing the definition earlier in
# this file, so these are the templates actually used for MMLU. They must follow
# the same variant grid as the other datasets: prompt_4 ("useful") and prompt_7
# ("suitable") have to differ from the base prompt, otherwise those two
# perturbations silently collapse onto prompt_1.
MMLU_prompts = dict(
    # base
    prompt_1="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    prompt_2="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    prompt_3="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",

    # first half token
    prompt_4="You are a very useful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    prompt_5="You are a very smart AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    prompt_6="You are a very friendly AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",

    # latter half token
    prompt_7="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the suitable answer (A, B, C, or D).\nAnswer:",
    prompt_8="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the letter of the correct answer (A, B, C, or D).\nAnswer:",
    prompt_9="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the choice of the correct answer (A, B, C, or D).\nAnswer:",

    # fewer misalignment
    prompt_10="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the answer (A, B, C, or D) below.\nAnswer:",
    prompt_11="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the answer (A, B, C, or D) carefully.\nAnswer:",
    prompt_12="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the answer (A, B, C, or D) now.\nAnswer:",

    # more misalignment
    prompt_13="Please choose the best option and respond only with the option of the correct answer (A, B, C, or D) below.\nYou are a helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
    prompt_14="Please choose the best option and respond only with the option of the correct answer (A, B, C, or D) carefully.\nYou are a helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
    prompt_15="Please choose the best option and respond only with the option of the correct answer (A, B, C, or D) now.\nYou are a very helpful AI assistant. Answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
)
# NOTE(review): this re-assigns OpenBookQA_prompts, replacing the definition
# earlier in this file, so these are the templates actually used for OpenBookQA.
# They must follow the same variant grid as the other datasets: prompt_4
# ("useful") and prompt_7 ("suitable") have to differ from the base prompt,
# otherwise those two perturbations silently collapse onto prompt_1.
OpenBookQA_prompts = dict(
    # base
    prompt_1="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    prompt_2="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    prompt_3="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",

    # first half token
    prompt_4="You are a very useful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    prompt_5="You are a very smart AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",
    prompt_6="You are a very friendly AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the correct answer (A, B, C, or D).\nAnswer:",

    # latter half token
    prompt_7="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the suitable answer (A, B, C, or D).\nAnswer:",
    prompt_8="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the letter of the correct answer (A, B, C, or D).\nAnswer:",
    prompt_9="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the choice of the correct answer (A, B, C, or D).\nAnswer:",

    # fewer misalignment
    prompt_10="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the answer (A, B, C, or D) below.\nAnswer:",
    prompt_11="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the answer (A, B, C, or D) carefully.\nAnswer:",
    prompt_12="You are a very helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nPlease choose the best option and respond only with the option of the answer (A, B, C, or D) now.\nAnswer:",

    # more misalignment
    prompt_13="Please choose the best option and respond only with the option of the correct answer (A, B, C, or D) below.\nYou are a helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
    prompt_14="Please choose the best option and respond only with the option of the correct answer (A, B, C, or D) carefully.\nYou are a helpful AI assistant. Please answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
    prompt_15="Please choose the best option and respond only with the option of the correct answer (A, B, C, or D) now.\nYou are a very helpful AI assistant. Answer the following questions:\nQuestion: {question}\nA. {textA} B. {textB} C. {textC} D. {textD}\nAnswer:",
)

# Target device for model weights and input tensors: CUDA when available, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def arguments():
    """Parse this script's command-line options.

    Returns:
        argparse.Namespace with ``model_name_or_path``, ``dataset`` (one of the four
        supported benchmarks) and ``cache_path`` (passed as ``cache_dir`` when loading).
    """
    option_specs = (
        ('--model_name_or_path', {"type": str, "default": "Qwen/Qwen1.5-1.8B"}),
        ('--dataset', {
            "type": str,
            "default": "ARC_Challenge",
            "choices": ["ARC_Challenge", "CommonSenseQA", "MMLU", "OpenBookQA"],
        }),
        ('--cache_path', {"default": ''}),
    )
    cli = ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def load_model(args):
    """Load the tokenizer and causal LM named by ``args.model_name_or_path``.

    Both are loaded with ``trust_remote_code=True`` and ``cache_dir=args.cache_path``;
    the model is moved to the module-level ``device`` and switched to eval mode.

    Returns:
        ``(model, tokenizer)``.
    """
    common_kwargs = {"trust_remote_code": True, "cache_dir": args.cache_path}
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, **common_kwargs)
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **common_kwargs)
    model = model.to(device)
    model.eval()
    return model, tokenizer

def build_prompt_ARC_Challenge(data, prompts=None):
    """Render every ARC-Challenge template for one example, followed by the gold letter.

    Args:
        data: one ARC record with ``question``, ``choices`` (a mapping holding the
            option ``text`` list and, when present, the matching ``label`` list) and
            ``answerKey``.
        prompts: mapping of prompt key -> template; defaults to
            ``ARC_Challenge_prompts``.

    Returns:
        One ``{"prompt_key", "prompt"}`` dict per template, where ``prompt`` is the
        filled template plus ``" " + <answer letter>``; ``[]`` if the question does
        not have exactly four options.
    """
    if prompts is None:
        prompts = ARC_Challenge_prompts
    question = data['question']
    choices = data['choices']
    answer_key = data['answerKey']

    choices_text = choices['text']
    if len(choices_text) != 4:
        return []

    # The templates always show the options as A-D in list order, but ARC labels
    # are not guaranteed to be "A".."D" (some records use "1".."4"). Translate the
    # gold label into the letter displayed at the same position; keep it unchanged
    # if it is not one of the listed labels.
    labels = list(choices.get('label', []))
    if answer_key in labels:
        answer_key = "ABCD"[labels.index(answer_key)]

    prompt_list = []
    for prompt_key, prompt_template in prompts.items():
        prompt = prompt_template.format(
            question=question,
            textA=choices_text[0],
            textB=choices_text[1],
            textC=choices_text[2],
            textD=choices_text[3],
        )
        prompt_list.append({
            "prompt_key": prompt_key,
            "prompt": prompt + " " + answer_key,
        })
    return prompt_list

def build_prompt_Other(data):
    """Wrap a raw-text record in a single next-token-prediction prompt.

    Args:
        data: mapping with ``label`` (used as the prompt key) and ``text``.

    Returns:
        A one-element list ``[{"prompt_key": data['label'],
        "prompt": "Please predict the next token: " + data['text']}]``, matching
        the list-of-dicts shape returned by the other ``build_prompt_*`` helpers.
    """
    return [{
        "prompt_key": data['label'],
        "prompt": "Please predict the next token: " + data['text'],
    }]

def build_prompt_CSQA(data):
    """Render every CommonsenseQA template for one example, followed by the gold label.

    Args:
        data: record with ``question``, ``choices`` (whose ``text`` entry lists the
            options) and ``answerKey``.

    Returns:
        One ``{"prompt_key", "prompt"}`` dict per template in
        ``CommonSenseQA_prompts``, or ``[]`` if there are not exactly five options.
    """
    question = data['question']
    choices = data['choices']
    gold = data['answerKey']

    options = choices['text']
    if len(options) != 5:
        return []

    slots = dict(zip(("textA", "textB", "textC", "textD", "textE"), options))
    return [
        {"prompt_key": key, "prompt": template.format(question=question, **slots) + " " + gold}
        for key, template in CommonSenseQA_prompts.items()
    ]

def build_prompt_MMLU(data):
    """Render every MMLU template for one example, followed by the gold letter.

    Args:
        data: record with ``question``, ``choices`` (a list of option strings) and
            ``answer`` (an index in 0..3 into ``choices``).

    Returns:
        One ``{"prompt_key", "prompt"}`` dict per template in ``MMLU_prompts``, or
        ``[]`` if there are not exactly four options.

    Raises:
        KeyError: if ``answer`` is not one of 0, 1, 2, 3 (and there are four options).
    """
    question = data['question']
    options = data['choices']
    answer_index = data['answer']
    index_to_letter = {0: 'A', 1: 'B', 2: 'C', 3: 'D'}

    if len(options) != 4:
        return []

    gold_letter = index_to_letter[answer_index]
    slots = dict(zip(("textA", "textB", "textC", "textD"), options))
    return [
        {"prompt_key": key, "prompt": template.format(question=question, **slots) + " " + gold_letter}
        for key, template in MMLU_prompts.items()
    ]

def build_prompt_OpenBookQA(data):
    """Render every OpenBookQA template for one example, followed by the gold label.

    Args:
        data: record with ``question_stem``, ``choices`` (whose ``text`` entry lists
            the options) and ``answerKey``.

    Returns:
        One ``{"prompt_key", "prompt"}`` dict per template in ``OpenBookQA_prompts``,
        or ``[]`` if there are not exactly four options.
    """
    stem = data['question_stem']
    choices = data['choices']
    gold = data['answerKey']

    options = choices['text']
    if len(options) != 4:
        return []

    slots = {"textA": options[0], "textB": options[1], "textC": options[2], "textD": options[3]}
    rendered = []
    for key, template in OpenBookQA_prompts.items():
        filled = template.format(question=stem, **slots)
        rendered.append({"prompt_key": key, "prompt": filled + " " + gold})
    return rendered

def encode_ids_and_mask(tokenizer, text):
    """Tokenize ``text`` without special tokens and move the result to ``device``.

    Returns:
        ``(input_ids, attention_mask)`` tensors as produced by the tokenizer with
        ``return_tensors="pt"``.
    """
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    return tuple(encoded[field].to(device) for field in ("input_ids", "attention_mask"))


def saliency(model, input_ids, attention_mask):
    """Back-propagate the gold next-token log-probability to every hidden state.

    The last token of ``input_ids[0]`` is the target. The model is run on the
    remaining prefix, and ``log p(target | prefix)`` at the final position is
    differentiated w.r.t. each tensor in ``outputs.hidden_states``.

    Args:
        model: causal LM whose forward accepts ``input_ids``, ``attention_mask``,
            ``output_hidden_states``, ``use_cache`` and ``return_dict`` and returns an
            object exposing ``logits`` and ``hidden_states``.
        input_ids: token-id tensor with a leading batch dimension; only row 0 is used.
        attention_mask: mask tensor aligned with ``input_ids``.

    Returns:
        ``(log_prob, grads, hidden_states)``: ``log_prob`` is a Python float;
        ``grads`` and ``hidden_states`` hold one float32 numpy array per returned
        hidden state, with the batch dimension squeezed out
        (``[seq_len, hidden_size]`` for a standard HF model). An entry of ``grads``
        is ``None`` if no gradient reached that hidden state.
    """
    target_id = input_ids[0][-1]
    prefix_ids = input_ids[0][:-1].unsqueeze(0)
    prefix_mask = attention_mask[0][:-1].unsqueeze(0)

    model.zero_grad(set_to_none=True)
    # ``torch.enable_grad()`` only takes effect as a context manager; a bare call
    # is a no-op, so a caller running under ``torch.no_grad()`` would otherwise
    # make ``retain_grad()`` / ``backward()`` fail.
    with torch.enable_grad():
        outputs = model(
            input_ids=prefix_ids,
            attention_mask=prefix_mask,
            output_hidden_states=True,
            use_cache=False,
            return_dict=True,
        )
        hidden_states = list(outputs.hidden_states)
        for h in hidden_states:
            h.retain_grad()  # keep .grad on these non-leaf tensors

        last_logprobs = torch.log_softmax(outputs.logits[:, -1, :], dim=-1)
        score = last_logprobs[0, target_id]
        score.backward()

    log_prob = score.item()

    # Cast to float32 before .numpy(): NumPy has no bfloat16 dtype.
    grads = [
        h.grad.detach().float().cpu().squeeze(0).numpy() if h.grad is not None else None
        for h in hidden_states
    ]
    states = [h.detach().float().cpu().squeeze(0).numpy() for h in hidden_states]

    # Parameter gradients are a by-product of backward() and are not returned;
    # release them instead of holding them until the next call.
    model.zero_grad(set_to_none=True)
    return log_prob, grads, states


def frobenius_norm(grads, length_norm=False):
    """Return the Frobenius norm of a 2-D array.

    Args:
        grads: 2-D array (e.g. per-token gradients of one layer).
        length_norm: when True, divide the norm by ``grads.shape[0]`` (the number of rows).
    """
    norm = np.linalg.norm(grads, ord='fro')
    return norm / grads.shape[0] if length_norm else norm

def forward_score_from_ids(model, input_ids, input_mask, correct_id=None):
    """Score ``correct_id`` as the token that follows the prefix ``input_ids[:-1]``.

    Args:
        model: causal LM returning an object with ``logits`` and ``hidden_states``.
        input_ids: 1-D sequence of token ids (list or tensor); its last element is
            dropped before the forward pass.
        input_mask: attention mask aligned with ``input_ids``.
        correct_id: target token id; defaults to ``input_ids[-1]``.

    Returns:
        ``(score, hidden_states)``: the log-probability of ``correct_id`` at the last
        prefix position (a Python float), and one float32 numpy array per returned
        hidden state with the batch dimension squeezed out.
    """
    if correct_id is None:
        correct_id = input_ids[-1]
    # torch.as_tensor accepts lists and tensors alike; torch.tensor() on a tensor
    # (e.g. the squeezed ids returned by calculate_gradient) warns and copies.
    ids_t = torch.as_tensor(input_ids[:-1], dtype=torch.long, device=device).unsqueeze(0)
    mask_t = torch.as_tensor(input_mask[:-1], dtype=torch.long, device=device).unsqueeze(0)
    with torch.no_grad():
        outputs = model(
            ids_t,
            attention_mask=mask_t,
            output_hidden_states=True,
        )
        last_logprobs = torch.log_softmax(outputs.logits[:, -1, :], dim=-1)  # [1, V]
        score = last_logprobs[0, int(correct_id)].item()
    # Cast to float32 before .numpy(): NumPy has no bfloat16 dtype.
    hidden_states = [
        h.detach().float().cpu().squeeze(0).numpy()
        for h in outputs.hidden_states
    ]
    return score, hidden_states


def calculate_gradient(model, tokenizer, prompt):
    """Run a saliency pass for ``prompt`` and summarise per-layer gradients.

    Args:
        model: Model passed through to ``saliency``.
        tokenizer: Tokenizer passed through to ``encode_ids_and_mask``.
        prompt: Prompt text to encode.

    Returns:
        dict with keys:
          - ``input_ids`` / ``attention_mask``: the encoded prompt flattened to
            1-D. This stays 1-D even for a single-token prompt, so ``len()``
            works.
          - ``grads``: per-layer gradient arrays as returned by ``saliency``
            (entries may be None).
          - ``grads_norms`` / ``grads_norms_length_norm``: per-layer Frobenius
            norm, raw and divided by sequence length. NaN for layers whose
            gradient is None.
          - ``log_prob``: the score returned by ``saliency``.
          - ``hidden_states``: per-layer hidden-state arrays from ``saliency``.
    """
    input_ids, attention_mask = encode_ids_and_mask(tokenizer, prompt)
    log_prob, layer_grads, hidden_states = saliency(model, input_ids, attention_mask)

    def _norm_or_nan(grads, length_norm):
        # saliency() yields None for a hidden state that received no gradient;
        # report NaN for that layer instead of crashing the whole run.
        if grads is None:
            return float("nan")
        return float(frobenius_norm(grads, length_norm=length_norm))

    grads_norms = [_norm_or_nan(grads, False) for grads in layer_grads]
    grads_norms_length_norm = [_norm_or_nan(grads, True) for grads in layer_grads]
    return {
        # reshape(-1) rather than squeeze(): a 1-token prompt must stay 1-D,
        # otherwise len() on it fails downstream. Assumes a single-sequence batch.
        "input_ids": input_ids.reshape(-1),
        "attention_mask": attention_mask.reshape(-1),
        "grads": layer_grads,
        "grads_norms": grads_norms,
        "grads_norms_length_norm": grads_norms_length_norm,
        "log_prob": log_prob,
        "hidden_states": hidden_states,
    }

def pad_to_len(arr, target_len):
    """Return ``arr`` with exactly ``target_len`` rows.

    - If ``arr`` has at least ``target_len`` rows, the extra rows are sliced
      off (a NumPy view of ``arr``).
    - Otherwise zero rows of the same dtype are appended (a new array).

    Args:
        arr: 2-D NumPy array, rows first.
        target_len: Desired number of rows.
    """
    n_rows, width = arr.shape
    missing = target_len - n_rows
    if missing <= 0:
        return arr[:target_len, :]
    zeros = np.zeros((missing, width), dtype=arr.dtype)
    return np.concatenate((arr, zeros), axis=0)

def _load_eval_dataset(dataset_name):
    """Load the evaluation split for ``dataset_name``, keeping only items with
    the expected number of answer options.

    Raises:
        ValueError: if ``dataset_name`` is not a supported dataset name.
    """
    if dataset_name == "ARC_Challenge":
        split = load_dataset("ai2_arc", "ARC-Challenge")["train"]
        return split.filter(lambda x: len(x["choices"]["label"]) == 4)
    if dataset_name == "CommonSenseQA":
        split = load_dataset("commonsense_qa")["train"]
        return split.filter(lambda x: len(x["choices"]["label"]) == 5)
    if dataset_name == "MMLU":
        #  ['abstract_algebra', 'all', 'anatomy', 'astronomy', 'auxiliary_train', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions']
        split = load_dataset("cais/mmlu", "all")["test"]
        return split.filter(lambda x: len(x["choices"]) == 4)
    if dataset_name == "OpenBookQA":
        split = load_dataset("openbookqa", "main")["train"]
        return split.filter(lambda x: len(x["choices"]["label"]) == 4)
    raise ValueError(f"Unsupported dataset: {dataset_name!r}")


def _build_prompt_list(dataset_name, example):
    """Return the list of prompt variants for one example via the dataset's builder."""
    if dataset_name == "ARC_Challenge":
        return build_prompt_ARC_Challenge(example)
    if dataset_name == "CommonSenseQA":
        return build_prompt_CSQA(example)
    if dataset_name == "MMLU":
        return build_prompt_MMLU(example)
    if dataset_name == "OpenBookQA":
        return build_prompt_OpenBookQA(example)
    raise ValueError(f"Unsupported dataset: {dataset_name!r}")


def _compare_pair(result0, result1):
    """Compare two ``calculate_gradient`` results layer by layer.

    Each of ``result1``'s hidden states is truncated or zero-padded to
    ``result0``'s sequence length, so per-position differences can be taken.

    Returns:
        ``(delta_log_prob, linear_preds, delta_hiddens)``:
        - ``delta_log_prob``: ``result0`` log-prob minus ``result1`` log-prob.
        - ``linear_preds``: per layer, ``sum(grad0 * (h1 - h0))``.
        - ``delta_hiddens``: per layer, the Frobenius norm of ``h0 - h1`` over
          the whole aligned sequence.
    """
    delta_log_prob = result0["log_prob"] - result1["log_prob"]
    seq_len = result0["hidden_states"][0].shape[0]

    delta_hiddens = []
    linear_preds = []
    for layer_idx in range(len(result0["hidden_states"])):
        # Align both prompts' hidden states for this layer to seq_len rows.
        h0 = result0["hidden_states"][layer_idx][:seq_len, :]
        h1 = pad_to_len(result1["hidden_states"][layer_idx], seq_len)
        delta = h1 - h0

        # Norm of the hidden-state difference over all aligned positions
        # (identical to the norm of h0 - h1).
        delta_hiddens.append(float(frobenius_norm(delta)))

        # First-order (gradient x difference) estimate.
        # NOTE(review): by Taylor expansion grad0 . (h1 - h0) approximates
        # log_prob1 - log_prob0, i.e. the NEGATIVE of delta_log_prob as stored.
        # Kept unchanged for compatibility with existing outputs; confirm the
        # intended sign.
        layer_grad = result0["grads"][layer_idx]
        if layer_grad is None:
            linear_preds.append(float("nan"))
            continue
        grads = pad_to_len(layer_grad, seq_len)
        linear_preds.append(float(np.sum(grads * delta)))
    return delta_log_prob, linear_preds, delta_hiddens


def main(args):
    """Measure how prompt variants shift the model's answer log-probability.

    Steps, for up to the first 500 examples of ``args.dataset``:
      1. Build the prompt variants for each question.
      2. Run ``calculate_gradient`` on every variant.
      3. For every unordered pair of variants with different prompt keys,
         record the log-prob difference, per-layer hidden-state difference
         norms and a gradient-based linear estimate.
      4. Write one JSON object per pair to
         ``results/data_results/real_dataset_misalignment/<model_name_or_path>/<dataset>_result.jsonl``.

    Uses the module-level ``model`` and ``tokenizer`` globals.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    max_examples = 500

    dataset = _load_eval_dataset(args.dataset)
    # Cap at the split size so small splits don't make select() fail.
    dataset = dataset.select(range(min(max_examples, len(dataset))))

    eval_data_list = []
    for index, example in enumerate(dataset):
        # Fall back to the row index when the example has no "id" value.
        question_id = example.get("id")
        if question_id is None:
            question_id = index
        eval_data_list.append(
            {
                "question_id": question_id,
                "prompt_list": _build_prompt_list(args.dataset, example),
            }
        )

    print(len(eval_data_list))

    result_list = []
    for eval_data in tqdm(eval_data_list, total=len(eval_data_list)):
        question_id = eval_data["question_id"]

        group_result_list = [
            {
                "prompt_key": item["prompt_key"],
                "result": calculate_gradient(model, tokenizer, item["prompt"]),
            }
            for item in eval_data["prompt_list"]
        ]

        # Every unordered pair (rst0 before rst1) with distinct prompt keys.
        for pair_idx, rst0 in enumerate(group_result_list):
            prompt_key0 = rst0["prompt_key"]
            result0 = rst0["result"]
            for rst1 in group_result_list[pair_idx + 1:]:
                prompt_key1 = rst1["prompt_key"]
                result1 = rst1["result"]
                if prompt_key0 == prompt_key1:
                    continue

                delta_log_prob, linear_preds, delta_hiddens = _compare_pair(result0, result1)

                result_list.append(
                    {
                        "question_id": question_id,
                        "prompt_key0": prompt_key0,
                        "prompt_key1": prompt_key1,
                        "len0": len(result0["input_ids"]),
                        "len1": len(result1["input_ids"]),
                        "delta_log_prob": delta_log_prob,
                        "grads": result0["grads_norms"],
                        "grads_length_norm": result0["grads_norms_length_norm"],
                        "linear_approx": linear_preds,
                        "delta_hiddens": delta_hiddens,
                    }
                )

    save_path = f"results/data_results/real_dataset_misalignment/{args.model_name_or_path}/{args.dataset}_result.jsonl"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item, ensure_ascii=False) + "\n" for item in result_list)

if __name__ == '__main__':
    # Script entry point.
    # `model` and `tokenizer` are bound at module level because main() reads
    # them as globals. They are also attached to `args` for convenience.
    args = arguments()
    model, tokenizer = load_model(args)
    args.model, args.tokenizer = model, tokenizer
    main(args)