import hashlib
import pickle
from pathlib import Path
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType

from ...utils.log import logger

def collection_drop(db_path, collection_name):
    """Drop ``collection_name`` from the Milvus database at ``db_path`` if it exists.

    A short-lived ``MilvusClient`` is opened for the operation and is always
    closed afterwards, even if the existence check or the drop raises.

    Args:
        db_path: URI or local file path (``str`` or ``Path``) of the database.
        collection_name: Name of the collection to drop.
    """
    # MilvusClient expects a string URI; str() also accepts pathlib.Path.
    client = MilvusClient(str(db_path))
    try:
        if client.has_collection(collection_name):
            client.drop_collection(collection_name)
            logger.debug(f"Collection {collection_name} dropped.\n")
    finally:
        client.close()
    

class MilvusDB:
    """Wrapper around a single Milvus collection of embedded text chunks.

    Each row stores a string primary key ``uni_id`` derived from the text, the
    raw ``text`` and its embedding ``emb``. The schema enables dynamic fields,
    so extra metadata keys passed to :meth:`insert` are stored as well.
    """

    def __init__(self, db: str | Path | MilvusClient, emb_func, collection_name, emb_dim=None):
        """Open the collection ``collection_name``, creating it if needed.

        Args:
            db: An existing ``MilvusClient`` to reuse, or a URI / local file path
                (``str`` or ``Path``) from which a new client is created.
            emb_func: Callable mapping a string to its embedding vector.
            collection_name: Name of the collection to use.
            emb_dim: Embedding dimension. If falsy and the collection has to be
                created, it is inferred by embedding a dummy string.

        Sets ``self.new`` to True if the collection was created here, else False.
        """
        # Both str and Path denote a connection target; MilvusClient wants a str URI.
        if isinstance(db, (str, Path)):
            self.client = MilvusClient(str(db))
        else:
            self.client = db
        self.name = collection_name
        self.new = True
        self.emb_func = emb_func

        if self.client.has_collection(self.name):
            self.new = False
            logger.debug(f"Collection {self.name} already exists.\n")
            return

        if not emb_dim:
            # Probe the embedding function only when the dimension is actually needed.
            emb_dim = len(emb_func('dummy'))
        fields = [
            # Primary key is a hex digest string (see insert), hence VARCHAR.
            FieldSchema("uni_id", DataType.VARCHAR, max_length=64, is_primary=True),
            FieldSchema("text", DataType.VARCHAR, max_length=20000),
            FieldSchema("emb", DataType.FLOAT_VECTOR, dim=emb_dim),
        ]
        schema = CollectionSchema(
            fields=fields,
            description="Collection for DRC document chunks",
            enable_dynamic_field=True,
        )
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="emb",
            metric_type="COSINE",
            index_type="FLAT",
            index_name=f"vector_index{self.name}",
        )
        # Passing index_params here builds the index AND loads the collection,
        # so it can be searched immediately (an unloaded collection cannot be).
        self.client.create_collection(
            collection_name=self.name, schema=schema, index_params=index_params
        )
        logger.debug(f"Collection {self.name} created.\n")
        logger.debug("Index for emb created.\n")

    def drop_collection(self):
        """Drop this wrapper's collection if it exists."""
        if self.client.has_collection(self.name):
            self.client.drop_collection(self.name)

    def insert(self, text, metadata=None, *, flush=True):
        """Upsert ``text`` keyed by a hash of its content.

        The primary key is the first 16 hex digits of the SHA-256 of the UTF-8
        encoded text, so re-inserting identical text overwrites the existing row
        instead of creating a duplicate.

        Args:
            text: Text to embed and store.
            metadata: Optional mapping of extra (dynamic) fields. The reserved
                fields ``uni_id``, ``text`` and ``emb`` always take the computed
                values; metadata cannot override them.
            flush: Flush the collection after the upsert. Pass ``False`` when
                inserting many rows and flush once at the end.

        Returns:
            True.
        """
        uni_id = hashlib.sha256(text.encode("utf-8")).hexdigest()[:16]
        row = dict(metadata) if metadata else {}
        # Reserved fields are written last so metadata cannot clobber them.
        row.update({"uni_id": uni_id, "text": text, "emb": self.emb_func(text)})
        self.client.upsert(self.name, row)
        if flush:
            self.client.flush(collection_name=self.name)
        return True

    def search(self, text, top_k=5, output_fields=None):
        """Return the ``top_k`` hits for the embedding of ``text``.

        Args:
            text: Query text; embedded with ``emb_func``.
            top_k: Maximum number of hits.
            output_fields: Fields to return with each hit; defaults to ``["text"]``.

        Returns:
            The hit list for the single query vector (``result[0]``).
        """
        if output_fields is None:
            output_fields = ["text"]  # fresh list per call (no shared mutable default)
        query_emb = self.emb_func(text)
        result = self.client.search(
            self.name, [query_emb], output_fields=output_fields, limit=top_k
        )
        return result[0]

    def chunks_insert(self, chunk_path):
        """Load pickled chunks from ``chunk_path`` and upsert each into the collection.

        Each chunk must expose ``metadata`` and ``page_content`` attributes
        (presumably LangChain ``Document`` objects — TODO confirm). The stored text
        is the metadata values joined by ", " on a first line, followed by the page
        content. Chunks that produce empty text are skipped. The collection is
        flushed once after all upserts.

        SECURITY: ``pickle.load`` can execute arbitrary code embedded in the file;
        only call this on chunk files from a trusted source.
        """
        path = Path(chunk_path)
        if not path.exists():
            logger.error(f"Chunk file {chunk_path} does not exist.")
            return
        with path.open('rb') as f:
            chunks = pickle.load(f)  # unsafe for untrusted files (see docstring)
        inserted = 0
        for chunk in chunks:
            parts = []
            if chunk.metadata and isinstance(chunk.metadata, dict):
                # str(): metadata values (e.g. page numbers) are not necessarily strings.
                parts.append(", ".join(str(v) for v in chunk.metadata.values()) + '\n')
            if chunk.page_content:
                parts.append(chunk.page_content)
            text = ''.join(parts)
            if not text:
                continue
            self.insert(text, flush=False)   # no metadata for now
            inserted += 1
        if inserted:
            # One flush for the whole batch instead of one per chunk.
            self.client.flush(collection_name=self.name)
        logger.info(f"Inserted {inserted} chunks into MilvusDB.")

    def collection_exists(self):
        """Return True if this wrapper's collection exists in the database."""
        return self.client.has_collection(self.name)
    

