from google.genai import types

from agentdojo.agent_pipeline.llms.google_llm import (
    _merge_tool_result_messages,
)


def test_merge_tool_result_messages():
    """Check that adjacent function-response messages are merged together.

    The input alternates one user text message with two consecutive
    function-response messages, twice. The expected output keeps each user
    message as-is and replaces each pair of function-response messages with
    a single message carrying both parts in their original order.
    """

    def user_text(text):
        # A user-role message holding a single text part.
        return types.Content(
            role="user",
            parts=[types.Part.from_text(text=text)],
        )

    def tool_results(*pairs):
        # A role-less message holding one function-response part per
        # (tool name, output) pair, in the order given.
        response_parts = [
            types.Part.from_function_response(
                name=tool_name,
                response={"content": output},
            )
            for tool_name, output in pairs
        ]
        return types.Content(parts=response_parts)

    conversation = [
        user_text("Hi"),
        tool_results(("tool1", "output1")),
        tool_results(("tool2", "output2")),
        user_text("Bye"),
        tool_results(("tool3", "output3")),
        tool_results(("tool4", "output4")),
    ]

    expected_messages = [
        user_text("Hi"),
        tool_results(("tool1", "output1"), ("tool2", "output2")),
        user_text("Bye"),
        tool_results(("tool3", "output3"), ("tool4", "output4")),
    ]

    assert _merge_tool_result_messages(conversation) == expected_messages
