"""
Tool embeddings and vector search utilities for intelligent tool selection.

This module provides functionality to embed tool descriptions and perform
semantic search to find the most relevant tools for a given query.
"""

import asyncio
import json
from typing import List, Dict, Optional, Tuple, Any
from dataclasses import dataclass
import numpy as np

try:
    import chromadb
    from chromadb.config import Settings
    CHROMADB_AVAILABLE = True
except ImportError:
    CHROMADB_AVAILABLE = False

try:
    import openai
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False

from mcp.types import Tool
from mcp_agent.logging.logger import get_logger

logger = get_logger(__name__)


@dataclass
class ToolMetadata:
    """
    Metadata for a tool stored in the vector database.

    NOTE(review): this dataclass is not constructed anywhere in the code
    visible in this module; ToolEmbeddingStore stores and returns plain
    metadata dicts instead. Confirm whether callers elsewhere still use it.
    """
    # Tool name as reported by the MCP server (``Tool.name``).
    tool_name: str
    # Name of the server that exposes the tool.
    server_name: str
    # Globally unique name; presumably the same "{server}__{tool}" form that
    # ToolEmbeddingStore.add_tools uses as the ChromaDB ID -- confirm with callers.
    namespaced_name: str
    # Human-readable tool description.
    description: str
    # JSON Schema for the tool's input, as a parsed dict (the vector store
    # itself keeps this JSON-encoded as a string).
    input_schema: Dict[str, Any]
    # The original MCP Tool object the fields above were taken from.
    full_tool_object: Tool


class ToolEmbeddingStore:
    """
    Manages tool embeddings and provides semantic search functionality.

    This class handles:
    - Converting tool descriptions to embeddings with the OpenAI embeddings API
    - Storing embeddings in a ChromaDB collection configured for cosine distance
    - Performing semantic search to find relevant tools

    Each tool is stored under the ID ``"{server_name}__{tool_name}"`` when a
    server name is supplied, or under the bare tool name otherwise.
    """

    # Metadata applied whenever this store creates its collection. Cosine space
    # matters: search_tools converts distances to similarities as 1 - distance.
    _COLLECTION_METADATA: Dict[str, str] = {"hnsw:space": "cosine"}

    def __init__(
        self,
        embedding_model: str = "text-embedding-3-large",
        collection_name: str = "mcp_tools",
        persist_directory: Optional[str] = None,
        openai_api_key: Optional[str] = None
    ):
        """
        Initialize the tool embedding store.

        Args:
            embedding_model: OpenAI embedding model to use
            collection_name: Name of the ChromaDB collection
            persist_directory: Directory to persist the database (None for in-memory)
            openai_api_key: OpenAI API key (if not set in environment)

        Raises:
            ImportError: If chromadb or openai is not installed.
        """
        if not CHROMADB_AVAILABLE:
            raise ImportError(
                "ChromaDB is required for tool embeddings. "
                "Install it with: pip install chromadb"
            )

        if not OPENAI_AVAILABLE:
            raise ImportError(
                "OpenAI is required for embeddings. "
                "Install it with: pip install openai"
            )

        self.embedding_model = embedding_model
        self.collection_name = collection_name

        # Initialize OpenAI client; without an explicit key the client falls
        # back to its own environment-based configuration.
        if openai_api_key:
            self.openai_client = openai.AsyncOpenAI(api_key=openai_api_key)
        else:
            self.openai_client = openai.AsyncOpenAI()

        # Initialize ChromaDB with anonymized telemetry disabled.
        settings = Settings(anonymized_telemetry=False)
        if persist_directory:
            self.chroma_client = chromadb.PersistentClient(
                path=persist_directory,
                settings=settings
            )
        else:
            self.chroma_client = chromadb.Client(settings=settings)

        # Reuse an existing collection if there is one, otherwise create it.
        # The exception raised for a missing collection has varied across
        # chromadb releases, so catch Exception -- but never use a bare
        # `except:`, which would also swallow KeyboardInterrupt/SystemExit.
        # NOTE(review): an existing collection is reused as-is; if it was
        # created with a non-cosine space, search_tools' similarity values
        # will be off.
        try:
            self.collection = self.chroma_client.get_collection(collection_name)
        except Exception as e:
            logger.debug(f"Collection '{collection_name}' not available ({e}); creating it")
            self.collection = self._create_collection()
            logger.info(f"Created new collection: {collection_name}")
        else:
            logger.info(f"Using existing collection: {collection_name}")

    def _create_collection(self):
        """Create ``self.collection_name`` with this store's cosine-space metadata."""
        return self.chroma_client.create_collection(
            name=self.collection_name,
            # Copy so the class-level dict can never be mutated through the client.
            metadata=dict(self._COLLECTION_METADATA)
        )

    async def _get_embedding(self, text: str) -> List[float]:
        """Get the embedding for a single text using the OpenAI API.

        Raises:
            Exception: Whatever the OpenAI client raises; it is logged first.
        """
        try:
            response = await self.openai_client.embeddings.create(
                model=self.embedding_model,
                input=text
            )
            return response.data[0].embedding
        except Exception as e:
            logger.error(f"Error getting embedding: {e}")
            raise

    async def _get_embeddings_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Get embeddings for multiple texts in a single API request.

        Returns:
            One embedding per input text, in the same order as ``texts``.
            An empty ``texts`` returns ``[]`` without calling the API.

        Raises:
            Exception: Whatever the OpenAI client raises; it is logged first.
        """
        if not texts:
            return []

        try:
            response = await self.openai_client.embeddings.create(
                model=self.embedding_model,
                input=texts
            )
        except Exception as e:
            logger.error(f"Error getting batch embeddings: {e}")
            raise

        # Each returned item carries the index of the input it belongs to; sort
        # on it so the result lines up with `texts` regardless of response order.
        ordered = sorted(response.data, key=lambda item: item.index)
        return [item.embedding for item in ordered]

    def _create_tool_description(self, tool: Tool, server_name: str = "") -> str:
        """
        Create a comprehensive description of a tool for embedding.

        This combines the tool name, server name, description, and input schema
        information into one " | "-separated string for better semantic matching.
        """
        parts = [f"Tool: {tool.name}"]

        # Add server context if available
        if server_name:
            parts.append(f"Server: {server_name}")

        if tool.description:
            parts.append(f"Description: {tool.description}")

        # Add a readable summary of the input schema's properties
        if tool.inputSchema:
            schema_info = self._extract_schema_info(tool.inputSchema)
            if schema_info:
                parts.append(f"Parameters: {schema_info}")

        return " | ".join(parts)

    def _extract_schema_info(self, schema: Dict[str, Any]) -> str:
        """
        Extract human-readable information from a JSON schema.

        Produces ``"name (type): description"`` entries (description omitted
        when absent, type defaulting to ``any``) joined by ", ". Returns ""
        when the schema has no usable ``properties`` mapping.
        """
        properties = schema.get("properties") if isinstance(schema, dict) else None
        if not isinstance(properties, dict):
            return ""

        info_parts = []
        for prop_name, prop_schema in properties.items():
            # JSON Schema allows boolean sub-schemas (`true`/`false`), which
            # carry no type or description; treat them as untyped.
            if not isinstance(prop_schema, dict):
                prop_schema = {}
            prop_type = prop_schema.get("type", "any")
            prop_desc = prop_schema.get("description", "")

            if prop_desc:
                info_parts.append(f"{prop_name} ({prop_type}): {prop_desc}")
            else:
                info_parts.append(f"{prop_name} ({prop_type})")

        return ", ".join(info_parts)

    async def add_tools(
        self,
        tools: List[Tool],
        server_name: str = "",
        batch_size: int = 100
    ) -> None:
        """
        Add tools to the embedding store, or refresh them if already present.

        Tools whose ID already exists are overwritten (upsert). Re-registering
        a server therefore updates stale descriptions/embeddings instead of
        silently keeping the old ones.

        Args:
            tools: List of tools to add
            server_name: Name of the server these tools belong to
            batch_size: Number of tools to process in each batch (must be >= 1)

        Raises:
            ValueError: If ``batch_size`` is less than 1 and ``tools`` is non-empty.
        """
        if not tools:
            return

        # A non-positive step would make range() either raise (0) or yield
        # nothing (negative), in which case nothing would be stored while
        # success was still logged.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        logger.info(f"Adding {len(tools)} tools to embedding store...")

        for i in range(0, len(tools), batch_size):
            batch = tools[i:i + batch_size]

            ids = []
            documents = []
            metadatas = []

            for tool in batch:
                # Unique ID: namespaced by server when one is given
                tool_id = f"{server_name}__{tool.name}" if server_name else tool.name
                ids.append(tool_id)

                # Text that gets embedded
                documents.append(self._create_tool_description(tool, server_name))

                # ChromaDB metadata values must be scalars, so the schema is
                # stored JSON-encoded.
                metadatas.append({
                    "tool_name": tool.name,
                    "server_name": server_name,
                    "namespaced_name": tool_id,
                    "description": tool.description or "",
                    "input_schema": json.dumps(tool.inputSchema) if tool.inputSchema else "{}",
                })

            embeddings = await self._get_embeddings_batch(documents)

            # upsert (not add): add() skips IDs that already exist, which
            # would leave outdated tool descriptions in a reused collection.
            self.collection.upsert(
                ids=ids,
                embeddings=embeddings,
                documents=documents,
                metadatas=metadatas
            )

            logger.debug(f"Added batch {i//batch_size + 1} ({len(batch)} tools)")

        logger.info(f"Successfully added {len(tools)} tools to embedding store")

    async def search_tools(
        self,
        query: str,
        top_k: int = 50,
        min_similarity: float = 0.0
    ) -> List[Tuple[str, float, Dict[str, Any]]]:
        """
        Search for tools relevant to a query.

        Args:
            query: The search query
            top_k: Maximum number of results to return; values < 1 return []
            min_similarity: Minimum similarity score (0-1) to include results

        Returns:
            List of tuples containing (tool_id, similarity_score, metadata),
            in the order returned by the ChromaDB query.
        """
        logger.debug(f"Searching for tools with query: {query}")

        # ChromaDB rejects n_results < 1, so short-circuit before any API call.
        if top_k < 1:
            logger.debug(f"top_k={top_k} requests no results; skipping search")
            return []

        collection_count = self.collection.count()
        logger.debug(f"Collection '{self.collection_name}' has {collection_count} items")

        if collection_count == 0:
            logger.warning(f"Collection '{self.collection_name}' is empty! No tools to search.")
            return []

        query_embedding = await self._get_embedding(query)

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=min(top_k, collection_count)  # Don't request more than available
        )

        # Each result field holds one inner list per query embedding; exactly
        # one was sent, so take index 0. Missing/None fields become [].
        ids = (results.get("ids") or [[]])[0]
        distances = (results.get("distances") or [[]])[0]
        metadatas = (results.get("metadatas") or [[]])[0]

        logger.debug(
            f"Raw ChromaDB results: ids={len(ids)}, distances={len(distances)}, "
            f"metadatas={len(metadatas)}"
        )

        if not ids:
            logger.warning("No results returned from ChromaDB query")

        tool_results = []
        for i, tool_id in enumerate(ids):
            # Cosine space: distance = 1 - cosine similarity.
            distance = distances[i] if i < len(distances) else 0
            similarity = 1 - distance

            logger.debug(f"Tool {i}: {tool_id}, distance={distance:.4f}, similarity={similarity:.4f}, min_similarity={min_similarity}")

            if similarity >= min_similarity:
                metadata = metadatas[i] if i < len(metadatas) else {}
                tool_results.append((tool_id, similarity, metadata))
            else:
                logger.debug(f"Tool {tool_id} filtered out: similarity {similarity:.4f} < {min_similarity}")

        logger.info(f"Found {len(tool_results)} relevant tools for query")
        return tool_results

    async def get_tools_by_ids(self, tool_ids: List[str]) -> List[Dict[str, Any]]:
        """
        Retrieve tool metadata by IDs.

        Args:
            tool_ids: List of tool IDs to retrieve; duplicates are ignored

        Returns:
            Metadata dictionaries in the order of ``tool_ids`` (first
            occurrence). IDs not present in the collection are skipped.
        """
        if not tool_ids:
            return []

        # Deduplicate while keeping the caller's order; ChromaDB's ID
        # validation may reject duplicate IDs in a single request.
        unique_ids = list(dict.fromkeys(tool_ids))
        results = self.collection.get(ids=unique_ids)

        # `get` is not documented to preserve request order, so re-key the
        # returned records by ID and emit them in the caller's order.
        found = dict(zip(results.get("ids") or [], results.get("metadatas") or []))
        return [found[tool_id] for tool_id in unique_ids if tool_id in found]

    def clear_collection(self) -> None:
        """Clear all tools by deleting and recreating the collection.

        Raises:
            Exception: Whatever ChromaDB raises; it is logged first.
        """
        try:
            self.chroma_client.delete_collection(self.collection_name)
            self.collection = self._create_collection()
            logger.info(f"Cleared collection: {self.collection_name}")
        except Exception as e:
            logger.error(f"Error clearing collection: {e}")
            raise

    def get_collection_stats(self) -> Dict[str, Any]:
        """Get the collection name, stored tool count, and embedding model."""
        return {
            "collection_name": self.collection_name,
            "total_tools": self.collection.count(),
            "embedding_model": self.embedding_model
        }