from typing import Callable

import pytest

from vita.data_model.message import (
    AssistantMessage,
    Message,
    ToolCall,
    ToolMessage,
    UserMessage,
)
from vita.data_model.tasks import (
    EnvAssertion,
    EnvFunctionCall,
)
from vita.environment.environment import Environment
from vita.environment.tool import Tool
from vita.environment.toolkit import ToolKitBase, ToolType, is_tool


@pytest.fixture
def domain_name() -> str:
    """Provide the domain name passed to the environment in these tests."""
    name = "mock_domain"
    return name


@pytest.fixture
def policy() -> str:
    """Provide the policy text passed to the environment in these tests."""
    policy_text = "You are a helpful assistant."
    return policy_text


@pytest.fixture
def mock_toolkit_class() -> Callable[[], ToolKitBase]:
    """Provide a small toolkit class with two registered tools.

    The class itself is returned rather than an instance, so every test can
    build its own toolkit whose running total starts at zero.

    Returns:
        The ``MockToolkit`` class, callable with no arguments.
    """

    class MockToolkit(ToolKitBase):
        def __init__(self):
            # Running total that both tools add to.
            self.val = 0

        # Adds ``param1`` to the running total and returns the total as text.
        @is_tool(ToolType.READ)
        def tool1(self, param1: int) -> str:
            self.val = self.val + param1
            return f"{self.val}"

        # Adds ``param2`` to the running total and returns the total as text.
        @is_tool(ToolType.READ)
        def tool2(self, param2: int) -> str:
            self.val = self.val + param2
            return f"{self.val}"

    return MockToolkit





@pytest.fixture
def super_mock_toolkit_class(
    mock_toolkit_class: Callable[[], ToolKitBase],
) -> Callable[[], ToolKitBase]:
    """Provide a subclass of the mock toolkit that adds a third tool.

    Args:
        mock_toolkit_class: The base mock toolkit class to extend.

    Returns:
        The ``SuperMockToolkit`` class, callable with no arguments.
    """

    class SuperMockToolkit(mock_toolkit_class):
        # Adds ``param3`` to the inherited running total and returns it as text.
        @is_tool(ToolType.READ)
        def tool3(self, param3: int) -> str:
            self.val = self.val + param3
            return f"{self.val}"

    return SuperMockToolkit


@pytest.fixture
def message_history() -> list[Message]:
    """Provide a scripted eight-message conversation with two tool round-trips.

    The conversation has two exchanges, each made of a user request, an
    assistant tool call, the tool's reply and an assistant confirmation:
    first creating a task, then marking a task as completed.

    Returns:
        The messages in chronological order.
    """
    # --- First exchange: create a task -------------------------------------
    create_request = UserMessage(
        id="1",
        content="Create a task called 'Important Meeting' for user_1",
        role="user",
    )
    create_call = ToolCall(
        id="3",
        name="create_task",
        arguments={"user_id": "user_1", "title": "Important Meeting"},
    )
    create_call_turn = AssistantMessage(
        id="2",
        content=None,
        role="assistant",
        tool_calls=[create_call],
    )
    create_result = ToolMessage(
        id="3",
        content='{"task_id": "task_2", "title": "Important Meeting", "description": null, "status": "pending"}',
        role="tool",
    )
    create_confirmation = AssistantMessage(
        id="4",
        content="Ok, I've created the task for you. The task ID is task_2.",
        role="assistant",
    )

    # --- Second exchange: mark a task as completed -------------------------
    complete_request = UserMessage(
        id="5",
        content="Please mark task_1 as completed",
        role="user",
    )
    complete_call = ToolCall(
        id="6",
        name="update_task_status",
        arguments={"task_id": "task_1", "status": "completed"},
    )
    complete_call_turn = AssistantMessage(
        id="6",
        content=None,
        role="assistant",
        tool_calls=[complete_call],
    )
    complete_result = ToolMessage(
        id="6",
        content='{"task_id": "task_1", "title": "Test task", "description": "A test task", "status": "completed"}',
        role="tool",
    )
    complete_confirmation = AssistantMessage(
        id="7",
        content="I've marked task_1 as completed.",
        role="assistant",
    )

    return [
        create_request,
        create_call_turn,
        create_result,
        create_confirmation,
        complete_request,
        complete_call_turn,
        complete_result,
        complete_confirmation,
    ]











def test_toolkit(
    mock_toolkit_class: Callable[[], ToolKitBase],
    super_mock_toolkit_class: Callable[[], ToolKitBase],
):
    """Check tool registration and invocation on a toolkit and its subclass."""
    # Base toolkit: exactly the two declared tools are registered.
    base = mock_toolkit_class()
    base_names = {"tool1", "tool2"}
    assert len(base.tools) == 2
    assert base.tools.keys() == base_names
    assert base.get_tools().keys() == base_names
    for tool in base.get_tools().values():
        assert isinstance(tool, Tool)

    # Calls accumulate into one running total, so the order matters.
    for tool_name, kwargs, expected in (
        ("tool1", {"param1": 1}, "1"),
        ("tool2", {"param2": 2}, "3"),
    ):
        assert base.use_tool(tool_name, **kwargs) == expected

    # Subclass toolkit: inherited tools plus the newly declared one.
    extended = super_mock_toolkit_class()
    assert len(extended.tools) == 3
    assert extended.tools.keys() == {"tool1", "tool2", "tool3"}
    for tool_name, kwargs, expected in (
        ("tool1", {"param1": 1}, "1"),
        ("tool2", {"param2": 2}, "3"),
        ("tool3", {"param3": 3}, "6"),
    ):
        assert extended.use_tool(tool_name, **kwargs) == expected




def test_environment(
    mock_toolkit_class: Callable[[], ToolKitBase],
    domain_name: str,
    policy: str,
):
    """Check that an Environment exposes its policy and runs toolkit tools."""
    env = Environment(
        domain_name=domain_name,
        policy=policy,
        tools=mock_toolkit_class(),
    )
    assert env.get_policy() == policy

    # Direct invocation through the environment.
    assert env.use_tool("tool1", param1=1) == "1"

    # Invocation through a ToolCall; the running total carries over (1 + 2).
    call = ToolCall(id="1", name="tool2", arguments={"param2": 2})
    reply = env.get_response(call)
    assert isinstance(reply, ToolMessage)
    assert reply.id == "1"
    assert reply.content == "3"
    assert reply.role == "tool"


def test_weather_function():
    """Test the toolkit ``weather`` function with date filtering.

    A toolkit is built over a mock database holding three consecutive days of
    weather for one city. The test checks that:

    * a two-day range returns both days and leaves out the third,
    * a three-day range returns all three days,
    * a single-day range returns one entry with no newline in it,
    * malformed start/end dates yield the matching error message.
    """
    from vita.data_model.tasks import Weather
    from vita.environment.toolkit import ToolKitBase

    class MockDB:
        """Stand-in database exposing only the ``weather`` records."""

        def __init__(self):
            self.weather = [
                Weather(
                    city="北京",
                    category="晴天",
                    datetime="2024-01-01",
                    temperature=[5.0, 15.0],
                    humidity=60.0,
                ),
                Weather(
                    city="北京",
                    category="多云",
                    datetime="2024-01-02",
                    temperature=[3.0, 12.0],
                    humidity=70.0,
                ),
                Weather(
                    city="北京",
                    category="雨天",
                    datetime="2024-01-03",
                    temperature=[0.0, 8.0],
                    humidity=85.0,
                ),
            ]

    class MockToolkit(ToolKitBase):
        """Toolkit with no tools of its own, backed by the mock database."""

        def __init__(self):
            super().__init__(MockDB())

    toolkit = MockToolkit()

    # Two-day range: both endpoints are included, the third day is not.
    result = toolkit.weather("北京", "2024-01-01", "2024-01-02")
    assert "晴天" in result
    assert "多云" in result
    assert "雨天" not in result
    assert result.count("\n") >= 1

    # Three-day range: every record is returned.
    result = toolkit.weather("北京", "2024-01-01", "2024-01-03")
    assert "多云" in result
    assert "雨天" in result
    assert "晴天" in result
    assert result.count("\n") >= 1

    # Single-day range: one entry, so the output holds no newline.
    result = toolkit.weather("北京", "2024-01-02", "2024-01-02")
    assert "多云" in result
    assert result.count("\n") == 0

    # Malformed dates are reported in the returned text.
    result = toolkit.weather("北京", "invalid-date", "2024-01-02")
    assert "Invalid date_start format" in result

    result = toolkit.weather("北京", "2024-01-01", "invalid-date")
    assert "Invalid date_end format" in result












