import types
from typing import Annotated

from pydantic import BaseModel, Field

from agentdojo.functions_runtime import Depends, EmptyEnv, FunctionCall, FunctionsRuntime, TaskEnvironment
from agentdojo.strenum import StrEnum


class ATestEnum(StrEnum):
    # Two-member string enum; the registration tests in this file use it as the
    # annotation of parameter ``d``.
    a = "a"
    b = "b"


def test_register_function():
    """Register a function annotated ``-> None`` and check its recorded metadata.

    The test checks the name, the short description, the JSON schema of the
    parameters, the return type and the full docstring stored by the runtime.
    """
    runtime = FunctionsRuntime([])

    class DummyModel(BaseModel):
        a: int

    @runtime.register_function
    def test_function(
        a: int,
        b: str,
        c: float,
        d: ATestEnum,
        e: list[str],
        f: int | None = None,
        g: int | str = 2,
        h: DummyModel = DummyModel(a=1),
    ) -> None:
        """Test function

        :param a: Description of a
        :param b: Description of b
        :param c: Description of c
        :param d: Description of d
        :param e: Description of e
        :param f: Description of f
        :param g: Description of g
        :param h: Description of h
        """

    # Hand-written reference model: its JSON schema is what the registered
    # function's parameters are compared against.
    class Function(BaseModel):
        model_config = {"title": f"Input schema for `{test_function.__name__}`"}

        a: int = Field(description="Description of a")
        b: str = Field(description="Description of b")
        c: float = Field(description="Description of c")
        d: ATestEnum = Field(description="Description of d")
        e: list[str] = Field(description="Description of e")
        f: int | None = Field(default=None, description="Description of f")
        g: int | str = Field(default=2, description="Description of g")
        h: DummyModel = Field(default=DummyModel(a=1), description="Description of h")

    assert "test_function" in runtime.functions

    registered = runtime.functions["test_function"]
    expected_schema = Function.model_json_schema()

    assert registered.name == "test_function"
    assert registered.description == "Test function"
    assert registered.parameters.model_json_schema() == expected_schema
    assert registered.return_type == types.NoneType
    assert registered.full_docstring == test_function.__doc__


def test_register_function_return_type():
    """Register a function annotated ``-> int`` and check its recorded metadata.

    Besides name, description, parameter schema and docstring, the test expects
    the stored return type to equal ``int``.
    """
    runtime = FunctionsRuntime([])

    class DummyModel(BaseModel):
        a: int

    @runtime.register_function
    def test_function(
        a: int,
        b: str,
        c: float,
        d: ATestEnum,
        e: list[str],
        f: int | None = None,
        g: int | str = 2,
        h: DummyModel = DummyModel(a=1),
    ) -> int:
        """Test function

        :param a: Description of a
        :param b: Description of b
        :param c: Description of c
        :param d: Description of d
        :param e: Description of e
        :param f: Description of f
        :param g: Description of g
        :param h: Description of h
        """
        ...

    # Hand-written reference model: its JSON schema is what the registered
    # function's parameters are compared against.
    class Function(BaseModel):
        model_config = {"title": f"Input schema for `{test_function.__name__}`"}

        a: int = Field(description="Description of a")
        b: str = Field(description="Description of b")
        c: float = Field(description="Description of c")
        d: ATestEnum = Field(description="Description of d")
        e: list[str] = Field(description="Description of e")
        f: int | None = Field(default=None, description="Description of f")
        g: int | str = Field(default=2, description="Description of g")
        h: DummyModel = Field(default=DummyModel(a=1), description="Description of h")

    assert "test_function" in runtime.functions

    registered = runtime.functions["test_function"]
    expected_schema = Function.model_json_schema()

    assert registered.name == "test_function"
    assert registered.description == "Test function"
    assert registered.parameters.model_json_schema() == expected_schema
    assert registered.return_type == int
    assert registered.full_docstring == test_function.__doc__


def test_register_function_return_union_type():
    """Register a function whose return annotation is a union and check its metadata.

    The return annotation is ``int | str | dict[str, list[int]]``; the test
    expects the stored return type to compare equal to that same union.
    """
    runtime = FunctionsRuntime([])

    class DummyModel(BaseModel):
        a: int

    @runtime.register_function
    def test_function(
        a: int,
        b: str,
        c: float,
        d: ATestEnum,
        e: list[str],
        f: int | None = None,
        g: int | str = 2,
        h: DummyModel = DummyModel(a=1),
    ) -> int | str | dict[str, list[int]]:
        """Test function

        :param a: Description of a
        :param b: Description of b
        :param c: Description of c
        :param d: Description of d
        :param e: Description of e
        :param f: Description of f
        :param g: Description of g
        :param h: Description of h
        """
        ...

    # Hand-written reference model: its JSON schema is what the registered
    # function's parameters are compared against.
    class Function(BaseModel):
        model_config = {"title": f"Input schema for `{test_function.__name__}`"}

        a: int = Field(description="Description of a")
        b: str = Field(description="Description of b")
        c: float = Field(description="Description of c")
        d: ATestEnum = Field(description="Description of d")
        e: list[str] = Field(description="Description of e")
        f: int | None = Field(default=None, description="Description of f")
        g: int | str = Field(default=2, description="Description of g")
        h: DummyModel = Field(default=DummyModel(a=1), description="Description of h")

    assert "test_function" in runtime.functions

    registered = runtime.functions["test_function"]
    expected_schema = Function.model_json_schema()
    expected_return = int | str | dict[str, list[int]]

    assert registered.name == "test_function"
    assert registered.description == "Test function"
    assert registered.parameters.model_json_schema() == expected_schema
    assert registered.return_type == expected_return
    assert registered.full_docstring == test_function.__doc__


def test_register_function_return_none():
    """Register a function explicitly annotated ``-> None`` and check its metadata.

    The test expects the stored return type to equal ``types.NoneType``, in
    addition to the name, description, parameter schema and docstring checks.
    """
    runtime = FunctionsRuntime([])

    class DummyModel(BaseModel):
        a: int

    @runtime.register_function
    def test_function(
        a: int,
        b: str,
        c: float,
        d: ATestEnum,
        e: list[str],
        f: int | None = None,
        g: int | str = 2,
        h: DummyModel = DummyModel(a=1),
    ) -> None:
        """Test function

        :param a: Description of a
        :param b: Description of b
        :param c: Description of c
        :param d: Description of d
        :param e: Description of e
        :param f: Description of f
        :param g: Description of g
        :param h: Description of h
        """
        ...

    # Hand-written reference model: its JSON schema is what the registered
    # function's parameters are compared against.
    class Function(BaseModel):
        model_config = {"title": f"Input schema for `{test_function.__name__}`"}

        a: int = Field(description="Description of a")
        b: str = Field(description="Description of b")
        c: float = Field(description="Description of c")
        d: ATestEnum = Field(description="Description of d")
        e: list[str] = Field(description="Description of e")
        f: int | None = Field(default=None, description="Description of f")
        g: int | str = Field(default=2, description="Description of g")
        h: DummyModel = Field(default=DummyModel(a=1), description="Description of h")

    assert "test_function" in runtime.functions

    registered = runtime.functions["test_function"]
    expected_schema = Function.model_json_schema()

    assert registered.name == "test_function"
    assert registered.description == "Test function"
    assert registered.parameters.model_json_schema() == expected_schema
    assert registered.return_type == types.NoneType
    assert registered.full_docstring == test_function.__doc__


def test_register_function_no_return_type():
    """Register a function with no return annotation and check its metadata.

    The registered function has no ``->`` annotation; the test expects the
    stored return type to be ``None`` (the singleton, not ``types.NoneType``),
    alongside the usual name, description, parameter schema and docstring checks.
    """
    runtime = FunctionsRuntime([])

    class DummyModel(BaseModel):
        a: int

    @runtime.register_function
    def test_function(
        a: int,
        b: str,
        c: float,
        d: ATestEnum,
        e: list[str],
        f: int | None = None,
        g: int | str = 2,
        h: DummyModel = DummyModel(a=1),
    ):
        """Test function

        :param a: Description of a
        :param b: Description of b
        :param c: Description of c
        :param d: Description of d
        :param e: Description of e
        :param f: Description of f
        :param g: Description of g
        :param h: Description of h
        """
        ...

    # Hand-written reference model: its JSON schema is what the registered
    # function's parameters are compared against.
    class Function(BaseModel):
        model_config = {"title": f"Input schema for `{test_function.__name__}`"}

        a: int = Field(description="Description of a")
        b: str = Field(description="Description of b")
        c: float = Field(description="Description of c")
        d: ATestEnum = Field(description="Description of d")
        e: list[str] = Field(description="Description of e")
        f: int | None = Field(description="Description of f", default=None)
        g: int | str = Field(description="Description of g", default=2)
        h: DummyModel = Field(description="Description of h", default=DummyModel(a=1))

    assert "test_function" in runtime.functions
    assert runtime.functions["test_function"].name == "test_function"
    assert runtime.functions["test_function"].description == "Test function"
    assert runtime.functions["test_function"].parameters.model_json_schema() == Function.model_json_schema()
    # Identity check: ``None`` is a singleton, and ``==`` could be satisfied by
    # any object with a permissive ``__eq__``.
    assert runtime.functions["test_function"].return_type is None
    assert runtime.functions["test_function"].full_docstring == test_function.__doc__


def test_register_function_return_union_type_nested():
    """Register a function whose return annotation is a parenthesised union.

    The annotation is written ``int | str | (None | float)``; the test expects
    the stored return type to compare equal to ``int | str | None | float``.
    """
    runtime = FunctionsRuntime([])

    class DummyModel(BaseModel):
        a: int

    @runtime.register_function
    def test_function(
        a: int,
        b: str,
        c: float,
        d: ATestEnum,
        e: list[str],
        f: int | None = None,
        g: int | str = 2,
        h: DummyModel = DummyModel(a=1),
    ) -> int | str | (None | float):
        """Test function

        :param a: Description of a
        :param b: Description of b
        :param c: Description of c
        :param d: Description of d
        :param e: Description of e
        :param f: Description of f
        :param g: Description of g
        :param h: Description of h
        """
        ...

    # Hand-written reference model: its JSON schema is what the registered
    # function's parameters are compared against.
    class Function(BaseModel):
        model_config = {"title": f"Input schema for `{test_function.__name__}`"}

        a: int = Field(description="Description of a")
        b: str = Field(description="Description of b")
        c: float = Field(description="Description of c")
        d: ATestEnum = Field(description="Description of d")
        e: list[str] = Field(description="Description of e")
        f: int | None = Field(default=None, description="Description of f")
        g: int | str = Field(default=2, description="Description of g")
        h: DummyModel = Field(default=DummyModel(a=1), description="Description of h")

    assert "test_function" in runtime.functions

    registered = runtime.functions["test_function"]
    expected_schema = Function.model_json_schema()
    expected_return = int | str | None | float

    assert registered.name == "test_function"
    assert registered.description == "Test function"
    assert registered.parameters.model_json_schema() == expected_schema
    assert registered.return_type == expected_return
    assert registered.full_docstring == test_function.__doc__


def test_run_stateful_function():
    """Run a function whose first argument is declared as an environment dependency.

    ``test_stateful_sum`` annotates ``l`` with ``Depends("list_state")``. The two
    calls through ``runtime.run_function`` pass only ``a`` and ``b``; the test
    checks the returned sum, that no error is reported, and that each sum was
    appended to ``env.list_state.l``. The last call invokes the function
    directly, passing the state object explicitly.
    """
    runtime = FunctionsRuntime([])

    class DummyState(BaseModel):
        l: list[int] = Field(default_factory=list)

    # Deliberately not called ``EmptyEnv``: that name is imported at module level
    # and used by the other tests, so reusing it here would shadow the import.
    class StatefulEnv(TaskEnvironment):
        list_state: DummyState = DummyState()

    @runtime.register_function
    def test_stateful_sum(l: Annotated[DummyState, Depends("list_state")], a: int, b: int) -> int:
        """Test sum

        :param a: First number
        :param b: Second number
        """
        res = a + b
        l.l.append(res)
        # Return the value that was recorded instead of recomputing the sum.
        return res

    env = StatefulEnv()
    result_1, error_1 = runtime.run_function(env, "test_stateful_sum", {"a": 1, "b": 2})
    assert result_1 == 3
    assert env.list_state.l == [3]
    assert error_1 is None

    result_2, error_2 = runtime.run_function(env, "test_stateful_sum", {"a": 4, "b": 5})
    assert result_2 == 9
    assert env.list_state.l == [3, 9]
    assert error_2 is None

    # Direct call: the dependency is supplied by hand as the first positional argument.
    assert test_stateful_sum(env.list_state, 1, 2) == 3
    assert env.list_state.l == [3, 9, 3]


def test_run_stateless_function():
    """Run a registered function that takes no environment dependency.

    Checks one run with integer arguments, one direct call of the decorated
    function, and one run where ``b`` is the string ``"10"`` and the expected
    result is ``11``.
    """
    runtime = FunctionsRuntime([])

    @runtime.register_function
    def test_sum(a: int, b: int):
        """Test sum

        :param a: First number
        :param b: Second number
        """
        return a + b

    # Run through the runtime with plain integer arguments.
    int_args = {"a": 1, "b": 2}
    total, failure = runtime.run_function(EmptyEnv(), "test_sum", int_args)
    assert total == 3
    # The decorated name is still directly callable.
    assert test_sum(1, 3) == 4
    assert failure is None

    # Run again with a numeric string for ``b``.
    string_args = {"a": 1, "b": "10"}
    string_total, string_failure = runtime.run_function(EmptyEnv(), "test_sum", string_args)
    assert string_total == 11
    assert string_failure is None


def test_run_nested_function_call():
    """Run a function whose ``b`` argument is itself a ``FunctionCall``.

    The inner call is ``test_sum(a=1, b=2)`` and the outer call passes ``a=1``;
    the test expects no error and a final result of ``4``.
    """
    runtime = FunctionsRuntime([])

    @runtime.register_function
    def test_sum(a: int, b: int):
        """Test sum

        :param a: First number
        :param b: Second number
        """
        return a + b

    inner_call = FunctionCall(function="test_sum", args={"a": 1, "b": 2})
    outer_args = {"a": 1, "b": inner_call}

    result, error = runtime.run_function(EmptyEnv(), "test_sum", outer_args)

    assert error is None
    assert result == 4


def test_run_error_function():
    """Check what ``run_function`` returns when an argument fails validation.

    In both scenarios (bad argument at the top level, bad argument inside a
    nested ``FunctionCall``) the test expects an empty-string result and an
    error message containing ``"ValidationError"``.
    """
    runtime = FunctionsRuntime([])

    @runtime.register_function
    def test_sum_2(a: int, b: int):
        """Test sum

        :param a: First number
        :param b: Second number
        """
        return a + b

    # Scenario 1: the invalid value is passed straight to the top-level call.
    top_level_args = {"a": 1, "b": "abc"}
    top_result, top_error = runtime.run_function(EmptyEnv(), "test_sum_2", top_level_args)
    assert top_result == ""
    assert top_error is not None
    assert "ValidationError" in top_error

    # Scenario 2: the invalid value sits inside a nested call used as ``b``.
    failing_inner = FunctionCall(function="test_sum_2", args={"a": 1, "b": "abc"})
    nested_args = {"a": 1, "b": failing_inner}
    nested_result, nested_error = runtime.run_function(EmptyEnv(), "test_sum_2", nested_args)
    assert nested_result == ""
    assert nested_error is not None
    assert "ValidationError" in nested_error
