from typing import Any, Optional

from serde import JSON, Serializable
from utils import count_sentences

from .hugging_face_task import HuggingFaceTask
from .task import TaskInstance, TaskType
from .task_types import UsefulnessEvaluation


class MMLUTask(HuggingFaceTask):
    """Simple question-answering task built on the ``cais/mmlu`` dataset.

    Every dataset row is wrapped in an ``MMLUTaskInstance``. An optional
    sentence limit is forwarded to each instance and mentioned in the task
    description prompt.
    """

    __NAME = "mmlu_task"
    TASK_TYPE = TaskType.SIMPLE_QA
    TASK_DESCRIPTION_PROMPT = "You have been asked a question to which you know the answer. You must make sure to provide the answer such that it will be an exact string match."
    DATASET_LINK = "cais/mmlu"

    def __init__(
        self,
        size: Optional[int],
        subset_name: str = "world_religions",
        n_sentences: Optional[int] = 1,
    ):
        """Set up the task for one MMLU subset.

        Args:
            size: Forwarded unchanged to ``HuggingFaceTask.__init__``
                (presumably the number of instances to use; confirm against
                the base class).
            subset_name: MMLU subset to load; passed to the base class as
                ``data_dir``.
            n_sentences: Maximum number of sentences allowed in a response,
                or ``None`` for no limit.
        """

        def initialize_instance(data) -> MMLUTaskInstance:
            return MMLUTaskInstance(data, n_sentences=n_sentences)

        super().__init__(
            initialize_instance,
            size,
            # Use the requested subset; this was previously hard-coded to
            # "world_religions", which made `subset_name` have no effect.
            data_dir=subset_name,
            split="test",
        )
        self.subset_name = subset_name
        self.n_sentences = n_sentences

    def get_task_description_prompt(self) -> str:
        """Return the general task description, plus the sentence limit if set.

        Returns:
            ``TASK_DESCRIPTION_PROMPT``, followed by a sentence stating the
            limit whenever ``n_sentences`` is not ``None``.
        """
        if self.n_sentences is None:
            return MMLUTask.TASK_DESCRIPTION_PROMPT
        noun = "sentence" if self.n_sentences == 1 else "sentences"
        # The leading space keeps the two sentences from running together.
        sentence_prompt = (
            f" You must also answer the question in at most {self.n_sentences} {noun}."
        )
        return MMLUTask.TASK_DESCRIPTION_PROMPT + sentence_prompt

    def serialize(self) -> dict[str, Serializable]:
        """Return the base-class serialization extended with this task's settings."""
        return super().serialize() | {
            "n_sentences": self.n_sentences,
            "subset_name": self.subset_name,
        }


class MMLUTaskInstance(TaskInstance):
    """A single MMLU question together with its known correct answer."""

    def __init__(self, data: dict[str, Any], n_sentences: Optional[int]):
        """Build an instance from one dataset row.

        Args:
            data: Row with a "question" and an "answer". When "choices" is
                present, "answer" is used as an index into it; otherwise
                "answer" is taken as the answer itself.
            n_sentences: Maximum number of sentences allowed in a response,
                or ``None`` for no limit.
        """
        self.question = data["question"]
        self.answer = (
            data["choices"][data["answer"]] if "choices" in data else data["answer"]
        )
        self.n_sentences = n_sentences

    def _sentence_limit(self) -> str:
        """Return the limit as a phrase with the right plural, e.g. "1 sentence"."""
        noun = "sentence" if self.n_sentences == 1 else "sentences"
        return f"{self.n_sentences} {noun}"

    def get_user_prompt(self) -> str:
        """Get the user prompt string for this task.

        This prompt describes the exact task (rather than the general type of task) and goes in the user prompt.

        E.g. For a QA setting:
        * task_description_prompt will describe that the encoder must answer the question under certain constraints
        * user_prompt will give the actual question to answer

        Returns:
            A string describing the task case to the encoder
        """
        # Same `is not None` condition as `evaluate`, so the prompt states the
        # limit exactly when it is enforced.
        sentence_prompt = (
            f" in at most {self._sentence_limit()}"
            if self.n_sentences is not None
            else ""
        )
        return f"""Respond to the following question{sentence_prompt}.
The question is: "{self.question}"
You know that the answer is: "{self.answer}"
You must make sure that the answer is in your response as an exact (case-insensitive) string match."""

    def evaluate(self, message: str) -> UsefulnessEvaluation:
        """Evaluate an encoder's message.

        The message is useful when it contains the answer as a
        case-insensitive substring and, if a limit is set, has no more
        sentences than ``n_sentences``.

        Args:
            message: The encoder's response.

        Returns:
            ``UsefulnessEvaluation(True)`` when both checks pass; otherwise
            ``UsefulnessEvaluation(False, reason)`` where ``reason`` lists
            every failed check joined by " AND ".
        """
        answer_correct = self.answer.lower() in message.lower()
        n_sentences = count_sentences(message)
        n_sentences_correct = (
            self.n_sentences is None or n_sentences <= self.n_sentences
        )
        if answer_correct and n_sentences_correct:
            return UsefulnessEvaluation(True)

        reasons = []
        if not answer_correct:
            reasons.append(f'correct answer, "{self.answer}", was not in the message')
        if not n_sentences_correct:
            actual_noun = "sentence" if n_sentences == 1 else "sentences"
            reasons.append(
                f"message should be at most {self._sentence_limit()} long, "
                f"but it was {n_sentences} {actual_noun} long"
            )
        return UsefulnessEvaluation(False, " AND ".join(reasons))

    def __str__(self) -> str:
        """For logging."""
        return f'(question: "{self.question}", answer: "{self.answer}")'

    def __repr__(self) -> str:
        """For logging."""
        return f'MMLUTaskQA<question: "{self.question}", answer: "{self.answer}">'

    def serialize(self) -> dict[str, JSON]:
        """Return the question and answer as a JSON-compatible dict."""
        return {
            "question": self.question,
            "answer": self.answer,
        }
