from agentdojo.agent_pipeline.llms.cohere_llm import _merge_tool_result_messages

try:
    from cohere.types import (
        Message_System,  # type: ignore
        Message_Tool,  # type: ignore
        Message_User,  # type: ignore
    )
except ImportError:
    from cohere.types import SystemMessage as Message_System
    from cohere.types import ToolMessage as Message_Tool
    from cohere.types import UserMessage as Message_User
from cohere.types import ToolCall, ToolResult


def test_merge_tool_result_messages():
    """Check the expected output of ``_merge_tool_result_messages``.

    The input holds a system message, three back-to-back tool messages, a
    user message, and two more back-to-back tool messages. The expectation is
    that each run of adjacent tool messages comes back as a single tool
    message carrying all of that run's tool results in their original order,
    while the system and user messages are left where they were.
    """

    def build_result(tool_name, params, output):
        # Copy ``params`` so no dict instance is shared between the input
        # fixtures and the expected fixtures.
        return ToolResult(
            call=ToolCall(name=tool_name, parameters=dict(params)),
            outputs=[{"result": output}],
        )

    def first_run():
        # Fresh objects on every call: the input and the expectation must
        # never alias each other.
        return [
            build_result("tool1", {}, "output1"),
            build_result("tool2", {"arg1": "value1"}, "output2"),
            build_result("tool3", {"arg1": "value1", "arg2": "value2"}, "output3"),
        ]

    def second_run():
        return [
            build_result("tool4", {"arg1": "value1"}, "output2"),
            build_result("tool5", {"arg1": "value1", "arg2": "value2"}, "output3"),
        ]

    # Input: one tool message per individual tool result.
    messages = [
        Message_System(message="Hi"),
        *[Message_Tool(tool_results=[single]) for single in first_run()],
        Message_User(role="USER", message="Bye"),
        *[Message_Tool(tool_results=[single]) for single in second_run()],
    ]

    # Expectation: one tool message per run of adjacent tool messages.
    expected_messages = [
        Message_System(message="Hi"),
        Message_Tool(tool_results=first_run()),
        Message_User(role="USER", message="Bye"),
        Message_Tool(tool_results=second_run()),
    ]

    merged_messages = _merge_tool_result_messages(messages)
    assert merged_messages == expected_messages
