"""
Intelligent tool selector using semantic search.

This module provides a smart tool selection mechanism that uses embeddings
to find the most relevant tools for a given task, solving the OpenAI 128 tool limit.
"""

import asyncio
from typing import List, Dict, Optional, Set, Callable, Any
from collections import defaultdict

from mcp.types import Tool, ListToolsResult
from mcp_agent.agents.agent import Agent
from mcp_agent.utils.tool_embeddings import ToolEmbeddingStore
from mcp_agent.utils.tool_filter import ToolFilter
from mcp_agent.logging.logger import get_logger

logger = get_logger(__name__)


class IntelligentToolSelector:
    """
    Intelligent tool selector that uses semantic search to find relevant tools.

    This selector addresses the OpenAI 128 tool limit by:
    1. Indexing all available tools with embeddings
    2. Using semantic search to find the most relevant tools for a query
    3. Dynamically selecting a subset of tools to stay under the limit
    """

    # Multiplier applied to the similarity of tools that belong to a
    # preferred server before candidates are ranked (20% boost).
    _PREFERRED_SERVER_BOOST = 1.2

    def __init__(
        self,
        embedding_store: Optional[ToolEmbeddingStore] = None,
        max_tools: int = 120,  # Leave some buffer under 128
        min_similarity: float = 0.3,
        always_include: Optional[List[str]] = None,
        embedding_model: str = "text-embedding-3-large",
        persist_directory: Optional[str] = None,
        openai_api_key: Optional[str] = None
    ):
        """
        Initialize the intelligent tool selector.

        Args:
            embedding_store: Pre-configured embedding store (optional). When
                given, ``embedding_model``, ``persist_directory`` and
                ``openai_api_key`` are ignored.
            max_tools: Maximum number of tools to select (default: 120)
            min_similarity: Minimum similarity score to include a tool
            always_include: List of tool names to always include
            embedding_model: OpenAI embedding model to use
            persist_directory: Directory to persist embeddings
            openai_api_key: OpenAI API key
        """
        self.max_tools = max_tools
        self.min_similarity = min_similarity
        self.always_include = set(always_include or [])

        # Compare against None rather than relying on truthiness: a store
        # object that happens to evaluate as falsy must not be silently
        # replaced by a freshly constructed one.
        if embedding_store is not None:
            self.embedding_store = embedding_store
        else:
            self.embedding_store = ToolEmbeddingStore(
                embedding_model=embedding_model,
                persist_directory=persist_directory,
                openai_api_key=openai_api_key
            )

        # Maps tool name -> Tool object for the most recently indexed agent.
        self._tool_cache: Dict[str, Tool] = {}
        self._initialized = False

    async def initialize_from_agent(self, agent: Agent) -> None:
        """
        Initialize the selector with tools from an agent.

        Lists the agent's tools (initializing the agent first if needed),
        rebuilds the name -> tool cache, and re-indexes the embedding store
        when its size does not match the number of tools.

        Args:
            agent: The agent to extract tools from
        """
        logger.info(f"Initializing tool selector from agent: {agent.name}")

        # Ensure agent is initialized
        if not agent.initialized:
            await agent.initialize()

        # Get all tools from the agent
        all_tools_result = await agent.list_tools()
        all_tools = all_tools_result.tools

        logger.info(f"Found {len(all_tools)} total tools from agent")

        # Rebuild the cache from scratch so tools left over from an earlier
        # initialization (e.g. a different agent) can no longer be selected.
        self._tool_cache = {tool.name: tool for tool in all_tools}

        # Check if collection already has tools
        collection_count = self.embedding_store.collection.count()
        logger.info(f"Current collection has {collection_count} tools")

        # Only clear and re-add if the collection size differs from the tool
        # count. NOTE(review): this is a size-only check -- a persisted
        # collection with the same number of *different* tools is reused
        # as-is; confirm whether a content-based check is needed.
        if collection_count != len(all_tools):
            logger.info(f"Collection size mismatch ({collection_count} vs {len(all_tools)}), rebuilding...")

            # Clear existing embeddings
            self.embedding_store.clear_collection()

            # Group tools by server; the server name is the part of the
            # namespaced tool name before the first "__" ("" if absent).
            tools_by_server = defaultdict(list)
            for tool in all_tools:
                prefix, separator, _ = tool.name.partition("__")
                server_name = prefix if separator else ""
                tools_by_server[server_name].append(tool)

            # Add tools to embedding store by server
            for server_name, server_tools in tools_by_server.items():
                await self.embedding_store.add_tools(
                    tools=server_tools,
                    server_name=server_name
                )
        else:
            logger.info(f"Collection already has {collection_count} tools, using existing embeddings")

        self._initialized = True

        # Log statistics
        stats = self.embedding_store.get_collection_stats()
        logger.info(f"Tool selector initialized: {stats}")

    async def select_tools_for_query(
        self,
        query: str,
        additional_context: Optional[str] = None,
        prefer_servers: Optional[List[str]] = None,
        exclude_servers: Optional[List[str]] = None
    ) -> List[Tool]:
        """
        Select the most relevant tools for a given query.

        Always-include tools come first, followed by search results. When
        ``prefer_servers`` is given, results from those servers have their
        similarity boosted and the candidates are re-ranked by the boosted
        score, so the preference affects which tools make the cut.

        Args:
            query: The task or query to find tools for
            additional_context: Additional context to help with tool selection
            prefer_servers: List of server names to prefer
            exclude_servers: List of server names to exclude (applies to
                search results only, not to always-include tools)

        Returns:
            List of selected tools (never more than max_tools)

        Raises:
            RuntimeError: If initialize_from_agent has not completed yet.
        """
        if not self._initialized:
            raise RuntimeError("Tool selector not initialized. Call initialize_from_agent first.")

        # Check collection status before search
        collection_count = self.embedding_store.collection.count()
        logger.info(f"Before search: Collection has {collection_count} tools, cache has {len(self._tool_cache)} tools")

        # Enhance query with additional context
        search_query = query
        if additional_context:
            search_query = f"{query} | Context: {additional_context}"

        logger.debug(f"Selecting tools for query: {search_query}")

        # Search for relevant tools
        search_results = await self.embedding_store.search_tools(
            query=search_query,
            top_k=self.max_tools * 2,  # Get more results for filtering
            min_similarity=self.min_similarity
        )

        logger.info(f"Search returned {len(search_results)} results")

        selected_tools: List[Tool] = []
        selected_names: Set[str] = set()

        # First, add always-include tools. Sorted iteration keeps the result
        # deterministic, and the cap keeps the total under max_tools even
        # when many names are pinned.
        for tool_name in sorted(self.always_include):
            tool = self._tool_cache.get(tool_name)
            if tool is None:
                continue
            if len(selected_tools) >= self.max_tools:
                logger.warning(
                    f"always_include exceeds max_tools ({self.max_tools}); "
                    f"remaining pinned tools were dropped"
                )
                break
            selected_tools.append(tool)
            selected_names.add(tool_name)

        excluded = set(exclude_servers or ())
        preferred = set(prefer_servers or ())

        # Apply server filters and the preferred-server boost.
        candidates = []
        for tool_id, similarity, metadata in search_results:
            tool_name = metadata.get("namespaced_name", tool_id)
            server_name = metadata.get("server_name", "")

            if server_name in excluded:
                continue

            if server_name in preferred:
                similarity *= self._PREFERRED_SERVER_BOOST

            candidates.append((tool_name, server_name, similarity))

        # Re-rank only when a boost may have been applied; otherwise the
        # store's own ordering is kept untouched. The sort is stable, so
        # equally scored tools keep their relative order.
        # Assumes a higher similarity means a better match -- TODO confirm
        # against ToolEmbeddingStore.search_tools.
        if preferred:
            candidates.sort(key=lambda candidate: candidate[2], reverse=True)

        # Fill the remaining slots from the ranked candidates.
        for tool_name, server_name, similarity in candidates:
            if len(selected_tools) >= self.max_tools:
                break

            # Skip if already selected
            if tool_name in selected_names:
                continue

            tool = self._tool_cache.get(tool_name)
            if tool is None:
                logger.warning(f"Tool {tool_name} found in search but not in cache!")
                continue

            selected_tools.append(tool)
            selected_names.add(tool_name)

            logger.debug(
                f"Selected tool: {tool_name} "
                f"(server: {server_name}, similarity: {similarity:.3f})"
            )

        logger.info(
            f"Selected {len(selected_tools)} tools for query "
            f"(from {len(search_results)} candidates)"
        )

        return selected_tools

    def create_dynamic_filter(
        self,
        query: str,
        additional_context: Optional[str] = None,
        prefer_servers: Optional[List[str]] = None,
        exclude_servers: Optional[List[str]] = None
    ) -> ToolFilter:
        """
        Create a ToolFilter that dynamically selects tools based on the query.

        This returns a ToolFilter that can be applied to an LLM to limit
        the tools it sees based on semantic relevance. The selection is
        computed once, when this method is called; the returned filter only
        checks membership in that fixed set of tool names.

        This is a blocking call. When the calling thread already runs an
        event loop, the selection is executed on a fresh loop in a worker
        thread and the caller waits for it, which stalls the running loop
        for the duration of the search. Async callers should prefer awaiting
        select_tools_for_query directly and building the filter from its
        result.

        Args:
            query: The task or query to find tools for
            additional_context: Additional context
            prefer_servers: Preferred servers
            exclude_servers: Excluded servers

        Returns:
            ToolFilter configured for dynamic selection
        """
        async def _collect_selected_names() -> Set[str]:
            tools = await self.select_tools_for_query(
                query=query,
                additional_context=additional_context,
                prefer_servers=prefer_servers,
                exclude_servers=exclude_servers
            )
            return {tool.name for tool in tools}

        def _run_on_fresh_loop() -> Set[str]:
            # A private loop that is always closed, even if selection fails.
            loop = asyncio.new_event_loop()
            try:
                return loop.run_until_complete(_collect_selected_names())
            finally:
                loop.close()

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: drive the coroutine here.
            selected_tool_names = _run_on_fresh_loop()
        else:
            # run_until_complete() raises RuntimeError while another loop is
            # running in the same thread, so hand the work to a helper
            # thread that owns its own loop.
            from concurrent.futures import ThreadPoolExecutor

            with ThreadPoolExecutor(max_workers=1) as pool:
                selected_tool_names = pool.submit(_run_on_fresh_loop).result()

        # Create custom filter function
        def custom_filter(tool: Tool) -> bool:
            return tool.name in selected_tool_names

        return ToolFilter(custom_filter=custom_filter)

    async def get_tool_selection_explanation(
        self,
        query: str,
        selected_tools: List[Tool]
    ) -> str:
        """
        Generate an explanation of why certain tools were selected.

        Lists up to the first ten selected tools with the similarity score
        the embedding store reports for the plain query (0.000 when the
        store returns no score for a tool), plus a count of the remainder.

        Args:
            query: The original query
            selected_tools: The tools that were selected

        Returns:
            Human-readable explanation
        """
        explanation_parts = [
            f"Tool selection for query: '{query}'",
            f"Selected {len(selected_tools)} tools from {len(self._tool_cache)} available tools:",
            ""
        ]

        # Map of tool names to similarity scores. The search is skipped when
        # there is nothing to explain or nothing indexed, which also avoids
        # issuing a request with top_k=0.
        similarity_map: Dict[str, float] = {}
        if selected_tools and self._tool_cache:
            search_results = await self.embedding_store.search_tools(
                query=query,
                top_k=len(self._tool_cache)
            )
            similarity_map = {
                metadata.get("namespaced_name", tool_id): score
                for tool_id, score, metadata in search_results
            }

        # Add tool explanations
        for i, tool in enumerate(selected_tools[:10], 1):  # Show top 10
            similarity = similarity_map.get(tool.name, 0.0)
            explanation_parts.append(
                f"{i}. {tool.name} (similarity: {similarity:.3f})"
            )
            if tool.description:
                explanation_parts.append(f"   - {tool.description}")

        if len(selected_tools) > 10:
            explanation_parts.append(f"... and {len(selected_tools) - 10} more tools")

        return "\n".join(explanation_parts)


# Convenience function for creating and applying intelligent tool selection
async def apply_intelligent_tool_selection(
    agent: Agent,
    query: str,
    max_tools: int = 120,
    additional_context: Optional[str] = None,
    persist_directory: Optional[str] = None,
    openai_api_key: Optional[str] = None
) -> Agent:
    """
    Apply intelligent tool selection to an agent for a specific query.

    This is a convenience function that:
    1. Creates an IntelligentToolSelector
    2. Initializes it with the agent's tools
    3. Selects the relevant tools for the query and wraps them in a filter
    4. Returns the agent with the filter applied

    The filter is only applied when the agent has a truthy ``llm``
    attribute; otherwise the agent is returned unchanged (the selector is
    still initialized and the selection is still computed).

    Args:
        agent: The agent to apply selection to
        query: The task or query
        max_tools: Maximum number of tools to select
        additional_context: Additional context
        persist_directory: Directory to persist embeddings
        openai_api_key: OpenAI API key

    Returns:
        The agent with intelligent tool selection applied
    """
    # Create selector
    selector = IntelligentToolSelector(
        max_tools=max_tools,
        persist_directory=persist_directory,
        openai_api_key=openai_api_key
    )

    # Initialize with agent's tools
    await selector.initialize_from_agent(agent)

    # create_dynamic_filter() is a synchronous API that runs the selection
    # on an event loop of its own; inside this coroutine the selection is
    # simply awaited on the loop that is already running.
    selected_tools = await selector.select_tools_for_query(
        query=query,
        additional_context=additional_context
    )
    selected_tool_names = {tool.name for tool in selected_tools}

    def custom_filter(tool: Tool) -> bool:
        return tool.name in selected_tool_names

    tool_filter = ToolFilter(custom_filter=custom_filter)

    # Apply filter to agent's LLM (if attached)
    if hasattr(agent, 'llm') and agent.llm:
        from mcp_agent.utils.tool_filter import apply_tool_filter
        agent.llm = apply_tool_filter(agent.llm, tool_filter)

    return agent