from datasets import load_dataset
import copy

# Template for one record of the intermediate representation. It holds mutable
# lists, so take a deep copy before filling it in for a given instance.
INTERMEDIATE_SCHEMA = {
    "task_type": "MCQ",
    "dataset": "",
    "original_dataset_metadata": "https://huggingface.co/datasets/ybisk/piqa",
    # Instance given to the LLM without any instruction.
    "dataset_input": "",
    # List of all possible answers for that instance.
    "candidate_answer_set": [],
    # List of all possible answer labels.
    "candidate_answer_label_space": [],
    "ground_truth_answer_label": "",
    "ground_truth_answer_text": "",
    # Task prompt. It should not define how to generate the answer.
    "dataset_instruction": "",
    # Final task instruction appended to the input and dataset_instruction.
    "final_suffix_task_instruction": "",
    # Final task instruction prepended to the input and dataset_instruction.
    "final_prefix_task_instruction": "",
    "task_instructions": [],
    "instruction_output": [],
    "instruction_following_errors_set": [],
    "reasoning_error_set": [],
}



def transform_piqa(test_instance: dict) -> dict:
    """Convert one raw PIQA record into the intermediate MCQ schema.

    Args:
        test_instance: A record providing the keys ``goal`` (the question),
            ``sol1`` and ``sol2`` (the two candidate answers) and ``label``
            (``0`` when ``sol1`` is the correct answer, ``1`` when ``sol2``
            is).

    Returns:
        A deep copy of ``INTERMEDIATE_SCHEMA`` filled in for this record.

    Raises:
        ValueError: If ``label`` is neither 0 nor 1, or if the candidate
            answer selected by the label is an empty string.
    """
    schema = copy.deepcopy(INTERMEDIATE_SCHEMA)
    schema["task_type"] = "MCQ"
    schema["dataset"] = "Piqa"

    candidates = [test_instance["sol1"], test_instance["sol2"]]
    labels = ["A", "B"]

    schema["dataset_input"] = f"Question: {test_instance['goal']}\nOptions:"
    schema["candidate_answer_set"] = candidates
    schema["candidate_answer_label_space"] = labels

    # Validate with explicit raises rather than `assert`: assertions are
    # stripped under `python -O`, which would let a record without a usable
    # label through with an empty ground truth. The explicit range check also
    # stops a negative label from silently indexing from the end of the list.
    answer_index = int(test_instance["label"])
    if not 0 <= answer_index < len(labels):
        raise ValueError(
            f"PIQA label must be 0 or 1, got {test_instance['label']!r}"
        )

    answer_text = candidates[answer_index]
    if answer_text == "":
        raise ValueError(
            f"PIQA candidate answer for label {answer_index} is empty"
        )

    schema["ground_truth_answer_label"] = labels[answer_index]
    schema["ground_truth_answer_text"] = answer_text
    schema["instruction_output"] = [labels[answer_index]]

    # Task prompt, plus the instructions that get prepended/appended around
    # the input and the task prompt.
    schema["dataset_instruction"] = "Given a question and two answer candidates 'A' and 'B', "
    schema["final_prefix_task_instruction"] = (
        "Given a question and the possible answer candidates 'A' and 'B', answer the question by selecting the value associated with the option label corresponding to the correct answer.\n"
    )
    schema["final_suffix_task_instruction"] = "\n"

    schema["task_instructions"].append(schema["dataset_instruction"])
    return schema


class Piqa:
    """PIQA validation split together with its intermediate MCQ representation.

    Attributes:
        dataset: The ``validation`` split of ``ybisk/piqa`` as returned by
            ``load_dataset``.
        intermediate_representation: ``dataset`` mapped through
            ``transform_piqa``, with all of the original columns removed.
    """

    def __init__(self) -> None:
        super().__init__()
        validation_split = load_dataset("ybisk/piqa", split="validation")
        original_columns = validation_split.column_names

        self.dataset = validation_split
        # Drop the raw columns so only the schema fields produced by
        # transform_piqa remain.
        self.intermediate_representation = validation_split.map(
            transform_piqa,
            remove_columns=original_columns,
        )
