from base import Server, Configuration, Tool
from llm_service import client
import asyncio
import logging
 
# Module-level API client used by LLMClient.get_response and main() via
# ``client.chat.completions.create``. Note that this rebinds the imported name
# ``client`` (from llm_service) to the object returned by calling it, so the
# original callable is no longer reachable under that name afterwards.
client = client()

class LLMClient:
    """Manages communication with the LLM provider.

    Requests go through the module-level ``client`` object and always ask for
    a JSON-object response (``response_format={"type": "json_object"}``).
    """

    def __init__(self, model: str = "gpt-4o-mini") -> None:
        """Create a client bound to one model.

        Args:
            model: Name of the chat-completions model to query.
        """
        self.model = model

    def get_response(self, messages: list[dict[str, str]]) -> str:
        """Get a response from the LLM.

        Args:
            messages: A list of chat message dictionaries
                (``{"role": ..., "content": ...}``).

        Returns:
            The content of the first choice of the LLM's reply.

        Raises:
            RuntimeError: If the request or reading its result fails for any
                reason. The underlying exception is chained as ``__cause__``
                so its traceback is not lost.
        """
        try:
            response = client.chat.completions.create(
                model=self.model,
                messages=messages,
                response_format={"type": "json_object"},
            )
            # Kept inside the ``try`` so a malformed response (e.g. no
            # choices) is reported as RuntimeError like any other failure.
            return response.choices[0].message.content
        except Exception as e:
            raise RuntimeError(f"Error getting LLM response: {e}") from e

class ChatSession:
    """Orchestrates the interaction between user, LLM, and tools."""

    def __init__(self, servers: list[Server], llm_client: LLMClient) -> None:
        self.servers: list[Server] = servers
        self.llm_client: LLMClient = llm_client

    async def cleanup_servers(self) -> None:
        """Clean up all servers, logging (never raising) per-server failures.

        Servers are cleaned up one at a time, in reverse order of
        initialization, and awaited in the current task rather than spawned as
        new tasks. Resources entered during ``initialize()`` (for example
        AnyIO cancel scopes used by MCP clients) generally must be exited from
        the task that entered them.
        NOTE(review): confirm against base.Server.cleanup.
        """
        for server in reversed(self.servers):
            try:
                await server.cleanup()
            except Exception as e:
                logging.warning(f"Warning during final cleanup: {e}")

    @staticmethod
    def _tool_result_text(result: object) -> str:
        """Return the text payload of a tool execution result.

        Assumes an MCP-style result exposing ``.content[0].text`` (TODO
        confirm against base.Server.execute_tool). Anything without that
        shape is rendered with ``str()`` instead of raising.
        """
        content = getattr(result, "content", None)
        if content:
            text = getattr(content[0], "text", None)
            if text is not None:
                return text
        return str(result)

    @staticmethod
    def _log_progress(result: object) -> None:
        """Log progress info when the tool result is a dict carrying it.

        A missing or zero ``total`` must not turn a successful tool run into
        an error, so the percentage is only computed when ``total`` is truthy.
        """
        if not isinstance(result, dict) or "progress" not in result:
            return
        progress = result["progress"]
        total = result.get("total")
        if total:
            percentage = (progress / total) * 100
            logging.info(f"Progress: {progress}/{total} ({percentage:.1f}%)")
        else:
            logging.info(f"Progress: {progress}")

    async def process_llm_response(self, llm_response: str) -> str:
        """Process the LLM response and execute a tool if one was requested.

        A tool call is a JSON object with both a ``"tool"`` and an
        ``"arguments"`` key. Any other reply, including valid JSON of another
        shape or non-JSON text, is returned unchanged.

        Args:
            llm_response: The raw response text from the LLM.

        Returns:
            One of the following:

            * ``"Tool execution result: <text>"`` when the tool ran.
            * ``"Error executing tool: ..."`` when it raised.
            * ``"No server found with tool: ..."`` when no server offers it.
            * The original ``llm_response`` when it was not a tool call.
        """
        import json

        try:
            tool_call = json.loads(llm_response)
        except json.JSONDecodeError:
            return llm_response

        # json.loads may yield a list, string, number, etc.; only a dict can
        # be a tool call (membership/indexing on other types misbehaves).
        if not (
            isinstance(tool_call, dict)
            and "tool" in tool_call
            and "arguments" in tool_call
        ):
            return llm_response

        tool_name = tool_call["tool"]
        arguments = tool_call["arguments"]
        logging.info(f"Executing tool: {tool_name}")
        logging.info(f"With arguments: {arguments}")

        for server in self.servers:
            tools = await server.list_tools()
            if not any(tool.name == tool_name for tool in tools):
                continue
            try:
                result = await server.execute_tool(tool_name, arguments)
            except Exception as e:
                error_msg = f"Error executing tool: {str(e)}"
                logging.error(error_msg)
                return error_msg
            self._log_progress(result)
            return f"Tool execution result: {self._tool_result_text(result)}"

        return f"No server found with tool: {tool_name}"

    @staticmethod
    def _build_system_message(tools: list) -> str:
        """Build the system prompt advertising ``tools`` and the call format."""
        tools_description = "\n".join([tool.format_for_llm() for tool in tools])
        return (
            "You are a helpful assistant with access to these tools:\n\n"
            f"{tools_description}\n"
            "Choose the appropriate tool based on the user's question. "
            "If no tool is needed, reply directly.\n\n"
            "IMPORTANT: When you need to use a tool, you must ONLY respond with "
            "the exact JSON object format below, nothing else:\n"
            "{\n"
            '    "tool": "tool-name",\n'
            '    "arguments": {\n'
            '        "argument-name": "value"\n'
            "    }\n"
            "}\n\n"
            "Please use only the tools that are explicitly defined above."
        )

    async def start(self) -> None:
        """Run the interactive chat loop.

        The session proceeds as follows:

        1. Initialize every server, aborting if any one fails.
        2. Advertise the servers' tools to the LLM in the system prompt.
        3. Repeatedly read a line from stdin and ask the LLM.
        4. When the reply is a tool call, execute it and ask the LLM again
           with the tool's output.

        Typing "quit" or "exit" (any case), Ctrl-C, or end of input ends the
        session. Servers are cleaned up exactly once, in the ``finally``
        clause.
        """
        try:
            for server in self.servers:
                try:
                    await server.initialize()
                except Exception as e:
                    logging.error(f"Failed to initialize server: {e}")
                    # Cleanup happens in the ``finally`` below; calling it
                    # here as well would clean every server up twice.
                    return

            all_tools = []
            for server in self.servers:
                all_tools.extend(await server.list_tools())

            messages = [
                {"role": "system", "content": self._build_system_message(all_tools)}
            ]

            while True:
                try:
                    user_input = input("You: ").strip()
                    if not user_input:
                        continue
                    # Commands are case-insensitive, but the question itself
                    # is sent to the LLM with its original casing.
                    if user_input.lower() in ("quit", "exit"):
                        logging.info("\nExiting...")
                        break

                    messages.append({"role": "user", "content": user_input})

                    llm_response = self.llm_client.get_response(messages)
                    logging.info("\nAssistant: %s", llm_response)
                    messages.append({"role": "assistant", "content": llm_response})

                    result = await self.process_llm_response(llm_response)
                    if result != llm_response:
                        # A tool was attempted; ``result`` is already text (the
                        # tool output or an error message), so hand it back to
                        # the LLM and record its follow-up answer.
                        messages.append({"role": "system", "content": result})
                        final_response = self.llm_client.get_response(messages)
                        logging.info("\nFinal response: %s", final_response)
                        messages.append(
                            {"role": "assistant", "content": final_response}
                        )
                except (KeyboardInterrupt, EOFError):
                    logging.info("\nExiting...")
                    break
        finally:
            await self.cleanup_servers()


async def main2() -> None:
    """Load the MCP server configuration and run an interactive chat session."""
    server_config = Configuration().load_config("servers_config.json")
    mcp_servers = server_config["mcpServers"]
    servers = [Server(name, srv_config) for name, srv_config in mcp_servers.items()]
    session = ChatSession(servers, LLMClient())
    await session.start()

async def main(question: str) -> None:
    """Answer one question using native tool calling over the MCP servers.

    The function:

    1. Loads ``tools/servers_config.json`` and initializes every configured
       server.
    2. Exposes all of the servers' tools to the model.
    3. Executes every tool call the model requests, on the server that owns
       that tool.
    4. Prints the model's final JSON answer, or its direct reply when no tool
       was needed.

    Servers are cleaned up on every exit path.

    Args:
        question: The user's question.
    """
    import json

    config = Configuration()
    server_config = config.load_config("tools/servers_config.json")

    servers = [
        Server(name, srv_config)
        for name, srv_config in server_config["mcpServers"].items()
    ]

    try:
        for server in servers:
            await server.initialize()

        # Map each tool name to the server that provides it, so a call is
        # routed to that server (the first server wins on duplicate names).
        all_tools = []
        tool_owner = {}
        for server in servers:
            for tool in await server.list_tools():
                all_tools.append(tool)
                tool_owner.setdefault(tool.name, server)

        available_tools = [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    # Chat Completions expects the JSON Schema under "parameters".
                    "parameters": tool.inputSchema,
                },
            }
            for tool in all_tools
        ]

        tool_listing = "\n".join(
            f"- {tool.name}: {tool.description}" for tool in all_tools
        )
        sys_prompt = (
            "You are an AI assistant that is very good at using tools.\n"
            "Please answer the user's question as well as possible. "
            "You can use the following tools:\n\n"
            f"{tool_listing}\n\n"
            "Choose the appropriate tool based on the user's question. "
            "If no tool is needed, reply directly.\n"
            "If tools are needed, please be sure to provide the tool name "
            "and arguments.\n"
            "Please use only the tools that are explicitly defined above."
        )

        messages = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": question},
        ]

        request = {"model": "gpt-4o-mini", "messages": messages}
        if available_tools:
            # The API rejects an empty ``tools`` array, so omit it entirely.
            request["tools"] = available_tools
        response = client.chat.completions.create(**request)
        choice = response.choices[0]

        if choice.finish_reason != "tool_calls":
            # No tool needed: the first reply is already the answer.
            print(choice.message.content)
            return

        messages.append(choice.message.model_dump())
        # Every tool_call_id in the assistant message must be answered by a
        # "tool" message, otherwise the follow-up request is rejected.
        for tool_call in choice.message.tool_calls:
            tool_name = tool_call.function.name
            tool_args = json.loads(tool_call.function.arguments)
            print(f"\n\n[Calling tool {tool_name} with args {tool_args}]\n\n")

            server = tool_owner.get(tool_name)
            if server is None:
                tool_output = f"No server found with tool: {tool_name}"
            else:
                try:
                    # NOTE(review): ChatSession uses Server.execute_tool for
                    # the same job; confirm call_tool is the intended API.
                    result = await server.call_tool(tool_name, tool_args)
                except Exception as e:
                    tool_output = f"Error executing tool: {e}"
                else:
                    # Assumes an MCP-style result with .content[0].text.
                    content = getattr(result, "content", None)
                    tool_output = content[0].text if content else str(result)

            messages.append({
                "role": "tool",
                "content": tool_output,
                "tool_call_id": tool_call.id,
            })

        # JSON mode requires the word "json" to appear in the messages.
        user_prompt = (
            "Please answer the user's question concisely based on the context "
            "and the results of the previous tool call.\n"
            "output format must be json: {'result': str, 'observation': str}. "
            "result is the final answer; observation is the process of the answer."
        )
        messages.append({"role": "user", "content": user_prompt})

        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            response_format={"type": "json_object"},
        )
        print(response.choices[0].message.content)
    finally:
        # Sequential, reverse-order cleanup in the current task; failures are
        # logged so one bad server does not prevent the others from closing.
        for server in reversed(servers):
            try:
                await server.cleanup()
            except Exception as e:
                logging.warning(f"Warning during final cleanup: {e}")


if __name__ == "__main__":
    # Script entry point: runs the interactive chat session (main2).
    # main(question) is a separate one-shot entry point that is not invoked here.
    asyncio.run(main2())



