from collections.abc import Callable, Sequence

import yaml
from pydantic import BaseModel

from agentgym.agent_pipeline.base_pipeline_element import (
    BasePipelineElement,
    DummyEnv,
)
from agentgym.agent_pipeline.llms.google_llm import EMPTY_FUNCTION_NAME
from agentgym.functions_engine.functions_engine import (
    Env,
    Function,
    FunctionReturnType,
    run_function,
)
from agentgym.logging import Logger
from agentgym.types import ChatMessage, ChatToolResultMessage


def tool_result_to_str(
    tool_result: FunctionReturnType,
    dump_fn: Callable[[dict | list[dict]], str] = yaml.safe_dump,
) -> str:
    """Render a tool's return value as text.

    A single pydantic model is dumped with `dump_fn` (YAML by default; a function
    such as `json.dumps` can be passed instead). A list is converted item by item
    and then dumped as a whole. Anything else is passed through `str`.

    Args:
        tool_result: The value returned by the tool.
        dump_fn: Serializer applied to the dumped model or to the list of converted items.

    Returns:
        The formatted result; output produced by `dump_fn` is stripped of
        surrounding whitespace.

    Raises:
        TypeError: If a list item is neither exactly a `str`/`int` nor a `BaseModel`.
    """

    def _convert_item(entry):
        # Exact type match on purpose: subclasses of str/int (e.g. bool) are rejected.
        if type(entry) in (str, int):
            return str(entry)
        if isinstance(entry, BaseModel):
            return entry.model_dump()
        raise TypeError("Not valid type for item tool result: " + str(type(entry)))

    if isinstance(tool_result, BaseModel):
        dumped_model = dump_fn(tool_result.model_dump())
        return dumped_model.strip()

    if not isinstance(tool_result, list):
        return str(tool_result)

    # The converted list holds plain strings and/or model dumps (dicts).
    converted = [_convert_item(entry) for entry in tool_result]
    return dump_fn(converted).strip()


class ToolsExecutor(BasePipelineElement):
    """Pipeline element that executes the tool calls of the last assistant message.

    Args:
        tool_output_formatter: Callable that converts a tool's return value into
            the string used as the content of the resulting tool message.
    """

    def __init__(
        self,
        tool_output_formatter: Callable[[FunctionReturnType], str] = tool_result_to_str,
    ) -> None:
        self.output_formatter = tool_output_formatter

    @staticmethod
    def _make_result_message(tool_call, content: str, error) -> ChatToolResultMessage:
        """Build the tool message answering `tool_call`; the call id is `None` when absent."""
        return ChatToolResultMessage(
            role="tool",
            content=content,
            tool_call_id=tool_call.get("id"),
            tool_call=tool_call,
            error=error,
        )

    def query(
        self,
        query: str,
        tools: Sequence[Function],
        env: Env = DummyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, Sequence[Function], Env, Sequence[ChatMessage], dict]:
        """Run every tool call requested by the last message and append the results.

        The inputs are returned unchanged when there are no messages, when the
        last message is not from the assistant, or when it carries no tool calls.
        Otherwise one tool message is produced per tool call, in order:

        * a call whose function name is `EMPTY_FUNCTION_NAME` gets an error message;
        * a call naming a function that is not in `tools` gets an error message;
        * any other call is executed with `run_function` and its result is
          formatted with the configured output formatter.

        NOTE: the default `messages` and `extra_args` objects are shared between
        calls. This method never mutates them (it returns a new message list).

        Args:
            query: The query, passed through untouched.
            tools: The tools the assistant is allowed to call.
            env: The environment handed to `run_function`.
            messages: The conversation so far.
            extra_args: Extra pipeline arguments, passed through untouched.

        Returns:
            The same tuple shape as the inputs, with the tool result messages
            appended to a new list of messages.
        """
        if not messages:
            return query, tools, env, messages, extra_args
        last_message = messages[-1]
        if last_message["role"] != "assistant":
            return query, tools, env, messages, extra_args
        # Covers both `None` and an empty list of tool calls.
        if not last_message["tool_calls"]:
            return query, tools, env, messages, extra_args

        # Built once so each membership test below is O(1) instead of a scan over `tools`.
        known_tool_names = {tool.name for tool in tools}

        tool_call_results = []
        for tool_call in last_message["tool_calls"]:
            function_name = tool_call["function"]
            if function_name == EMPTY_FUNCTION_NAME:
                tool_call_results.append(
                    self._make_result_message(
                        tool_call,
                        "",
                        "Empty function name provided. Provide a valid function name.",
                    )
                )
                continue
            if function_name not in known_tool_names:
                tool_call_results.append(
                    self._make_result_message(
                        tool_call,
                        "",
                        f"Invalid tool {function_name} provided.",
                    )
                )
                continue
            tool_call_result, error = run_function(env, function_name, tool_call["args"])
            tool_call_results.append(
                self._make_result_message(
                    tool_call,
                    self.output_formatter(tool_call_result),
                    error,
                )
            )
        return query, tools, env, [*messages, *tool_call_results], extra_args


class ToolsExecutionLoop(BasePipelineElement):
    """Runs a sequence of pipeline elements repeatedly while tool calls are pending.

    Args:
        elements: The pipeline elements executed, in order, on every round.
        max_iters: Upper bound on the number of rounds.
    """

    def __init__(self, elements: Sequence[BasePipelineElement], max_iters: int = 15) -> None:
        self.elements = elements
        self.max_iters = max_iters

    def query(
        self,
        query: str,
        tools: Sequence[Function],
        env: Env = DummyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, Sequence[Function], Env, Sequence[ChatMessage], dict]:
        """Execute the wrapped elements until no tool call is left to handle.

        A round is started only if the last message comes from the assistant
        and has a non-empty list of tool calls; at most `max_iters` rounds are
        run. The messages are logged after every element.

        Raises:
            ValueError: If `messages` is empty.
        """
        if not len(messages):
            raise ValueError("Messages should not be empty when calling ToolsExecutionLoop")

        logger = Logger().get()
        rounds_done = 0
        while rounds_done < self.max_iters:
            latest = messages[-1]
            if latest["role"] != "assistant":
                break
            pending_calls = latest["tool_calls"]
            if pending_calls is None or len(pending_calls) == 0:
                break
            for step in self.elements:
                query, tools, env, messages, extra_args = step.query(query, tools, env, messages, extra_args)
                logger.log(messages)
            rounds_done += 1
        return query, tools, env, messages, extra_args
