from datasets import load_dataset
import copy, pdb
import re


# Template for one record in the intermediate representation. Every value here
# is an empty placeholder except the task type and the source dataset name.
INTERMEDIATE_SCHEMA = {
    "task_type": "MCQ",
    "dataset": "",
    "original_dataset_metadata": "TIGER-Lab/MMLU-Pro",
    # Instance given to the LLM without any instruction.
    "dataset_input": "",
    # The list of all possible answers for that instance.
    "candidate_answer_set": [],
    # The list of all possible answer labels.
    "candidate_answer_label_space": [],
    "ground_truth_answer_label": "",
    "ground_truth_answer_text": "",
    # Task prompt. It should not define how to generate the answer.
    "dataset_instruction": "",
    # The final task instruction which gets appended to the input and
    # dataset_instruction.
    "final_suffix_task_instruction": "",
    # The final task instruction which gets prepended to the input and
    # dataset_instruction.
    "final_prefix_task_instruction": "",
    "task_instructions": [],
    "instruction_output": [],
    "instruction_following_errors_set": [],
    "reasoning_error_set": [],
}

def get_choices(options):
    """Return the candidate answers for an instance.

    The given ``options`` object is returned as-is (no copy is made).
    """
    return options

def get_label_space(options):
    """Return one uppercase letter label per option, starting at "A".

    At most 26 labels ("A" through "Z") are produced; options beyond the
    26th do not receive a label.
    """
    label_count = min(len(options), 26)
    return [chr(ord("A") + offset) for offset in range(label_count)]

def transform_mmlupro(test_instance: dict):
    """Convert one MMLU-Pro record into the intermediate MCQ schema.

    Args:
        test_instance: A record providing the keys ``category``,
            ``question``, ``options``, ``answer`` and ``answer_index``.
            ``answer_index`` is used to index into ``options``; ``answer``
            is stored as the ground-truth label (presumably the letter of
            the correct option -- confirm against the dataset card).

    Returns:
        A fresh deep copy of ``INTERMEDIATE_SCHEMA`` populated from the
        record.

    Raises:
        AssertionError: If a populated field fails its sanity check.
    """
    schema = copy.deepcopy(INTERMEDIATE_SCHEMA)

    # Adding meta data information
    schema["dataset"] = "MMLUPro_" + test_instance['category']

    # Copying untransformed data fields
    schema["dataset_input"] = f"Question: {test_instance['question']}\n"
    schema["candidate_answer_set"] = get_choices(test_instance['options'])
    schema["candidate_answer_label_space"] = get_label_space(test_instance['options'])
    # The label is stored as a one-element list, mirroring instruction_output.
    schema["ground_truth_answer_label"] = [test_instance['answer']]
    schema["ground_truth_answer_text"] = schema["candidate_answer_set"][test_instance['answer_index']]

    label_space = schema["candidate_answer_label_space"]
    schema["dataset_instruction"] = (
        "Given a question about " + test_instance['category']
        + " and " + str(len(label_space))
        + " options: " + ", ".join(label_space)
        + " as candidate answers, "
    )
    # dataset_instruction already ends with ", ", so the appended text must not
    # start with another comma (that produced "candidate answers, , answer").
    schema["final_prefix_task_instruction"] = (
        schema["dataset_instruction"]
        + "answer the question by selecting the value associated with the option label corresponding to the correct answer.\n"
    )
    schema["final_suffix_task_instruction"] = "\n"

    schema["instruction_output"] = [test_instance['answer']]

    # Sanity checks. These stay as asserts so a failure keeps raising
    # AssertionError, as it did before.
    assert schema["dataset_input"] != "", "dataset_input is empty"
    assert schema["ground_truth_answer_text"] != "", "ground_truth_answer_text is empty"
    assert len(schema["candidate_answer_set"]) == len(schema["candidate_answer_label_space"]), \
        "candidate answers and labels differ in length"
    # ground_truth_answer_label is a list, so comparing the list itself to ""
    # is always true; check the labels it contains instead.
    assert all(label != "" for label in schema["ground_truth_answer_label"]), \
        "ground_truth_answer_label is empty"
    assert len(schema["candidate_answer_label_space"]) > 0, "candidate_answer_label_space is empty"
    assert len(schema["candidate_answer_set"]) > 0, "candidate_answer_set is empty"

    assert schema["dataset_instruction"] != "", "dataset_instruction is empty"
    assert len(schema["instruction_output"]) > 0, "instruction_output is empty"
    return schema

class MMLUPro:
    """MMLU-Pro test split together with its intermediate representation.

    Attributes:
        dataset: Result of ``load_dataset("TIGER-Lab/MMLU-Pro", split="test")``.
        intermediate_representation: ``dataset`` mapped through
            ``transform_mmlupro``, with the original column names passed as
            ``remove_columns``.
    """

    def __init__(self):
        super().__init__()
        self.dataset = load_dataset("TIGER-Lab/MMLU-Pro", split="test")
        source_columns = self.dataset.column_names
        self.intermediate_representation = self.dataset.map(
            transform_mmlupro,
            remove_columns=source_columns,
        )
