from abc import abstractmethod
from enum import Enum
from typing import Generator, Optional

from serde import JSON, Serializable

from .task_types import UsefulnessEvaluation


class TaskType(str, Enum):
    """The kinds of task known to this module.

    Because the enum also derives from ``str``, each member compares equal to
    its human-readable string value (e.g. ``TaskType.CODING == "Coding"``).
    """

    SIMPLE_QA = "Simple QA"
    MCQ_COT_QA = "MCQ CoT QA"
    SUMMARIZATION = "Summarization"
    CODING = "Coding"
    NATURAL = "Natural"


class Task:
    """A base class for a task which we want the encoder to do.

    Subclasses are expected to define ``TASK_TYPE`` and
    ``TASK_DESCRIPTION_PROMPT`` (both are only annotated here, not assigned)
    and to override the methods decorated with ``@abstractmethod``.

    Note:
        This class does not use ``abc.ABCMeta``, so ``@abstractmethod`` is not
        enforced at instantiation time. A subclass that forgets to override
        one of these methods inherits a stub that returns ``None``.
    """

    TASK_TYPE: TaskType
    TASK_DESCRIPTION_PROMPT: str
    # If not None, EXPECTED_PERFORMANCE is used for early stopping on the ICL loop
    EXPECTED_PERFORMANCE: Optional[float] = None

    @classmethod
    def get_task_description_prompt(cls) -> str:
        """Get the task description prompt for this task type.

        Returns:
            A string describing what the task is

        Raises:
            AssertionError: If ``TASK_DESCRIPTION_PROMPT`` is missing or
                ``None`` on ``cls`` and this method has not been overridden.
        """
        # The attribute is only annotated on the base class, so a direct
        # ``cls.TASK_DESCRIPTION_PROMPT`` would fail with a bare AttributeError
        # on subclasses that never set it; look it up defensively instead.
        prompt = getattr(cls, "TASK_DESCRIPTION_PROMPT", None)
        if prompt is None:
            # Raised explicitly rather than via ``assert`` so the check is not
            # stripped when Python runs with ``-O``.
            raise AssertionError(
                "This must be defined or this function must be overriden"
            )
        return prompt

    @abstractmethod
    def __iter__(self) -> Generator["TaskInstance", None, None]:
        """Iterate over the task instances for this task."""

    @abstractmethod
    def __getitem__(self, idx: int) -> "TaskInstance":
        """Get the item at a particular index."""

    @abstractmethod
    def get_random_instance(self) -> "TaskInstance":
        """Get a random task instance."""

    @abstractmethod
    def instance_initializer(self, data: dict[str, JSON]) -> "TaskInstance":
        """Initialize a task instance from a dictionary"""

    @abstractmethod
    def serialize(self) -> dict[str, Serializable]:
        """Serialize this task into a dictionary of serializable values."""

    @abstractmethod
    def split(self, split_factor: float) -> "TaskInstance":
        """Split this task according to ``split_factor``.

        NOTE(review): the exact meaning of ``split_factor`` is not visible
        here -- presumably a fraction of the instances; confirm against the
        concrete subclasses. The ``TaskInstance`` return annotation also looks
        suspicious for a method that splits a whole task -- TODO confirm the
        intended return type.
        """


class TaskInstance:
    """A specific instance of a task that the encoder must do.

    Note:
        This class does not use ``abc.ABCMeta``, so ``@abstractmethod`` is not
        enforced at instantiation time. Subclasses are still expected to
        override every method marked abstract.
    """

    @abstractmethod
    def get_user_prompt(self) -> str:
        """Get the user prompt string for this task.

        This prompt describes the exact task (rather than the general type of task) and goes in the user prompt.

        E.g. For a QA setting:
        * task_description_prompt will describe that the encoder must answer the question under certain constraints
        * user_prompt will give the actual question to answer

        Returns:
            A string describing the task case to the encoder
        """

    @abstractmethod
    def evaluate(self, encoder_message: str) -> UsefulnessEvaluation:
        """Evaluate an encoder message against this task instance.

        Args:
            encoder_message: The message produced by the encoder.

        Returns:
            A ``UsefulnessEvaluation`` for the encoder's answer.
        """

    def __repr__(self) -> str:
        """Return the ``str()`` form of the instance when one is defined.

        The default ``object.__str__`` delegates to ``repr()``; calling
        ``str(self)`` unconditionally would therefore recurse forever for any
        subclass that does not define ``__str__``. In that case fall back to
        the default object representation.
        """
        if type(self).__str__ is object.__str__:
            return object.__repr__(self)
        return str(self)

    @abstractmethod
    def serialize(self) -> dict[str, JSON]:
        """Serialize this task instance into a JSON-compatible dictionary."""
