from agentgym.agent_pipeline.llms.anthropic_llm import (
    _merge_tool_result_messages,
)
from anthropic.types import (
    MessageParam,
    TextBlockParam,
    ToolResultBlockParam,
)


def test_merge_tool_result_messages():
    """Verify the output of ``_merge_tool_result_messages`` on a mixed history.

    The input alternates runs of consecutive ``user`` messages (plain text
    and tool results) with single ``assistant`` messages. The expected output
    has each run of consecutive ``user`` messages collapsed into one message
    whose content lists the original blocks in their original order, with the
    ``assistant`` messages left as they were.
    """

    # Every helper returns brand-new objects on each call, so the input and
    # the expected value never share a dict or list with one another.
    def text(value: str) -> TextBlockParam:
        return TextBlockParam(type="text", text=value)

    def tool_result(index: int) -> ToolResultBlockParam:
        return ToolResultBlockParam(
            type="tool_result",
            tool_use_id=f"tool_use_id_{index}",
            content=[text(f"output_{index}")],
            is_error=False,
        )

    def user(*blocks) -> MessageParam:
        return MessageParam(role="user", content=list(blocks))

    def assistant() -> MessageParam:
        return MessageParam(role="assistant", content="Hello")

    messages = [
        # Run 1: text followed by one tool result.
        user(text("Hi!")),
        user(tool_result(1)),
        assistant(),
        # Run 2: two tool results followed by text.
        user(tool_result(2)),
        user(tool_result(3)),
        user(text("Bye")),
        assistant(),
        # Run 3: two tool results closing the conversation.
        user(tool_result(4)),
        user(tool_result(5)),
    ]
    expected_messages = [
        user(text("Hi!"), tool_result(1)),
        assistant(),
        user(tool_result(2), tool_result(3), text("Bye")),
        assistant(),
        user(tool_result(4), tool_result(5)),
    ]

    assert _merge_tool_result_messages(messages) == expected_messages
