import torch
import sys
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import json
import argparse
from tqdm import tqdm
from pathlib import Path 
# Resolve this script's absolute path and append its grandparent directory
# (`file.parents[1]`) to sys.path before the project-local import that follows.
# Presumably that directory is where `decoding_algorithm` lives — verify
# against the repository layout.
# NOTE(review): `parent` is not referenced anywhere else in this file.
file = Path(__file__).resolve()
parent, root = file.parent, file.parents[1]
sys.path.append(str(root))
from decoding_algorithm import ContrastiveDecoding

# Maps each subject name (the `<subject>` part of a `<subject>_test.csv` file,
# as looked up in run_eval) to a list of sub-category labels. Every label used
# here also appears in one of the lists in `categories`.
subcategories = {
    "abstract_algebra": ["math"],
    "anatomy": ["health"],
    "astronomy": ["physics"],
    "business_ethics": ["business"],
    "clinical_knowledge": ["health"],
    "college_biology": ["biology"],
    "college_chemistry": ["chemistry"],
    "college_computer_science": ["computer science"],
    "college_mathematics": ["math"],
    "college_medicine": ["health"],
    "college_physics": ["physics"],
    "computer_security": ["computer science"],
    "conceptual_physics": ["physics"],
    "econometrics": ["economics"],
    "electrical_engineering": ["engineering"],
    "elementary_mathematics": ["math"],
    "formal_logic": ["philosophy"],
    "global_facts": ["other"],
    "high_school_biology": ["biology"],
    "high_school_chemistry": ["chemistry"],
    "high_school_computer_science": ["computer science"],
    "high_school_european_history": ["history"],
    "high_school_geography": ["geography"],
    "high_school_government_and_politics": ["politics"],
    "high_school_macroeconomics": ["economics"],
    "high_school_mathematics": ["math"],
    "high_school_microeconomics": ["economics"],
    "high_school_physics": ["physics"],
    "high_school_psychology": ["psychology"],
    "high_school_statistics": ["math"],
    "high_school_us_history": ["history"],
    "high_school_world_history": ["history"],
    "human_aging": ["health"],
    "human_sexuality": ["culture"],
    "international_law": ["law"],
    "jurisprudence": ["law"],
    "logical_fallacies": ["philosophy"],
    "machine_learning": ["computer science"],
    "management": ["business"],
    "marketing": ["business"],
    "medical_genetics": ["health"],
    "miscellaneous": ["other"],
    "moral_disputes": ["philosophy"],
    "moral_scenarios": ["philosophy"],
    "nutrition": ["health"],
    "philosophy": ["philosophy"],
    "prehistory": ["history"],
    "professional_accounting": ["other"],
    "professional_law": ["law"],
    "professional_medicine": ["health"],
    "professional_psychology": ["psychology"],
    "public_relations": ["politics"],
    "security_studies": ["politics"],
    "sociology": ["culture"],
    "us_foreign_policy": ["politics"],
    "virology": ["health"],
    "world_religions": ["philosophy"],
}

# Maps each top-level category name to the sub-category labels it contains.
# run_eval uses this to route a subject's results (via its sub-category labels)
# into the per-category accuracy buckets.
categories = {
    "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"],
    "humanities": ["history", "philosophy", "law"],
    "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
    "other (business, health, misc.)": ["other", "business", "health"],
}

# NOTE(review): N_TRAIN and SAMPLE_NUM are not referenced anywhere else in the
# visible part of this file; run_eval hard-codes its own slice of 5 dev rows.
N_TRAIN = 5
SAMPLE_NUM = 80
# Answer letters, in option-column order, used when rendering the options.
choices = ["A", "B", "C", "D"]
# Per-shot answer letters that the biased few-shot prompt forces. An empty list
# makes gen_prompt fall back to "A" for every shot. Alternative patterns:
# prompt_bias_list = ["A", "B", "C", "D", "D"]
# prompt_bias_list = ["D", "C", "B", "A", "D"]
prompt_bias_list = []

# Per-example records appended by eval() and dumped to JSON by main().
data_to_save = []

# filename = 'mmlu_problem_ans_always_a.json'

def format_subject(subject):
    """Convert an underscore-separated subject id into space-prefixed words.

    Every word, including the first one, is preceded by a single space, so
    ``"college_physics"`` becomes ``" college physics"``.
    """
    words = subject.split("_")
    return "".join(" " + word for word in words)

def format_example(df, idx, include_answer=True, bias=False, ans="A"):
    """Render row ``idx`` of ``df`` as one multiple-choice prompt entry.

    Column 0 is read as the question, the last column as the answer letter,
    and every column in between as an option, labelled with the letters from
    the module-level ``choices`` list.

    Args:
        df: DataFrame holding one question per row in the layout above.
        idx: Positional row index of the question to render.
        include_answer: When true, append the answer letter and a blank line
            after ``"Answer:"``; otherwise the text ends at ``"Answer:"``.
        bias: When true, the text of the correct option is swapped with the
            text of option ``ans`` so that ``ans`` becomes the correct letter.
        ans: Letter ("A"-"D") the correct answer is moved to when ``bias`` is
            true. Ignored otherwise.

    Returns:
        The formatted prompt fragment as a string.

    ``df`` is never modified: the swap is done on a local copy of the row, so
    the same frame can be rendered both with and without bias in any order.
    """
    n_options = df.shape[1] - 2
    options = [df.iloc[idx, col + 1] for col in range(n_options)]
    answer = df.iloc[idx, n_options + 1]

    if bias:
        # Move the correct option's text to the slot of the forced letter by
        # exchanging the two option texts, then report the forced letter as
        # the answer. If the answer is already `ans`, the swap is a no-op.
        letter_to_pos = {"A": 0, "B": 1, "C": 2, "D": 3}
        true_pos = letter_to_pos[answer]
        forced_pos = letter_to_pos[ans]
        options[true_pos], options[forced_pos] = options[forced_pos], options[true_pos]
        answer = ans

    prompt = df.iloc[idx, 0]
    for pos, option in enumerate(options):
        prompt += "\n{}. {}".format(choices[pos], option)
    prompt += "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(answer)
    return prompt

def gen_prompt(train_df, subject, k=-1, bias=False, prompt_bias_list=None):
    """Build the few-shot prefix: a header line followed by ``k`` solved examples.

    Args:
        train_df: DataFrame of solved examples in the layout expected by
            ``format_example``.
        subject: Underscore-separated subject id; rendered into the header
            through ``format_subject``.
        k: Number of examples to include. ``-1`` means every row of
            ``train_df``.
        bias: Forwarded to ``format_example``; when true each example's
            correct answer is moved to the letter given by
            ``prompt_bias_list``.
        prompt_bias_list: Forced answer letter for each example, by position.
            ``None`` or an empty list means "A" for every example.

    Returns:
        The prompt prefix as a single string.

    Raises:
        ValueError: If ``bias`` is true and a non-empty ``prompt_bias_list``
            has fewer than ``k`` entries.
    """
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )
    if k == -1:
        k = train_df.shape[0]
    if not prompt_bias_list:
        prompt_bias_list = ["A"] * k
    elif bias and len(prompt_bias_list) < k:
        raise ValueError(
            "prompt_bias_list has {} entries but {} examples were requested".format(
                len(prompt_bias_list), k
            )
        )

    pieces = [header]
    for i in range(k):
        # Without bias the letter is unused, so a short list must not fail here.
        forced = prompt_bias_list[i] if i < len(prompt_bias_list) else "A"
        pieces.append(format_example(train_df, i, bias=bias, ans=forced))
    return "".join(pieces)

def _choice_probs(logits, option_token_ids):
    """Return softmax probabilities restricted to the answer-letter tokens.

    Args:
        logits: 1-D tensor of next-token logits, indexed by token id.
        option_token_ids: Token ids of the answer letters, in A-D order.

    Returns:
        float32 NumPy array with one probability per answer letter.
    """
    picked = logits[option_token_ids].detach().to(torch.float32)
    return torch.nn.functional.softmax(picked, dim=0).cpu().numpy()


@torch.no_grad()
def eval(subject, llm, attn_t, dev_df, test_df, max_token, k):
    """Score one subject with and without the attention-temperature setting.

    Each test question is answered four times: with an unbiased and with a
    biased few-shot prompt, and for each of those once with
    ``attention_temperature=attn_t`` and once with a plain forward pass
    ("origin"). One record per question is appended to the module-level
    ``data_to_save`` list, and the four accuracies are printed.

    NOTE: the name shadows the built-in ``eval``; it is kept because callers
    in this file use it.

    Args:
        subject: Subject id, used in the prompt header, the saved records and
            the printed summary.
        llm: Object exposing ``tokenizer`` and ``model`` as used below.
        attn_t: Value passed to the model as ``attention_temperature``.
        dev_df: DataFrame of solved examples used as few-shot demonstrations.
        test_df: DataFrame of questions to answer; the last column holds the
            answer letter.
        max_token: Prompts longer than this many tokens are shortened by
            dropping demonstrations.
        k: Number of few-shot demonstrations to start with.

    Returns:
        ``(cors, bias_cors)``: boolean NumPy arrays with per-question
        correctness of the attention-temperature run under the unbiased and
        the biased prompt respectively.
    """
    letter_to_index = {"A": 0, "B": 1, "C": 2, "D": 3}
    # The answer-letter token ids do not depend on the example; look them up once.
    option_token_ids = [
        llm.tokenizer(letter).input_ids[-1] for letter in ("A", "B", "C", "D")
    ]
    # Biased prompts are built from a private copy so that any in-place edits
    # made while rendering them can never leak into the unbiased prompts.
    biased_dev_df = dev_df.copy()

    cors = []
    origin_cors = []
    bias_cors = []
    bias_origin_cors = []
    for i in range(test_df.shape[0]):
        problem = format_example(test_df, i, include_answer=False)
        label = test_df.iloc[i, test_df.shape[1] - 1]
        y_true = int(letter_to_index[label])
        record = {"task": subject, "id": i, "problem": problem, "y_true": y_true}
        data_to_save.append(record)

        for bias in (False, True):
            shots_df = biased_dev_df if bias else dev_df
            prompt = gen_prompt(shots_df, subject, k, bias=bias, prompt_bias_list=prompt_bias_list)
            content = prompt + problem
            input_ids = llm.tokenizer(content, return_tensors="pt").input_ids.to("cuda")
            # Drop demonstrations one at a time until the prompt fits. The
            # length must be re-measured after every drop. NOTE: k stays
            # reduced for all subsequent prompts of this subject.
            while input_ids.shape[-1] > max_token and k != 0:
                k -= 1
                prompt = gen_prompt(shots_df, subject, k, bias=bias, prompt_bias_list=prompt_bias_list)
                content = prompt + problem
                input_ids = llm.tokenizer(content, return_tensors="pt").input_ids.to("cuda")

            # Forward pass with the attention-temperature intervention.
            logits = llm.model(
                input_ids=input_ids,
                attention_temperature=attn_t,
            )[0][:, -1].flatten()
            y_pred = int(np.argmax(_choice_probs(logits, option_token_ids)))

            # Plain forward pass on the same input, for comparison.
            origin_logits = llm.model(
                input_ids=input_ids,
            )[0][:, -1].flatten()
            origin_y_pred = int(np.argmax(_choice_probs(origin_logits, option_token_ids)))

            if bias:
                record["bias_prompt"] = prompt
                record["bias_content"] = content
                record["bias_attn_y_pred"] = y_pred
                record["bias_origin_y_pred"] = origin_y_pred
                bias_cors.append(y_pred == y_true)
                bias_origin_cors.append(origin_y_pred == y_true)
            else:
                record["prompt"] = prompt
                record["content"] = content
                record["attn_y_pred"] = y_pred
                record["origin_y_pred"] = origin_y_pred
                cors.append(y_pred == y_true)
                origin_cors.append(origin_y_pred == y_true)

    cors = np.array(cors)
    origin_cors = np.array(origin_cors)
    bias_cors = np.array(bias_cors)
    bias_origin_cors = np.array(bias_origin_cors)
    acc = np.mean(cors)
    origin_acc = np.mean(origin_cors)
    bias_acc = np.mean(bias_cors)
    bias_origin_acc = np.mean(bias_origin_cors)
    print("Average accuracy {:.3f} - use_attn_t {}".format(acc, subject))
    print("Average accuracy {:.3f} - origin {}".format(origin_acc, subject))
    print("Average accuracy {:.3f} - bias use_attn_t {}".format(bias_acc, subject))
    print("Average accuracy {:.3f} - bias origin {}".format(bias_origin_acc, subject))
    return cors, bias_cors

def run_eval(args, input_dir, output_dir, attn_t):
    """Evaluate every subject found under ``input_dir`` and print accuracies.

    Subjects are discovered from the ``<subject>_test.csv`` files in
    ``input_dir/test``; the matching ``input_dir/dev/<subject>_dev.csv``
    supplies the few-shot demonstrations. Accuracies of the unbiased
    attention-temperature run are printed per sub-category, per category and
    overall. Groups that received no subject are skipped instead of failing,
    so the function also works on a subset of the subjects.

    Args:
        args: Parsed command-line namespace; ``model_name``, ``num_gpus``,
            ``max_token`` and ``prompt_num`` are read.
        input_dir: Directory containing the ``dev`` and ``test`` folders.
        output_dir: Directory that is created if missing. Nothing is written
            to it by this function.
        attn_t: Attention-temperature setting forwarded to ``eval``.
    """
    llm = ContrastiveDecoding(model_name=args.model_name, num_gpus=int(args.num_gpus))
    test_suffix = "_test.csv"
    subjects = sorted(
        name[: -len(test_suffix)]
        for name in os.listdir(os.path.join(input_dir, "test"))
        if name.endswith(test_suffix)
    )
    os.makedirs(output_dir, exist_ok=True)

    all_cors = []
    subcat_cors = {
        subcat: [] for subcat_list in subcategories.values() for subcat in subcat_list
    }
    cat_cors = {cat: [] for cat in categories}
    for subject in subjects:
        # Only the first 5 dev rows are kept as few-shot demonstrations.
        dev_df = pd.read_csv(
            os.path.join(input_dir, "dev", subject + "_dev.csv"), header=None
        )[:5]
        test_df = pd.read_csv(
            os.path.join(input_dir, "test", subject + "_test.csv"), header=None
        )
        # The biased-prompt results are printed by eval itself and are not
        # aggregated here.
        cors, _bias_cors = eval(
            subject, llm, attn_t, dev_df, test_df, args.max_token, args.prompt_num
        )
        for subcat in subcategories[subject]:
            subcat_cors[subcat].append(cors)
            for cat, members in categories.items():
                if subcat in members:
                    cat_cors[cat].append(cors)
        all_cors.append(cors)

    if not all_cors:
        print("No subjects found under {}".format(os.path.join(input_dir, "test")))
        return

    for subcat, results in subcat_cors.items():
        if not results:
            # np.concatenate rejects an empty list; nothing to report here.
            continue
        subcat_acc = np.mean(np.concatenate(results))
        print("==Average accuracy {:.3f} - {}".format(subcat_acc, subcat))

    for cat, results in cat_cors.items():
        if not results:
            continue
        cat_acc = np.mean(np.concatenate(results))
        print("##Average accuracy {:.3f} - {}".format(cat_acc, cat))

    weighted_acc = np.mean(np.concatenate(all_cors))
    print("Average accuracy: {:.3f}".format(weighted_acc))

def get_attn_t(model_name, T=0.5):
    """Return the attention-temperature preset for a supported model.

    The preset is selected by case-insensitive substring match on
    ``model_name``. Each preset has the form ``[0, {key: (list_of_ints,
    temperature), ...}]``.
    NOTE(review): the dict keys look like layer indices and the lists like
    attention-head indices — confirm against the model code that consumes
    ``attention_temperature``.

    Args:
        model_name: Model identifier, e.g. a hub path.
        T: Temperature used by the "gemma-2b" preset only; every other preset
            carries its own fixed temperature.

    Returns:
        The preset for the matching model. If several keys match, the one
        listed last wins.

    Raises:
        ValueError: If ``model_name`` matches none of the supported models.
    """
    name = model_name.lower()
    presets = (
        ("gemma-2b", [0, {12: ([3, 7, 2, 1], T), 14: ([0, 1, 6, 7], T)}]),
        ("gemma-7b", [0, {18: ([0, 8, 6, 2], 0.5), 2: ([1, 5, 3, 0], 0.5)}]),
        # Alternative: [0, {13: (range(0, 32), T), 14: (range(0, 32), T)}]
        ("llama-2-7b", [0, {14: ([24, 4, 20, 31], 0.4), 18: ([30, 10, 25, 28], 0.4)}]),
        # Alternative: [0, {17: ([0, 1, 3, 4, 5, 6, 7, 9, 10, 12, 13, 14, 16, 17, 20, 21, 24, 25, 26, 27, 28, 29, 30, 31], 0.5)}]
        ("llama-3-8b", [0, {17: ([24, 25, 26, 28], 0.5), 14: ([23, 5, 4, 20], 0.5)}]),
        # Alternative: [0, {16: ([12, 14, 13, 0], 0.5), 19: ([8, 9, 16, 10], 0.5)}]
        ("mistral-7b", [0, {16: ([12, 14, 13, 1], 0.5), 19: ([8, 16, 9, 10], 0.5)}]),
    )

    attn_t = None
    for key, preset in presets:
        if key in name:
            attn_t = preset
    if attn_t is None:
        raise ValueError(
            "no attention-temperature preset for model {!r}; supported: {}".format(
                model_name, ", ".join(key for key, _ in presets)
            )
        )
    return attn_t

def main():
    """Parse the command line, run the evaluation and dump the records.

    The collected per-example records in ``data_to_save`` are written as JSON
    to a file in the current directory whose name is derived from the model
    name, the shot count and the configured bias letters.
    """
    option_specs = (
        ("--model-name", str, "huggyllama/llama-7b"),
        ("--max-token", int, 2048),
        ("--num-gpus", str, "1"),
        ("--data-path", str, "/mnt/llms/data/MMLU/test"),
        ("--prompt-num", int, 5),
    )
    parser = argparse.ArgumentParser()
    for flag, kind, default in option_specs:
        parser.add_argument(flag, type=kind, default=default)
    args = parser.parse_args()

    model_tag = args.model_name.split("/")[-1]
    bias_tag = "".join(prompt_bias_list)
    filename = 'mmlu_problem_{}_{}_shot_{}.json'.format(model_tag, args.prompt_num, bias_tag)
    print("filename {}".format(filename))

    attn_t = get_attn_t(args.model_name)
    print("use attn_t {}".format(attn_t))

    run_eval(args, args.data_path, "results", attn_t)

    with open(filename, 'w', encoding='utf-8') as out_file:
        json.dump(data_to_save, out_file, ensure_ascii=False, indent=4)

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
