from typing import Annotated

from agentgym.functions_engine.functions_engine import (
    FUNCTIONS_CALLABLES,
    FUNCTIONS_DOCS,
    Depends,
    register_function,
    run_function,
)
from agentgym.strenum import StrEnum
from pydantic import BaseModel, Field


# Two-member string enum used as the annotation of parameter `d` in
# `test_register_function`, so the registered function's schema has to
# describe an enum-typed argument.
# NOTE(review): documented with comments rather than a class docstring —
# pydantic may surface an enum docstring as a schema `description`; confirm
# before adding one.
class ATestEnum(StrEnum):
    a = "a"
    b = "b"


def test_register_function():
    """Check the registry entries created by `register_function`.

    A function whose signature mixes plain scalars, an enum, a list, an
    optional, a union and a pydantic model (some with defaults) is registered.
    The test then verifies that it shows up in both `FUNCTIONS_CALLABLES` and
    `FUNCTIONS_DOCS`, and that the documented name, description and parameter
    JSON schema match a hand-written reference model.
    """

    class DummyModel(BaseModel):
        a: int

    def test_function(
        a: int,
        b: str,
        c: float,
        d: ATestEnum,
        e: list[str],
        f: int | None = None,
        g: int | str = 2,
        h: DummyModel = DummyModel(a=1),
    ):
        """Test function

        :param a: Description of a
        :param b: Description of b
        :param c: Description of c
        :param d: Description of d
        :param e: Description of e
        :param f: Description of f
        :param g: Description of g
        :param h: Description of h
        """

    # Reference model: one field per parameter of `test_function`, carrying
    # the description from the matching `:param:` line and the same default.
    # NOTE(review): intentionally left without a docstring — pydantic surfaces
    # a model docstring as the schema `description`, which would presumably
    # make the schema comparison below fail.
    class Function(BaseModel):
        model_config = {"title": f"Input schema for `{test_function.__name__}`"}

        a: int = Field(description="Description of a")
        b: str = Field(description="Description of b")
        c: float = Field(description="Description of c")
        d: ATestEnum = Field(description="Description of d")
        e: list[str] = Field(description="Description of e")
        f: int | None = Field(default=None, description="Description of f")
        g: int | str = Field(default=2, description="Description of g")
        h: DummyModel = Field(default=DummyModel(a=1), description="Description of h")

    register_function(test_function)

    # The function is expected to be registered under its own name.
    registered_name = "test_function"
    assert registered_name in FUNCTIONS_CALLABLES
    assert registered_name in FUNCTIONS_DOCS

    doc_entry = FUNCTIONS_DOCS[registered_name]
    assert doc_entry.name == registered_name
    # Only the docstring's summary line is expected as the description.
    assert doc_entry.description == "Test function"

    expected_schema = Function.model_json_schema()
    assert doc_entry.parameters.model_json_schema() == expected_schema


def test_run_stateful_function():
    """Run a registered function that mutates state taken from the environment.

    `test_stateful_sum` declares its first argument with
    `Depends("list_state")`; `run_function` is called with only `a` and `b`,
    and the test expects each call to return the sum, report no error, and
    append the sum to `env.list_state.l`. Calling the function directly with
    an explicit state object must behave the same way.
    """

    class DummyState(BaseModel):
        l: list[int] = Field(default_factory=list)

    class DummyEnv(BaseModel):
        list_state: DummyState = DummyState()

    @register_function
    def test_stateful_sum(l: Annotated[DummyState, Depends("list_state")], a: int, b: int) -> int:
        """Test sum

        :param a: First number
        :param b: Second number
        """
        total = a + b
        l.l.append(total)
        return total

    env = DummyEnv()

    # Each row: (arguments passed to run_function, expected return value,
    # expected contents of the environment's list after the call).
    scenarios = [
        ({"a": 1, "b": 2}, 3, [3]),
        ({"a": 4, "b": 5}, 9, [3, 9]),
    ]
    for call_args, expected_sum, expected_history in scenarios:
        outcome, failure = run_function(env, "test_stateful_sum", call_args)
        assert outcome == expected_sum
        assert env.list_state.l == expected_history
        assert failure is None

    # The decorated name must still be directly callable, with the state
    # passed explicitly as the first positional argument.
    direct_outcome = test_stateful_sum(env.list_state, 1, 2)
    assert direct_outcome == 3
    assert env.list_state.l == [3, 9, 3]


def test_run_stateless_function():
    """Run a registered function that needs nothing from the environment.

    `run_function` is invoked against an empty environment model and is
    expected to return the sum with no error, including when `b` is supplied
    as the string `"10"`. The decorated function must also remain directly
    callable.
    """

    class DummyEnv(BaseModel):
        pass

    @register_function
    def test_sum(a: int, b: int):
        """Test sum

        :param a: First number
        :param b: Second number
        """
        return a + b

    # Each row: (arguments passed to run_function, expected return value).
    scenarios = [
        ({"a": 1, "b": 2}, 3),
        # A numeric string for `b` is expected to be accepted and summed as 10.
        ({"a": 1, "b": "10"}, 11),
    ]
    for call_args, expected_sum in scenarios:
        outcome, failure = run_function(DummyEnv(), "test_sum", call_args)
        assert outcome == expected_sum
        assert failure is None

    # Direct invocation bypasses run_function entirely.
    assert test_sum(1, 3) == 4


def test_run_error_function():
    """Check how `run_function` reports arguments that fail validation.

    Passing the non-numeric string `"abc"` for the integer parameter `b` is
    expected to yield an empty-string result together with an error message
    that mentions `ValidationError`, rather than raising.
    """

    class DummyEnv(BaseModel):
        pass

    @register_function
    def test_sum_2(a: int, b: int):
        """Test sum

        :param a: First number
        :param b: Second number
        """
        return a + b

    invalid_args = {"a": 1, "b": "abc"}
    outcome, failure = run_function(DummyEnv(), "test_sum_2", invalid_args)

    # Check for None first so the substring test below cannot hit a TypeError.
    assert failure is not None
    assert "ValidationError" in failure
    assert outcome == ""
