# Optional "intervention" hints, keyed by integer id. MMLUProcessor.gen_prompt
# appends the selected hint to the instruction line; an id that is not a key
# here (such as the constructor default of -1) adds no hint.
INTERV_PROMPTS = {
    0: "Keep in mind that the question and options may be presented in various languages.",
    1: "Remember that the question and options can be in different languages.",
    2: (
        "Remember that the question and options can be in different languages. "
        "First translate them all to English. Then output the answer."
    ),
}

class MMLUProcessor:
    """Build few-shot multiple-choice prompts from tabular question data.

    Every data frame handed to this class is read positionally: column 0 is
    the question text, the last column is the gold answer given as one of
    ``dataset_choices`` ("A".."D"), and the columns in between are the
    answer options.

    Args:
        choices: Labels to display in front of the options and to report as
            the answer, positionally aligned with ``dataset_choices``.
            Defaults to ``dataset_choices`` itself.
        interv_prompt_id: Key into ``INTERV_PROMPTS`` selecting an extra hint
            appended to the instruction line. Ids that are not keys of
            ``INTERV_PROMPTS`` (including the default ``-1``) add no hint.
    """

    def __init__(self, choices=None, interv_prompt_id=-1):
        self.dataset_choices = ["A", "B", "C", "D"]
        self.choices = self.dataset_choices if choices is None else choices
        self.interv_prompt_id = interv_prompt_id
        # Few-shot prefix cache: subject -> prompt text.
        self.subject_train_prompt = {}
        # Shot count each cached prefix was built with: subject -> ntrain.
        # Kept separately so ``subject_train_prompt`` keeps its original shape.
        self._subject_train_ntrain = {}

    def _map_answer(self, dataset_answer):
        """Translate a gold answer from ``dataset_choices`` into ``choices``.

        Raises:
            ValueError: If ``dataset_answer`` is not in ``dataset_choices``.
        """
        return self.choices[self.dataset_choices.index(dataset_answer)]

    def format_subject(self, subject):
        """Return ``subject`` with underscores turned into spaces.

        Each word is prefixed by a space, so the result always starts with a
        space (``"high_school_physics"`` -> ``" high school physics"``).
        """
        return "".join(" " + word for word in subject.split("_"))

    def format_example(self, df, idx, include_answer=True):
        """Render row ``idx`` of ``df`` as a question block.

        Args:
            df: Data frame laid out as described on the class.
            idx: Positional row index.
            include_answer: When true, append the (remapped) gold answer and
                a blank line, as used for few-shot demonstrations. When
                false, the text ends right after ``"Answer:"``.

        Returns:
            The formatted question, its labelled options and the answer cue.
        """
        num_options = df.shape[1] - 2

        lines = [df.iloc[idx, 0]]
        lines.extend(
            "{}. {}".format(self.choices[j], df.iloc[idx, j + 1])
            for j in range(num_options)
        )
        prompt = "\n".join(lines) + "\nAnswer:"

        if include_answer:
            # The answer sits in the last column, right after the options.
            answer = self._map_answer(df.iloc[idx, num_options + 1])
            prompt += " {}\n\n".format(answer)
        return prompt

    def gen_prompt(self, train_df, subject, k=-1):
        """Build the instruction line followed by ``k`` solved examples.

        Args:
            train_df: Data frame providing the demonstrations.
            subject: Underscore-separated subject name.
            k: Number of leading rows of ``train_df`` to include; ``-1``
                means all rows.

        Returns:
            The few-shot prefix, ending with a blank line.

        Raises:
            IndexError: If ``k`` exceeds the number of rows in ``train_df``.
        """
        # NOTE: format_subject() returns a leading space, so the rendered text
        # has two spaces after "about". Kept as-is so prompts do not change.
        instruction = "The following are multiple choice questions (with answers) about {}.".format(
            self.format_subject(subject))

        if self.interv_prompt_id in INTERV_PROMPTS:
            prompt = instruction + \
                " {}\n\n".format(INTERV_PROMPTS[self.interv_prompt_id])
        else:
            prompt = instruction + "\n\n"

        if k == -1:
            k = train_df.shape[0]

        # Few-shot demonstrations.
        prompt += "".join(
            self.format_example(train_df, i, include_answer=True)
            for i in range(k)
        )
        return prompt

    def gen_test_prompt(self, ntrain, test_df, dev_df, idx, subject):
        """Build the full prompt for one test question and return its label.

        The few-shot prefix is cached per subject. The cache is rebuilt when
        ``ntrain`` differs from the value the cached prefix was built with;
        it is NOT invalidated when a different ``dev_df`` is passed for the
        same subject.

        Args:
            ntrain: Number of demonstrations taken from ``dev_df``.
            test_df: Data frame holding the test questions.
            dev_df: Data frame holding the demonstrations.
            idx: Positional row index into ``test_df``.
            subject: Underscore-separated subject name.

        Returns:
            A ``(prompt, mapped_label)`` tuple where ``mapped_label`` is the
            gold answer expressed in ``choices``.
        """
        # Entries without a recorded shot count (e.g. written to
        # ``subject_train_prompt`` directly by a caller) are trusted as-is.
        cache_hit = (
            subject in self.subject_train_prompt
            and self._subject_train_ntrain.get(subject, ntrain) == ntrain
        )
        if cache_hit:
            train_prompt = self.subject_train_prompt[subject]
        else:
            train_prompt = self.gen_prompt(dev_df, subject, ntrain)
            self.subject_train_prompt[subject] = train_prompt
            self._subject_train_ntrain[subject] = ntrain

        prompt = train_prompt + self.format_example(test_df, idx, include_answer=False)

        label = test_df.iloc[idx, test_df.shape[1] - 1]
        return prompt, self._map_answer(label)


# Subject identifier -> one-element list naming the subject's subcategory.
# Every subcategory used here appears in one of the lists in `categories`.
# Built from an ordered table of pairs so each subject gets its own list.
subcategories = {
    subject: [subcategory]
    for subject, subcategory in (
        ("abstract_algebra", "math"),
        ("anatomy", "health"),
        ("astronomy", "physics"),
        ("business_ethics", "business"),
        ("clinical_knowledge", "health"),
        ("college_biology", "biology"),
        ("college_chemistry", "chemistry"),
        ("college_computer_science", "computer science"),
        ("college_mathematics", "math"),
        ("college_medicine", "health"),
        ("college_physics", "physics"),
        ("computer_security", "computer science"),
        ("conceptual_physics", "physics"),
        ("econometrics", "economics"),
        ("electrical_engineering", "engineering"),
        ("elementary_mathematics", "math"),
        ("formal_logic", "philosophy"),
        ("global_facts", "other"),
        ("high_school_biology", "biology"),
        ("high_school_chemistry", "chemistry"),
        ("high_school_computer_science", "computer science"),
        ("high_school_european_history", "history"),
        ("high_school_geography", "geography"),
        ("high_school_government_and_politics", "politics"),
        ("high_school_macroeconomics", "economics"),
        ("high_school_mathematics", "math"),
        ("high_school_microeconomics", "economics"),
        ("high_school_physics", "physics"),
        ("high_school_psychology", "psychology"),
        ("high_school_statistics", "math"),
        ("high_school_us_history", "history"),
        ("high_school_world_history", "history"),
        ("human_aging", "health"),
        ("human_sexuality", "culture"),
        ("international_law", "law"),
        ("jurisprudence", "law"),
        ("logical_fallacies", "philosophy"),
        ("machine_learning", "computer science"),
        ("management", "business"),
        ("marketing", "business"),
        ("medical_genetics", "health"),
        ("miscellaneous", "other"),
        ("moral_disputes", "philosophy"),
        ("moral_scenarios", "philosophy"),
        ("nutrition", "health"),
        ("philosophy", "philosophy"),
        ("prehistory", "history"),
        ("professional_accounting", "other"),
        ("professional_law", "law"),
        ("professional_medicine", "health"),
        ("professional_psychology", "psychology"),
        ("public_relations", "politics"),
        ("security_studies", "politics"),
        ("sociology", "culture"),
        ("us_foreign_policy", "politics"),
        ("virology", "health"),
        ("world_religions", "philosophy"),
    )
}

# Top-level category -> subcategory names (the values used in `subcategories`)
# that belong to it.
categories = {
    "STEM": [
        "physics",
        "chemistry",
        "biology",
        "computer science",
        "math",
        "engineering",
    ],
    "humanities": [
        "history",
        "philosophy",
        "law",
    ],
    "social sciences": [
        "politics",
        "culture",
        "economics",
        "geography",
        "psychology",
    ],
    "other (business, health, misc.)": [
        "other",
        "business",
        "health",
    ],
}
