import copy
import types
from collections.abc import Sequence
from typing import Union, get_args, get_origin, get_type_hints

from agentgym.agent_pipeline.base_pipeline_element import BasePipelineElement
from agentgym.agent_pipeline.tool_execution import tool_result_to_str
from agentgym.base_tasks import BaseInjectionTask, BaseUserTask
from agentgym.functions_engine.functions_engine import (
    FUNCTIONS_CALLABLES,
    Env,
    Function,
    FunctionCallArgs,
    FunctionCallable,
    run_function,
)
from agentgym.types import (
    ChatAssistantMessage,
    ChatMessage,
    ChatToolResultMessage,
    DummyEnv,
)

_ARG_PLACEHOLDERS = {
    str: "",
    int: 0,
    float: 0.0,
    bool: False,
    dict: {},
    list: [],
}


def _placeholder_for(arg_type: object) -> object:
    """Return a fresh placeholder value for the annotation ``arg_type``.

    Parameterized generics are reduced to their origin (``list[str]`` is
    treated as ``list``). For a union, the first member that has a placeholder
    is used (``list[str] | None`` yields ``[]``).

    Raises:
        KeyError: If no placeholder is registered for ``arg_type``.
    """
    if get_origin(arg_type) in (Union, types.UnionType):
        candidates = get_args(arg_type)
    else:
        candidates = (arg_type,)
    for candidate in candidates:
        base = get_origin(candidate) or candidate
        if base in _ARG_PLACEHOLDERS:
            # Copy so that a tool mutating its argument cannot alter the
            # shared module-level dict/list placeholder objects.
            return copy.copy(_ARG_PLACEHOLDERS[base])
    raise KeyError(arg_type)


def substitute_ellipses(
    tool_name: str,
    args: FunctionCallArgs,
    tool_callables: dict[str, FunctionCallable],
) -> FunctionCallArgs:
    """Return a copy of ``args`` with every ``...`` value replaced by a placeholder.

    The placeholder for an argument is chosen from ``_ARG_PLACEHOLDERS``
    according to the type hint the tool declares for that argument.
    Arguments whose value is not ``Ellipsis`` are copied over unchanged.

    Args:
        tool_name: Key of the tool in ``tool_callables``.
        args: Mapping of argument name to value; values may be ``Ellipsis``.
        tool_callables: Mapping of tool name to the callable whose type hints
            are inspected.

    Returns:
        A new dict with the same keys as ``args``.

    Raises:
        KeyError: If ``tool_name`` is not in ``tool_callables``, if an
            ``Ellipsis`` argument has no type hint on the tool, or if its
            annotation has no registered placeholder.
    """
    tool = tool_callables[tool_name]
    # Resolved lazily: type hints are only needed (and only evaluated) when at
    # least one argument is an Ellipsis, and then only once per call.
    hints = None
    new_args = {}
    for key, value in args.items():
        if value is not Ellipsis:
            new_args[key] = value
            continue
        if hints is None:
            hints = get_type_hints(tool)
        new_args[key] = _placeholder_for(hints[key])
    return new_args


class GroundTruthPipeline(BasePipelineElement):
    """Pipeline element that replays a task's ground-truth tool calls against the environment."""

    def __init__(self, task: BaseUserTask | BaseInjectionTask | None) -> None:
        """Store the task whose ground truth will be replayed; may be ``None``."""
        self._task = task

    def query(
        self,
        query: str,
        tools: Sequence[Function],
        env: Env = DummyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, Sequence[Function], Env, Sequence[ChatMessage], dict]:
        """Run every ground-truth tool call of the task and append the resulting messages.

        For each tool call returned by the task's ``ground_truth(env)``, an
        assistant message carrying the call and a tool-result message carrying
        the stringified result are produced. ``...`` arguments are replaced by
        placeholders before the tool is run. A closing assistant message holds
        the task's ``GROUND_TRUTH_OUTPUT``.

        Returns:
            ``query``, ``tools``, ``env`` and ``extra_args`` as passed in, with
            a new list made of ``messages`` followed by the generated messages.

        Raises:
            ValueError: If the pipeline was built with ``task=None``.
        """
        task = self._task
        if task is None:
            raise ValueError("No task set for the ground truth. Call `set_task` first.")

        replayed = []
        for call in task.ground_truth(env):
            name = call["function"]
            resolved_args = substitute_ellipses(name, call["args"], FUNCTIONS_CALLABLES)
            # The recorded assistant message keeps the original call (with any
            # ``...`` arguments); only the executed arguments are substituted.
            call_message = ChatAssistantMessage(role="assistant", tool_calls=[call], content="")
            replayed.append(call_message)
            # Second element of the returned pair is not used here.
            outcome, _ = run_function(env, name, resolved_args, raise_on_error=True)
            result_message = ChatToolResultMessage(
                role="tool",
                content=tool_result_to_str(outcome),
                tool_call=call,
                tool_call_id=None,
                error=None,
            )
            replayed.append(result_message)

        closing_message = ChatAssistantMessage(
            role="assistant",
            content=task.GROUND_TRUTH_OUTPUT,
            tool_calls=None,
        )
        replayed.append(closing_message)
        return query, tools, env, [*messages, *replayed], extra_args
