import asyncio
import os
from typing import Any, final, List
from dataclasses import dataclass
import numpy as np
import hashlib
import uuid
from ..utils import logger
from ..base import BaseVectorStorage
import configparser
import pipmaster as pm

# Install the qdrant-client package at import time when pipmaster reports it
# missing, so that the qdrant_client import that follows can succeed.
if not pm.is_installed("qdrant-client"):
    pm.install("qdrant-client")

from qdrant_client import QdrantClient, models  # type: ignore

# Module-level settings read from "config.ini" (relative path, decoded as UTF-8).
# ConfigParser.read() ignores files it cannot find, so `config` may be empty;
# lookups against it in this module always pass an explicit fallback.
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")


def compute_mdhash_id_for_qdrant(
    content: str, prefix: str = "", style: str = "simple"
) -> str:
    """
    Derive a deterministic UUID string from text, for use as a Qdrant point ID.

    The SHA-256 digest of ``prefix + content`` is truncated to its first 16 bytes
    and stamped as a version-4 UUID, so identical inputs always yield the same ID.

    :param content: The text to hash; must be non-empty.
    :param prefix: Text prepended to ``content`` before hashing (default: empty).
    :param style: Output format: "simple" (32 hex digits), "hyphenated"
        (canonical 8-4-4-4-12 form) or "urn" ("urn:uuid:" + hyphenated form).
    :return: The UUID rendered in the requested style.
    :raises ValueError: If ``content`` is empty or ``style`` is not recognised.
    """
    if not content:
        raise ValueError("Content must not be empty.")

    # Hash first, then keep only the 16 bytes a UUID can hold.
    digest = hashlib.sha256((prefix + content).encode("utf-8")).digest()
    derived = uuid.UUID(bytes=digest[:16], version=4)

    # Render in the requested format; anything unrecognised is rejected.
    if style == "simple":
        return derived.hex
    if style == "hyphenated":
        return str(derived)
    if style == "urn":
        return f"urn:uuid:{derived}"
    raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")


@final
@dataclass
class QdrantVectorDBStorage(BaseVectorStorage):
    """Vector storage backed by a Qdrant collection.

    Each instance works against a single collection named after ``namespace``
    (prefixed with a workspace name when one is in effect). Record IDs are
    converted to UUID strings with ``compute_mdhash_id_for_qdrant`` before being
    used as Qdrant point IDs; the original ID is kept in the point payload
    under the ``"id"`` key.

    NOTE(review): the client created here is the synchronous ``QdrantClient``
    and its calls are not awaited, so the ``async`` methods block while a
    request is in flight.
    """

    def __init__(
        self, namespace, global_config, embedding_func, workspace=None, meta_fields=None
    ):
        """Initialise the base storage fields, then connect to Qdrant.

        Args:
            namespace: Base name of the collection.
            global_config: Configuration mapping; ``__post_init__`` reads
                ``vector_db_storage_cls_kwargs`` and ``embedding_batch_num`` from it.
            embedding_func: Async callable that embeds a list of texts; it must
                expose ``embedding_dim``.
            workspace: Optional workspace name; ``None`` is passed on as ``""``.
            meta_fields: Payload fields to persist; ``None`` is passed on as an
                empty set.
        """
        super().__init__(
            namespace=namespace,
            workspace=workspace or "",
            global_config=global_config,
            embedding_func=embedding_func,
            meta_fields=meta_fields or set(),
        )
        # This hand-written __init__ is used instead of a dataclass-generated
        # one, so __post_init__ is invoked explicitly.
        # NOTE(review): if the base class __init__ also calls __post_init__, the
        # workspace prefix would be applied twice — confirm against the base class.
        self.__post_init__()

    @staticmethod
    def create_collection_if_not_exist(
        client: QdrantClient, collection_name: str, **kwargs
    ):
        """Create ``collection_name`` unless the client reports it already exists.

        Args:
            client: Qdrant client used for the existence check and the creation.
            collection_name: Name of the collection to ensure.
            **kwargs: Passed through unchanged to ``client.create_collection``.
        """
        if client.collection_exists(collection_name):
            return
        client.create_collection(collection_name, **kwargs)

    def __post_init__(self):
        """Resolve the effective namespace, connect to Qdrant and ensure the collection.

        Raises:
            ValueError: If ``cosine_better_than_threshold`` is missing from
                ``global_config["vector_db_storage_cls_kwargs"]``.
        """
        # Check for QDRANT_WORKSPACE environment variable first (higher priority)
        # This allows administrators to force a specific workspace for all Qdrant storage instances
        qdrant_workspace = os.environ.get("QDRANT_WORKSPACE")
        if qdrant_workspace and qdrant_workspace.strip():
            # Use environment variable value, overriding the passed workspace parameter
            effective_workspace = qdrant_workspace.strip()
            logger.info(
                f"Using QDRANT_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
            )
        else:
            # Use the workspace parameter passed during initialization
            effective_workspace = self.workspace
            if effective_workspace:
                logger.debug(
                    f"Using passed workspace parameter: '{effective_workspace}'"
                )

        # Build namespace with workspace prefix for data isolation
        if effective_workspace:
            self.namespace = f"{effective_workspace}_{self.namespace}"
            logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
        # When workspace is empty, keep the original namespace unchanged

        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = kwargs.get("cosine_better_than_threshold")
        if cosine_threshold is None:
            raise ValueError(
                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
            )
        self.cosine_better_than_threshold = cosine_threshold

        # Environment variables take precedence over the [qdrant] section of
        # config.ini; both settings fall back to None when neither is present.
        self._client = QdrantClient(
            url=os.environ.get(
                "QDRANT_URL", config.get("qdrant", "uri", fallback=None)
            ),
            api_key=os.environ.get(
                "QDRANT_API_KEY", config.get("qdrant", "apikey", fallback=None)
            ),
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        QdrantVectorDBStorage.create_collection_if_not_exist(
            self._client,
            self.namespace,
            vectors_config=models.VectorParams(
                size=self.embedding_func.embedding_dim, distance=models.Distance.COSINE
            ),
        )

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        """Embed the given records and upsert them into the collection.

        Args:
            data: Mapping of record ID to record fields. Every record must have
                a ``"content"`` entry (the text that is embedded); only fields
                listed in ``self.meta_fields`` are stored in the payload,
                together with ``"id"`` and ``"created_at"``.

        NOTE(review): despite the ``None`` annotation, the value returned by the
        client's ``upsert`` call is passed back when ``data`` is non-empty.
        """
        logger.debug(f"Inserting {len(data)} to {self.namespace}")
        if not data:
            return

        import time

        # One timestamp (whole seconds) is shared by every record in this call.
        current_time = int(time.time())

        list_data = [
            {
                "id": k,
                "created_at": current_time,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        # Embed all batches concurrently; gather() keeps the batch order, so the
        # concatenated rows line up with list_data.
        embedding_tasks = [self.embedding_func(batch) for batch in batches]
        embeddings_list = await asyncio.gather(*embedding_tasks)

        embeddings = np.concatenate(embeddings_list)

        list_points = [
            models.PointStruct(
                id=compute_mdhash_id_for_qdrant(d["id"]),
                vector=embeddings[i],
                payload=d,
            )
            for i, d in enumerate(list_data)
        ]

        results = self._client.upsert(
            collection_name=self.namespace, points=list_points, wait=True
        )
        return results

    async def query(
        self, query: str, top_k: int, ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Embed ``query`` and return the closest stored points.

        Args:
            query: Text to search for.
            top_k: Maximum number of hits requested from Qdrant.
            ids: Accepted for interface compatibility; not used by this
                implementation.

        Returns:
            One dict per hit: the stored payload plus ``"distance"`` (the score
            reported by Qdrant) and ``"created_at"`` (``None`` when the payload
            has no such field). Hits are limited by passing
            ``self.cosine_better_than_threshold`` as the score threshold.
        """
        embedding = await self.embedding_func(
            [query], _priority=5
        )  # higher priority for query
        results = self._client.search(
            collection_name=self.namespace,
            query_vector=embedding[0],
            limit=top_k,
            with_payload=True,
            score_threshold=self.cosine_better_than_threshold,
        )

        logger.debug(f"query result: {results}")

        return [
            {
                **dp.payload,
                "distance": dp.score,
                "created_at": dp.payload.get("created_at"),
            }
            for dp in results
        ]

    async def index_done_callback(self) -> None:
        """No-op hook; nothing needs to be flushed for this backend."""
        # Qdrant handles persistence automatically
        pass

    async def delete(self, ids: List[str]) -> None:
        """Delete vectors with specified IDs

        Errors are logged rather than raised.

        Args:
            ids: List of vector IDs to be deleted
        """
        # Nothing to delete: skip the request instead of sending an empty selector.
        if not ids:
            return

        try:
            # Convert regular ids to Qdrant compatible ids
            qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
            # Delete points from the collection
            self._client.delete(
                collection_name=self.namespace,
                points_selector=models.PointIdsList(
                    points=qdrant_ids,
                ),
                wait=True,
            )
            logger.debug(
                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
            )
        except Exception as e:
            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")

    async def delete_entity(self, entity_name: str) -> None:
        """Delete an entity by name

        The point ID is derived from ``entity_name`` with the ``"ent-"`` prefix.
        Errors are logged rather than raised.

        Args:
            entity_name: Name of the entity to delete
        """
        try:
            # Generate the entity ID
            entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-")
            logger.debug(
                f"Attempting to delete entity {entity_name} with ID {entity_id}"
            )

            # Delete the entity point from the collection
            self._client.delete(
                collection_name=self.namespace,
                points_selector=models.PointIdsList(
                    points=[entity_id],
                ),
                wait=True,
            )
            logger.debug(f"Successfully deleted entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting entity {entity_name}: {e}")

    async def delete_entity_relation(self, entity_name: str) -> None:
        """Delete all relations associated with an entity

        A relation is any point whose ``src_id`` or ``tgt_id`` payload field
        equals ``entity_name``. Errors are logged rather than raised.

        Args:
            entity_name: Name of the entity whose relations should be deleted
        """
        try:
            # Find relations where the entity is either source or target
            relation_filter = models.Filter(
                should=[
                    models.FieldCondition(
                        key="src_id", match=models.MatchValue(value=entity_name)
                    ),
                    models.FieldCondition(
                        key="tgt_id", match=models.MatchValue(value=entity_name)
                    ),
                ]
            )

            # scroll() returns one page plus the offset of the next page. Follow
            # that offset until it is None so entities with more relations than
            # fit in a single page are fully cleaned up. IDs are collected first
            # and deleted afterwards so deletion cannot disturb the paging.
            ids_to_delete = []
            next_offset = None
            while True:
                points, next_offset = self._client.scroll(
                    collection_name=self.namespace,
                    scroll_filter=relation_filter,
                    with_payload=False,  # only the point IDs are needed
                    limit=1000,  # page size
                    offset=next_offset,
                )
                ids_to_delete.extend(point.id for point in points)
                if next_offset is None:
                    break

            if ids_to_delete:
                # Delete the relations
                self._client.delete(
                    collection_name=self.namespace,
                    points_selector=models.PointIdsList(
                        points=ids_to_delete,
                    ),
                    wait=True,
                )
                logger.debug(
                    f"Deleted {len(ids_to_delete)} relations for {entity_name}"
                )
            else:
                logger.debug(f"No relations found for entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting relations for {entity_name}: {e}")

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        """Get vector data by its ID

        Args:
            id: The unique identifier of the vector

        Returns:
            The stored payload (with ``"created_at"`` set to ``None`` when it is
            absent) if found; ``None`` if not found or if retrieval fails (the
            error is logged).
        """
        try:
            # Convert to Qdrant compatible ID
            qdrant_id = compute_mdhash_id_for_qdrant(id)

            # Retrieve the point by ID
            result = self._client.retrieve(
                collection_name=self.namespace,
                ids=[qdrant_id],
                with_payload=True,
            )

            if not result:
                return None

            # Ensure the result contains created_at field
            payload = result[0].payload
            if "created_at" not in payload:
                payload["created_at"] = None

            return payload
        except Exception as e:
            logger.error(f"Error retrieving vector data for ID {id}: {e}")
            return None

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        """Get multiple vector data by their IDs

        Args:
            ids: List of unique identifiers

        Returns:
            List of payloads for the points that were found, each with
            ``"created_at"`` set to ``None`` when absent. Returns an empty list
            when ``ids`` is empty or retrieval fails (the error is logged).
        """
        if not ids:
            return []

        try:
            # Convert to Qdrant compatible IDs
            qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]

            # Retrieve the points by IDs
            results = self._client.retrieve(
                collection_name=self.namespace,
                ids=qdrant_ids,
                with_payload=True,
            )

            # Ensure each result contains created_at field
            payloads = []
            for point in results:
                payload = point.payload
                if "created_at" not in payload:
                    payload["created_at"] = None
                payloads.append(payload)

            return payloads
        except Exception as e:
            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
            return []

    async def drop(self) -> dict[str, str]:
        """Drop all vector data from storage and clean up resources

        This method deletes the Qdrant collection (if it exists) and recreates
        it empty with the same vector size and cosine distance.

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
        """
        try:
            # Delete the collection and recreate it
            if self._client.collection_exists(self.namespace):
                self._client.delete_collection(self.namespace)

            # Recreate the collection
            QdrantVectorDBStorage.create_collection_if_not_exist(
                self._client,
                self.namespace,
                vectors_config=models.VectorParams(
                    size=self.embedding_func.embedding_dim,
                    distance=models.Distance.COSINE,
                ),
            )

            logger.info(
                f"Process {os.getpid()} drop Qdrant collection {self.namespace}"
            )
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping Qdrant collection {self.namespace}: {e}")
            return {"status": "error", "message": str(e)}
