import functools
import inspect
import sys
import warnings
from collections.abc import Callable, Mapping, Sequence
from types import EllipsisType, NoneType
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Concatenate,
    Literal,
    TypeAlias,
    get_origin,
    get_type_hints,
)

from docstring_parser import parse
from docstring_parser.common import DocstringParam
from pydantic import BaseModel, Field, create_model, ValidationError
from typing_extensions import NotRequired, ParamSpec, TypedDict, TypeVar

if TYPE_CHECKING:
    from agentgym.task_suites.task_suite import TaskEnvironment


class Depends:
    """Marker naming an attribute of the environment that a tool needs.

    Instances are looked for in the ``Annotated[...]`` metadata of a tool's
    parameters; the matching argument is then filled in from the environment
    instead of being supplied by the caller of the tool.
    """

    def __init__(self, env_dependency: str) -> None:
        # Name of the attribute to read from the environment object.
        self.env_dependency = env_dependency

    def __repr__(self) -> str:
        return "Depends('{}')".format(self.env_dependency)

    def extract_dep_from_env(self, env: BaseModel) -> BaseModel:
        """Return the attribute of ``env`` whose name is ``self.env_dependency``.

        Raises:
            AttributeError: If ``env`` has no attribute with that name.
        """
        attribute_name = self.env_dependency
        return getattr(env, attribute_name)


# Metadata describing one registered tool. Instances are stored in
# FUNCTIONS_DOCS, keyed by the tool name.
class Function(BaseModel):
    # Name the tool is registered under (the Python function's ``__name__``).
    name: str
    # Short description extracted from the function's docstring.
    description: str
    # Pydantic model class describing the tool's arguments; ``run_function``
    # validates incoming call arguments against it.
    parameters: type[BaseModel]


# Types that a single tool-call argument value may take.
# NOTE(review): ``EllipsisType`` presumably allows ``...`` as a placeholder for
# a not-yet-known value (see ``FunctionCall.placeholder_args``) — confirm.
FunctionCallArgTypes = str | int | float | bool | NoneType | dict | list | EllipsisType

# Arguments of a tool call: parameter name -> argument value.
FunctionCallArgs = Mapping[str, FunctionCallArgTypes]


class FunctionCall(TypedDict):
    """Dictionary shape describing a single call to a registered tool."""

    # Name of the tool to call.
    function: str
    # Arguments to pass to the tool, keyed by parameter name.
    args: FunctionCallArgs
    # Optional identifier of the call.
    id: NotRequired[str]
    # Optional second set of arguments.
    # NOTE(review): presumably stand-in values for arguments that are not known
    # in advance — confirm against the code that populates this key.
    placeholder_args: NotRequired[FunctionCallArgs]


# Values a tool may return: a pydantic model, a (possibly nested) sequence of
# such values, a dict, or a plain scalar / None.
FunctionReturnType: TypeAlias = BaseModel | Sequence["FunctionReturnType"] | dict | str | int | float | bool | NoneType

# Parameter specification of a wrapped tool (everything after the environment).
P = ParamSpec("P")
# Unconstrained type variable (not used in the code visible in this section).
S = TypeVar("S")
# Environment type; the bound is a forward reference to ``TaskEnvironment``,
# which is only imported under ``TYPE_CHECKING``.
Env = TypeVar("Env", bound="TaskEnvironment")
# Return type of a tool, restricted to ``FunctionReturnType``.
T = TypeVar("T", bound=FunctionReturnType)


def _register_function(
    function: Callable[Concatenate[Env, P], T],
) -> Callable[Concatenate[Env, P], T | str]:
    """Record ``function`` in the module-level tool registries.

    The function's docstring is parsed to build a ``Function`` description
    (name, short description, pydantic argument schema) that is stored in
    ``FUNCTIONS_DOCS``; a thin wrapper around the function is stored in
    ``FUNCTIONS_CALLABLES`` under the same name.

    Args:
        function: Callable taking the environment as first argument. It must
            have a docstring with a short description.

    Returns:
        The wrapper that was stored in ``FUNCTIONS_CALLABLES``.

    Raises:
        ValueError: If the function has no docstring or no short description
            (or if ``_parse_args`` rejects the documented parameters).

    Warns:
        UserWarning: If a function with the same name is already registered;
            the previous registration is then overwritten.
    """
    name = function.__name__
    if name in FUNCTIONS_CALLABLES:
        warnings.warn(f"Function {name} already registered")
    if not function.__doc__:
        raise ValueError(f"Function {name} has no docstring")

    function_docs = parse(function.__doc__)
    if function_docs.short_description is None:
        raise ValueError(f"Function {name} has no short description")

    # Build the description *before* touching either registry: if schema
    # creation fails, no half-registered tool (callable without docs) is left
    # behind.
    function_doc = Function(
        name=name,
        # NOTE(review): "None" -> "null" presumably adapts Python wording to
        # JSON terminology in the tool description — confirm.
        description=function_docs.short_description.replace("None", "null"),
        parameters=_parse_args(name, function_docs.params, inspect.signature(function)),
    )

    @functools.wraps(function)
    def wrapped_function(env: Env, *args: P.args, **kwargs: P.kwargs) -> T:
        return function(env, *args, **kwargs)

    FUNCTIONS_CALLABLES[name] = wrapped_function
    FUNCTIONS_DOCS[name] = function_doc

    return wrapped_function


def _get_dependencies(function: Callable[P, T]) -> dict[str, Depends]:
    """Collect the ``Depends`` markers declared on ``function``'s parameters.

    A parameter declares a dependency by being annotated as
    ``Annotated[SomeType, ..., Depends("attr"), ...]``. The marker may appear
    at any position in the ``Annotated`` metadata; if several are present the
    first one wins.

    Args:
        function: Callable whose type hints are inspected.

    Returns:
        Mapping from parameter name to its ``Depends`` marker. Parameters
        without such a marker are omitted.
    """
    arg_types = get_type_hints(function, include_extras=True)
    dependencies: dict[str, Depends] = {}
    for arg_name, arg_type in arg_types.items():
        # ``get_type_hints`` also reports the return annotation under the key
        # "return"; that is not a parameter and must never be injected.
        if arg_name == "return" or get_origin(arg_type) is not Annotated:
            continue
        dependency = next(
            (meta for meta in arg_type.__metadata__ if isinstance(meta, Depends)),
            None,
        )
        if dependency is not None:
            dependencies[arg_name] = dependency
    return dependencies


def register_function(function: Callable[P, T]) -> Callable[P, T]:
    """Decorator that registers ``function`` as a tool.

    Parameters annotated with a ``Depends`` marker are filled in from the
    environment at call time. The registered callable therefore takes the
    environment as its first argument, extracts the dependencies from it and
    forwards everything else to ``function``.

    Args:
        function: The tool implementation to register.

    Returns:
        ``function`` itself, unchanged, so it stays directly callable.

    Raises:
        Exception: Whatever ``_register_function`` raises; a diagnostic line
            naming the function and its module is written to stderr first.
    """
    injected: dict[str, Depends] = _get_dependencies(function)

    @functools.wraps(function)
    def wrapped_function(env: BaseModel, *args: P.args, **kwargs: P.kwargs) -> FunctionReturnType:
        # Environment-provided values take precedence over caller-supplied
        # keyword arguments of the same name.
        for dependency_name, dependency in injected.items():
            kwargs[dependency_name] = dependency.extract_dep_from_env(env)
        return function(*args, **kwargs)

    # Make the wrapper (not the raw function) available for tool use.
    try:
        _register_function(wrapped_function)
    except Exception as e:
        message = f"Error when registering function {function.__name__} in {function.__module__}"
        print(message, file=sys.stderr, flush=True)
        raise e

    return function


# Signature of a registered tool: environment first, then arbitrary arguments.
FunctionCallable = Callable[Concatenate[Env, ...], FunctionReturnType]
# Registry of callables, keyed by tool name; filled by ``_register_function``.
FUNCTIONS_CALLABLES: dict[str, FunctionCallable] = {}
# Registry of tool descriptions (name, description, argument schema), keyed by
# tool name; filled by ``_register_function``.
FUNCTIONS_DOCS: dict[str, Function] = {}


def get_available_tools_callables(tools: list[str] | None) -> list[Callable]:
    """Return the registered callables for the requested tools.

    Args:
        tools: Names of the tools to look up, or ``None`` to select every
            tool that has an entry in ``FUNCTIONS_DOCS``.

    Returns:
        The callables, in the order of ``tools`` (or registry order when
        ``tools`` is ``None``).

    Raises:
        KeyError: If a requested name is not in ``FUNCTIONS_CALLABLES``.
    """
    selected_names = FUNCTIONS_DOCS if tools is None else tools
    return [FUNCTIONS_CALLABLES[tool_name] for tool_name in selected_names]


def run_function(
    env: BaseModel, function: str, kwargs: FunctionCallArgs, raise_on_error: bool = False
) -> tuple[FunctionReturnType, str | None]:
    """Validate ``kwargs`` against a tool's schema and call the tool.

    Args:
        env: Environment object passed as first argument to the registered
            callable.
        function: Name of the registered tool to run.
        kwargs: Raw call arguments; they are validated by the tool's pydantic
            input model and the dumped result is what the tool receives.
        raise_on_error: When true, argument-validation errors and exceptions
            raised by the tool propagate to the caller instead of being
            converted into an error message.

    Returns:
        ``(result, None)`` on success, or ``("", message)`` on failure, where
        ``message`` is prefixed with the kind of error.

    Raises:
        pydantic.ValidationError: If ``raise_on_error`` is true and ``kwargs``
            do not match the tool's input schema.
        Exception: If ``raise_on_error`` is true, whatever the tool raises.
    """
    if function not in FUNCTIONS_DOCS:
        # No exception type exists in this module for an unknown tool, so this
        # case is always reported through the message, whatever the flag says.
        return "", f"ToolNotFoundException: The requested function `{function}` is not available."
    function_input_schema = FUNCTIONS_DOCS[function].parameters
    try:
        args = function_input_schema.model_validate(kwargs).model_dump()
    except ValidationError as e:
        # Honour the caller's request to surface errors: previously invalid
        # arguments were swallowed even with raise_on_error=True.
        if raise_on_error:
            raise
        return "", f"ValidationError: {e}"

    try:
        return FUNCTIONS_CALLABLES[function](env, **args), None
    except Exception as e:
        if raise_on_error:
            # Bare ``raise`` keeps the original traceback intact.
            raise
        return "", f"{type(e).__name__}: {e}"


# Type variable for the pydantic model class produced by ``_parse_args``.
BM = TypeVar("BM", bound=BaseModel)


def _parse_args(
    function_name: str,
    params_docs: list[DocstringParam],
    signature: inspect.Signature,
    field_fn: Callable = Field,
    create_model_fn: Callable[..., type[BM]] = create_model,
) -> type[BM]:
    """Build the input-schema model for a tool from its docstring and signature.

    Only parameters that are documented in the docstring become fields of the
    model; their type and default come from the signature, their description
    from the docstring. Signature parameters without documentation are left
    out of the schema.

    Args:
        function_name: Name of the tool, used in the generated model's name.
        params_docs: Parameters parsed from the tool's docstring.
        signature: Signature of the tool.
        field_fn: Factory called with ``description`` and ``default`` to build
            each field definition.
        create_model_fn: Factory called with the model name and one
            ``name=(type, field)`` keyword per parameter.

    Returns:
        The model class returned by ``create_model_fn``.

    Raises:
        ValueError: If a documented parameter has no description, or is not a
            parameter of ``signature``.
    """
    params = {}
    for arg in params_docs:
        arg_name = arg.arg_name
        if arg.description is None:
            raise ValueError(f"Argument {arg_name} has no description")
        if arg_name not in signature.parameters:
            # Give an actionable message instead of a bare KeyError.
            raise ValueError(
                f"Argument {arg_name} is documented in the docstring of `{function_name}` "
                "but is not a parameter of the function"
            )
        parameter = signature.parameters[arg_name]
        # ``Parameter.empty`` is a sentinel: compare by identity so defaults
        # with a custom ``__eq__`` cannot be mistaken for "no default".
        has_default = parameter.default is not inspect.Parameter.empty
        params[arg_name] = (
            parameter.annotation,
            field_fn(
                description=arg.description,
                # ``...`` marks the field as required.
                default=parameter.default if has_default else ...,
            ),
        )

    return create_model_fn(f"Input schema for `{function_name}`", **params)
