from collections.abc import Sequence

from agentdojo.agent_pipeline.base_pipeline_element import BasePipelineElement
from agentdojo.functions_runtime import EmptyEnv, Env, Function, FunctionCall, FunctionsRuntime
from agentdojo.types import ChatMessage


class ToolSelector(BasePipelineElement):
    """Pipeline element that replaces the runtime's functions with a selected subset.

    Subclasses implement :meth:`select_tools`; :meth:`query` feeds the result
    to ``runtime.update_functions`` and passes everything else through.
    """

    def __init__(self, llm: BasePipelineElement) -> None:
        # Kept on the instance; this base class does not call it itself.
        self.llm = llm

    def select_tools(self, query: str, tools: dict[str, Function]) -> dict[str, Function]:
        """Return the name-to-function mapping to keep for ``query``.

        Raises:
            NotImplementedError: always, in this base class; subclasses override.
        """
        raise NotImplementedError()

    def query(
        self,
        query: str,
        runtime: FunctionsRuntime,
        env: Env = EmptyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, FunctionsRuntime, Env, Sequence[ChatMessage], dict]:
        """Select tools for ``query`` and install them on ``runtime``.

        ``query``, ``env``, ``messages`` and ``extra_args`` are returned
        untouched; ``runtime`` is the same object after ``update_functions``
        has been called on it with the selection.
        """
        selected = self.select_tools(query, runtime.functions)
        runtime.update_functions(selected)
        return (query, runtime, env, messages, extra_args)


class ToolUsagePlanner(BasePipelineElement):
    """Pipeline element that plans tool calls and restricts the runtime to the planned tools.

    Subclasses implement :meth:`plan`; :meth:`query` keeps only the runtime
    functions whose names appear in the plan and publishes the plan under
    the ``"plan"`` key of the returned extra arguments.
    """

    def __init__(self, llm: BasePipelineElement) -> None:
        # Kept on the instance; this base class does not call it itself.
        self.llm = llm

    def plan(self, query: str, tools: dict[str, Function]) -> Sequence[FunctionCall]:
        """Return the sequence of function calls planned for ``query``.

        Raises:
            NotImplementedError: always, in this base class; subclasses override.
        """
        raise NotImplementedError()

    def query(
        self,
        query: str,
        runtime: FunctionsRuntime,
        env: Env = EmptyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, FunctionsRuntime, Env, Sequence[ChatMessage], dict]:
        """Plan tool usage for ``query`` and narrow ``runtime`` to the planned tools.

        Args:
            query: The query handed to :meth:`plan`.
            runtime: Runtime whose ``functions`` are filtered; ``update_functions``
                is called on it with the filtered mapping.
            env: Returned unchanged.
            messages: Returned unchanged.
            extra_args: Extra arguments to extend. This mapping is not modified.

        Returns:
            ``(query, runtime, env, messages, new_extra_args)`` where
            ``new_extra_args`` is a shallow copy of ``extra_args`` with
            ``"plan"`` set to the plan.
        """
        plan = self.plan(query, runtime.functions)
        planned_names = {tool_call.function for tool_call in plan}
        tools_to_use = {name: tool for name, tool in runtime.functions.items() if name in planned_names}
        runtime.update_functions(tools_to_use)
        # Build a new dict instead of assigning into ``extra_args``: the default
        # ``{}`` is a single object shared by every call, so writing to it would
        # leak this plan into later calls (and would mutate a caller's dict).
        extra_args = {**extra_args, "plan": plan}
        return query, runtime, env, messages, extra_args
