from serde import JSON
from utils import custom_xml_parser, tag

from .hugging_face_task import HuggingFaceTask
from .task import TaskInstance, TaskType
from .task_types import UsefulnessEvaluation


class AquaRatTask(HuggingFaceTask):
    """Multiple-choice, chain-of-thought QA task backed by the `deepmind/aqua_rat` dataset."""

    # Double leading underscore: Python name-mangles this to `_AquaRatTask__NAME`.
    __NAME = "aqua_rat_task"
    TASK_TYPE = TaskType.MCQ_COT_QA
    DATASET_LINK = "deepmind/aqua_rat"

    # Prompt describing the general task format (reasoning first, then answer).
    TASK_DESCRIPTION_PROMPT = """Answer the given multiple-choice question. You must output your reasoning first, then your answer.
We give an example below with xml tags to show the user message and your response:
<question>If I have 6 apples and I lose one, how many apples do I have?</question>
<options>A)5, B)6, C)7, D)8</options>
<reasoning>I have 6 apples, and I lose one, so I have 6-1=5 apples.</reasoning>
<answer>A</answer>
The question and options are fixed - you can only respond with reasoning and your answer.
Make sure to fill in the REASONING and ANSWER sections in the same format as above."""

    # Rough ceiling on the accuracy the model usually reaches on this task.
    EXPECTED_PERFORMANCE = 0.8

    def __init__(self, size: int, use_test_set: bool = False):
        """Set up the task over a chosen dataset split.

        Args:
            size: Forwarded unchanged to the parent constructor.
            use_test_set: When truthy the "validation" split is requested,
                otherwise the "train" split.
        """
        # Stored before delegating to the parent, as the original ordering did.
        self.use_test_set = use_test_set
        if self.use_test_set:
            split_name = "validation"
        else:
            split_name = "train"
        super().__init__(AquaRatTaskInstance, size, split=split_name)

    def serialize(self) -> dict[str, JSON]:
        """Return the parent's serialized form extended with this task's settings."""
        base = super().serialize()
        return base | {
            "expected_performance": self.EXPECTED_PERFORMANCE,
            "use_test_set": self.use_test_set,
        }


class AquaRatTaskInstance(TaskInstance):
    """One multiple-choice question, its options and its correct answer."""

    def __init__(self, data: dict[str, str]):
        """Build an instance from a single record.

        Args:
            data: Record holding "question" and "options", plus the correct
                answer under "correct" or, if that key is absent, "answer".
                "options" is joined with ", " in `get_user_prompt`, so it is
                presumably a list of strings -- TODO confirm against the dataset.

        Raises:
            KeyError: If "question", "options", or both answer keys are missing.
        """
        self.question = data["question"]
        self.options = data["options"]
        # `serialize` writes the answer under "answer"; accepting both keys
        # presumably lets serialized instances be loaded again -- verify with callers.
        self.answer = data["correct"] if "correct" in data else data["answer"]

    def get_user_prompt(self) -> str:
        """Get the user prompt string for this task.

        This prompt describes the exact task (rather than the general type of task) and goes in the user prompt.

        E.g. For a QA setting:
        * task_description_prompt will describe that the encoder must answer the question under certain constraints
        * user_prompt will give the actual question to answer

        Returns:
            The question wrapped in a "question" tag, a newline, then the
            options joined by ", " and wrapped in an "options" tag.
        """
        question_block = tag(self.question, "question")
        options_block = tag(", ".join(self.options), "options")
        # The two tagged blocks are separated by a bare newline; a stray `"`
        # used to be emitted after the question block and has been removed.
        return f"{question_block}\n{options_block}"

    def evaluate(self, message: str) -> UsefulnessEvaluation:
        """Evaluate an encoder's message.

        The "reasoning" and "answer" sections are parsed out of the message,
        but only the answer is scored. The comparison against the stored
        answer ignores case and surrounding whitespace.

        Args:
            message: The encoder's raw response text.

        Returns:
            A UsefulnessEvaluation that is useful when the extracted answer
            matches the correct one, and otherwise not useful with a reason
            (answer missing, or answer incorrect).
        """
        extracted_message = custom_xml_parser(message, ["reasoning", "answer"])

        if "answer" not in extracted_message:
            return UsefulnessEvaluation(False, "Answer not found")
        answer = extracted_message["answer"].strip()
        if answer.lower() == self.answer.strip().lower():
            return UsefulnessEvaluation(True, None)
        return UsefulnessEvaluation(
            False,
            f"Incorrect answer. Answer was {self.answer}, but got {answer}",
        )

    def __str__(self) -> str:
        """Return a compact question/options/answer summary for logging."""
        return f'(question: "{self.question}", options: "{self.options}", answer: "{self.answer}")'

    def __repr__(self) -> str:
        """Return the logging representation (labelled "AquaRatTask", not the class name)."""
        return f'AquaRatTask<question: "{self.question}", options: "{self.options}", answer: "{self.answer}">'

    def serialize(self) -> dict[str, JSON]:
        """Return the question, options and answer as a plain dict."""
        return {
            "question": self.question,
            "options": self.options,
            "answer": self.answer,
        }
