import random
import os
import pandas as pd
# MMLU subject name (the ``<subject>`` part of a ``<subject>_test.csv`` file
# name) -> list holding that subject's subcategory. Every entry below has
# exactly one subcategory.
subcategories = {
    "abstract_algebra": ["math"],
    "anatomy": ["health"],
    "astronomy": ["physics"],
    "business_ethics": ["business"],
    "clinical_knowledge": ["health"],
    "college_biology": ["biology"],
    "college_chemistry": ["chemistry"],
    "college_computer_science": ["computer science"],
    "college_mathematics": ["math"],
    "college_medicine": ["health"],
    "college_physics": ["physics"],
    "computer_security": ["computer science"],
    "conceptual_physics": ["physics"],
    "econometrics": ["economics"],
    "electrical_engineering": ["engineering"],
    "elementary_mathematics": ["math"],
    "formal_logic": ["philosophy"],
    "global_facts": ["other"],
    "high_school_biology": ["biology"],
    "high_school_chemistry": ["chemistry"],
    "high_school_computer_science": ["computer science"],
    "high_school_european_history": ["history"],
    "high_school_geography": ["geography"],
    "high_school_government_and_politics": ["politics"],
    "high_school_macroeconomics": ["economics"],
    "high_school_mathematics": ["math"],
    "high_school_microeconomics": ["economics"],
    "high_school_physics": ["physics"],
    "high_school_psychology": ["psychology"],
    "high_school_statistics": ["math"],
    "high_school_us_history": ["history"],
    "high_school_world_history": ["history"],
    "human_aging": ["health"],
    "human_sexuality": ["culture"],
    "international_law": ["law"],
    "jurisprudence": ["law"],
    "logical_fallacies": ["philosophy"],
    "machine_learning": ["computer science"],
    "management": ["business"],
    "marketing": ["business"],
    "medical_genetics": ["health"],
    "miscellaneous": ["other"],
    "moral_disputes": ["philosophy"],
    "moral_scenarios": ["philosophy"],
    "nutrition": ["health"],
    "philosophy": ["philosophy"],
    "prehistory": ["history"],
    "professional_accounting": ["other"],
    "professional_law": ["law"],
    "professional_medicine": ["health"],
    "professional_psychology": ["psychology"],
    "public_relations": ["politics"],
    "security_studies": ["politics"],
    "sociology": ["culture"],
    "us_foreign_policy": ["politics"],
    "virology": ["health"],
    "world_religions": ["philosophy"],
}

# Top-level category -> the subcategories (values used in ``subcategories``)
# that it groups together.
categories = {
    "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"],
    "humanities": ["history", "philosophy", "law"],
    "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
    "other (business, health, misc.)": ["other", "business", "health"],
}
# Answer letters; ``choices[j]`` labels the option stored in column ``j + 1``
# of a question row.
choices = ["A", "B", "C", "D"]

def format_subject(subject):
    """Turn an underscore-separated subject name into space-separated words.

    Every underscore-delimited part gets a single leading space, so the result
    itself starts with a space, e.g. ``"abstract_algebra"`` becomes
    ``" abstract algebra"``.
    """
    return "".join(" " + word for word in subject.split("_"))

# Variant of ``format_example`` that reduces the original four options to two
# or three. With include_answer=False no answer letter is appended, so only the
# question (ending in "Answer:") is produced.
def format_example_n_choices(df, idx, include_answer=True, bias_ans="A", n=4):
    """Format row ``idx`` of ``df`` showing only its first ``n`` options.

    The row is read as ``question, option_1, ..., option_k, label`` with the
    letter of the correct option in the last column. The correct option is
    always swapped into slot ``bias_ans``, which guarantees it is among the
    ``n`` options shown. There is no unbiased mode for this variant.

    Args:
        df: DataFrame of examples. It is NOT modified.
        idx: positional row index.
        include_answer: when true, append ``bias_ans`` as the answer.
        bias_ans: letter the correct option is moved to; it must index one of
            the first ``n`` options.
        n: number of options to show.

    Returns:
        The formatted example string.
    """
    letter_to_idx = {"A": 0, "B": 1, "C": 2, "D": 3}
    assert letter_to_idx[bias_ans] + 1 <= n
    k = df.shape[1] - 2  # number of option columns; the label is column k + 1
    options = [df.iloc[idx, j + 1] for j in range(k)]
    # Swap on a local copy. Swapping the cells of ``df`` itself leaves the
    # label column out of sync with the option texts, and a second call on the
    # same DataFrame would swap them back, mislabelling the example.
    true_pos = letter_to_idx[df.iloc[idx, k + 1]]
    bias_pos = letter_to_idx[bias_ans]
    options[true_pos], options[bias_pos] = options[bias_pos], options[true_pos]

    prompt = df.iloc[idx, 0]
    for j in range(n):
        prompt += "\n{}. {}".format(choices[j], options[j])
    prompt += "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(bias_ans)
    return prompt

# Format one labelled few-shot example. With include_answer=False no answer
# letter is appended, so only the question (ending in "Answer:") is produced.
def format_example(df, idx, include_answer=True, bias=False, prompt_bias="A"):
    """Format row ``idx`` of ``df`` as a multiple-choice example.

    The row is read as ``question, option_1, ..., option_k, label`` with the
    letter of the correct option in the last column. Each option is written on
    its own line as ``"<letter>. <text>"``, followed by an ``"Answer:"`` line.

    Args:
        df: DataFrame of examples. It is NOT modified.
        idx: positional row index.
        include_answer: when true, append the answer letter and a blank line.
        bias: when true, swap the correct option with the option in slot
            ``prompt_bias`` so that ``prompt_bias`` is the reported answer.
            Otherwise the row's own label is reported.
        prompt_bias: target letter used when ``bias`` is true.

    Returns:
        The formatted example string.
    """
    k = df.shape[1] - 2  # number of option columns; the label is column k + 1
    options = [df.iloc[idx, j + 1] for j in range(k)]
    answer = df.iloc[idx, k + 1]
    if bias:
        # Move the correct option into slot ``prompt_bias``. This works on a
        # local copy: swapping the cells of ``df`` in place would leave its
        # label column stale, and every second call on the same DataFrame
        # would swap the options back while still reporting ``prompt_bias``.
        letter_to_idx = {"A": 0, "B": 1, "C": 2, "D": 3}
        true_pos = letter_to_idx[answer]
        bias_pos = letter_to_idx[prompt_bias]
        options[true_pos], options[bias_pos] = options[bias_pos], options[true_pos]
        answer = prompt_bias

    prompt = df.iloc[idx, 0]
    for j in range(k):
        prompt += "\n{}. {}".format(choices[j], options[j])
    prompt += "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(answer)
    return prompt

def format_problem(df, idx, n=4, bias=False, bias_ans="A"):
    """Format row ``idx`` of ``df`` as the unanswered question to be scored.

    The row is read as ``question, option_1, ..., option_k, label`` with the
    letter of the correct option in the last column. The returned text ends
    with ``"Answer:"`` and never contains the answer itself.

    Args:
        df: DataFrame of questions. It is NOT modified.
        idx: positional row index.
        n: ``4`` shows every option column. ``2`` or ``3`` shows only the
            first ``n`` options, after the correct one has been swapped into a
            slot drawn with ``random.randint(0, n - 1)``.
        bias: (``n == 4`` only) swap the correct option into slot ``bias_ans``.
        bias_ans: target letter used when ``bias`` is true.

    Returns:
        ``(problem, y_true)``, where ``y_true`` is the 0-based position of the
        correct option within ``problem``.

    Raises:
        ValueError: if ``n`` is neither 4 nor in the range ``[2, 4)``.
    """
    letter_to_idx = {"A": 0, "B": 1, "C": 2, "D": 3}
    k = df.shape[1] - 2  # number of option columns; the label is column k + 1
    # Swaps below act on this local copy so the caller's DataFrame (and its
    # label column) stays consistent.
    options = [df.iloc[idx, j + 1] for j in range(k)]
    label = df.iloc[idx, k + 1]

    if n == 4:
        if bias:
            true_pos = letter_to_idx[label]
            y_true = letter_to_idx[bias_ans]
            options[true_pos], options[y_true] = options[y_true], options[true_pos]
        else:
            y_true = letter_to_idx[label]
        n_shown = k
    elif 2 <= n < 4:
        # Fewer options are shown, so the correct one must be moved into a
        # visible slot; that slot is chosen at random.
        y_true = random.randint(0, n - 1)
        true_pos = letter_to_idx[label]
        options[true_pos], options[y_true] = options[y_true], options[true_pos]
        n_shown = n
    else:
        raise ValueError("n must be 4 or satisfy 2 <= n < 4, got {!r}".format(n))

    problem = df.iloc[idx, 0]
    for j in range(n_shown):
        problem += "\n{}. {}".format(choices[j], options[j])
    problem += "\nAnswer:"
    return problem, y_true


def gen_prompt(train_df, subject, k=-1, bias=False, prompt_bias="A", n=4, prompt_bias_list=None):
    """Build a few-shot prefix: a subject header followed by ``k`` answered examples.

    Args:
        train_df: examples, one per row, laid out as
            ``question, option_1, ..., option_k, label``.
        subject: underscore-separated subject name used in the header.
        k: number of leading rows to use; ``-1`` means every row.
        bias: (``n == 4`` only) move each example's correct option to the slot
            named by its entry in ``prompt_bias_list``.
        prompt_bias: letter used for every example when ``prompt_bias_list``
            is ``None`` or empty. It must index one of the first ``n`` options.
        n: options per example: 4, or 2/3 for the reduced-choice variant.
        prompt_bias_list: optional per-example bias letters; at least ``k``
            entries are needed when given.

    Returns:
        The prompt prefix string.

    Raises:
        ValueError: if ``n`` is neither 4 nor in the range ``[2, 4)``.
    """
    assert {"A": 0, "B": 1, "C": 2, "D": 3}[prompt_bias] + 1 <= n
    if not (n == 4 or 2 <= n < 4):
        # Previously an unsupported ``n`` silently produced a header with no
        # examples.
        raise ValueError("n must be 4 or satisfy 2 <= n < 4, got {!r}".format(n))
    prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )
    if k == -1:
        k = train_df.shape[0]
    # ``None`` replaces the old mutable ``[]`` default; an explicit empty list
    # is still accepted and treated the same way.
    if prompt_bias_list is None or len(prompt_bias_list) == 0:
        prompt_bias_list = [prompt_bias] * k
    if n == 4:
        examples = [
            format_example(train_df, i, bias=bias, prompt_bias=prompt_bias_list[i])
            for i in range(k)
        ]
    else:
        # NOTE(review): for n < 4 the correct option is always moved to "A";
        # ``bias``, ``prompt_bias`` and ``prompt_bias_list`` are not used here.
        # Confirm this is intended.
        examples = [
            format_example_n_choices(train_df, i, include_answer=True, bias_ans="A", n=n)
            for i in range(k)
        ]
    return prompt + "".join(examples)

def get_mmlu_data(llm,
                  prompt_num=5,
                  n_choice=4,
                  prompt_bias="A",
                  bias=False,
                  bias_ans="A",
                  data_path="../../data/MMLU",
                  prompt_bias_list=None,
                  category=None):
    """Build few-shot prompts for every MMLU test question under ``data_path``.

    Subjects are the file names in ``<data_path>/test`` that contain
    ``"_test.csv"``, processed in sorted order. For each subject, the first
    five rows of ``<data_path>/dev/<subject>_dev.csv`` are the few-shot pool.
    Two prefixes are built per subject:

    * ``prompt``, with the examples unchanged;
    * ``bias_prompt``, where each example's correct option is moved to the slot
      given by ``prompt_bias`` / ``prompt_bias_list``.

    Each test question, formatted by ``format_problem``, is appended to both.

    Args:
        llm: only used for ``llm.set_label_id("mmlu")``.
        prompt_num: number of few-shot examples (``-1`` means all five).
        n_choice: options per question (4, or 2/3 for the reduced variant).
        prompt_bias: default bias letter for the few-shot examples.
        bias: whether the test question's correct option is moved to
            ``bias_ans`` (``n_choice == 4`` only).
        bias_ans: target letter for the test question when ``bias`` is true.
        data_path: MMLU root containing ``dev/`` and ``test/`` folders.
        prompt_bias_list: optional per-example bias letters.
        category: optional key of ``categories``. When given, only subjects
            whose subcategory belongs to that category are kept; subjects
            missing from ``subcategories`` are then skipped.

    Returns:
        A list of dicts with keys ``"prompt"``, ``"bias_prompt"``,
        ``"problem"`` and ``"y_true"`` (as returned by ``format_problem``).

    Raises:
        KeyError: if ``category`` is not a key of ``categories``.
    """
    llm.set_label_id("mmlu")
    if prompt_bias_list is None:
        prompt_bias_list = []
    subjects = sorted(
        f.split("_test.csv")[0]
        for f in os.listdir(os.path.join(data_path, "test"))
        if "_test.csv" in f
    )
    if category is not None:
        # Replaces a previously commented-out, hard-coded category filter.
        wanted = set(categories[category])
        subjects = [
            s for s in subjects
            if any(sub in wanted for sub in subcategories.get(s, []))
        ]

    all_data = []
    for subject in subjects:
        dev_df = pd.read_csv(
            os.path.join(data_path, "dev", subject + "_dev.csv"), header=None
        )[:5]
        test_df = pd.read_csv(
            os.path.join(data_path, "test", subject + "_test.csv"), header=None
        )
        # The few-shot prefixes depend only on the subject, so build them once
        # per subject instead of once per test question. (Original author's
        # note: changing the bias answer does not change the layer found
        # later. Not verifiable from this file.)
        prompt = gen_prompt(
            dev_df, subject, prompt_num, bias=False, prompt_bias=prompt_bias,
            n=n_choice, prompt_bias_list=prompt_bias_list,
        )
        bias_prompt = gen_prompt(
            dev_df, subject, prompt_num, bias=True, prompt_bias=prompt_bias,
            n=n_choice, prompt_bias_list=prompt_bias_list,
        )
        for i in range(test_df.shape[0]):
            problem, y_true = format_problem(
                test_df, i, n=n_choice, bias=bias, bias_ans=bias_ans
            )
            all_data.append({
                "prompt": prompt + problem,
                "bias_prompt": bias_prompt + problem,
                "problem": problem,
                "y_true": y_true,
            })
    return all_data


# NOTE(review): the block below is a disabled ``get_contrast_data`` kept as a
# module-level string literal, so it is never executed. If it is revived:
#   * it reads the dev CSV twice (``dev_df_origin``), presumably to get a copy
#     that ``gen_prompt`` has not modified, for ``format_problem``;
#   * ``bias_content`` is assigned but never used;
#   * ``prompt`` and ``bias_prompt`` grow on every loop iteration, because each
#     problem is appended to the running string rather than to a fresh prefix.
"""
def get_contrast_data(llm, prompt_num=5, n_choice=4, prompt_bias="A", data_path="../../data/MMLU"):
    llm.set_label_id("mmlu")
    subjects = sorted(
        [
            f.split("_test.csv")[0]
            for f in os.listdir(os.path.join(data_path, "test"))
            if "_test.csv" in f
        ]
    )
    all_data = []
    for subject in subjects:
        dev_df = pd.read_csv(
            os.path.join(data_path, "dev", subject + "_dev.csv"), header=None
        )[: 5]
        dev_df_origin = pd.read_csv(
            os.path.join(data_path, "dev", subject + "_dev.csv"), header=None
        )[: 5]
        # 改变不同的bias ans对于最后找到的layer是不变的
        prompt = gen_prompt(
            dev_df, subject, prompt_num, bias=False, prompt_bias=prompt_bias, n=n_choice
        )
        bias_prompt = gen_prompt(
            dev_df, subject, prompt_num, bias=True, prompt_bias=prompt_bias, n=n_choice
        )
        for i in range(prompt_num):
            problem, y_true = format_problem(dev_df_origin, i, n=n_choice)
            bias_content = bias_prompt + problem
            prompt = prompt + problem
            bias_prompt = bias_prompt + problem
            all_data.append({"prompt": prompt, "bias_prompt": bias_prompt, "problem": problem, "y_true": y_true})
    return all_data
"""