from typing import Dict, Any, List, Optional, Callable, Union, Awaitable
import asyncio
import inspect
from .base import Server, Tool

class UnifiedTool:
    """
    A unified wrapper for both function-based tools and MCP server tools.
    This allows consistent handling of different tool types in the Agent system.

    A tool is backed either by a local callable (``func``) or by a server
    object (``server``). When both are supplied, ``func`` takes precedence
    in :meth:`execute`.
    """
    def __init__(
        self,
        name: str,
        description: str,
        func: Optional[Callable] = None,
        server: Optional["Server"] = None,
        input_schema: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize a unified tool.

        Args:
            name: The name of the tool
            description: A description of what the tool does
            func: Function to call (for function-based tools)
            server: Server object (for MCP server tools)
            input_schema: JSON schema describing the input parameters
                (a missing or empty schema is stored as ``{}``)
        """
        self.name = name
        self.description = description
        self.func = func
        self.server = server
        self.input_schema = input_schema or {}
        # inspect.iscoroutinefunction is the supported spelling of this check;
        # asyncio.iscoroutinefunction is deprecated in recent Python versions.
        self.is_async = func is not None and inspect.iscoroutinefunction(func)

    async def execute(self, **kwargs) -> Any:
        """
        Execute the tool with the provided arguments.

        Args:
            **kwargs: Arguments to pass to the tool

        Returns:
            The result of the tool execution

        Raises:
            ValueError: If the tool has neither a function nor a server
        """
        if self.func is not None:
            # Function-based tool
            if self.is_async:
                return await self.func(**kwargs)

            # Run synchronous callables in a worker thread so they do not
            # block the event loop.
            result = await asyncio.to_thread(self.func, **kwargs)

            # Some callables are not coroutine functions themselves but still
            # hand back an awaitable (a lambda wrapping a coroutine, an object
            # with an async __call__, ...). Resolve it instead of leaking an
            # un-awaited coroutine to the caller.
            if inspect.isawaitable(result):
                result = await result
            return result

        if self.server is not None:
            # MCP server tool: the server receives the arguments as one dict
            return await self.server.execute_tool(self.name, kwargs)

        raise ValueError(f"Tool {self.name} has no implementation")

    def format_for_llm(self) -> str:
        """
        Format tool information for LLM.

        Each entry under ``input_schema["properties"]`` becomes one
        ``- name: description`` line, suffixed with `` (required)`` when the
        name is listed in ``input_schema["required"]``.

        Returns:
            A formatted string describing the tool.
        """
        # Treat a missing or null "properties"/"required" entry as empty.
        properties = self.input_schema.get("properties") or {}
        required = self.input_schema.get("required") or ()

        args_desc = []
        for param_name, param_info in properties.items():
            arg_desc = (
                f"- {param_name}: {param_info.get('description', 'No description')}"
            )
            if param_name in required:
                arg_desc += " (required)"
            args_desc.append(arg_desc)

        arguments = "\n".join(args_desc)
        return f"""
Tool: {self.name}
Description: {self.description}
Arguments:
{arguments}
"""

    def get_tool_config(self) -> Dict[str, Any]:
        """
        Get the tool configuration for LLM function calling.

        Returns:
            A dictionary with ``"type": "function"`` and a ``"function"``
            entry holding the tool's name, description and input schema
            (under ``"parameters"``)
        """
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.input_schema
            }
        }


class ToolRegistry:
    """
    Registry for managing and accessing unified tools.

    Tools are stored by name in ``self.tools``; registering a tool under a
    name that is already present replaces the earlier entry.
    """

    # JSON Schema type names for the Python types recognised in annotations.
    # Anything not listed here falls back to "string".
    _JSON_TYPES = (
        (bool, "boolean"),
        (int, "integer"),
        (float, "number"),
        (str, "string"),
        (list, "array"),
        (tuple, "array"),
        (set, "array"),
        (frozenset, "array"),
        (dict, "object"),
    )

    # Headers that open a parameter section in Google-style docstrings.
    _DOC_SECTION_HEADERS = (
        "Args:",
        "Arguments:",
        "Parameters:",
        "Params:",
        "Keyword Args:",
        "Keyword Arguments:",
    )

    def __init__(self):
        """Initialize an empty tool registry."""
        self.tools: Dict[str, "UnifiedTool"] = {}

    def register_function(
        self,
        func: Callable,
        name: Optional[str] = None,
        description: Optional[str] = None,
        input_schema: Optional[Dict[str, Any]] = None
    ) -> "UnifiedTool":
        """
        Register a function as a tool.

        Args:
            func: The function to register
            name: Optional name override (defaults to function name)
            description: Optional description (defaults to function docstring)
            input_schema: Optional input schema (defaults to generated from function signature)

        Returns:
            The registered UnifiedTool
        """
        # Callables such as functools.partial objects or instances defining
        # __call__ have no __name__; fall back to their type's name rather
        # than failing with AttributeError.
        name = name or getattr(func, "__name__", None) or type(func).__name__
        description = (
            description
            or getattr(func, "__doc__", None)
            or f"Execute the {name} function"
        )

        # Generate input schema from function signature if not provided
        if input_schema is None:
            input_schema = self._generate_schema_from_function(func)

        tool = UnifiedTool(
            name=name,
            description=description,
            func=func,
            input_schema=input_schema
        )

        self.tools[name] = tool
        return tool

    async def register_server_tools(self, server: "Server") -> List["UnifiedTool"]:
        """
        Register all tools from an MCP server.

        Every item returned by ``server.list_tools()`` is wrapped using its
        ``name``, ``description`` and ``input_schema`` attributes.

        Args:
            server: The server to register tools from

        Returns:
            List of registered UnifiedTools
        """
        registered_tools = []

        # Get tools from server
        server_tools = await server.list_tools()

        for tool_info in server_tools:
            tool = UnifiedTool(
                name=tool_info.name,
                description=tool_info.description,
                server=server,
                input_schema=tool_info.input_schema
            )

            self.tools[tool_info.name] = tool
            registered_tools.append(tool)

        return registered_tools

    def get_tool(self, name: str) -> Optional["UnifiedTool"]:
        """
        Get a tool by name.

        Args:
            name: The name of the tool to get

        Returns:
            The tool if found, None otherwise
        """
        return self.tools.get(name)

    def get_all_tools(self) -> List["UnifiedTool"]:
        """
        Get all registered tools.

        Returns:
            List of all registered tools
        """
        return list(self.tools.values())

    def get_tool_descriptions(self) -> str:
        """
        Get formatted descriptions of all tools for LLM.

        Returns:
            String with formatted tool descriptions, separated by blank lines
        """
        return "\n\n".join(tool.format_for_llm() for tool in self.tools.values())

    def get_tool_configs(self) -> List[Dict[str, Any]]:
        """
        Get tool configurations for LLM function calling.

        Returns:
            List of tool configurations
        """
        return [tool.get_tool_config() for tool in self.tools.values()]

    def _generate_schema_from_function(self, func: Callable) -> Dict[str, Any]:
        """
        Generate a JSON schema from a function's signature.

        ``self``/``cls`` and variadic parameters (``*args``, ``**kwargs``) are
        left out of the schema because they cannot be supplied as named JSON
        properties. Parameters without a default value are listed as required.

        Args:
            func: The function to generate a schema for

        Returns:
            A JSON schema describing the function's parameters
        """
        signature = inspect.signature(func)
        docstring = getattr(func, "__doc__", None)
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        properties: Dict[str, Any] = {}
        required: List[str] = []

        for name, param in signature.parameters.items():
            if name in ("self", "cls") or param.kind in variadic_kinds:
                continue

            param_schema = {"type": self._json_type_for_annotation(param.annotation)}

            # Add description from docstring if available
            if docstring:
                param_docs = self._extract_param_docs(docstring, name)
                if param_docs:
                    param_schema["description"] = param_docs

            properties[name] = param_schema

            # Identity check: "==" would call __eq__ on arbitrary default
            # values, which may raise or return a non-boolean.
            if param.default is inspect.Parameter.empty:
                required.append(name)

        schema: Dict[str, Any] = {
            "type": "object",
            "properties": properties
        }

        if required:
            schema["required"] = required

        return schema

    def _json_type_for_annotation(self, annotation: Any) -> str:
        """
        Map a parameter annotation to a JSON Schema type name.

        Handles plain types (``int``), parameterised generics (``List[int]``,
        ``dict[str, int]``), optional types (``Optional[int]``, ``int | None``)
        and simple string annotations (``"int"``).

        Args:
            annotation: The annotation taken from ``inspect.Parameter``

        Returns:
            The JSON Schema type name, ``"string"`` when nothing matches
        """
        import types
        from typing import Union, get_args, get_origin

        if annotation is inspect.Parameter.empty:
            return "string"

        if isinstance(annotation, str):
            # Unevaluated annotation: look at the bare type name only,
            # e.g. "typing.List[int]" -> "list".
            by_name = {py_type.__name__: json_type for py_type, json_type in self._JSON_TYPES}
            base = annotation.split("[", 1)[0].strip().rsplit(".", 1)[-1].lower()
            return by_name.get(base, "string")

        origin = get_origin(annotation)
        union_type = getattr(types, "UnionType", None)  # "X | Y", Python 3.10+
        if origin is not None and (origin is Union or origin is union_type):
            members = [arg for arg in get_args(annotation) if arg is not type(None)]
            # Optional[X] describes X; a genuine multi-type union has no
            # single JSON type, so keep the default.
            if len(members) == 1:
                return self._json_type_for_annotation(members[0])
            return "string"

        target = annotation if origin is None else origin
        for py_type, json_type in self._JSON_TYPES:
            if target is py_type:
                return json_type
        return "string"

    def _extract_param_docs(self, docstring: str, param_name: str) -> Optional[str]:
        """
        Extract parameter documentation from a docstring.

        Two layouts are recognised: reST fields (``:param name: text`` and
        ``:param type name: text``) and Google-style sections
        (``Args:`` followed by indented ``name: text`` entries).

        Args:
            docstring: The function's docstring
            param_name: The parameter name to look for

        Returns:
            The parameter description if found, None otherwise
        """
        # cleandoc removes the common indentation so indent levels can be
        # compared reliably.
        lines = inspect.cleandoc(docstring).split("\n")
        return (
            self._extract_rest_param_doc(lines, param_name)
            or self._extract_google_param_doc(lines, param_name)
        )

    @staticmethod
    def _extract_rest_param_doc(lines: List[str], param_name: str) -> Optional[str]:
        """Return the text of a reST ``:param ...:`` field for param_name, or None."""
        marker = ":param"
        for index, line in enumerate(lines):
            start = line.find(marker)
            if start == -1:
                continue

            # Text between ":param" and the closing colon is either "name"
            # or "type name"; the parameter name is always the last word.
            head, colon, tail = line[start + len(marker):].partition(":")
            if not colon or not head[:1].isspace():
                continue
            words = head.split()
            if not words or words[-1] != param_name:
                continue

            parts = [tail.strip()]
            for follow in lines[index + 1:]:
                text = follow.strip()
                # A blank line or the next field ends this description.
                if not text or text.startswith(":"):
                    break
                parts.append(text)

            description = " ".join(part for part in parts if part)
            return description or None

        return None

    @classmethod
    def _extract_google_param_doc(cls, lines: List[str], param_name: str) -> Optional[str]:
        """Return the text of a Google-style ``Args:`` entry for param_name, or None."""
        section_indent = None  # indent of the open "Args:" header, None outside a section
        entry_indent = None    # indent shared by the entries of the open section

        for index, line in enumerate(lines):
            text = line.strip()
            if not text:
                continue
            indent = len(line) - len(line.lstrip())

            # A line at or left of the header's indent closes the section.
            if section_indent is not None and indent <= section_indent:
                section_indent = entry_indent = None

            if section_indent is None:
                if text in cls._DOC_SECTION_HEADERS:
                    section_indent = indent
                continue

            if entry_indent is None:
                entry_indent = indent
            if indent != entry_indent:
                continue  # continuation line of some other entry

            head, colon, tail = text.partition(":")
            if not colon:
                continue
            # Entries look like "name: text", "name (type): text" or "*args: text".
            entry_name = head.split("(", 1)[0].strip().lstrip("*")
            if entry_name != param_name:
                continue

            parts = [tail.strip()]
            for follow in lines[index + 1:]:
                follow_text = follow.strip()
                follow_indent = len(follow) - len(follow.lstrip())
                # Only more deeply indented, non-blank lines continue the entry.
                if not follow_text or follow_indent <= entry_indent:
                    break
                parts.append(follow_text)

            description = " ".join(part for part in parts if part)
            return description or None

        return None
