from langchain_chroma import Chroma
from langchain.schema import Document
from langchain.embeddings.base import Embeddings
from RoboMemory.agent_utils import ModelConfig, VectorDBConfig
import json
import os
import requests
from typing import List
import logging


class DashScopeEmbeddings(Embeddings):
    """LangChain ``Embeddings`` implementation that calls an OpenAI-compatible
    ``/embeddings`` HTTP endpoint (DashScope by default).

    Texts are sent in batches. If a batch request fails for any reason the
    error is logged and zero vectors are returned for that batch, so the
    result always contains exactly one vector per input text.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "text-embedding-v4",
        base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1",
        timeout: float = 30.0,
        batch_size: int = 10,
        fallback_dim: int = 1024,
    ):
        """Store the connection settings.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            model: Embedding model name placed in the request body.
            base_url: API root; a trailing slash is stripped.
            timeout: Seconds to wait for each HTTP request before it is
                treated as a failed batch.
            batch_size: Number of texts per request. The default of 10 follows
                the limit stated by the original author for the Alibaba Cloud
                API.
            fallback_dim: Length of the zero vectors used when no batch in a
                call succeeded, so the real dimension is unknown.
        """
        self.api_key = api_key
        self.model = model
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        # Guard against 0/negative values, which would make range() fail or loop forever.
        self.batch_size = max(1, batch_size)
        self.fallback_dim = fallback_dim

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` and return one vector per text, in input order.

        Args:
            texts: Strings to embed.

        Returns:
            A list with ``len(texts)`` vectors. Texts belonging to a batch
            whose request failed get an all-zero vector.
        """
        url = f"{self.base_url}/embeddings"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        # One slot per input text; None marks a text whose batch failed.
        embeddings = [None] * len(texts)

        for start in range(0, len(texts), self.batch_size):
            batch = texts[start:start + self.batch_size]
            payload = {
                "model": self.model,
                "input": batch,
                "encoding_format": "float"
            }

            try:
                response = requests.post(
                    url, headers=headers, json=payload, timeout=self.timeout
                )
                response.raise_for_status()
                items = response.json()["data"]

                # A short/long response would silently misalign texts and vectors.
                if len(items) != len(batch):
                    raise ValueError(
                        f"expected {len(batch)} embeddings, got {len(items)}"
                    )

                # Restore input order when the response carries an "index"
                # field; the sort is stable, so items without it keep their
                # response order.
                items = sorted(items, key=lambda item: item.get("index", 0))

                # Build the full batch first so a malformed item cannot leave
                # the batch half-filled.
                vectors = [item["embedding"] for item in items]
                embeddings[start:start + len(batch)] = vectors

            except Exception as e:
                # Best effort: keep going, the failed slots are filled below.
                logging.error(f"Embedding API call failed: {e}")

        # Match the dimension of the vectors that did succeed; only fall back
        # to the configured default when nothing succeeded.
        dim = next(
            (len(vec) for vec in embeddings if vec is not None),
            self.fallback_dim,
        )
        return [vec if vec is not None else [0.0] * dim for vec in embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string via :meth:`embed_documents`."""
        return self.embed_documents([text])[0]

class VectorDB():
    """Chroma-backed store of short text messages with exact-match
    de-duplication.

    Alongside the Chroma collection, a ``message_set.json`` file in the
    persistence directory records every normalised message that has been
    inserted, so repeated inserts can be rejected without querying Chroma.
    """

    def __init__(
            self,
            embedding_conf : ModelConfig,
            vectordb_conf : VectorDBConfig,
    ) -> None:
        """Open (or create) the collection and load the message set.

        Args:
            embedding_conf: Supplies ``base_url``, ``model`` and ``api_key``
                for the embedding client.
            vectordb_conf: Supplies ``collection_name`` and
                ``persist_directory``.
        """
        self.base_url = embedding_conf.base_url
        self.model = embedding_conf.model
        self.api_key = embedding_conf.api_key
        self.collection_name = vectordb_conf.collection_name
        self.persist_directory = vectordb_conf.persist_directory

        # Warn when the persistence directory already exists, then make sure
        # it is present either way.
        if os.path.exists(self.persist_directory):
            logging.warning("The path of vector database is already exist!")

        os.makedirs(self.persist_directory, exist_ok=True)

        # Use custom DashScope embedding class
        self.embedding = DashScopeEmbeddings(
            api_key=self.api_key,
            model=self.model,
            base_url=self.base_url
        )

        self.db = Chroma(
            collection_name = self.collection_name,
            embedding_function = self.embedding,
            persist_directory = self.persist_directory,
        )

        # Normalised message -> insertion index. A dict is used instead of a
        # set because it serialises directly to JSON.
        self.message_set: dict[str, int] = {}

        # Load existing message set
        self._load_message_set()

    @staticmethod
    def _normalize(text: str) -> str:
        """Return the canonical form used as message-set key and stored text.

        Whitespace is stripped and the text is case-folded, except that a lone
        ``'i'`` is stored as ``'I'`` (presumably to keep the first-person
        pronoun capitalised; confirm with the callers).
        """
        normalized = text.strip().casefold()
        return 'I' if normalized == 'i' else normalized

    def _message_set_path(self) -> str:
        """Return the path of the JSON file that persists the message set."""
        return os.path.join(self.persist_directory, "message_set.json")

    def _load_message_set(self):
        """Load the message set from disk, if the file exists.

        A file that cannot be read, or that does not contain a JSON object, is
        logged and treated as empty.
        """
        message_set_path = self._message_set_path()
        if not os.path.exists(message_set_path):
            return

        try:
            with open(message_set_path, mode="r", encoding="utf-8") as fp:
                loaded = json.load(fp)
            # Anything but a dict would break the `in` / `del` / item
            # assignment done by the other methods.
            if not isinstance(loaded, dict):
                raise ValueError(
                    f"expected a JSON object, got {type(loaded).__name__}"
                )
            self.message_set = loaded
        except Exception as e:
            logging.error(f"Failed to load message set: {e}")
            self.message_set = {}

    def _save_message_set(self):
        """Persist the message set to disk; failures are logged, not raised."""
        message_set_path = self._message_set_path()
        tmp_path = f"{message_set_path}.tmp"
        try:
            # Write to a temporary file and swap it in, so an interrupted
            # write cannot leave a truncated message_set.json behind.
            with open(tmp_path, mode="w", encoding="utf-8") as fp:
                json.dump(self.message_set, fp, ensure_ascii=False, indent=2)
            os.replace(tmp_path, message_set_path)
        except Exception as e:
            logging.error(f"Failed to save message set: {e}")

    def search_DB(
            self,
            query : str,
            k : int = 10,
    ) -> list[str]:
        """Return the text of up to ``k`` stored messages most similar to
        ``query``.

        The query is stripped and case-folded before searching. An empty
        query, a non-positive ``k`` or any search error yields ``[]``.
        """
        try:
            q = query.strip().casefold()
            if not q or k < 1:
                return []

            results = self.db.similarity_search(
                query = q,
                k = k,
            )

            return [message.page_content for message in results]
        except Exception as e:
            logging.error(f"Vector database search failed: {e}")
            return []

    def insert_message2DB(
            self,
            message : str,
    ) -> list[str]:
        """Insert ``message`` unless it is blank or already stored.

        Returns:
            The document ids returned by Chroma, or ``[]`` when nothing was
            inserted (blank message, duplicate, or error).
        """
        try:
            insert_message = self._normalize(message)
            if not insert_message or insert_message in self.message_set:
                return []

            # Insert into vector database
            doc = Document(page_content=insert_message)
            doc_ids = self.db.add_documents([doc])

            # Record the message only after Chroma accepted it.
            # NOTE(review): index values come from len(), so they can repeat
            # after deletions. Confirm nothing relies on them being unique.
            self.message_set[insert_message] = len(self.message_set)
            self._save_message_set()

            return doc_ids

        except Exception as e:
            logging.error(f"Vector database insertion failed: {e}")
            return []

    def get_collection_count(self) -> int:
        """Return the number of items in the Chroma collection, falling back
        to the size of the message set if Chroma cannot be queried."""
        try:
            return self.db._collection.count()
        except Exception as e:
            logging.warning(f"Could not count collection, using message set size: {e}")
            return len(self.message_set)

    def _index_ids_by_document(self) -> dict:
        """Fetch every document once and map document text to its ids."""
        all_docs = self.db._collection.get(include=["documents"])
        ids_by_doc: dict = {}
        if all_docs and 'documents' in all_docs and 'ids' in all_docs:
            documents = all_docs['documents'] or []
            for doc_id, doc in zip(all_docs['ids'], documents):
                ids_by_doc.setdefault(doc, []).append(doc_id)
        return ids_by_doc

    def delete_entities(self, entities: List[str]) -> bool:
        """Delete the given entities from Chroma and from the message set.

        Only entities present in the message set are processed; matching in
        Chroma is by exact equality with the normalised text.

        Args:
            entities: Entity strings to delete. A single string is treated as
                a one-element list.

        Returns:
            ``True`` if at least one entity was removed, otherwise ``False``.
        """
        try:
            # Iterating a bare string would delete it character by character.
            if isinstance(entities, str):
                entities = [entities]

            deleted_count = 0
            ids_by_doc = None  # fetched lazily, at most once per call

            for entity in entities:
                try:
                    normalized_entity = self._normalize(entity)
                except AttributeError:
                    # Skip bad items instead of aborting and losing the
                    # deletions already applied in this call.
                    logging.warning(f"Skipping non-string entity {entity!r}")
                    continue

                if normalized_entity not in self.message_set:
                    logging.warning(f"Entity '{normalized_entity}' doesn't exist in message set, skipping deletion")
                    continue

                try:
                    if ids_by_doc is None:
                        ids_by_doc = self._index_ids_by_document()

                    # If exact matches found, delete them
                    ids_to_delete = ids_by_doc.get(normalized_entity, [])
                    if ids_to_delete:
                        self.db._collection.delete(ids=ids_to_delete)
                        ids_by_doc.pop(normalized_entity, None)

                    # Remove from message set (regardless of whether found in ChromaDB)
                    del self.message_set[normalized_entity]
                    deleted_count += 1

                except Exception as e:
                    logging.error(f"Failed to delete entity '{entity}': {e}")
                    continue

            if deleted_count > 0:
                # Save updated message set
                self._save_message_set()
                return True

            logging.warning("No entities were deleted")
            return False

        except Exception as e:
            logging.error(f"Error occurred during entity deletion: {e}")
            return False