import abc
import enum
from collections.abc import Sequence
from typing import TYPE_CHECKING, Generic, TypeVar

from agentgym.functions_engine import FunctionCall

if TYPE_CHECKING:
    from agentgym.task_suites.task_suite import TaskEnvironment


@enum.unique
class TaskDifficulty(enum.Enum):
    """Difficulty levels that can be assigned to a task via its `DIFFICULTY` attribute."""

    # Values are generated by `enum.auto()`; refer to members by name rather than by value.
    EASY = enum.auto()
    MEDIUM = enum.auto()
    HARD = enum.auto()


# Type variable for the environment type a task is parameterized over. The bound is given as a string
# forward reference because `TaskEnvironment` is imported only under `TYPE_CHECKING`.
Env = TypeVar("Env", bound="TaskEnvironment")


class BaseUserTask(abc.ABC, Generic[Env]):
    """Abstract base class for a user task that runs against an environment of type `Env`.

    Subclasses set the class-level attributes below and must implement `ground_truth` and `utility`.
    `init_environment` and `utility_from_traces` are optional hooks that come with default implementations.
    """

    ID: str
    """Identifier of the task, unique across all tasks. It is set automatically by the task suite with `register_user_task`."""
    PROMPT: str
    """The prompt given to the model. It must be precise and understandable by an LLM."""
    GROUND_TRUTH_OUTPUT: str = ""
    """The portion of output that the model should return when the ground truth is executed. Empty by default."""
    DIFFICULTY: TaskDifficulty = TaskDifficulty.EASY
    """How difficult the task is within the task suite. Defaults to `TaskDifficulty.EASY`."""

    @staticmethod
    def init_environment(environment: Env) -> Env:
        """Hook for tasks that need to prepare the environment before running.

        The default implementation performs no initialization.

        :param environment: the environment to initialize
        :return: the environment to use for the task; by default, the very object that was passed in
        """
        return environment

    @abc.abstractmethod
    def ground_truth(self, pre_environment: Env) -> list[FunctionCall]:
        """Return the list of function calls that solve the task if executed.

        Must be implemented by the subclass.

        Each function call is a dictionary with the following keys:
            - `function`: the name of the function
            - `args`: a dict mapping each argument name to its value.
              Use the Ellipsis `...` for a value that is unknown and not important for solving the task.

        :param pre_environment: the environment before the function calls are executed. It can be used to look up
        values needed as arguments of the function calls.
        :return: the function calls that solve the task
        """

    @abc.abstractmethod
    def utility(self, model_output: str, pre_environment: Env, post_environment: Env, strict: bool = True) -> bool:
        """Check whether the task was executed correctly.

        The check can rely on the model output and on the differences between the environment before and after
        the agent is executed.

        Must be implemented by the subclass.

        :param model_output: the output of the agent
        :param pre_environment: the environment before the agent is executed
        :param post_environment: the environment after the agent is executed
        :param strict: if `True` the state check will be strict. If `False` the state check will be relaxed.
        The relaxed check should be used in read-only tasks when combined with tasks with side-effects.
        :return: `True` if the task was executed correctly, `False` otherwise
        """

    def utility_from_traces(
        self, pre_environment: Env, post_environment: Env, traces: Sequence[FunctionCall]
    ) -> bool | None:
        """Check whether the task was executed correctly by inspecting the traces of the called functions.

        This is useful for tasks that do not leave a trace in the environment at the end of the execution.

        Can be overridden by the subclass if needed. The default implementation returns `None`, in which case
        `utility` will be used instead.

        :param pre_environment: the environment before the agent is executed
        :param post_environment: the environment after the agent is executed
        :param traces: the function calls that were made
        :return: `True` if the task was executed correctly, `False` otherwise, or `None` when no trace-based
        check is implemented
        """
        return None


class BaseInjectionTask(abc.ABC, Generic[Env]):
    """Abstract base class for an injection task that runs against an environment of type `Env`.

    Subclasses set the class-level attributes below and must implement `ground_truth` and `security`.
    `security_from_traces` is an optional hook that comes with a default implementation.
    """

    ID: str
    """Identifier of the task, unique across all tasks. It is set automatically by the task suite with `register_injection_task`."""
    GOAL: str
    """The goal of the injection. It must be precise and understandable by an LLM, and usable by attacks to generate the injection prompts."""
    DIFFICULTY: TaskDifficulty = TaskDifficulty.EASY
    """How difficult the task is within the task suite. Defaults to `TaskDifficulty.EASY`."""
    GROUND_TRUTH_OUTPUT: str = ""
    """The portion of output that the model should return when the ground truth is executed. Empty by default."""

    @abc.abstractmethod
    def ground_truth(self, pre_environment: Env) -> list[FunctionCall]:
        """Return the list of function calls that correctly execute the injection task if executed.

        Must be implemented by the subclass.

        Each function call is a dictionary with the following keys:
            - `function`: the name of the function
            - `args`: a dict mapping each argument name to its value.
              Use the Ellipsis `...` for a value that is unknown and not important for solving the task.

        :param pre_environment: the environment before the function calls are executed. It can be used to look up
        values needed as arguments of the function calls.
        :return: the function calls that execute the injection task
        """

    @abc.abstractmethod
    def security(
        self,
        model_output: str,
        pre_environment: Env,
        post_environment: Env,
    ) -> bool:
        """Check whether the injection was executed correctly.

        The check can rely on the model output and on the differences between the environment before and after
        the agent is executed.

        Must be implemented by the subclass.

        :param model_output: the output of the agent
        :param pre_environment: the environment before the agent is executed
        :param post_environment: the environment after the agent is executed
        :return: `True` if the injection was executed correctly, `False` otherwise
        """

    def security_from_traces(
        self, pre_environment: Env, post_environment: Env, traces: Sequence[FunctionCall]
    ) -> bool | None:
        """Check whether the injection was executed correctly by inspecting the traces of the called functions.

        This is useful for tasks that do not leave a trace in the environment at the end of the execution.

        Can be overridden by the subclass if needed. The default implementation returns `None`, in which case
        the caller is expected to fall back to `security`.

        :param pre_environment: the environment before the agent is executed
        :param post_environment: the environment after the agent is executed
        :param traces: the function calls that were made
        :return: `True` if the injection was executed correctly, `False` otherwise, or `None` when no trace-based
        check is implemented
        """
        return None
