import os
import pickle
import uuid

import openai
import chromadb
from chromadb import EmbeddingFunction, Documents, Embeddings
from utils.prompts import initial_vdb_content

class VllmEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by the ``openai`` embeddings API.

    The request goes through the module-level ``openai`` client, so whatever
    ``openai.api_key`` / ``openai.base_url`` are set to at call time decides
    which server is used (presumably a vLLM OpenAI-compatible endpoint, going
    by the class name -- confirm against the deployment).
    """

    def __init__(self, model_id):
        """Remember the embedding model name sent with every request."""
        self.model_id = model_id

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` and return one vector per item of the response."""
        reply = openai.embeddings.create(input=input, model=self.model_id)

        # Collect the embedding vector of every entry in the response payload.
        vectors = []
        for entry in reply.data:
            vectors.append(entry.embedding)
        return vectors

class ChromaVdb:
    """Wrapper around one persistent Chroma collection plus a pickle mirror.

    Every stored document carries a metadata dict; ``search`` returns the
    ``"template"`` entry of the best match.  ``self.vdb_content`` mirrors the
    collection (documents, ids, metadatas) and is pickled to
    ``./<collection_name>.pkl`` so a collection can be re-seeded later.
    """

    def __init__(self, collection_name, model_id, api_key=None, api_base=None):
        """Open the collection, creating and seeding it when it is missing.

        Args:
            collection_name: Name of the Chroma collection; also used for the
                pickle file name and as key into ``initial_vdb_content``.
            model_id: Embedding model name passed to the embedding endpoint.
            api_key: Value assigned to ``openai.api_key``.
            api_base: Value assigned to ``openai.base_url``.

        Raises:
            KeyError: If the collection must be seeded, no pickle exists and
                ``initial_vdb_content`` has no entry for ``collection_name``.
        """
        self.model_id = model_id
        # NOTE: these are module-level settings of the ``openai`` package, so
        # they affect every user of the module-level client in this process.
        openai.api_key = api_key
        openai.base_url = api_base

        self.collection_name = collection_name
        self.pickle_path = f'./{collection_name}.pkl'
        self.vdb_content = {
            "documents": [],
            "ids": [],
            "metadatas": [],
        }

        self.client = chromadb.PersistentClient(path=".chroma_db")
        openai_embedder = VllmEmbeddingFunction(self.model_id)

        # Depending on the chromadb release, list_collections() yields either
        # Collection objects or plain name strings; accept both.
        existing_names = {
            getattr(col, "name", col) for col in self.client.list_collections()
        }

        if self.collection_name in existing_names:
            # Load the existing collection and rebuild the in-memory mirror
            # from it. Leaving the mirror empty here would make the next add()
            # overwrite the pickle with only the new rows and trip its
            # consistency assertion.
            self.collection = self.client.get_collection(
                name=self.collection_name,
                embedding_function=openai_embedder,
            )
            stored = self.collection.get()
            self.vdb_content = {
                "documents": list(stored.get("documents") or []),
                "ids": list(stored.get("ids") or []),
                "metadatas": list(stored.get("metadatas") or []),
            }
        else:
            # Resolve the seed first so a missing seed does not leave an empty
            # collection behind.
            seed = self._load_seed_content()
            self.collection = self.client.create_collection(
                name=self.collection_name,
                embedding_function=openai_embedder,
                metadata={"hnsw:space": "cosine"},
            )
            self.add(seed["documents"], seed["metadatas"])

    def _load_seed_content(self):
        """Return seed content: the pickle if present, else the built-in default."""
        if os.path.exists(self.pickle_path):
            # SECURITY: pickle.load executes arbitrary code from the file.
            # Only safe while the pickle is written by this class and the
            # working directory is trusted.
            with open(self.pickle_path, 'rb') as f:
                return pickle.load(f)
        return initial_vdb_content[self.collection_name]

    def add(self, documents, metadatas):
        """Add documents with their metadata and persist the mirror.

        Fresh UUID4 ids are generated for the new rows.  An empty batch is a
        no-op.

        Raises:
            ValueError: If ``documents`` and ``metadatas`` differ in length.
        """
        if len(documents) != len(metadatas):
            raise ValueError(
                f"documents ({len(documents)}) and metadatas ({len(metadatas)}) "
                "must have the same length"
            )
        if not documents:
            return

        ids = [str(uuid.uuid4()) for _ in documents]
        # Write to the collection first: if embedding or insertion fails, the
        # mirror and the pickle stay consistent with what is actually stored.
        self.collection.add(
            documents=documents,
            ids=ids,
            metadatas=metadatas
        )
        self.vdb_content['documents'].extend(documents)
        self.vdb_content['ids'].extend(ids)
        self.vdb_content['metadatas'].extend(metadatas)
        self.save()
        # Internal invariant: the mirror and the collection hold the same rows.
        assert len(self.vdb_content['documents']) == len(self)

    def search(self, query):
        """Return the ``"template"`` metadata of the closest document.

        Prints the distance of the retrieved match.

        Raises:
            IndexError: If the query returns no match (e.g. empty collection).
        """
        results = self.collection.query(
            query_texts=[query],  # The input query you want to search
            n_results=1  # Only the single best match is needed
        )
        matches = results["metadatas"][0]
        if not matches:
            raise IndexError(
                f"no match found in collection {self.collection_name!r}"
            )
        metadata = matches[0]
        distance = results["distances"][0][0]
        print(f"Retrieved distance: {distance}")
        return metadata["template"]

    def clear(self):
        """Remove every row from the collection and reset the in-memory mirror.

        Rows are deleted by id rather than with an empty ``where`` filter.
        The pickle file is left untouched; call ``save()`` to overwrite it.
        """
        ids = self.collection.get(include=[])["ids"]
        if ids:
            self.collection.delete(ids=ids)
        # Keep the mirror in step, otherwise the next add() fails its check.
        self.vdb_content = {
            "documents": [],
            "ids": [],
            "metadatas": [],
        }

    def delete_collection(self):
        """Drop the whole collection from the Chroma client."""
        self.client.delete_collection(name=self.collection_name)

    def __len__(self) -> int:
        """Return the number of rows stored in the collection."""
        # count() avoids fetching every document just to measure the length.
        return self.collection.count()

    def save(self) -> None:
        """Pickle the in-memory mirror to ``self.pickle_path``."""
        with open(self.pickle_path, 'wb') as f:
            pickle.dump(self.vdb_content, f)

if __name__ == "__main__":
    import sys

    # ChromaVdb requires a collection name as its first argument; the previous
    # call omitted it and failed with a TypeError. Take the name from the
    # command line, falling back to the first built-in seed collection.
    collection_name = (
        sys.argv[1] if len(sys.argv) > 1 else next(iter(initial_vdb_content))
    )

    # Alternative endpoint previously used here:
    #   model_id="meta-llama/Meta-Llama-3.1-8B-Instruct", api_base='http://localhost:8002/v1/'
    vdb = ChromaVdb(
        collection_name,
        model_id="intfloat/e5-mistral-7b-instruct",
        api_key="EMPTY",
        api_base='http://localhost:8001/v1/',
    )