import dataclasses
import inspect
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from typing import Annotated, Any, Concatenate, Generic, TypeAlias, get_origin, get_type_hints, overload

from docstring_parser import parse
from docstring_parser.common import DocstringParam
from pydantic import BaseModel, Field, ValidationError, create_model
from typing_extensions import ParamSpec, TypeVar


class TaskEnvironment(BaseModel):
    """Base class for the environment in which the agent is running."""

    # Declares no fields of its own. Within this module it is the bound of the `Env` type
    # variable, the base class of `EmptyEnv`, and the type of the `env` argument accepted by
    # `FunctionsRuntime.run_function`.
    ...


class EmptyEnv(TaskEnvironment):
    """An empty environment for when a state is not needed."""

    # Adds nothing on top of `TaskEnvironment`: no fields and no methods.
    ...


@dataclasses.dataclass(frozen=True)
class Depends:
    """Marks a function argument as being supplied from the environment.

    The wrapped value describes where in the environment the argument comes from: either the
    name of an attribute, or a callable that receives the environment and returns the value.
    """

    env_dependency: str | Callable[[BaseModel], BaseModel]
    """The name of the attribute to extract from the environment or a function that returns the attribute from the environment."""

    def extract_dep_from_env(self, env: BaseModel) -> BaseModel:
        """Return the part of `env` that this dependency points at.

        Args:
            env: The environment to read from.

        Returns:
            The result of calling `env_dependency` with `env` when it is callable, otherwise the
            attribute of `env` named by `env_dependency`.
        """
        source = self.env_dependency
        return source(env) if callable(source) else getattr(env, source)

    def __repr__(self) -> str:
        source = self.env_dependency
        return f"Depends('{source}')"


class FunctionCall(BaseModel):
    """An object containing information about a function call requested by an agent."""

    # NOTE: the quoted "FunctionCallArgTypes" annotations are forward references. The alias is
    # defined right after this class because it includes `FunctionCall` itself, which is what
    # allows an argument value to be another (nested) function call.
    function: str
    """The name of the function to call."""
    args: MutableMapping[str, "FunctionCallArgTypes"]
    """The arguments to pass to the function."""
    id: str | None = None
    """An optional ID for the function call. E.g., used by OpenAI and Anthropic."""
    # Optional and defaults to None, like `id`.
    placeholder_args: Mapping[str, "FunctionCallArgTypes"] | None = None
    """An optional dictionary of placeholder arguments to use in by ground truth agent in injection tasks."""


# Defined after `FunctionCall` because the union refers to that class; `FunctionCall` in turn
# refers to this alias through a string forward reference.
FunctionCallArgTypes = str | int | float | bool | None | dict | list | FunctionCall
"""Valid types for function call arguments."""


# Recursive alias: a sequence of valid return values is itself a valid return value.
FunctionReturnType: TypeAlias = BaseModel | Sequence["FunctionReturnType"] | dict | str | int | float | bool | None
"""Union of valid return types for functions. The default [FunctionsRuntime][agentdojo.functions_runtime.FunctionsRuntime]
is not guaranteed to work with other types."""
# Parameter specification of a wrapped callable (used by `Function` and `make_function`).
P = ParamSpec("P")
# Unconstrained type variable; not referenced in the visible part of this file.
S = TypeVar("S")
# Type of the environment, restricted to `TaskEnvironment` subclasses.
Env = TypeVar("Env", bound=TaskEnvironment)
# Return type of a wrapped callable, restricted to `FunctionReturnType`.
T = TypeVar("T", bound=FunctionReturnType)
# A callable that takes the environment as first argument, followed by the parameters in `P`.
FunctionCallable = Callable[Concatenate[Env, P], FunctionReturnType]


class Function(BaseModel, Generic[P, T]):
    """A function that can be called by an agent and run by a [`FunctionsRuntime`][agentdojo.functions_runtime.FunctionsRuntime]."""

    # Generic over `P` (the parameters accepted by `run`) and `T` (the type `run` returns).
    # None of the fields has a default, so all of them must be provided on construction.
    name: str
    """The name of the function."""
    description: str
    """A short description of the function."""
    parameters: type[BaseModel]
    """The input schema for the function. It is defined as a Pydantic model."""
    # Maps argument names of `run` to the `Depends` marker describing how to obtain them.
    dependencies: dict[str, Depends]
    """A dictionary of dependencies for the function. See [`Depends`][agentdojo.functions_runtime.Depends] for more information."""
    run: Callable[P, T]
    """The function implementation."""
    full_docstring: str
    """The full docstring of the functions."""
    return_type: Any | None
    """The return type of the function."""

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        """Call `run` with the given arguments, forwarding them unchanged.

        No validation against `parameters` and no dependency injection happens here.
        """
        return self.run(*args, **kwargs)


def _get_dependencies(function: Callable[P, T]) -> dict[str, Depends]:
    """Collect the `Depends` markers declared in the annotations of `function`.

    An argument is a dependency when it is annotated as `Annotated[SomeType, ..., Depends(...), ...]`.
    The marker may appear at any position in the `Annotated` metadata; when several are present
    the first one is used.

    Args:
        function: The function whose type hints are inspected.

    Returns:
        A dictionary mapping argument names to their `Depends` marker. Arguments without such a
        marker are left out.
    """
    arg_types = get_type_hints(function, include_extras=True)
    dependencies: dict[str, Depends] = {}
    for arg_name, arg_type in arg_types.items():
        # `get_type_hints` also reports the return annotation, which is never an argument.
        if arg_name == "return" or get_origin(arg_type) is not Annotated:
            continue
        # Scan all the metadata rather than only the first element, so that other annotations
        # placed before `Depends` do not hide the dependency.
        dependency = next((meta for meta in arg_type.__metadata__ if isinstance(meta, Depends)), None)
        if dependency is not None:
            dependencies[arg_name] = dependency
    return dependencies


def make_function(function: Callable[P, T]) -> Function:
    """Wraps a plain Python function into a [`Function`][agentdojo.functions_runtime.Function].

    The name, description, input schema, dependencies and return type are all derived from the
    function itself: its docstring is parsed for the short description and the argument
    descriptions, and its signature and type hints provide the argument types and defaults.

    Note:
        ReStructuredText is the recommended docstring format.

    Args:
        function: The function to wrap. Its docstring must contain a short description and a
            description for each argument that should appear in the input schema.
            [`Depends`][agentdojo.functions_runtime.Depends] arguments should not be documented
            in the docstring.

    Returns:
        The [`Function`][agentdojo.functions_runtime.Function] wrapping `function`.

    Raises:
        ValueError: If the function has no docstring or short description, or if the arguments docstrings are
            invalid.
    """
    # Dependencies are resolved first, before the docstring is looked at.
    env_dependencies = _get_dependencies(function)

    docstring = function.__doc__
    if not docstring:
        raise ValueError(f"Function {function.__name__} has no docstring")

    parsed_docs = parse(docstring)
    summary = parsed_docs.short_description
    if summary is None:
        raise ValueError(f"Function {function.__name__} has no short description")

    function_name = function.__name__
    # In the description every occurrence of "None" is replaced by "null".
    description = summary.replace("None", "null")
    input_schema = _parse_args(function_name, parsed_docs.params, inspect.signature(function))
    declared_return = get_type_hints(function).get("return")

    return Function[P, T](
        name=function_name,
        description=description,
        parameters=input_schema,
        run=function,
        dependencies=env_dependencies,
        full_docstring=docstring,
        return_type=declared_return,
    )


# Type variable for the Pydantic model class produced by `_parse_args`.
BM = TypeVar("BM", bound=BaseModel)


def _parse_args(
    function_name: str,
    params_docs: list[DocstringParam],
    signature: inspect.Signature,
    field_fn: Callable = Field,
    create_model_fn: Callable[..., type[BM]] = create_model,
) -> type[BM]:
    """Builds the input schema of a function from its documented arguments.

    Only the arguments listed in `params_docs` end up in the schema. Each one takes its
    description from the docstring entry and its type and default value from `signature`.

    Args:
        function_name: The name of the function, used in the name of the generated model.
        params_docs: The argument entries parsed from the function's docstring.
        signature: The signature of the function.
        field_fn: The factory called with `description` and `default` to build each field.
        create_model_fn: The factory called with the model name and one `(type, field)` keyword
            argument per documented argument.

    Returns:
        The model class returned by `create_model_fn`.

    Raises:
        ValueError: If a documented argument has no description, or if it does not appear in the
            signature of the function.
    """
    params: dict[str, tuple[Any, Any]] = {}
    for arg in params_docs:
        arg_name = arg.arg_name
        if arg.description is None:
            raise ValueError(f"Argument {arg_name} has no description")
        parameter = signature.parameters.get(arg_name)
        if parameter is None:
            # A docstring that documents a non-existent argument is an invalid docstring: report
            # it as such instead of leaking a bare `KeyError`.
            raise ValueError(
                f"Argument {arg_name} is documented in the docstring of `{function_name}` "
                "but is not part of its signature"
            )
        # `inspect.Parameter.empty` is a sentinel and must be compared by identity: `==` would
        # invoke the `__eq__` of the default value, which can return a non-boolean or raise.
        default = ... if parameter.default is inspect.Parameter.empty else parameter.default
        params[arg_name] = (
            parameter.annotation,
            field_fn(description=arg.description, default=default),
        )

    return create_model_fn(f"Input schema for `{function_name}`", **params)


class ToolNotFoundError(Exception):
    """Raised when a tool is requested by name but is not available."""


class FunctionsRuntime:
    """A runtime for functions.

    It keeps a registry of [`Function`][agentdojo.functions_runtime.Function] instances indexed by
    name and runs them on request: nested calls are executed first, the arguments are validated
    against the function's input schema, and the function's dependencies are extracted from the
    environment.

    Args:
        functions: A list of functions that the runtime has access to.
    """

    def __init__(self, functions: Sequence[Function] = ()) -> None:
        # The default is an immutable empty tuple so that no mutable object is shared across calls.
        self.functions: dict[str, Function] = {f.name: f for f in functions}

    def __repr__(self) -> str:
        return f"FunctionsRuntime(functions={self.functions})"

    @overload
    def register_function(self, function: Callable[P, T]) -> Callable[P, T]: ...
    @overload
    def register_function(self, function: Function) -> Function: ...

    def register_function(self, function: Callable[P, T] | Function) -> Callable[P, T] | Function:
        """Registers a function with the runtime.

        If `function` is a function, it will be converted to a `Function` instance and registered,
        otherwise it will be registered as is. The conversion is made with [`make_function`][agentdojo.functions_runtime.make_function].

        Args:
            function: The function to register. It can be a function or a `Function` instance. See [`make_function`][agentdojo.functions_runtime.make_function]
                for more information.

        Returns:
            The object that was passed in (not the `Function` created from it), so that this method
            can be used as a decorator without replacing the decorated function.

        Raises:
            Exception: If the conversion to a `Function` fails. The original error is chained as the cause.
        """
        if isinstance(function, Function):
            self.functions[function.name] = function
            return function
        try:
            f = make_function(function)
        except Exception as e:
            # Not every callable has `__name__`/`__module__`; fall back to something printable
            # so that building the error message cannot fail itself.
            name = getattr(function, "__name__", repr(function))
            module = getattr(function, "__module__", None)
            raise Exception(f"Error when registering function {name} in {module}.") from e
        self.functions[f.name] = f
        return function

    def _execute_nested_calls(
        self, env: TaskEnvironment | None, kwargs: Mapping[str, FunctionCallArgTypes]
    ) -> dict[str, FunctionCallArgTypes]:
        """Executes nested function calls.

        Only the values directly in `kwargs` are inspected; `FunctionCall` instances inside lists or
        dictionaries are left untouched.

        Args:
            env: The environment to extract dependencies from.
            kwargs: The arguments to pass to the function.

        Returns:
            A new dictionary with the same keys as `kwargs`, in which each value that is a
            `FunctionCall` is replaced by the result of running that call.

        Raises:
            Exception: Any error raised while running a nested call, as nested calls are run with
                `raise_on_error=True`.
        """
        updated_kwargs: dict[str, FunctionCallArgTypes] = {}
        for arg_name, arg_value in kwargs.items():
            if not isinstance(arg_value, FunctionCall):
                updated_kwargs[arg_name] = arg_value
                continue
            # We need to raise the exception to exit early if an error occurs.
            # This will be handled by `run_function` depending on the value of `raise_on_error`
            # in the caller. We can safely take the first element of the tuple
            # as the second will always be None (otherwise an exception would be raised)
            updated_kwargs[arg_name], _ = self.run_function(
                env, arg_value.function, arg_value.args, raise_on_error=True
            )
        return updated_kwargs

    def run_function(
        self,
        env: TaskEnvironment | None,
        function: str,
        kwargs: Mapping[str, FunctionCallArgTypes],
        raise_on_error: bool = False,
    ) -> tuple[FunctionReturnType, str | None]:
        """Runs a function with the given arguments.

        If `raise_on_error` is `True`, the function will raise an exception if an error occurs. Otherwise, it will return an error message as the second
        argument. The error message has the format `'ErrorType: ErrorMessage'`.

        Args:
            env: The environment to extract dependencies from. It can be `None` if the function has no dependencies.
            function: The name of the function to run.
            kwargs: The arguments to pass to the function.
            raise_on_error: If `True`, the function will raise an exception if an error occurs. Otherwise, it will return an error message.

        Returns:
            A tuple containing the result of the function and an error message if an error occurred. If no error occurred, the second element will be `None`.
            When an error is reported through the second element, the first element is the empty string.

        Raises:
            ValidationError: If the arguments are invalid and do not match the function's input schema.
            ToolNotFoundError: If the requested function is not available.
            ValueError: If the function has dependencies but `env` is `None`. This is raised regardless of `raise_on_error`.
        """
        if function not in self.functions:
            if raise_on_error:
                raise ToolNotFoundError(f"The requested function `{function}` is not available.")
            return "", f"ToolNotFoundError: The requested function `{function}` is not available."
        f = self.functions[function]

        # Execute nested function calls
        try:
            resolved_kwargs = self._execute_nested_calls(env, kwargs)
        except Exception as e:
            if raise_on_error:
                raise
            return "", f"{type(e).__name__}: {e}"

        # Validate the arguments
        try:
            args = f.parameters.model_validate(resolved_kwargs).model_dump()
        except ValidationError as e:
            if raise_on_error:
                raise
            return "", f"ValidationError: {e}"

        # Extract dependencies from the environment
        if env is None and f.dependencies:
            raise ValueError(f"Function {function} has dependencies but no environment was provided.")
        env_args = {
            arg_name: dependency.extract_dep_from_env(env)  # type: ignore -- If there are no dependencies, this line will not be executed
            for arg_name, dependency in f.dependencies.items()
        }
        # On a name clash the value extracted from the environment wins over the validated argument.
        kwargs_with_deps = args | env_args

        # Run the function
        try:
            return f(**kwargs_with_deps), None
        except Exception as e:
            if raise_on_error:
                raise
            return "", f"{type(e).__name__}: {e}"

    def update_functions(self, new_functions: dict[str, Function]) -> None:
        """Replaces the whole function registry.

        Args:
            new_functions: The new registry, mapping function names to functions. The dictionary
                is stored as is, without being copied.
        """
        self.functions = new_functions
