import sys                                
import os                                
import logging                                      
import uuid                          
import json                                              
import shutil                                          
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime, timezone                  
import logging                             
import uuid                          
import json                                              
import os                        
import shutil                                          

from pydantic import ValidationError
import chromadb                              
import numpy as np                       
import hdbscan                               
from sklearn.cluster import KMeans                              

from fortress.common.data_models import DatabasePromptRecord, NLPFeatures, InputPromptRecord
from fortress.config import get_config

# Module-level logger; the vector-store classes below emit all their diagnostics through it.
logger = logging.getLogger(__name__)

class VectorStoreInterface(ABC):
    """
    Contract that every vector-database backend must fulfil.

    Implementations provide persistence of prompt records, similarity search,
    metadata maintenance and nearest-neighbour cluster assignment.
    """

    @abstractmethod
    def add_documents(self, documents: List[DatabasePromptRecord]) -> Tuple[List[str], List[str]]:
        """Persist ``documents`` and report the outcome as ``(succeeded_ids, failed_ids)``."""
        ...

    @abstractmethod
    def query_similar(self, embedding: List[float], top_k: int = 10, filters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` nearest documents, each with its metadata and distance."""
        ...

    @abstractmethod
    def get_collection_size(self) -> int:
        """Return how many documents the collection currently holds."""
        ...

    @abstractmethod
    def get_document_by_id(self, doc_id: str) -> Optional[DatabasePromptRecord]:
        """Return the document stored under ``doc_id``, or ``None`` if unavailable."""
        ...

    @abstractmethod
    def delete_documents(self, ids: List[str]) -> Tuple[List[str], List[str]]:
        """Remove the given IDs and report the outcome as ``(succeeded_ids, failed_ids)``."""
        ...

    @abstractmethod
    def get_all_unsafe_embeddings_with_ids(self, batch_size: int = 1000) -> List[Tuple[str, List[float]]]:
        """Return ``(id, embedding)`` pairs for every document labelled unsafe."""
        ...

    @abstractmethod
    def update_document_metadata(self, doc_id: str, metadata_update: Dict[str, Any]) -> bool:
        """Apply ``metadata_update`` to one document; return whether it succeeded."""
        ...

    @abstractmethod
    def assign_cluster_to_new_prompt(self,
                                     prompt_embedding: List[float],
                                     cluster_field_name: str = "prompt_category",
                                     default_top_k_neighbors: int = 3,
                                     input_label: Optional[int] = None) -> Optional[Any]:
        """
        Suggest a single cluster label for a new prompt from its nearest neighbours.

        When ``input_label`` is given, only neighbours carrying that label are considered.
        """
        ...

    @abstractmethod
    def assign_clusters_with_weights_to_new_prompt(self,
                                                   prompt_embedding: List[float],
                                                   cluster_field_name: str = "prompt_category",
                                                   top_k_neighbors: int = 3,
                                                   input_label: Optional[int] = None) -> List[Tuple[Any, float]]:
        """
        Suggest weighted cluster labels for a new prompt from its nearest neighbours.

        Returns ``(cluster_label, weight)`` tuples ordered from highest to lowest weight.
        When ``input_label`` is given, only neighbours carrying that label are considered.
        """
        ...

    @abstractmethod
    def analyze_cluster_drift(self):
        """
        Examine how cluster characteristics change over time.

        Not implemented yet: the base version logs a warning and raises
        ``NotImplementedError``.
        """
        message = "analyze_cluster_drift is not yet implemented."
        logger.warning(message)
        raise NotImplementedError(message)


class ChromaVectorStore(VectorStoreInterface):
    """
    A ChromaDB implementation of the VectorStoreInterface.
    """
    DEFAULT_COLLECTION_NAME = "fortress_prompts_collection" 

    def __init__(self, collection_name: Optional[str] = None, db_path: Optional[str] = None):
        """
        Open (or create) a persistent ChromaDB collection.

        Every setting is resolved in priority order: explicit argument, then the
        ``vector_database`` section of the application config, then a built-in default.

        Args:
            collection_name: Collection to open. Falls back to config ``collection_name``,
                then to ``DEFAULT_COLLECTION_NAME``.
            db_path: Directory for the persistent client. Falls back to config ``path``,
                then to ``data/default_chroma_db_fallback``. Created if it does not exist.

        Raises:
            ConnectionError: If the Chroma client or collection cannot be initialized.
        """
        vector_db_config = get_config().get("vector_database", {})

        # Storage directory: argument -> config -> hard-coded fallback.
        path = db_path or vector_db_config.get("path")
        if not path:
            path = "data/default_chroma_db_fallback"
            logger.warning(f"ChromaVectorStore db_path not provided via argument or config. Defaulting to: {path}")
        self.db_path = path
        # The persistent client expects the directory to be usable, so create it up front.
        os.makedirs(self.db_path, exist_ok=True)

        # Collection name: argument -> config -> class default.
        if collection_name:
            logger.debug(f"Using provided collection_name: {collection_name}")
            self.collection_name = collection_name
        else:
            configured_name = vector_db_config.get("collection_name")
            if configured_name:
                logger.debug(f"Using collection_name from config: {configured_name}")
                self.collection_name = configured_name
            else:
                self.collection_name = self.DEFAULT_COLLECTION_NAME
                logger.warning(
                    f"Collection_name not provided via argument or config. Defaulting to class default: {self.collection_name}"
                )

        # Only the metrics accepted for "hnsw:space" are allowed; anything else falls back to cosine.
        metric = vector_db_config.get("distance_function", "cosine")
        if metric not in ("l2", "ip", "cosine"):
            logger.warning(f"Unsupported distance_function '{metric}' in config. Defaulting to 'cosine'.")
            metric = "cosine"
        self.distance_metric = metric

        try:
            self.client = chromadb.PersistentClient(path=self.db_path)
            self.collection = self.client.get_or_create_collection(
                name=self.collection_name,
                metadata={"hnsw:space": self.distance_metric},
            )
            logger.info(f"ChromaDB collection '{self.collection_name}' (metric: {self.distance_metric}) loaded/created successfully at path '{self.db_path}'.")
        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB client or collection '{self.collection_name}' at path '{self.db_path}': {e}", exc_info=True)
            raise ConnectionError(f"ChromaDB connection failed for collection '{self.collection_name}' at path '{self.db_path}'") from e

    def _serialize_metadata(self, record: DatabasePromptRecord) -> Dict[str, Any]:
        """
        Flatten a record into a metadata dict suitable for ChromaDB.

        ``prompt_id``, ``original_prompt`` and ``embedding`` are left out because
        ``add_documents`` stores them as the Chroma ID, document text and vector.
        Fields whose value is ``None`` are also omitted. Dict and list values are
        JSON-encoded, since Chroma metadata values are expected to be scalars.
        """
        fields = record.model_dump(
            exclude={'prompt_id', 'original_prompt', 'embedding'},
            exclude_none=True,
        )
        return {
            name: json.dumps(value) if isinstance(value, (dict, list)) else value
            for name, value in fields.items()
        }

    def _deserialize_metadata(self,
                              metadata_dict: Dict[str, Any],
                              doc_id: str,
                              document_content: Optional[str] = None,
                              embedding: Optional[List[float]] = None
                             ) -> Optional[DatabasePromptRecord]:
        """
        Rebuild a DatabasePromptRecord from metadata stored in ChromaDB.

        Inverse of ``_serialize_metadata``:
          * JSON-encoded fields are decoded.
          * The stored document text becomes ``original_prompt``.
          * ``prompt_id`` is recovered from the Chroma ID when it is absent from the metadata.
          * Keys that no record model declares are dropped before validation.

        Args:
            metadata_dict: Metadata as returned by Chroma. Not mutated. ``None`` is
                treated as an empty dict.
            doc_id: Chroma document ID. Used in log messages and as the fallback ``prompt_id``.
            document_content: Stored document text. Overrides ``original_prompt`` when
                not ``None``.
            embedding: Stored vector, as a list or a numpy array. Converted to a plain
                list of Python floats. Ignored when ``None`` or empty.

        Returns:
            The validated record, or ``None`` if Pydantic validation fails.
        """
        data_for_model: Dict[str, Any] = dict(metadata_dict) if metadata_dict else {}

        # These fields are stored as JSON strings by _serialize_metadata; decode them.
        for field_name in ('sentiment_scores', 'char_level_stats'):
            raw_value = data_for_model.get(field_name)
            if isinstance(raw_value, str):
                try:
                    data_for_model[field_name] = json.loads(raw_value)
                except json.JSONDecodeError:
                    # Keep the raw string; validation below decides whether it is acceptable.
                    logger.warning(f"Failed to parse JSON string for {field_name} in doc {doc_id}: {raw_value}")

        if document_content is not None:
            data_for_model['original_prompt'] = document_content

        if 'prompt_id' in data_for_model:
            raw_prompt_id = data_for_model['prompt_id']
            if isinstance(raw_prompt_id, str):
                try:
                    data_for_model['prompt_id'] = uuid.UUID(raw_prompt_id)
                except ValueError:
                    # Do not raise here. Leave the raw value for Pydantic to validate, so a
                    # malformed record yields None like any other bad record.
                    logger.warning(f"prompt_id {raw_prompt_id!r} in metadata for doc {doc_id} is not a valid UUID.")
        else:
            # _serialize_metadata excludes prompt_id because it is the Chroma ID; recover it.
            try:
                data_for_model['prompt_id'] = uuid.UUID(doc_id)
            except ValueError:
                logger.warning(f"doc_id {doc_id} is not a valid UUID for prompt_id, Pydantic will generate a new one.")

        # `if embedding:` raises "truth value ... is ambiguous" for a numpy array, which
        # chromadb may return. Test None/length explicitly and normalize to Python floats.
        if embedding is not None and len(embedding) > 0:
            data_for_model['embedding'] = np.asarray(embedding, dtype=float).tolist()

        # DatabasePromptRecord is assembled from these models' fields. Anything else in the
        # metadata (e.g. keys added via update_document_metadata) would fail validation.
        allowed_fields = {
            name
            for model in (InputPromptRecord, NLPFeatures, DatabasePromptRecord)
            for name in model.model_fields
        }
        allowed_fields.add('embedding')

        unexpected = [key for key in data_for_model if key not in allowed_fields]
        if unexpected:
            logger.debug(f"Removing unexpected fields from metadata for doc {doc_id}: {unexpected}")
            for key in unexpected:
                del data_for_model[key]

        try:
            return DatabasePromptRecord.model_validate(data_for_model)
        except ValidationError as e:
            logger.error(f"Validation error deserializing document {doc_id}: {e}. Data: {data_for_model}")
            return None

    def add_documents(self, documents: List[DatabasePromptRecord]) -> Tuple[List[str], List[str]]:
        """
        Add records to the collection in one batched ``collection.add`` call.

        Each record is stored as follows: ``prompt_id`` becomes the Chroma ID, ``embedding``
        the vector, ``original_prompt`` the document text, and the remaining fields the
        metadata (via ``_serialize_metadata``).

        Args:
            documents: Records to add. Records without an embedding, or whose metadata
                cannot be serialized, are reported as failed and are not sent to Chroma.

        Returns:
            ``(succeeded_ids, failed_ids)`` as string IDs. If the batched add raises,
            every ID in the batch is reported as failed.
        """
        if not documents:
            return [], []

        failed_ids: List[str] = []
        embeddings_to_add: List[List[float]] = []
        metadatas_to_add: List[Dict[str, Any]] = []
        texts_to_add: List[str] = []
        ids_to_add: List[str] = []

        for doc in documents:
            doc_id = str(doc.prompt_id)
            if doc.embedding is None:
                logger.warning(f"Skipping document {doc.prompt_id} due to missing embedding.")
                failed_ids.append(doc_id)
                continue

            try:
                metadata = self._serialize_metadata(doc)
            except (TypeError, ValueError) as e:
                # One malformed record (e.g. a non-JSON-serializable nested value) must not
                # abort the whole batch; report it as failed, like a missing embedding.
                logger.warning(f"Skipping document {doc.prompt_id}: metadata serialization failed: {e}")
                failed_ids.append(doc_id)
                continue

            embeddings_to_add.append(doc.embedding)
            metadatas_to_add.append(metadata)
            texts_to_add.append(doc.original_prompt)
            ids_to_add.append(doc_id)

        if not ids_to_add:
            logger.info("No valid documents with embeddings found to add.")
            return [], failed_ids

        try:
            logger.debug(f"Attempting to add {len(ids_to_add)} documents to collection '{self.collection_name}'.")
            self.collection.add(
                embeddings=embeddings_to_add,
                metadatas=metadatas_to_add,
                documents=texts_to_add,
                ids=ids_to_add,
            )
        except Exception as e:
            logger.error(f"Failed to add batch to ChromaDB collection '{self.collection_name}': {e}", exc_info=True)
            # The add is a single call, so a failure is attributed to every document in it.
            failed_ids.extend(ids_to_add)
            return [], failed_ids

        logger.info(f"Successfully added {len(ids_to_add)} documents to ChromaDB. Failed: {len(failed_ids)}.")
        return ids_to_add, failed_ids

    def query_similar(self, embedding: List[float], top_k: int = 10, filters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """
        Query the collection for documents close to ``embedding``.

        Args:
            embedding: Query vector, as a list or a numpy array. ``None`` or empty returns ``[]``.
            top_k: Maximum number of neighbours to request from Chroma.
            filters: Optional Chroma ``where`` filter on metadata.

        Returns:
            A list, in the order returned by Chroma, of dicts with keys ``id``,
            ``distance``, ``metadata`` and ``document``.

            ``metadata`` is a per-result copy of the stored metadata. When the record
            deserializes, ``label``, ``original_prompt``, ``source_file`` and
            ``perplexity`` are taken from the validated record. Otherwise ``label`` and
            ``perplexity`` are best-effort parsed from strings.

            Returns ``[]`` on any Chroma error.
        """
        if not self.collection:
            logger.error("ChromaDB collection is not initialized.")
            return []

        # `not embedding` raises for a numpy array, so check None/length explicitly.
        if embedding is None or len(embedding) == 0:
            logger.error("Query embedding is empty.")
            return []

        try:
            query_results = self.collection.query(
                query_embeddings=[embedding],
                n_results=int(top_k),
                where=filters if filters else None,
                include=['metadatas', 'distances', 'documents'],
            )

            # Chroma returns one inner list per query embedding; exactly one was sent.
            if not query_results or not query_results.get('ids') or not query_results['ids'][0]:
                logger.debug("No similar documents found or query_results structure unexpected.")
                return []

            ids_list = query_results['ids'][0]
            n_hits = len(ids_list)
            distances_list = query_results['distances'][0] if query_results.get('distances') else [None] * n_hits
            metadatas_list = query_results['metadatas'][0] if query_results.get('metadatas') else [None] * n_hits
            documents_list = query_results['documents'][0] if query_results.get('documents') else [None] * n_hits

            formatted_results: List[Dict[str, Any]] = []
            for doc_id, distance, raw_metadata, document_content in zip(ids_list, distances_list, metadatas_list, documents_list):
                # Copy so that each result owns its metadata dict and Chroma's response is not mutated.
                metadata: Dict[str, Any] = dict(raw_metadata) if raw_metadata is not None else {}

                record = self._deserialize_metadata(
                    metadata_dict=metadata.copy(),
                    doc_id=doc_id,
                    document_content=document_content,
                )

                if record is not None:
                    # Prefer the validated, correctly-typed values from the record.
                    metadata["label"] = record.label
                    metadata["original_prompt"] = record.original_prompt
                    metadata["source_file"] = record.source_file
                    metadata["perplexity"] = record.perplexity
                else:
                    logger.warning(f"Failed to deserialize document {doc_id} fully. Using raw metadata.")
                    if "original_prompt" not in metadata and document_content:
                        metadata["original_prompt"] = document_content
                    # Raw metadata may carry numbers as strings; coerce them best-effort.
                    raw_label = metadata.get("label")
                    if isinstance(raw_label, str):
                        try:
                            metadata["label"] = int(raw_label)
                        except ValueError:
                            logger.warning(f"Could not parse 'label' from raw metadata for doc {doc_id}")
                    raw_perplexity = metadata.get("perplexity")
                    if isinstance(raw_perplexity, str):
                        try:
                            metadata["perplexity"] = float(raw_perplexity)
                        except ValueError:
                            logger.warning(f"Could not parse 'perplexity' from raw metadata for doc {doc_id}")

                if metadata.get("label") is None:
                    logger.warning(f"Retrieved document {doc_id} is missing critical metadata (label). Metadata: {metadata}")

                formatted_results.append({
                    "id": doc_id,
                    "distance": distance,
                    "metadata": metadata,
                    "document": document_content,
                })

            return formatted_results

        except Exception as e:
            logger.error(f"Error querying ChromaDB: {e}", exc_info=True)
            return []

    def get_collection_size(self) -> int:
        """Return the document count of the collection, or 0 if Chroma raises."""
        size = 0
        try:
            size = self.collection.count()
        except Exception as e:
            logger.error(f"Error getting collection size from ChromaDB: {e}")
        return size

    def get_document_by_id(self, doc_id: str) -> Optional[DatabasePromptRecord]:
        """
        Fetch one document (embedding, metadata and text) by ID and rebuild the record.

        Args:
            doc_id: Chroma ID of the document.

        Returns:
            The deserialized record, or ``None`` in any of these cases:
              * the collection is not initialized;
              * the ID is unknown;
              * deserialization fails;
              * Chroma raises.
        """
        if not self.collection:
            logger.error("Collection is not initialized.")
            return None
        try:
            results = self.collection.get(ids=[doc_id], include=['embeddings', 'metadatas', 'documents'])

            if not results['ids']:
                logger.info(f"No document found with ID: {doc_id}")
                return None

            ret_id = results['ids'][0]

            # Chroma may return a None metadata entry for documents stored without metadata.
            metadatas = results.get('metadatas')
            metadata = (metadatas[0] if metadatas else None) or {}

            # chromadb may return embeddings as a numpy array, whose truth value is ambiguous.
            # Test None/length explicitly and hand a plain list of floats to the deserializer.
            embeddings = results.get('embeddings')
            embedding = None
            if embeddings is not None and len(embeddings) > 0 and embeddings[0] is not None:
                embedding = np.asarray(embeddings[0], dtype=float).tolist()

            documents = results.get('documents')
            document_content = documents[0] if documents else None

            if document_content is None and 'original_prompt' not in metadata:
                logger.warning(f"Document content for {doc_id} is None and not in metadata.")

            return self._deserialize_metadata(metadata, ret_id, document_content=document_content, embedding=embedding)

        except Exception as e:
            logger.error(f"Error retrieving document by ID {doc_id} from Chroma: {e}", exc_info=True)
            return None

    def delete_documents(self, ids: List[str]) -> Tuple[List[str], List[str]]:
        """
        Delete the given IDs from the collection.

        IDs are first checked for existence; only existing ones are passed to
        ``collection.delete``.

        Returns:
            ``(deleted_ids, failed_ids)``. IDs that do not exist count as failed.
            If Chroma raises, every ID is reported as failed.
        """
        if not ids:
            return [], []

        try:
            if not self.collection:
                logger.error("Collection is not initialized. Cannot delete documents.")
                return [], ids

            present = set(self.collection.get(ids=ids)['ids'])
            found: List[str] = []
            missing: List[str] = []
            for doc_id in ids:
                if doc_id in present:
                    found.append(doc_id)
                else:
                    missing.append(doc_id)

            if not found:
                logger.info(f"None of the provided IDs for deletion exist in the collection: {ids}")
                return [], ids

            self.collection.delete(ids=found)
            logger.info(f"Successfully deleted documents with IDs: {found}. IDs not found: {missing}")
        except Exception as e:
            logger.error(f"Error during ChromaDB deletion process for IDs {ids}: {e}")
            return [], list(ids)

        return found, missing


    def get_all_unsafe_embeddings_with_ids(self, batch_size: int = 1000) -> List[Tuple[str, List[float]]]:
        """
        Page through documents whose ``label`` metadata is 1 and collect their embeddings.

        Paging stops at the first short or empty page, on a fetch error, or when a page
        comes back without embeddings. Whatever was gathered so far is returned.
        Documents with a null embedding are skipped.
        """
        collected: List[Tuple[str, List[float]]] = []
        start = 0
        # label == 1 marks a document as unsafe in this store.
        unsafe_only = {"label": 1}

        while True:
            try:
                page = self.collection.get(
                    where=unsafe_only,
                    limit=batch_size,
                    offset=start,
                    include=["embeddings"],
                )
            except Exception as e:
                logger.error(f"Error fetching batch of unsafe embeddings from ChromaDB: {e}")
                break

            if not page or not page['ids']:
                break

            page_ids = page['ids']
            page_vectors = page['embeddings']
            if page_vectors is None:
                logger.warning(f"Received null embeddings for a batch offset {start}. Stopping.")
                break

            for idx, doc_id in enumerate(page_ids):
                vector = page_vectors[idx]
                if vector is None:
                    logger.warning(f"Document ID {doc_id} in unsafe query result has a null embedding. Skipping.")
                    continue
                collected.append((doc_id, vector))

            if len(page_ids) < batch_size:
                break
            start += len(page_ids)

        logger.info(f"Retrieved {len(collected)} unsafe embeddings with IDs.")
        return collected

    def update_document_metadata(self, doc_id: str, metadata_update: Dict[str, Any]) -> bool:
        """
        Merge ``metadata_update`` into one document's stored metadata.

        Keys in ``metadata_update`` overwrite existing keys; other existing keys are kept.
        Dict and list values are JSON-encoded, matching how ``_serialize_metadata`` stores
        them when documents are added.

        Returns:
            True on success. False in any of these cases:
              * the collection is not initialized;
              * the document does not exist;
              * a value cannot be JSON-encoded;
              * Chroma raises.
        """
        if not self.collection:
            logger.error("Collection not initialized. Cannot update document metadata.")
            return False
        try:
            # Read the current metadata first, both to detect a missing document and to merge.
            existing_doc_data = self.collection.get(ids=[doc_id], include=["metadatas"])

            if not existing_doc_data or not existing_doc_data['ids']:
                logger.warning(f"Document with ID {doc_id} not found. Cannot update metadata.")
                return False

            metadatas = existing_doc_data.get('metadatas')
            current_metadata = (metadatas[0] if metadatas else None) or {}

            # Encode nested values the same way documents are written on insert. This keeps
            # values scalar for Chroma and readable by _deserialize_metadata.
            encoded_update = {
                key: json.dumps(value) if isinstance(value, (dict, list)) else value
                for key, value in metadata_update.items()
            }
            updated_metadata = {**current_metadata, **encoded_update}

            self.collection.update(ids=[doc_id], metadatas=[updated_metadata])
            logger.info(f"Successfully updated metadata for document ID: {doc_id} with {metadata_update}")
            return True
        except Exception as e:
            logger.error(f"Failed to update metadata for document ID {doc_id}: {e}", exc_info=True)
            return False

    def _get_embeddings_for_clustering(self, label_filter: Optional[int] = None, batch_size: int = 1000) -> Tuple[List[str], List[List[float]], List[Optional[Dict[str, Any]]]]:
        """
        Page through the collection, gathering IDs, embeddings and metadata for clustering.

        Args:
            label_filter: When given, only documents whose ``label`` metadata equals it are read.
            batch_size: Page size for each ``collection.get`` call.

        Returns:
            Three parallel lists ``(ids, embeddings, metadatas)``.

            Pages whose embeddings or metadatas come back as ``None`` are skipped. A fetch
            error stops paging, and what was gathered so far is returned.
        """
        collected_ids: List[str] = []
        collected_vectors: List[List[float]] = []
        collected_metas: List[Optional[Dict[str, Any]]] = []

        if not self.collection:
            logger.error("Collection not initialized. Cannot fetch embeddings for clustering.")
            return collected_ids, collected_vectors, collected_metas

        where_filter: Dict[str, Any] = {} if label_filter is None else {"label": label_filter}
        logger.info(f"Fetching embeddings for clustering with filter: {where_filter or 'None'}")

        offset = 0
        while True:
            try:
                page = self.collection.get(
                    where=where_filter or None,
                    limit=batch_size,
                    offset=offset,
                    include=["embeddings", "metadatas"],
                )
            except Exception as e:
                logger.error(f"Error fetching batch from ChromaDB for clustering: {e}", exc_info=True)
                break

            if not page or not page.get('ids'):
                logger.debug(f"No more documents found for clustering at offset {offset} with filter {where_filter}.")
                break

            page_ids = page['ids']
            page_vectors = page['embeddings']
            page_metas = page['metadatas']

            if page_vectors is None or page_metas is None:
                logger.warning(f"Batch embeddings or metadatas are None for IDs: {page_ids} (offset {offset}). Skipping this batch segment.")
            else:
                collected_ids.extend(page_ids)
                collected_vectors.extend(page_vectors)
                collected_metas.extend(page_metas)
                logger.debug(f"Fetched batch of {len(page_ids)} for clustering. Total fetched: {len(collected_ids)}")

            # A short page means the result set is exhausted.
            if len(page_ids) < batch_size:
                break
            offset += len(page_ids)

        logger.info(f"Retrieved {len(collected_ids)} embeddings with their metadatas for clustering.")
        return collected_ids, collected_vectors, collected_metas

    def assign_cluster_to_new_prompt(self,
                                     prompt_embedding: List[float],
                                     cluster_field_name: str = "prompt_category",
                                     default_top_k_neighbors: int = 3,
                                     input_label: Optional[int] = None) -> Optional[Any]:
        """
        Return only the highest-weighted cluster label for a new prompt.

        Backward-compatible wrapper around ``assign_clusters_with_weights_to_new_prompt``.
        ``default_top_k_neighbors`` is forwarded as ``top_k_neighbors``. Returns ``None``
        when no weighted clusters are produced.
        """
        ranked = self.assign_clusters_with_weights_to_new_prompt(
            prompt_embedding=prompt_embedding,
            cluster_field_name=cluster_field_name,
            top_k_neighbors=default_top_k_neighbors,
            input_label=input_label,
        )
        if not ranked:
            return None
        # Entries are (cluster_label, weight), heaviest first; keep just the label.
        return ranked[0][0]

    def assign_clusters_with_weights_to_new_prompt(self,
                                                   prompt_embedding: List[float],
                                                   cluster_field_name: str = "prompt_category",
                                                   top_k_neighbors: int = 3,
                                                   input_label: Optional[int] = None) -> List[Tuple[Any, float]]:
        """
        Rank candidate clusters for a new prompt by a distance-weighted vote of its neighbors.

        The ``top_k_neighbors`` nearest stored prompts are fetched from ``self.collection``.
        When ``input_label`` is given, the query is restricted to entries whose ``label``
        metadata equals it. Each neighbor whose metadata has a non-None
        ``cluster_field_name`` votes for that cluster.

        The vote weight is ``1 - (d - d_min) / (d_max - d_min)``. If every voting neighbor
        is at the same distance, each vote is ``1 / n`` instead. A neighbor with no
        returned distance is treated as distance 1.0. Per-cluster totals are then
        rescaled to sum to 1.

        Returns:
            ``(cluster_label, weight)`` pairs ordered from highest to lowest weight
            (ties keep first-seen order). An empty list is returned when the collection
            is not set, no usable neighbors are found, or any exception occurs (it is
            logged).
        """
        if not self.collection:
            logger.warning("Collection not initialized. Cannot assign clusters.")
            return []

        try:
            query_filters: Optional[Dict[str, Any]] = (
                {"label": input_label} if input_label is not None else None
            )
            if query_filters is None:
                logger.debug(f"Querying for neighbors (no label filter) to assign '{cluster_field_name}'.")
            else:
                logger.debug(f"Querying for neighbors with label: {input_label} to assign '{cluster_field_name}'.")

            response = self.collection.query(
                query_embeddings=[prompt_embedding],
                n_results=int(top_k_neighbors),
                where=query_filters,
                include=["metadatas", "distances"]
            )

            # A single query embedding was sent, so only the first result row is relevant.
            neighbor_metas = response.get('metadatas', [[]])[0]
            neighbor_dists = response.get('distances', [[]])[0]

            if not neighbor_metas:
                label_note = f" matching label {input_label}" if input_label is not None else ""
                logger.debug(f"No neighbors found for the new prompt{label_note}, cannot assign clusters.")
                return []

            # (cluster_label, distance) for every neighbor that carries the target field;
            # a missing distance entry falls back to 1.0.
            votes = [
                (meta.get(cluster_field_name),
                 neighbor_dists[idx] if idx < len(neighbor_dists) else 1.0)
                for idx, meta in enumerate(neighbor_metas)
                if meta and meta.get(cluster_field_name) is not None
            ]
            if not votes:
                logger.debug(f"No neighbors found with field '{cluster_field_name}' in their metadata, cannot assign clusters.")
                return []

            vote_distances = [dist for _, dist in votes]
            farthest = max(vote_distances)
            nearest = min(vote_distances)
            all_equidistant = farthest == nearest
            uniform_share = 1.0 / len(votes)

            # Min-max scaling: the nearest voter contributes 1.0. When distances differ,
            # the farthest voter contributes 0.0.
            totals: Dict[Any, float] = {}
            for cluster, distance in votes:
                if all_equidistant:
                    share = uniform_share
                else:
                    share = 1.0 - (distance - nearest) / (farthest - nearest)
                totals[cluster] = totals.get(cluster, 0.0) + share

            grand_total = sum(totals.values())
            if grand_total > 0:
                totals = {cluster: weight / grand_total for cluster, weight in totals.items()}

            ranked = sorted(totals.items(), key=lambda pair: pair[1], reverse=True)

            logger.debug(f"Assigned new prompt to clusters with weights: {ranked} for field '{cluster_field_name}' based on {top_k_neighbors} neighbors (label filter: {input_label is not None}).")

            return ranked

        except Exception as e:
            logger.error(f"Error assigning clusters to new prompt: {e}", exc_info=True)
            return []

    def analyze_cluster_drift(self):
        """
        Placeholder for analysing how cluster characteristics change over time.

        Not implemented yet: it logs a warning and then raises.

        Raises:
            NotImplementedError: Always.
        """
        not_ready = "analyze_cluster_drift is not yet implemented."
        logger.warning(not_ready)
        raise NotImplementedError(not_ready)


    def cluster_prompts(self,
                        cluster_algorithm: str = "hdbscan",
                        label_filter: Optional[int] = None,
                        cluster_field_name: str = "prompt_category",
                        min_cluster_size: int = 5,
                        min_samples: Optional[int] = None,
                        hdbscan_metric: str = 'euclidean',
                        n_clusters_kmeans: Optional[int] = None,
                        random_state_kmeans: int = 42
                       ) -> bool:
        """
        Cluster the stored prompt embeddings and write each prompt's cluster ID into its metadata.

        Embeddings are fetched with ``_get_embeddings_for_clustering``, optionally
        restricted to ``label_filter``. They are clustered with HDBSCAN or KMeans.

        Each document's integer cluster label is written to ``cluster_field_name``, along
        with ``category_assignment_source="algorithmic_batch_clustering"``. Writes go
        through ``self.collection.update`` in batches of 200 documents. HDBSCAN noise
        points are written with label -1.

        Args:
            cluster_algorithm: "hdbscan" or "kmeans".
            label_filter: Optional integer label to filter prompts before clustering.
            cluster_field_name: Metadata field to store the assigned cluster ID. Defaults to "prompt_category".
            min_cluster_size: Minimum cluster size for HDBSCAN.
            min_samples: Minimum samples for HDBSCAN (None lets HDBSCAN use min_cluster_size).
            hdbscan_metric: Distance metric for HDBSCAN.
            n_clusters_kmeans: Number of clusters for KMeans (required, must be >= 1).
                It is capped at the number of available embeddings.
            random_state_kmeans: Random state for KMeans.

        Returns:
            True if clustering succeeded and every metadata batch was updated.
            False in any of these cases: retrieval or clustering failed, no embeddings
            were found, parameters were invalid, or any update batch failed. Errors are
            logged rather than raised.
        """
        logger.info(f"Starting prompt clustering with algorithm: {cluster_algorithm}, target field: '{cluster_field_name}'")
        if label_filter is not None:
            logger.info(f"Filtering prompts by label: {label_filter}")

        try:
            doc_ids, embeddings, _ = self._get_embeddings_for_clustering(label_filter=label_filter)
        except Exception as e:
            logger.error(f"Failed to retrieve embeddings for clustering: {e}", exc_info=True)
            return False

        if not embeddings:
            logger.warning("No embeddings found for clustering (possibly after filtering).")
            return False

        num_embeddings = len(embeddings)
        logger.info(f"Retrieved {num_embeddings} embeddings for clustering.")

        # Labels are written back positionally (doc_ids[i] <- cluster_labels[i]).
        # A length mismatch would silently attach clusters to the wrong documents.
        if len(doc_ids) != num_embeddings:
            logger.error(f"Mismatch between number of document IDs ({len(doc_ids)}) and embeddings ({num_embeddings}). Aborting clustering.")
            return False

        embeddings_np = np.array(embeddings)
        if embeddings_np.ndim != 2:
            logger.error(f"Embeddings array has unexpected shape: {embeddings_np.shape}. Expected 2D array.")
            return False

        cluster_labels: Optional[np.ndarray] = None

        if cluster_algorithm == "hdbscan":
            if num_embeddings < min_cluster_size:
                logger.warning(f"Number of embeddings ({num_embeddings}) is less than min_cluster_size ({min_cluster_size}). "
                               f"HDBSCAN may not produce meaningful clusters or might error. Adjust parameters or add more data.")

            # Mirrors hdbscan's documented default: min_samples falls back to min_cluster_size.
            effective_min_samples = min_samples if min_samples is not None else min_cluster_size
            if num_embeddings < effective_min_samples:
                logger.warning(f"Number of embeddings ({num_embeddings}) is less than effective min_samples ({effective_min_samples}). "
                               f"Consider adjusting parameters.")

            try:
                import hdbscan
                clusterer = hdbscan.HDBSCAN(
                    min_cluster_size=min_cluster_size,
                    min_samples=min_samples,
                    metric=hdbscan_metric,
                    core_dist_n_jobs=-1,
                    allow_single_cluster=True
                )
                logger.info(f"Running HDBSCAN with min_cluster_size={min_cluster_size}, min_samples={min_samples}, metric='{hdbscan_metric}'...")
                cluster_labels = clusterer.fit_predict(embeddings_np)
                # Label -1 marks noise, so it is not counted as a cluster.
                num_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
                noise_points = np.sum(cluster_labels == -1)
                logger.info(f"HDBSCAN clustering complete. Found {num_clusters} clusters and {noise_points} noise points.")
            except ImportError:
                logger.error("HDBSCAN library not installed. Please install it: pip install hdbscan")
                return False
            except Exception as e:
                logger.error(f"Error during HDBSCAN clustering: {e}", exc_info=True)
                return False

        elif cluster_algorithm == "kmeans":
            if n_clusters_kmeans is None:
                logger.error("n_clusters_kmeans must be specified for KMeans algorithm.")
                return False
            if n_clusters_kmeans < 1:
                logger.error(f"n_clusters_kmeans must be a positive integer, got {n_clusters_kmeans}.")
                return False

            # Bind actual_n_clusters on every path. Previously it was only assigned when
            # samples were insufficient, so the normal case hit an UnboundLocalError.
            actual_n_clusters = n_clusters_kmeans
            if num_embeddings < n_clusters_kmeans:
                logger.warning(f"Number of embeddings ({num_embeddings}) is less than n_clusters_kmeans ({n_clusters_kmeans}). "
                               f"KMeans might error or produce trivial results. Adjust parameters or add more data.")
                # KMeans cannot form more clusters than there are samples.
                actual_n_clusters = num_embeddings
                logger.warning(f"Adjusting n_clusters_kmeans from {n_clusters_kmeans} to {actual_n_clusters} due to insufficient samples.")

            try:
                kmeans = KMeans(
                    n_clusters=actual_n_clusters,
                    random_state=random_state_kmeans,
                    n_init='auto'
                )
                logger.info(f"Running KMeans with n_clusters={actual_n_clusters}, random_state={random_state_kmeans}...")
                cluster_labels = kmeans.fit_predict(embeddings_np)
                num_clusters = len(set(cluster_labels))
                logger.info(f"KMeans clustering complete. Found {num_clusters} clusters.")
            except Exception as e:
                logger.error(f"Error during KMeans clustering: {e}", exc_info=True)
                return False
        else:
            logger.error(f"Unsupported clustering algorithm: {cluster_algorithm}. Choose 'hdbscan' or 'kmeans'.")
            return False

        if cluster_labels is None:
            logger.error("Cluster labels were not generated. Aborting metadata update.")
            return False

        # Write labels back in fixed-size batches. A failing batch is logged and
        # counted, and the remaining batches are still attempted.
        update_batch_size = 200
        updated_count = 0
        failed_count = 0

        for i in range(0, len(doc_ids), update_batch_size):
            batch_doc_ids = doc_ids[i:i + update_batch_size]
            batch_cluster_labels = cluster_labels[i:i + update_batch_size]

            # int() converts numpy integer labels into plain Python ints for metadata storage.
            metadatas_to_update = [
                {cluster_field_name: int(label), "category_assignment_source": "algorithmic_batch_clustering"}
                for label in batch_cluster_labels
            ]

            try:
                self.collection.update(ids=batch_doc_ids, metadatas=metadatas_to_update)
                updated_count += len(batch_doc_ids)
                logger.debug(f"Successfully updated metadata for {len(batch_doc_ids)} documents in batch.")
            except Exception as e:
                failed_count += len(batch_doc_ids)
                logger.error(f"Failed to update metadata for batch starting at index {i}: {e}", exc_info=True)

        logger.info(f"Clustering metadata update complete. Target field: '{cluster_field_name}'.")
        logger.info(f"Successfully updated metadata for {updated_count} documents.")
        if failed_count > 0:
            logger.warning(f"Failed to update metadata for {failed_count} documents.")

        return failed_count == 0

