from string import Template
from typing import List, NamedTuple, Optional, Union

from langchain.tools import BaseTool
from pydantic import Field


from agentverse.memory import BaseMemory, ChatHistoryMemory
from agentverse.message import Message
from agentverse.utils import AgentAction, AgentFinish
from agentverse.logging import logger

# from . import agent_registry
# from .base import BaseAgent

from agentverse.agents import agent_registry
from agentverse.agents.base import BaseAgent


class ToolNotExistError(Exception):
    """Raised when a requested tool name matches none of the available tools.

    Derives from ``Exception`` rather than ``BaseException``: the latter is
    reserved for system-exiting conditions (``KeyboardInterrupt``,
    ``SystemExit``, ...) and is not meant to be subclassed by application
    errors.

    Attributes:
        tool_name: The name of the tool that could not be found.
    """

    def __init__(self, tool_name=""):
        # Forward the name so ``args`` is populated even when the name is
        # passed by keyword (keeps repr/pickling of the exception sensible).
        super().__init__(tool_name)
        self.tool_name = tool_name

    def __str__(self):
        return f"Tool {self.tool_name} does not exist."


@agent_registry.register("tool")
class ToolAgent(BaseAgent):
    """Agent that may call tools before producing its reply for a turn.

    Each turn the agent repeatedly prompts the LLM. While the parsed output
    is an ``AgentAction`` the named tool is run and its observation is fed
    back into the next prompt; the turn ends when the parsed output is an
    ``AgentFinish`` (or when no valid output could be obtained).

    Attributes such as ``llm``, ``output_parser``, ``max_retry``, ``memory``,
    ``name``, ``role_description`` and ``prompt_template`` are used here but
    not defined in this class (presumably provided by ``BaseAgent``).
    """

    # Tools the agent may invoke; they are looked up by ``tool.name``.
    tools: List[BaseTool] = Field(default_factory=list)
    # Tool calls and observations recorded in previous turns.
    tool_memory: BaseMemory = Field(default_factory=ChatHistoryMemory)
    # Forwarded as ``verbose=`` to ``tool.run`` / ``tool.arun``.
    verbose: bool = Field(default=False)

    def step(self, env_description: str = "") -> Message:
        """Run one turn synchronously and return the agent's message.

        Args:
            env_description: Text substituted for ``${env_description}`` in
                the prompt template.

        Returns:
            A ``Message`` whose content is the ``"output"`` return value of
            the final ``AgentFinish``, or ``""`` if no valid response was
            produced.
        """
        parsed_response: Optional[Union[AgentAction, AgentFinish]] = None
        # The first entry is the tool history from earlier turns; entries
        # appended below are the observations gathered during this turn.
        tool_observation = [self.tool_memory.to_string()]
        while True:
            prompt = self._fill_prompt_template(env_description, tool_observation)

            # Reset every round so that a round in which all retries fail is
            # seen as a failure instead of re-using the previous round's
            # action (which would keep this loop spinning forever).
            parsed_response = None
            for _ in range(self.max_retry):
                try:
                    response = self.llm.generate_response(prompt)
                    candidate = self.output_parser.parse(response)
                    if isinstance(candidate, AgentAction):
                        observation = self._call_tool(candidate)
                        tool_observation.append(
                            self._format_observation(candidate, observation)
                        )
                    # Commit only once the whole attempt, including the tool
                    # call, has succeeded.
                    parsed_response = candidate
                    break
                # Do not catch BaseException: KeyboardInterrupt / SystemExit
                # must propagate. ToolNotExistError is named explicitly so an
                # unknown tool is still retried even if that class does not
                # derive from Exception.
                except (Exception, ToolNotExistError) as e:
                    logger.error(e)
                    logger.warn("Retrying...")
                    continue
            if parsed_response is None or isinstance(parsed_response, AgentFinish):
                break

        return self._finalize_step(parsed_response, tool_observation)

    async def astep(self, env_description: str = "") -> Message:
        """Asynchronous version of ``step``.

        Args:
            env_description: Text substituted for ``${env_description}`` in
                the prompt template.

        Returns:
            A ``Message`` whose content is the ``"output"`` return value of
            the final ``AgentFinish``, or ``""`` if no valid response was
            produced.
        """
        parsed_response: Optional[Union[AgentAction, AgentFinish]] = None
        # Initialize the tool_observation with tool_memory
        tool_observation = [self.tool_memory.to_string()]
        while True:
            prompt = self._fill_prompt_template(env_description, tool_observation)

            # Reset every round so that a round in which all retries fail is
            # seen as a failure instead of re-using the previous round's
            # action (which would keep this loop spinning forever).
            parsed_response = None
            for _ in range(self.max_retry):
                try:
                    response = await self.llm.agenerate_response(prompt)
                    candidate = self.output_parser.parse(response)
                    if isinstance(candidate, AgentAction):
                        # If the response is an action, call the tool
                        # and append the observation to tool_observation
                        observation = await self._acall_tool(candidate)
                        tool_observation.append(
                            self._format_observation(candidate, observation)
                        )
                    parsed_response = candidate
                    break
                # Do not catch BaseException: task cancellation
                # (asyncio.CancelledError), KeyboardInterrupt and SystemExit
                # must propagate instead of being retried. ToolNotExistError
                # is named explicitly so an unknown tool is still retried
                # even if that class does not derive from Exception.
                except (Exception, ToolNotExistError) as e:
                    logger.error(e)
                    logger.warn("Retrying...")
                    continue
            if parsed_response is None or isinstance(parsed_response, AgentFinish):
                break

        return self._finalize_step(parsed_response, tool_observation)

    def _format_observation(self, action, observation) -> str:
        """Render one tool call and its result as a prompt entry."""
        # ``str()`` guards against tools that return a non-string value.
        return action.log.strip() + f"\nObservation: {str(observation).strip()}"

    def _finalize_step(
        self, parsed_response, tool_observation: List[str]
    ) -> Message:
        """Log a failed turn, persist new observations and build the reply."""
        if parsed_response is None:
            logger.error(f"{self.name} failed to generate valid response.")

        self._update_tool_memory(tool_observation)

        if parsed_response is None:
            content = ""
        else:
            content = parsed_response.return_values["output"]
        return Message(
            content=content,
            sender=self.name,
            receiver=self.get_receiver(),
        )

    def _get_tool(self, tool_name: str) -> BaseTool:
        """Return the tool registered under ``tool_name``.

        Raises:
            ToolNotExistError: If no tool in ``self.tools`` has that name.
        """
        name_to_tool = {tool.name: tool for tool in self.tools}
        if tool_name not in name_to_tool:
            raise ToolNotExistError(tool_name)
        return name_to_tool[tool_name]

    def _call_tool(self, response: NamedTuple) -> str:
        """Call the tool named by ``response.tool`` and return its output.

        Raises:
            ToolNotExistError: If the requested tool is not available.
        """
        tool = self._get_tool(response.tool)
        return tool.run(response.tool_input, verbose=self.verbose)

    async def _acall_tool(self, response: NamedTuple) -> str:
        """Asynchronously call the tool named by ``response.tool``.

        Raises:
            ToolNotExistError: If the requested tool is not available.
        """
        tool = self._get_tool(response.tool)
        return await tool.arun(response.tool_input, verbose=self.verbose)

    def _update_tool_memory(self, tool_observation: List[str]):
        """Store this turn's observations in ``tool_memory``.

        The first element of ``tool_observation`` is the pre-existing tool
        history and is therefore skipped.
        """
        new_observations = tool_observation[1:]
        if not new_observations:
            # If no tool is called this turn, do nothing
            return
        messages = [Message(content=observation) for observation in new_observations]
        self.tool_memory.add_message(messages)

    def _fill_prompt_template(
        self, env_description: str = "", tool_observation: Optional[List[str]] = None
    ) -> str:
        """Fill the placeholders in the prompt template

        In the tool agent, these placeholders are supported:
        - ${agent_name}: the name of the agent
        - ${env_description}: the description of the environment
        - ${role_description}: the description of the role of the agent
        - ${chat_history}: the chat history of the agent
        - ${tools}: the list of tools and their usage
        - ${tool_names}: the list of tool names
        - ${tool_observation}: the observation of the tool in this turn

        Placeholders without a matching key are left untouched
        (``safe_substitute``).
        """
        # ``None`` instead of a mutable ``[]`` default argument.
        if tool_observation is None:
            tool_observation = []
        tools = "\n".join([f"> {tool.name}: {tool.description}" for tool in self.tools])
        # Collapse doubled braces in tool descriptions into single braces.
        tools = tools.replace("{{", "{").replace("}}", "}")
        tool_names = ", ".join([tool.name for tool in self.tools])
        input_arguments = {
            "agent_name": self.name,
            "env_description": env_description,
            "role_description": self.role_description,
            "chat_history": self.memory.to_string(add_sender_prefix=True),
            "tools": tools,
            "tool_names": tool_names,
            "tool_observation": "\n".join(tool_observation),
        }
        return Template(self.prompt_template).safe_substitute(input_arguments)

    def add_message_to_memory(self, messages: List[Message]) -> None:
        """Append ``messages`` to the agent's chat memory."""
        self.memory.add_message(messages)

    def reset(self) -> None:
        """Reset the agent"""
        self.memory.reset()
        # NOTE(review): ``tool_memory`` is not reset here - confirm whether
        # tool history is meant to survive a reset.
        # TODO: reset receiver
