from abc import abstractmethod
from typing import Generator, Optional

from serde import JSON

from .natural_task_types import EpisodeEvaluation


class NaturalTask:
    """Interface for a "natural" task: a collection of task instances plus
    the prompt and model settings that go with it.

    Subclasses are expected to override the class-level settings below and to
    implement every method marked ``@abstractmethod``.

    NOTE(review): this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers are NOT enforced at instantiation time; a
    subclass that forgets an override will only fail (or silently return
    ``None``) when the missing method is actually called.
    """

    # Presumably the performance a well-behaved run should reach on this
    # task; ``None`` when no expectation is set. TODO confirm units/range.
    EXPECTED_PERFORMANCE: Optional[float] = None

    # Free-text descriptions, empty by default. Presumably shown to the
    # encoder / decoder / overseer respectively -- confirm against callers.
    ENCODER_TASK_DESCRIPTION: str = ""
    DECODER_TASK_DESCRIPTION: str = ""
    OVERSEER_DESCRIPTION: str = ""

    # Model identifiers for the encoder and decoder, empty by default.
    ENCODER_MODEL_NAME: str = ""
    DECODER_MODEL_NAME: str = ""

    @abstractmethod
    def __iter__(self) -> Generator["NaturalTaskInstance", None, None]:
        """Iterate over the task instances for this task."""
        pass

    @abstractmethod
    def __getitem__(self, idx: int) -> "NaturalTaskInstance":
        """Get the task instance at a particular index."""
        pass

    @abstractmethod
    def get_random_instance(self) -> "NaturalTaskInstance":
        """Return a randomly chosen task instance."""
        pass

    @abstractmethod
    def instance_initializer(self, data: dict[str, JSON]) -> "NaturalTaskInstance":
        """Initialize a task instance from a dictionary.

        Args:
            data: JSON-compatible dictionary describing the instance
                (presumably the output of ``NaturalTaskInstance.serialize``
                -- confirm).
        """
        pass

    def serialize(self) -> dict[str, JSON]:
        """Return a JSON-compatible description of this task.

        The base implementation only records the task type; it is the one
        concrete method on this interface.
        """
        return {
            "task_type": "natural",
        }

    @abstractmethod
    def results_to_encoder_str(
        self, episode_results: EpisodeEvaluation, n_examples: int
    ) -> str:
        """Render episode results as a string intended for the encoder.

        Args:
            episode_results: The evaluation to render.
            n_examples: Presumably the number of examples to include in the
                rendered string -- confirm against implementations.

        Raises:
            NotImplementedError: Always, in this base class. Subclasses must
                provide the actual rendering.
        """
        # This class defines no ``experiment_config`` to delegate to, so fail
        # with an explicit, correctly-attributed error instead of an
        # AttributeError about a missing attribute.
        raise NotImplementedError(
            f"{type(self).__name__} must implement results_to_encoder_str()"
        )

    @abstractmethod
    def split(self, split_factor: float) -> "NaturalTaskInstance":
        """Split the task according to ``split_factor``.

        NOTE(review): the meaning of ``split_factor`` is not visible here,
        and the return annotation names ``NaturalTaskInstance`` although a
        split of a task would more naturally yield task object(s) -- confirm
        against implementations before relying on the annotation.
        """
        pass


class NaturalTaskInstance:
    """Interface for a single concrete case of a natural task.

    Every method here is a stub marked ``@abstractmethod``; concrete
    instances are expected to override all of them.
    """

    # Raw input of this instance (declared only; no default is assigned).
    input: str

    @abstractmethod
    def get_encoder_prompt(self) -> str:
        """Build the encoder's user prompt for this specific case.

        The general task description lives elsewhere; this prompt carries
        the concrete case. In a QA setting, for example, the task
        description explains that the encoder must answer a question under
        certain constraints, while this prompt supplies the actual question.

        Returns:
            A string describing this task case to the encoder.
        """

    @abstractmethod
    def get_decoder_prompt(self, message: str) -> str:
        """Build the decoder's user prompt for this specific case.

        As with the encoder prompt, this covers the concrete case rather
        than the general type of task.

        Args:
            message: Text to build the prompt from (presumably the encoder's
                output for this case -- confirm against callers).

        Returns:
            A string describing this task case to the decoder.
        """

    @abstractmethod
    def evaluate(self, decoder_output: str) -> EpisodeEvaluation:
        """Evaluate a decoder's output for this case.

        Args:
            decoder_output: The text produced by the decoder.

        Returns:
            An evaluation of the decoder's output.
        """

    @abstractmethod
    def __str__(self) -> str:
        """Return a readable form of this instance, for logging."""

    @abstractmethod
    def __repr__(self) -> str:
        """Return a representation of this instance, for logging."""

    @abstractmethod
    def serialize(self) -> dict[str, JSON]:
        """Return a JSON-compatible dictionary describing this instance."""

    @abstractmethod
    def input_no_label(self) -> str:
        """Return the input as a string.

        NOTE(review): going by the name, presumably the input with any
        label information removed -- confirm against implementations.
        """
