from datasets import load_dataset
import copy
import re


# Template for one converted MCQ instance. Callers are expected to copy it
# before filling it in, since several of the values are mutable lists.
INTERMEDIATE_SCHEMA = {
    "task_type": "MCQ",
    "dataset": "",
    "original_dataset_metadata": "https://huggingface.co/datasets/allenai/math_qa",
    # Instance given to the LLM without any instruction.
    "dataset_input": "",
    # The list of all possible answers for that instance.
    "candidate_answer_set": [],
    # The list of all possible answer labels.
    "candidate_answer_label_space": [],
    "ground_truth_answer_label": "",
    "ground_truth_answer_text": "",
    # Task prompt. The task prompt should not define how to generate the answer.
    "dataset_instruction": "",
    # Final task instruction which gets appended to the input and the
    # original task instruction.
    "final_suffix_task_instruction": "",
    # Final task instruction which gets prepended to the input and the
    # original task instruction.
    "final_prefix_task_instruction": "",
    "task_instructions": [],
    "instruction_output": [],
    "instruction_following_errors_set": [],
    "reasoning_error_set": [],
}


# An option label ("a ) " ... "e ) ") located at the very start of the string
# or directly after a comma separator. Anchoring on the separator keeps
# label-like text in the middle of an option from being treated as a label.
_OPTION_LABEL_RE = re.compile(r"(?:^|,)\s*[a-e] \) ")


def get_choices(options):
    """Split a flat MathQA options string into the individual option texts.

    The expected layout is ``"a ) <text> , b ) <text> , ... , e ) <text>"``.
    Each option runs from the end of its own label up to the start of the
    next label (or the end of the string), so commas inside an option's text
    (e.g. ``"a ) 1 , 2 , b ) 3"``) are preserved rather than truncating the
    option at the first comma.

    Args:
        options: The raw options string.

    Returns:
        A list of option texts in the order they appear, with the label
        prefix removed and trailing spaces, commas and newlines stripped.
        An input with no recognisable label yields an empty list.
    """
    labels = list(_OPTION_LABEL_RE.finditer(options))
    # Text of option i starts where its label ends ...
    starts = [m.end() for m in labels]
    # ... and stops where the next label (including its leading comma) begins.
    stops = [m.start() for m in labels[1:]] + [len(options)]
    return [
        options[begin:end].rstrip(" ,\n")
        for begin, end in zip(starts, stops)
    ]

def transform_mathqa(test_instance: dict):
    """Convert one MathQA record into the intermediate MCQ schema.

    A deep copy of ``INTERMEDIATE_SCHEMA`` is filled in from the record's
    ``"Problem"``, ``"options"`` and ``"correct"`` fields and returned.

    Args:
        test_instance: A single dataset row; must provide the keys
            ``"Problem"``, ``"options"`` and ``"correct"``.

    Returns:
        The populated schema dictionary.

    Raises:
        ValueError: If ``test_instance["correct"]`` is not one of the labels
            ``a``-``e`` (raised by ``list.index``).
        IndexError: If fewer options were parsed than the position of the
            correct label requires.
        AssertionError: If a mandatory schema field ends up empty.
    """
    record = copy.deepcopy(INTERMEDIATE_SCHEMA)
    labels = ["a", "b", "c", "d", "e"]

    # Dataset-level metadata.
    record["task_type"] = "MCQ"
    record["dataset"] = "MathQA"

    # Fields taken from the raw record.
    record["dataset_input"] = f"Question: {test_instance['Problem']}\nOptions:"
    answers = get_choices(test_instance["options"])
    gold_label = test_instance["correct"]

    record["candidate_answer_set"] = answers
    record["candidate_answer_label_space"] = labels
    record["ground_truth_answer_label"] = gold_label
    # The gold answer text is the option at the same position as the gold label.
    record["ground_truth_answer_text"] = answers[labels.index(gold_label)]

    # Instruction texts wrapped around the input.
    instruction = "Given a mathematical question and 5 options namely 'a', 'b', 'c', 'd', and, 'e', as candidate answers, "
    record["dataset_instruction"] = instruction
    record["final_prefix_task_instruction"] = "Given a mathematical question and 5 options namely 'a', 'b', 'c', 'd', and, 'e', answer the question by selecting the value associated with the option label corresponding to the correct answer.\n"
    record["final_suffix_task_instruction"] = "\n"

    record["instruction_output"] = [gold_label]
    record["task_instructions"].append(instruction)

    # Sanity checks: text fields must be non-empty strings ...
    for text_field in (
        "dataset_input",
        "ground_truth_answer_text",
        "ground_truth_answer_label",
        "dataset_instruction",
    ):
        assert record[text_field] != ""
    # ... and list fields must contain at least one element.
    for list_field in (
        "candidate_answer_label_space",
        "candidate_answer_set",
        "instruction_output",
    ):
        assert len(record[list_field]) > 0

    return record

class MathQA:
    """Loads the MathQA test split and its intermediate-schema conversion.

    Attributes are set in ``__init__``:
        dataset: Result of ``load_dataset("allenai/math_qa", split="test")``.
        intermediate_representation: ``dataset`` mapped through
            ``transform_mathqa``, with all of the dataset's original column
            names passed as ``remove_columns``.
    """

    def __init__(self):
        super().__init__()
        raw_split = load_dataset("allenai/math_qa", split="test")
        self.dataset = raw_split
        # The original columns are dropped from the mapped output;
        # presumably only the schema fields remain - verify against
        # Dataset.map semantics.
        self.intermediate_representation = raw_split.map(
            transform_mathqa,
            remove_columns=raw_split.column_names,
            desc="Converting dataset to schema",
        )
