# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from camel.storages.vectordb_storages import (
    BaseVectorStorage,
    VectorDBQuery,
    VectorDBQueryResult,
    VectorDBStatus,
    VectorRecord,
)
from camel.utils import dependencies_required

logger = logging.getLogger(__name__)


class MilvusStorage(BaseVectorStorage):
    r"""An implementation of the `BaseVectorStorage` for interacting with
    Milvus, a cloud-native vector search engine.

    The detailed information about Milvus is available at:
    `Milvus <https://milvus.io/docs/overview.md/>`_

    Args:
        vector_dim (int): The dimension of storing vectors.
        url_and_api_key (Tuple[str, str]): Tuple containing
           the URL and API key for connecting to a remote Milvus instance.
           URL maps to Milvus uri concept, typically "endpoint:port".
           API key maps to Milvus token concept, for self-hosted it's
           "username:pwd", for Zilliz Cloud (fully-managed Milvus) it's API
           Key.
        collection_name (Optional[str], optional): Name for the collection in
            the Milvus. If not provided, a name derived from the current
            time in ISO format is generated. (default: :obj:`None`)
        **kwargs (Any): Additional keyword arguments for initializing
            `MilvusClient`.

    Raises:
        ImportError: If `pymilvus` package is not installed.
        ValueError: If the named collection already exists with a vector
            dimension different from `vector_dim`.
    """

    @dependencies_required('pymilvus')
    def __init__(
        self,
        vector_dim: int,
        url_and_api_key: Tuple[str, str],
        collection_name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        from pymilvus import MilvusClient

        self._client: MilvusClient
        self._create_client(url_and_api_key, **kwargs)
        self.vector_dim = vector_dim
        # Fall back to a timestamp-based name when none (or an empty one) is
        # given.
        self.collection_name = (
            collection_name or self._generate_collection_name()
        )
        self._check_and_create_collection()

    def _create_client(
        self,
        url_and_api_key: Tuple[str, str],
        **kwargs: Any,
    ) -> None:
        r"""Initializes the Milvus client with the provided connection details.

        Args:
            url_and_api_key (Tuple[str, str]): The URL and API key for the
                Milvus server. The first item is passed as `uri`, the second
                as `token`.
            **kwargs (Any): Additional keyword arguments passed to the Milvus
                client.
        """
        from pymilvus import MilvusClient

        uri, token = url_and_api_key[0], url_and_api_key[1]
        self._client = MilvusClient(
            uri=uri,
            token=token,
            **kwargs,
        )

    def _check_and_create_collection(self) -> None:
        r"""Checks if the specified collection exists in Milvus and creates it
        if it doesn't, ensuring it matches the specified vector dimensionality.

        Raises:
            ValueError: If the collection exists but its vector dimension
                differs from `self.vector_dim`.
        """
        if not self._collection_exists(self.collection_name):
            self._create_collection(
                collection_name=self.collection_name,
            )
            return

        in_dim = self._get_collection_info(self.collection_name)[
            "vector_dim"
        ]
        if in_dim != self.vector_dim:
            # The name of collection has to be confirmed by the user
            raise ValueError(
                "Vector dimension of the existing collection "
                f'"{self.collection_name}" ({in_dim}) is different from '
                f"the given embedding dim ({self.vector_dim})."
            )

    def _create_collection(
        self,
        collection_name: str,
        **kwargs: Any,
    ) -> None:
        r"""Creates a new collection in the database.

        The collection gets three fields (`id`, `vector`, `payload`) and an
        index named `vector_index` on the `vector` field.

        Args:
            collection_name (str): Name of the collection to be created.
            **kwargs (Any): Additional keyword arguments pass to create
                collection.
        """

        from pymilvus import DataType

        # Set the schema
        schema = self._client.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
            description='collection schema',
        )

        # max_length reference: https://milvus.io/docs/limitations.md
        schema.add_field(
            field_name="id",
            datatype=DataType.VARCHAR,
            description='A unique identifier for the vector',
            is_primary=True,
            max_length=65535,
        )
        # NOTE: `_get_collection_info` recognises this field by its name and
        # by this exact description; keep both in sync with that method.
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            description='The numerical representation of the vector',
            dim=self.vector_dim,
        )
        schema.add_field(
            field_name="payload",
            datatype=DataType.JSON,
            description=(
                'Any additional metadata or information related '
                'to the vector'
            ),
        )

        # Create the collection
        self._client.create_collection(
            collection_name=collection_name,
            schema=schema,
            **kwargs,
        )

        # Set the index of the parameters
        index_params = self._client.prepare_index_params()

        index_params.add_index(
            field_name="vector",
            metric_type="COSINE",
            index_type="AUTOINDEX",
            index_name="vector_index",
        )

        self._client.create_index(
            collection_name=collection_name, index_params=index_params
        )

    def _delete_collection(
        self,
        collection_name: str,
    ) -> None:
        r"""Deletes an existing collection from the database.

        Args:
            collection_name (str): Name of the collection to be deleted.
        """
        self._client.drop_collection(collection_name=collection_name)

    def _collection_exists(self, collection_name: str) -> bool:
        r"""Checks whether a collection with the specified name exists in the
        database.

        Args:
            collection_name (str): The name of the collection to check.

        Returns:
            bool: True if the collection exists, False otherwise.
        """
        return self._client.has_collection(collection_name)

    def _generate_collection_name(self) -> str:
        r"""Generates a unique name for a new collection based on the current
        timestamp. Milvus collection names can only contain alphanumeric
        characters and underscores.

        Returns:
            str: A unique, valid collection name.
        """
        timestamp = datetime.now().isoformat()
        # Replace every character Milvus does not accept (e.g. "-", ":", ".")
        # and prefix with letters so the name does not start with a digit.
        transformed_name = re.sub(r'[^a-zA-Z0-9_]', '_', timestamp)
        return "Time" + transformed_name

    def _get_collection_info(self, collection_name: str) -> Dict[str, Any]:
        r"""Retrieves details of an existing collection.

        Args:
            collection_name (str): Name of the collection to be checked.

        Returns:
            Dict[str, Any]: A dictionary containing details about the
                collection: `id` (the collection id), `vector_count` (the
                reported row count) and `vector_dim` (the dimension of the
                vector field, or :obj:`None` if no such field is found).
        """
        vector_count = self._client.get_collection_stats(collection_name)[
            'row_count'
        ]
        collection_info = self._client.describe_collection(collection_name)
        collection_id = collection_info['collection_id']

        # Identify the vector field by the name used in `_create_collection`,
        # keeping the description match as a fallback. Relying on the
        # free-text description alone misses collections whose "vector" field
        # was created without that exact description.
        dim_value = next(
            (
                field['params']['dim']
                for field in collection_info['fields']
                if (
                    field.get('name') == 'vector'
                    or field.get('description')
                    == 'The numerical representation of the vector'
                )
                and 'dim' in (field.get('params') or {})
            ),
            None,
        )

        return {
            "id": collection_id,  # the id of the collection
            "vector_count": vector_count,  # the number of the vector
            "vector_dim": dim_value,  # the dimension of the vector
        }

    def _validate_and_convert_vectors(
        self, records: List[VectorRecord]
    ) -> List[dict]:
        r"""Converts VectorRecord instances to the format expected by Milvus.

        Args:
            records (List[VectorRecord]): List of vector records to convert.

        Returns:
            List[dict]: A list of dictionaries formatted for Milvus insertion,
                each with the keys `id`, `payload` and `vector`. A missing
                payload (:obj:`None`) is stored as an empty string.
        """
        return [
            {
                "id": record.id,
                "payload": record.payload
                if record.payload is not None
                else '',
                "vector": record.vector,
            }
            for record in records
        ]

    def add(
        self,
        records: List[VectorRecord],
        **kwargs: Any,
    ) -> None:
        r"""Adds a list of vectors to the specified collection.

        Errors raised by the Milvus client are not caught here and propagate
        to the caller.

        Args:
            records (List[VectorRecord]): List of vectors to be added.
            **kwargs (Any): Additional keyword arguments pass to insert.
        """
        validated_records = self._validate_and_convert_vectors(records)

        op_info = self._client.insert(
            collection_name=self.collection_name,
            data=validated_records,
            **kwargs,
        )
        logger.debug("Successfully added vectors in Milvus: %s", op_info)

    def delete(
        self,
        ids: List[str],
        **kwargs: Any,
    ) -> None:
        r"""Deletes a list of vectors identified by their IDs from the
        storage. If unsure of ids you can first query the collection to grab
        the corresponding data.

        Errors raised by the Milvus client are not caught here and propagate
        to the caller.

        Args:
            ids (List[str]): List of unique identifiers for the vectors to be
                deleted.
            **kwargs (Any): Additional keyword arguments passed to delete.
        """

        op_info = self._client.delete(
            collection_name=self.collection_name, pks=ids, **kwargs
        )
        logger.debug("Successfully deleted vectors in Milvus: %s", op_info)

    def status(self) -> VectorDBStatus:
        r"""Retrieves the current status of the Milvus collection. This method
        provides information about the collection, including its vector
        dimensionality and the total number of vectors stored.

        Returns:
            VectorDBStatus: An object containing information about the
                collection's status.
        """
        status = self._get_collection_info(self.collection_name)
        return VectorDBStatus(
            vector_dim=status["vector_dim"],
            vector_count=status["vector_count"],
        )

    def query(
        self,
        query: VectorDBQuery,
        **kwargs: Any,
    ) -> List[VectorDBQueryResult]:
        r"""Searches for similar vectors in the storage based on the provided
        query.

        Args:
            query (VectorDBQuery): The query object containing the search
                vector and the number of top similar vectors to retrieve.
            **kwargs (Any): Additional keyword arguments passed to search.

        Returns:
            List[VectorDBQueryResult]: A list of vectors retrieved from the
                storage based on similarity to the query vector.
        """
        search_result = self._client.search(
            collection_name=self.collection_name,
            data=[query.query_vector],
            limit=query.top_k,
            output_fields=['vector', 'payload'],
            **kwargs,
        )
        query_results = []
        # The result is iterated as a list of hit lists; all hits are
        # flattened into a single result list.
        for points in search_result:
            for point in points:
                entity = point['entity']
                query_results.append(
                    VectorDBQueryResult.create(
                        similarity=point['distance'],
                        id=str(point['id']),
                        payload=entity.get('payload'),
                        vector=entity.get('vector'),
                    )
                )

        return query_results

    def clear(self) -> None:
        r"""Removes all vectors from the Milvus collection. This method
        deletes the existing collection and then recreates it with the same
        schema to effectively remove all stored vectors.
        """
        self._delete_collection(self.collection_name)
        self._create_collection(collection_name=self.collection_name)

    def load(self) -> None:
        r"""Load the collection hosted on cloud service."""
        self._client.load_collection(self.collection_name)

    @property
    def client(self) -> Any:
        r"""Provides direct access to the Milvus client. This property allows
        for direct interactions with the Milvus client for operations that are
        not covered by the `MilvusStorage` class.

        Returns:
            Any: The Milvus client instance.
        """
        return self._client
