from collections.abc import Sequence

from agentdojo.agent_pipeline.base_pipeline_element import BasePipelineElement
from agentdojo.agent_pipeline.tool_execution import tool_result_to_str
from agentdojo.base_tasks import BaseInjectionTask, BaseUserTask
from agentdojo.functions_runtime import EmptyEnv, Env, FunctionsRuntime
from agentdojo.types import ChatAssistantMessage, ChatMessage, ChatToolResultMessage, text_content_block_from_string

# Maps a builtin type to an "empty" value of that type.
# NOTE(review): nothing in the visible part of this module reads this mapping;
# presumably it supplies stand-in argument values -- confirm against callers.
# The `dict` and `list` entries are single shared instances, so consumers
# should copy them before mutating.
_ARG_PLACEHOLDERS: dict[type, object] = {
    str: "",
    int: 0,
    float: 0.0,
    bool: False,
    dict: {},
    list: [],
}


class GroundTruthPipeline(BasePipelineElement):
    """Pipeline element that replays a task's `ground_truth` tool calls.

    Rather than deciding on tool calls itself, it runs the calls returned by
    the task's `ground_truth` and records them as chat messages. It is used to
    identify which injection placeholders are seen by the agent during the
    task execution.
    """

    def __init__(self, task: BaseUserTask | BaseInjectionTask | None) -> None:
        # The task may be `None` at construction time; `query` refuses to run
        # until a task is present.
        self._task = task

    def query(
        self,
        query: str,
        runtime: FunctionsRuntime,
        env: Env = EmptyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, FunctionsRuntime, Env, Sequence[ChatMessage], dict]:
        """Run every ground-truth tool call of the task and record the exchange.

        For each tool call returned by `ground_truth(env)`, an assistant
        message carrying that call (with empty text) is recorded, the function
        is executed through `runtime`, and a tool-result message with the
        stringified result is recorded. A closing assistant message holding
        the task's `GROUND_TRUTH_OUTPUT` is appended last.

        Args:
            query: Passed through to the return value unchanged.
            runtime: Runtime used to execute each ground-truth tool call.
            env: Environment given to `ground_truth` and to every tool call.
            messages: Prior messages; they are copied into the returned list,
                not modified.
            extra_args: Passed through to the return value unchanged.

        Returns:
            `(query, runtime, env, messages + generated messages, extra_args)`.

        Raises:
            ValueError: If the pipeline holds no task.
        """
        # NOTE(review): the default values above are created once at definition
        # time; this method never mutates `messages` or `extra_args`.
        task = self._task
        if task is None:
            # NOTE(review): the message refers to `set_task`, which this class
            # does not define -- confirm it exists elsewhere.
            raise ValueError("No task set for the ground truth. Call `set_task` first.")

        as_text = text_content_block_from_string
        transcript: list[ChatMessage] = []

        for call in task.ground_truth(env):
            # Record the assistant "requesting" the call before it is executed.
            request = ChatAssistantMessage(role="assistant", tool_calls=[call], content=[as_text("")])
            transcript.append(request)

            # `raise_on_error=True` is passed, so the second return value
            # (ignored here) is not inspected.
            outcome, _ = runtime.run_function(env, call.function, call.args, raise_on_error=True)
            response = ChatToolResultMessage(
                role="tool",
                content=[as_text(tool_result_to_str(outcome))],
                tool_call=call,
                tool_call_id=None,
                error=None,
            )
            transcript.append(response)

        # Final assistant turn: the task's ground-truth output, with no tool calls.
        closing = ChatAssistantMessage(
            role="assistant",
            content=[as_text(task.GROUND_TRUTH_OUTPUT)],
            tool_calls=None,
        )
        transcript.append(closing)

        return query, runtime, env, [*messages, *transcript], extra_args
