"""Refactored search engine for ChromaDB interaction."""

import os
import json
import logging
import time
import random
from pathlib import Path
from typing import List, Dict, Any, Optional, Set, Tuple
from dataclasses import dataclass, field
import chromadb
from chromadb.utils import embedding_functions


@dataclass
class SearchResult:
    """One document hit, as produced by a query or a lookup by ID."""
    # Identifier of the document inside the collection.
    document_id: str
    # Stored text of the document.
    content: str
    # Metadata mapping attached to the document.
    metadata: Dict[str, Any]
    # Relevance value; stays at 0.0 when the producer does not supply one.
    score: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return the hit as a plain dict keyed 'id', 'content', 'metadata', 'score'.

        The metadata mapping is referenced, not copied.
        """
        exported: Dict[str, Any] = dict(id=self.document_id, content=self.content)
        exported['metadata'] = self.metadata
        exported['score'] = self.score
        return exported


@dataclass
class SearchState:
    """Holds one query's cached hits together with the paging cursor."""
    # Query text the cached hits belong to.
    query: str
    # Every cached hit, in the order it was returned.
    results: List[SearchResult] = field(default_factory=list)
    # Zero-based index of the page most recently served.
    current_page: int = 0
    # Number of hits per page.
    page_size: int = 5
    # Hit count recorded by whoever created this state.
    total_results: int = 0
    # IDs of the documents present in ``results``.
    found_document_ids: Set[str] = field(default_factory=set)

    def get_page(self, page_num: int) -> List[SearchResult]:
        """Return the hits on zero-based page ``page_num``.

        A page past the end of ``results`` yields an empty list.
        """
        offset = page_num * self.page_size
        window = slice(offset, offset + self.page_size)
        return self.results[window]

    def has_next_page(self) -> bool:
        """Tell whether any hits lie beyond the current page."""
        served_through = self.page_size * (self.current_page + 1)
        return len(self.results) > served_through


class SearchEngine:
    """Clean, refactored search engine for ChromaDB."""
    
    def __init__(
        self,
        chromadb_path: str,
        collection_name: str = "default",
        embedding_model: Optional[str] = None,
        cache_size: int = 100,
        engine_id: str = "default",
        logger: Optional[logging.Logger] = None
    ):
        """Open the ChromaDB store and bind this engine to one collection.

        Args:
            chromadb_path: Location of the persistent database; a leading
                ``~`` is expanded.
            collection_name: Collection to open; it is created if missing.
            embedding_model: Model key handed to ``_load_embedding_function``;
                ``None`` selects ChromaDB's default embeddings.
            cache_size: Number of hits ``search`` requests when the caller
                gives no ``n_results``.
            engine_id: Caller-chosen label for this engine instance.
            logger: Logger to use; the module logger is used when omitted.

        Raises:
            ValueError: An OpenAI model was requested without OPENAI_API_KEY.
            Exception: Any error raised while opening the client or the
                collection; collection errors are logged before re-raising.
        """
        # Logger and path go first: the two helper calls below read them
        # from the instance.
        self.logger = logger or logging.getLogger(__name__)
        self.chromadb_path = Path(chromadb_path).expanduser()
        self.engine_id = engine_id
        self.collection_name = collection_name
        self.embedding_model = embedding_model
        self.cache_size = cache_size

        # Client creation retries internally on tenant errors.
        self.client = self._create_client_with_retry()

        # None here means "let ChromaDB use its default embeddings".
        self.embedding_function = self._load_embedding_function(embedding_model)

        try:
            # Only pass an embedding function when one was actually loaded.
            collection_kwargs: Dict[str, Any] = {'name': collection_name}
            if self.embedding_function:
                collection_kwargs['embedding_function'] = self.embedding_function
            self.collection = self.client.get_or_create_collection(**collection_kwargs)

            document_total = self.collection.count()
            self.logger.info(f"Connected to collection '{collection_name}' with {document_total} documents")
        except Exception as e:
            self.logger.error(f"Failed to connect to ChromaDB: {e}")
            raise

        # No search has run yet, so there is nothing to paginate.
        self.current_search_state: Optional[SearchState] = None
    
    def _create_client_with_retry(self, max_retries: int = 5, base_delay: float = 0.1):
        """Open the persistent ChromaDB client, retrying on tenant errors.

        Only exceptions whose message mentions "tenant" (case-insensitive)
        are retried; this code treats them as transient concurrent-access
        failures. Any other exception propagates immediately.

        Args:
            max_retries: Retries allowed after the first attempt. Negative
                values are treated as 0 so one attempt is always made.
            base_delay: Backoff in seconds before the first retry; it doubles
                on each further attempt and gets up to 0.1s of random jitter.

        Returns:
            The ``chromadb.PersistentClient`` opened at ``self.chromadb_path``.

        Raises:
            Exception: The original ChromaDB error, either at once for a
                non-tenant failure or after the last attempt for a tenant one.
        """
        # With a negative budget the loop would never run and there would be
        # no exception to re-raise; clamp so at least one attempt happens.
        max_retries = max(max_retries, 0)
        total_attempts = max_retries + 1

        for attempt in range(total_attempts):
            try:
                client = chromadb.PersistentClient(path=str(self.chromadb_path))
            except Exception as e:
                # The lowercase match also covers "default_tenant".
                if "tenant" not in str(e).lower():
                    # Non-tenant error, don't retry
                    raise

                if attempt >= max_retries:
                    self.logger.error(f"ChromaDB tenant connection failed after {total_attempts} attempts")
                    raise

                # Exponential backoff with jitter for concurrent access
                delay = base_delay * (2 ** attempt) + random.uniform(0, 0.1)
                self.logger.warning(
                    f"ChromaDB tenant connection failed (attempt {attempt + 1}/{total_attempts}): {e}. "
                    f"Retrying in {delay:.2f}s..."
                )
                time.sleep(delay)
            else:
                if attempt > 0:
                    self.logger.info(f"ChromaDB client created successfully on attempt {attempt + 1}")
                return client
        # Not reachable: the final iteration always returns or raises.
    
    def _load_embedding_function(self, embedding_model: Optional[str]):
        """Resolve a model key to an embedding function.

        Args:
            embedding_model: Model key. Empty, ``None`` or "default" select
                ChromaDB's built-in embeddings; "openai-small" and
                "openai-large" select the matching OpenAI model.

        Returns:
            An OpenAI embedding function, or ``None`` when ChromaDB's default
            should be used (also the fallback for unrecognised keys).

        Raises:
            ValueError: An OpenAI key was requested but OPENAI_API_KEY is
                unset or empty.
        """
        # Each row: (model key, size word for the log line, OpenAI model id).
        # Further providers (e.g. "huggingface-mpnet", "cohere-multilingual")
        # would need their own handling added here.
        openai_variants = (
            ("openai-small", "small", "text-embedding-3-small"),
            ("openai-large", "large", "text-embedding-3-large"),
        )

        if not embedding_model or embedding_model == "default":
            self.logger.info("Using default ChromaDB embedding function")
            return None

        for model_key, size_label, openai_name in openai_variants:
            if embedding_model == model_key:
                api_key = os.environ.get("OPENAI_API_KEY")
                if not api_key:
                    raise ValueError("OpenAI embedding model requires OPENAI_API_KEY environment variable")
                self.logger.info(f"Loading OpenAI {size_label} embedding model: {openai_name}")
                return embedding_functions.OpenAIEmbeddingFunction(
                    api_key=api_key,
                    model_name=openai_name
                )

        # Unknown model - log warning and use default
        self.logger.warning(f"Unknown embedding model '{embedding_model}', using default ChromaDB embeddings")
        return None
    
    @staticmethod
    def requires_openai_api_key(embedding_model: Optional[str]) -> bool:
        """Static utility to check if an embedding model requires OpenAI API key.
        
        Simple explicit check for OpenAI embedding models.
        """
        return embedding_model in ["openai-small", "openai-large"]

    # Future embedding providers can be added here:
    # 
    # def _is_huggingface_model(self, embedding_model: str) -> bool:
    #     """Check if model is from Hugging Face."""
    #     return embedding_model.startswith(('sentence-transformers/', 'huggingface/'))
    # 
    # def _load_huggingface_embedding(self, embedding_model: str):
    #     """Load Hugging Face embedding function."""
    #     # Implementation for HF models
    #     pass
    # 
    # def _is_cohere_model(self, embedding_model: str) -> bool:
    #     """Check if model is from Cohere."""
    #     return 'cohere' in embedding_model.lower()
    # 
    # def _load_cohere_embedding(self, embedding_model: str):
    #     """Load Cohere embedding function.""" 
    #     # Implementation for Cohere models
    #     pass
    
    def search(
        self,
        query: str,
        n_results: Optional[int] = None,
        filter_metadata: Optional[Dict[str, Any]] = None
    ) -> Tuple[List[SearchResult], Set[str]]:
        """Run a query against the collection and cache the hits for paging.

        Args:
            query: Search query text.
            n_results: Number of hits to request; falls back to
                ``self.cache_size`` when omitted or 0.
            filter_metadata: Optional filter passed to the collection as
                ``where``.

        Returns:
            Tuple of (hits in backend order, set of their document IDs).
            Each hit's metadata is a fresh dict without the 'question_id'
            key; a document stored without metadata gets an empty dict.
            On any failure the error is logged and ``([], set())`` is
            returned; the previously cached search state is left untouched.
        """
        n_results = n_results or self.cache_size

        try:
            query_params: Dict[str, Any] = {
                'n_results': n_results,
                'include': ['metadatas', 'documents', 'distances'],
            }

            # With a custom embedding function the query is embedded here;
            # otherwise the raw text is handed to the collection.
            if self.embedding_function:
                query_params['query_embeddings'] = self.embedding_function([query])
            else:
                query_params['query_texts'] = [query]

            if filter_metadata:
                query_params['where'] = filter_metadata

            results = self.collection.query(**query_params)

            def first_batch(key: str):
                # Results are nested one list per query and we send exactly
                # one query. A missing or None field becomes an empty list so
                # one absent field cannot abort the whole search.
                batches = results.get(key) if results else None
                if not batches or batches[0] is None:
                    return []
                return batches[0]

            documents = first_batch('documents')
            metadatas = first_batch('metadatas')
            ids = first_batch('ids')
            distances = first_batch('distances')

            search_results: List[SearchResult] = []
            found_ids: Set[str] = set()

            for i, document in enumerate(documents):
                # A metadata entry may be None; copy instead of mutating the
                # backend's dict, and drop 'question_id' from what we expose.
                raw_metadata = metadatas[i] if i < len(metadatas) else None
                metadata = {
                    key: value
                    for key, value in (raw_metadata or {}).items()
                    if key != 'question_id'
                }
                result = SearchResult(
                    document_id=ids[i] if i < len(ids) else f"doc_{i}",
                    content=document,
                    metadata=metadata,
                    # Score is one minus the reported distance.
                    score=(1.0 - distances[i]) if i < len(distances) else 0.0
                )
                search_results.append(result)
                found_ids.add(result.document_id)

            # Replace the paging state only after a successful query.
            self.current_search_state = SearchState(
                query=query,
                results=search_results,
                total_results=len(search_results),
                found_document_ids=found_ids
            )

            self.logger.debug(f"Search for '{query}' returned {len(search_results)} results")

            return search_results, found_ids

        except Exception as e:
            self.logger.error(f"Search failed for query '{query}': {e}")
            return [], set()
    
    def get_first_page(self) -> Tuple[str, Set[str]]:
        """Rewind the paging cursor and render page 0 of the cached hits.

        Returns:
            Tuple of (formatted page text, IDs of the documents on it). When
            no search has been cached, or it produced no hits, a hint message
            and an empty set are returned instead.
        """
        state = self.current_search_state
        if not (state and state.results):
            return "No search results available. Please perform a search first.", set()

        # Always restart from the beginning, wherever the cursor was.
        state.current_page = 0
        first_hits = state.get_page(0)

        page_text = self._format_results_page(first_hits, 0)
        page_ids = {hit.document_id for hit in first_hits}
        return page_text, page_ids
    
    def get_next_page(self) -> Tuple[str, Set[str]]:
        """Advance the paging cursor by one page and render that page.

        Returns:
            Tuple of (formatted page text, IDs of the documents on it). With
            no cached search a hint message and an empty set are returned.
            When the cursor is already on the last page it is reset to 0 and
            a notice with an empty set is returned instead of any hits.
        """
        state = self.current_search_state
        if not state:
            return "No search results available. Please perform a search first.", set()

        if state.has_next_page():
            state.current_page += 1
            next_hits = state.get_page(state.current_page)
            page_text = self._format_results_page(next_hits, state.current_page)
            return page_text, {hit.document_id for hit in next_hits}

        # Past the last page: rewind the cursor and tell the caller.
        state.current_page = 0
        return "No more pages available. Returning to first page.", set()
    
    def _format_results_page(self, results: List[SearchResult], page_num: int) -> str:
        """Render one page of hits as display text.

        Args:
            results: Hits to show on this page.
            page_num: Zero-based page index. It is printed as-is in the
                header and as ``page_num + 1`` in the footer.

        Returns:
            A header line, one block per hit (ID, content, JSON metadata,
            score to three decimals) and a "[Page X of Y]" footer; or a short
            "no results" line when ``results`` is empty. The page total is
            read from ``self.current_search_state``.
        """
        if not results:
            return f"Page {page_num}: No results"

        sections = [f"# Page {page_num}:"]
        sections.extend(
            "\n".join((
                f"\nResult {position}:",
                f"  ID: {hit.document_id}",
                f"  Content: {hit.content}",
                f"  Metadata: {json.dumps(hit.metadata, indent=2)}",
                f"  Relevance Score: {hit.score:.3f}",
            ))
            for position, hit in enumerate(results, start=1)
        )

        # Page total is the cached hit count divided by page size, rounded up.
        state = self.current_search_state
        total_pages = (len(state.results) - 1) // state.page_size + 1
        sections.append(f"\n[Page {page_num + 1} of {total_pages}]")

        return "\n".join(sections)
    
    def get_documents_by_ids(self, document_ids: List[str]) -> List[SearchResult]:
        """Fetch specific documents from the collection by ID.

        Args:
            document_ids: IDs of the documents to retrieve.

        Returns:
            One ``SearchResult`` per returned document, with the default
            score. A document stored without metadata gets an empty dict.
            An empty ``document_ids`` returns ``[]`` without querying the
            collection. On any failure the error is logged and ``[]`` is
            returned.
        """
        # Nothing requested: skip the backend call, which could otherwise
        # fail on an empty ID list and log a spurious error.
        if not document_ids:
            return []

        try:
            results = self.collection.get(ids=document_ids)
            if not results or not results.get('documents'):
                return []

            documents = results['documents']
            metadatas = results.get('metadatas') or []
            ids = results.get('ids') or []

            search_results: List[SearchResult] = []
            for i, document in enumerate(documents):
                # Prefer the ID reported by the backend; fall back to the
                # requested ID at the same position, then a positional label.
                if i < len(ids):
                    document_id = ids[i]
                elif i < len(document_ids):
                    document_id = document_ids[i]
                else:
                    document_id = f"doc_{i}"

                # A metadata entry may be None; callers expect a dict.
                metadata = metadatas[i] if i < len(metadatas) else None
                search_results.append(
                    SearchResult(
                        document_id=document_id,
                        content=document,
                        metadata=metadata if metadata is not None else {}
                    )
                )

            return search_results

        except Exception as e:
            self.logger.error(f"Failed to retrieve documents by IDs: {e}")
            return []
    
    def get_collection_stats(self) -> Dict[str, Any]:
        """Summarise the bound collection.

        Returns:
            Dict with 'collection_name', 'document_count', 'metadata_fields'
            (sorted key names seen in a sample of up to 10 documents) and
            'embedding_function' (the function's class name, or 'Default').
            If the collection cannot be read, the error is logged and a dict
            with only 'collection_name' and 'error' is returned.
        """
        try:
            count = self.collection.count()

            # Metadata keys are discovered from a small sample, so fields
            # used only by documents outside it will not be listed.
            sample = self.collection.get(limit=10, include=['metadatas'])
            sampled_metadatas = (sample.get('metadatas') if sample else None) or []

            metadata_fields: Set[str] = set()
            for metadata in sampled_metadatas:
                # Entries may be None or empty for documents without metadata.
                if metadata:
                    metadata_fields.update(metadata.keys())

            if self.embedding_function:
                embedding_name = type(self.embedding_function).__name__
            else:
                embedding_name = 'Default'

            return {
                'collection_name': self.collection_name,
                'document_count': count,
                # Sorted so the listing is stable between calls and runs.
                'metadata_fields': sorted(metadata_fields),
                'embedding_function': embedding_name
            }

        except Exception as e:
            self.logger.error(f"Failed to get collection stats: {e}")
            return {
                'collection_name': self.collection_name,
                'error': str(e)
            }
