from typing import Any

from src.configs.memory import QdrantVecDBConfig
from src.memorykit.vec_dbs.base import BaseVecDB
from src.memorykit.vec_dbs.item import VecDBItem
from src.utils import require_python_package
import logging

logger = logging.getLogger(__name__)

class QdrantVecDB(BaseVecDB):
    """Qdrant vector database implementation."""

    @require_python_package(
        import_name="qdrant_client",
        install_command="pip install qdrant-client",
        install_link="https://python-client.qdrant.tech/",
    )
    def __init__(self, config: QdrantVecDBConfig):
        """Build the Qdrant client from ``config`` and ensure the collection exists.

        Args:
            config: Supplies ``host``, ``port`` and ``path`` for the client, plus
                the collection settings read by ``create_collection``.
        """
        # Function-scope import: the client library is only needed here.
        from qdrant_client import QdrantClient

        self.config = config
        settings = self.config

        # With neither host nor port set, the client is used in local mode.
        local_mode = settings.host is None and settings.port is None
        if local_mode:
            logger.warning(
                "Qdrant is running in local mode (host and port are both None). "
                "In local mode, there may be race conditions during concurrent reads/writes. "
                "It is strongly recommended to deploy a standalone Qdrant server "
                "(e.g., via Docker: https://qdrant.tech/documentation/quickstart/)."
            )

        connection = {
            "host": settings.host,
            "port": settings.port,
            "path": settings.path,
        }
        self.client = QdrantClient(**connection)

        # Creates the collection on first use; a no-op (with a warning) otherwise.
        self.create_collection()

    def create_collection(self) -> None:
        """Create a new collection with specified parameters."""
        from qdrant_client.http import models

        if self.collection_exists(self.config.collection_name):
            collection_info = self.client.get_collection(self.config.collection_name)
            logger.warning(
                f"Collection '{self.config.collection_name}' (vector dimension: {collection_info.config.params.vectors.size}) already exists. Skipping creation."
            )

            return

        # Map string distance metric to Qdrant Distance enum
        distance_map = {
            "cosine": models.Distance.COSINE,
            "euclidean": models.Distance.EUCLID,
            "dot": models.Distance.DOT,
        }

        self.client.create_collection(
            collection_name=self.config.collection_name,
            vectors_config=models.VectorParams(
                size=self.config.vector_dimension,
                distance=distance_map[self.config.distance_metric],
            ),
        )

        logger.info(
            f"Collection '{self.config.collection_name}' created with {self.config.vector_dimension} dimensions."
        )

    def list_collections(self) -> list[str]:
        """List all collections."""
        collections = self.client.get_collections()
        return [collection.name for collection in collections.collections]

    def delete_collection(self, name: str) -> None:
        """Delete a collection."""
        self.client.delete_collection(collection_name=name)

    def collection_exists(self, name: str) -> bool:
        """Check if a collection exists."""
        try:
            self.client.get_collection(collection_name=name)
            return True
        except Exception:
            return False

    def search(
        self, query_vector: list[float], top_k: int, filter: dict[str, Any] | None = None, score_threshold: float | None = None
    ) -> list[VecDBItem]:
        """
        Search for similar items in the database.

        Args:
            query_vector: Single vector to search
            top_k: Number of results to return
            filter: Payload filters

        Returns:
            List of search results with distance scores and payloads.
        """
        qdrant_filter = self._dict_to_filter(filter) if filter else None
        response = self.client.search(
            collection_name=self.config.collection_name,
            query_vector=query_vector,
            limit=top_k,
            score_threshold=score_threshold,
            query_filter=qdrant_filter,
            with_vectors=True,
            with_payload=True,
        )
        logger.info(f"Qdrant search completed with {len(response)} results.")
        return [
            VecDBItem(
                id=point.id,
                vector=point.vector,
                payload=point.payload,
                score=point.score,
            )
            for point in response
        ]

    def _dict_to_filter(self, filter_dict: dict[str, Any]) -> Any:
        from qdrant_client.http import models

        """
        TODO: without testing
        Convert a dictionary filter to a Qdrant Filter object.
        
        Supports multiple filter formats:
        1. Simple exact match: {"field": value} -> MatchValue
        2. Operators: {"field": {"$eq": value}} -> MatchValue
                     {"field": {"$in": [val1, val2]}} -> MatchAny
                     {"field": {"$gt": value}} -> Range (gt)
                     {"field": {"$gte": value}} -> Range (gte)
                     {"field": {"$lt": value}} -> Range (lt)
                     {"field": {"$lte": value}} -> Range (lte)
                     {"field": {"$range": {"gte": min, "lte": max}}} -> Range
                     {"field": {"$text": value}} -> MatchText
                     {"field": {"$is_null": bool}} -> IsNull
                     {"field": {"$is_empty": bool}} -> IsEmpty
        3. Logical operators: {"$must": [conditions]} -> Filter(must=...)
                              {"$should": [conditions]} -> Filter(should=...)
                              {"$must_not": [conditions]} -> Filter(must_not=...)
        4. Nested filters: {"$and": [filter1, filter2]} -> Filter(must=[...])
                           {"$or": [filter1, filter2]} -> Filter(should=[...])
                           {"$not": filter} -> Filter(must_not=[...])
        """
        # Handle logical operators at top level
        if "$must" in filter_dict:
            conditions = []
            for cond in filter_dict["$must"]:
                if isinstance(cond, dict) and any(k.startswith("$") for k in cond.keys()):
                    # It's a nested filter
                    nested_filter = self._dict_to_filter(cond)
                    if nested_filter:
                        conditions.append(nested_filter)
                else:
                    # It's a field condition
                    field_cond = self._parse_field_condition(cond)
                    if field_cond:
                        conditions.append(field_cond)
            return models.Filter(must=conditions) if conditions else None
        
        if "$should" in filter_dict:
            conditions = []
            for cond in filter_dict["$should"]:
                if isinstance(cond, dict) and any(k.startswith("$") for k in cond.keys()):
                    # It's a nested filter
                    nested_filter = self._dict_to_filter(cond)
                    if nested_filter:
                        conditions.append(nested_filter)
                else:
                    # It's a field condition
                    field_cond = self._parse_field_condition(cond)
                    if field_cond:
                        conditions.append(field_cond)
            return models.Filter(should=conditions) if conditions else None
        
        if "$must_not" in filter_dict:
            conditions = []
            for cond in filter_dict["$must_not"]:
                if isinstance(cond, dict) and any(k.startswith("$") for k in cond.keys()):
                    # It's a nested filter
                    nested_filter = self._dict_to_filter(cond)
                    if nested_filter:
                        conditions.append(nested_filter)
                else:
                    # It's a field condition
                    field_cond = self._parse_field_condition(cond)
                    if field_cond:
                        conditions.append(field_cond)
            return models.Filter(must_not=conditions) if conditions else None
        
        if "$and" in filter_dict:
            filters = [self._dict_to_filter(f) for f in filter_dict["$and"]]
            must_conditions = []
            for f in filters:
                if f.must:
                    must_conditions.extend(f.must)
                elif f.should:
                    # Convert should to must by creating nested conditions
                    must_conditions.append(f)
                else:
                    must_conditions.extend(f.must_not if f.must_not else [])
            return models.Filter(must=must_conditions)
        
        if "$or" in filter_dict:
            filters = [self._dict_to_filter(f) for f in filter_dict["$or"]]
            should_conditions = []
            for f in filters:
                if f.should:
                    should_conditions.extend(f.should)
                elif f.must:
                    # Convert must to should by creating nested conditions
                    should_conditions.append(f)
                else:
                    should_conditions.extend(f.must_not if f.must_not else [])
            return models.Filter(should=should_conditions)
        
        if "$not" in filter_dict:
            inner_filter = self._dict_to_filter(filter_dict["$not"])
            must_not_conditions = []
            if inner_filter.must:
                must_not_conditions.extend(inner_filter.must)
            elif inner_filter.should:
                must_not_conditions.append(inner_filter)
            else:
                must_not_conditions.extend(inner_filter.must_not if inner_filter.must_not else [])
            return models.Filter(must_not=must_not_conditions)
        
        # Handle field conditions
        conditions = []
        for field, value in filter_dict.items():
            condition = self._parse_field_condition({field: value})
            if condition:
                conditions.append(condition)
        
        return models.Filter(must=conditions) if conditions else None
    
    def _parse_field_condition(self, field_dict: dict[str, Any]) -> Any:
        """Parse a single field condition from a dictionary."""
        from qdrant_client.http import models
        
        if not field_dict or len(field_dict) != 1:
            return None
        
        field, value = next(iter(field_dict.items()))
        
        # Simple exact match: {"field": value}
        if not isinstance(value, dict):
            return models.FieldCondition(
                key=field, match=models.MatchValue(value=value)
            )
        
        # Operator-based conditions: {"field": {"$op": value}}
        operators = value
        
        # Handle multiple operators (take the first one, or combine them)
        # For simplicity, we'll handle one operator at a time
        if "$eq" in operators:
            return models.FieldCondition(
                key=field, match=models.MatchValue(value=operators["$eq"])
            )
        
        if "$in" in operators:
            return models.FieldCondition(
                key=field, match=models.MatchAny(any=operators["$in"])
            )
        
        if "$text" in operators:
            return models.FieldCondition(
                key=field, match=models.MatchText(text=operators["$text"])
            )
        
        if "$is_null" in operators:
            return models.FieldCondition(
                key=field, is_null=models.IsNull(is_null=operators["$is_null"])
            )
        
        if "$is_empty" in operators:
            return models.FieldCondition(
                key=field, is_empty=models.IsEmpty(is_empty=operators["$is_empty"])
            )
        
        # Range conditions
        range_params = {}
        if "$gt" in operators:
            range_params["gt"] = operators["$gt"]
        if "$gte" in operators:
            range_params["gte"] = operators["$gte"]
        if "$lt" in operators:
            range_params["lt"] = operators["$lt"]
        if "$lte" in operators:
            range_params["lte"] = operators["$lte"]
        
        if "$range" in operators:
            range_params.update(operators["$range"])
        
        if range_params:
            return models.FieldCondition(
                key=field, range=models.Range(**range_params)
            )
        
        # If no recognized operator, treat as exact match
        logger.warning(
            f"Unknown filter operator for field '{field}': {operators}. "
            "Treating as exact match."
        )
        return models.FieldCondition(
            key=field, match=models.MatchValue(value=value)
        )

    def get_by_id(self, id: str) -> VecDBItem | None:
        """Get a single item by ID."""
        response = self.client.retrieve(
            collection_name=self.config.collection_name,
            ids=[id],
            with_payload=True,
            with_vectors=True,
        )

        if not response:
            return None

        point = response[0]
        return VecDBItem(
            id=point.id,
            vector=point.vector,
            payload=point.payload,
        )

    def get_by_ids(self, ids: list[str]) -> list[VecDBItem]:
        """Get multiple items by their IDs."""
        response = self.client.retrieve(
            collection_name=self.config.collection_name,
            ids=ids,
            with_payload=True,
            with_vectors=True,
        )

        if not response:
            return []

        return [
            VecDBItem(
                id=point.id,
                vector=point.vector,
                payload=point.payload,
            )
            for point in response
        ]
    def scroll_query(
        self,
        filter: dict[str, Any] | None = None,
        limit: int = 100,
        offset: str | int | None = None,
        with_vectors: bool = False,
        with_payload: bool | list[str] = True,
        order_by: str | None = None,
        order_desc: bool = False,
        start_from: int | float | str | None = None,
    ) -> tuple[list[VecDBItem], str | int | None]:
        """
        Pure payload query (not based on similarity), uses Qdrant scroll API under the hood.

        Args:
            filter: Payload filter conditions (will go through _dict_to_filter)
            limit: Number of items to return this time (page size)
            offset: Scroll offset (for next page), usually use next_page_offset returned last time
            with_vectors: Whether to return vectors (pure payload queries generally False)
            with_payload: True returns all payload; or pass field list to return only partial fields
            order_by: Sort by a payload field (requires corresponding index for that field, especially numeric fields)
            order_desc: Whether descending order
            start_from: Sort starting point (optional)

        Returns:
            (items, next_offset)
        """
        from qdrant_client.http import models

        qdrant_filter = self._dict_to_filter(filter) if filter else None

            qdrant_order_by = None
        if order_by:
            # Qdrant client's scroll supports order_by parameter
            qdrant_order_by = models.OrderBy(
                key=order_by,
                direction=models.Direction.DESC if order_desc else models.Direction.ASC,
                start_from=start_from,
            )

        points, next_offset = self.client.scroll(
            collection_name=self.config.collection_name,
            scroll_filter=qdrant_filter,
            limit=limit,
            offset=offset,
            order_by=qdrant_order_by,
            with_vectors=with_vectors,
            with_payload=with_payload,
        )

        items = [
            VecDBItem(
                id=p.id,
                vector=p.vector if with_vectors else None,
                payload=p.payload,
                score=None,  # scroll does not return similarity score
            )
            for p in points
        ]
        return items, next_offset
    
    def get_by_filter(
        self,
        filter: dict[str, Any],
        scroll_limit: int = 100,
        with_vectors: bool = False,
        with_payload: bool | list[str] = True,
    ) -> list[VecDBItem]:
        """
        Retrieve all items that match the given filter criteria.

        Args:
            filter: Payload filters to match against stored items
            scroll_limit: Maximum number of items to retrieve per scroll request

        Returns:
            List of items including vectors and payload that match the filter
        """
        qdrant_filter = self._dict_to_filter(filter) if filter else None
        all_points = []
        offset = None

        # Use scroll to paginate through all matching points
        while True:
            points, offset = self.client.scroll(
                collection_name=self.config.collection_name,
                limit=scroll_limit,
                scroll_filter=qdrant_filter,
                offset=offset,
                with_vectors=with_vectors,
                with_payload=with_payload,
            )

            if not points:
                break

            all_points.extend(points)

            # Update offset for next iteration
            if offset is None:
                break

        logger.info(f"Qdrant retrieve by filter completed with {len(all_points)} results.")
        return [
            VecDBItem(
                id=point.id,
                vector=point.vector,
                payload=point.payload,
            )
            for point in all_points
        ]

    def get_all(self, scroll_limit: int = 100, with_vectors: bool = False, with_payload: bool | list[str] = True) -> list[VecDBItem]:
        """Retrieve all items in the vector database."""
        return self.get_by_filter({}, scroll_limit=scroll_limit, with_vectors=with_vectors, with_payload=with_payload)

    def count(self, filter: dict[str, Any] | None = None) -> int:
        """Count items in the database, optionally with filter."""
        qdrant_filter = None
        if filter:
            qdrant_filter = self._dict_to_filter(filter)

        response = self.client.count(
            collection_name=self.config.collection_name, count_filter=qdrant_filter
        )

        return response.count

    def add(self, data: list[VecDBItem | dict[str, Any]]) -> None:
        from qdrant_client.http import models

        """
        Add data to the vector database.

        Args:
            data: List of VecDBItem objects or dictionaries containing:
                - 'id': unique identifier
                - 'vector': embedding vector
                - 'payload': additional fields for filtering/retrieval
        """
        points = []
        for item in data:
            if isinstance(item, dict):
                item = item.copy()
                item = VecDBItem.from_dict(item)
            point = models.PointStruct(id=item.id, vector=item.vector, payload=item.payload)
            points.append(point)

        self.client.upsert(collection_name=self.config.collection_name, points=points)

    def update(self, id: str, data: VecDBItem | dict[str, Any]) -> None:
        """Update the item stored under ``id``.

        Args:
            id: ID of the point to update.
            data: New content as a VecDBItem or an equivalent dictionary. With
                a non-empty ``vector`` the point is upserted (vector and
                payload); otherwise only the payload is set.
        """
        from qdrant_client.http import models

        item = VecDBItem.from_dict(data.copy()) if isinstance(data, dict) else data
        collection = self.config.collection_name

        if not item.vector:
            # Payload-only update: the stored vector is left untouched.
            self.client.set_payload(
                collection_name=collection, payload=item.payload, points=[id]
            )
            return

        # Vector given (with or without payload): upsert under the same ID.
        replacement = models.PointStruct(id=id, vector=item.vector, payload=item.payload)
        self.client.upsert(collection_name=collection, points=[replacement])

    def ensure_payload_indexes(self, fields: list[str]) -> None:
        """
        Create payload indexes for specified fields in the collection.
        This is idempotent: it will skip if index already exists.

        Args:
            fields (list[str]): List of field names to index (as keyword).
        """
        for field in fields:
            try:
                self.client.create_payload_index(
                    collection_name=self.config.collection_name,
                    field_name=field,
                    field_schema="keyword",  # Could be extended in future
                )
                logger.info(f"Qdrant payload index on '{field}' ensured.")
            except Exception as e:
                logger.warning(f"Failed to create payload index on '{field}': {e}")

    def upsert(self, data: list[VecDBItem | dict[str, Any]]) -> None:
        """
        Add or update data in the vector database.

        If an item with the same ID exists, it will be updated.
        Otherwise, it will be added as a new item.
        """
        # Qdrant's upsert operation already handles this logic
        self.add(data)

    def delete(self, ids: list[str]) -> None:
        from qdrant_client.http import models

        """Delete items from the vector database."""
        point_ids: list[str | int] = ids
        self.client.delete(
            collection_name=self.config.collection_name,
            points_selector=models.PointIdsList(points=point_ids),
        )
