# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import logging
import re
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from camel.storages.graph_storages.base import BaseGraphStorage
from camel.storages.graph_storages.graph_element import (
    GraphElement,
)
from camel.utils.commons import dependencies_required

logger = logging.getLogger(__name__)


if TYPE_CHECKING:
    from nebula3.data.ResultSet import (  # type: ignore[import-untyped]
        ResultSet,
    )
    from nebula3.gclient.net import (  # type: ignore[import-untyped]
        ConnectionPool,
        Session,
    )


MAX_RETRIES = 5
RETRY_DELAY = 3


class NebulaGraph(BaseGraphStorage):
    @dependencies_required('nebula3')
    def __init__(
        self, host, username, password, space, port=9669, timeout=10000
    ):
        r"""Initializes the NebulaGraph client.

        Args:
            host (str): The host address of the NebulaGraph service.
            username (str): The username for authentication.
            password (str): The password for authentication.
            space (str): The graph space to use. If it doesn't exist, a new
                one will be created.
            port (int, optional): The port number for the connection.
                (default: :obj:`9669`)
            timeout (int, optional): The connection timeout in milliseconds.
                (default: :obj:`10000`)
        """
        self.host = host
        self.username = username
        self.password = password
        self.space = space
        self.timeout = timeout
        self.port = port
        self.schema: str = ""
        self.structured_schema: Dict[str, Any] = {}
        self.connection_pool = self._init_connection_pool()
        self.session = self._get_session()

    def _init_connection_pool(self) -> "ConnectionPool":
        r"""Initialize the connection pool.

        Returns:
            ConnectionPool: A connection pool instance.

        Raises:
            Exception: If the connection pool initialization fails.
        """
        from nebula3.Config import Config  # type: ignore[import-untyped]
        from nebula3.gclient.net import ConnectionPool

        config = Config()
        config.max_connection_pool_size = 10
        config.timeout = self.timeout

        # Create the connection pool
        connection_pool = ConnectionPool()

        # Initialize the connection pool with Nebula Graph's address and port
        if not connection_pool.init([(self.host, self.port)], config):
            raise Exception("Failed to initialize the connection pool")

        return connection_pool

    def _get_session(self) -> "Session":
        r"""Get a session from the connection pool.

        Returns:
            Session: A session object connected to NebulaGraph.

        Raises:
            Exception: If session creation or space usage fails.
        """
        session = self.connection_pool.get_session(
            self.username, self.password
        )
        if not session:
            raise Exception("Failed to create a session")

        # Use the specified space
        session.execute(
            f"CREATE SPACE IF NOT EXISTS {self.space} "
            "(vid_type=FIXED_STRING(30));"
        )

        for attempt in range(MAX_RETRIES):
            res = session.execute(f"USE {self.space};")

            if res.is_succeeded():
                return session

            if attempt < MAX_RETRIES - 1:
                time.sleep(RETRY_DELAY)
            else:
                # Final attempt failed, raise an exception
                raise Exception(
                    f"Failed to execute `{self.space}` after "
                    f"{MAX_RETRIES} attempts: {res.error_msg()}"
                )

    @property
    def get_client(self) -> Any:
        r"""Get the underlying graph storage client."""
        return self.session

    def query(self, query: str) -> "ResultSet":  # type:ignore[override]
        r"""Execute a query on the graph store.

        Args:
            query (str): The Cypher-like query to be executed.

        Returns:
            ResultSet: The result set of the query execution.

        Raises:
            ValueError: If the query execution fails.
        """
        try:
            # Get the session
            result_set = self.session.execute(query)
            return result_set

        except Exception as e:
            raise ValueError(f"Query execution error: {e!s}")

    def get_relationship_types(self) -> List[str]:
        r"""Retrieve relationship types from the graph.

        Returns:
            List[str]: A list of relationship (edge) type names.
        """
        # Query all edge types
        result = self.query('SHOW EDGES')
        rel_types = []

        # Extract relationship type names
        for row in result.rows():
            edge_name = row.values[0].get_sVal().decode('utf-8')
            rel_types.append(edge_name)

        return rel_types

    def add_graph_elements(
        self,
        graph_elements: List[GraphElement],
    ) -> None:
        r"""Add graph elements (nodes and relationships) to the graph.

        Args:
            graph_elements (List[GraphElement]): A list of graph elements
                containing nodes and relationships.
        """
        nodes = self._extract_nodes(graph_elements)
        for node in nodes:
            try:
                self.add_node(node['id'], node['type'])
            except Exception as e:
                logger.warning(f"Failed to add node {node}. Error: {e}")
                continue

        relationships = self._extract_relationships(graph_elements)
        for rel in relationships:
            try:
                self.add_triplet(
                    rel['subj']['id'], rel['obj']['id'], rel['type']
                )
            except Exception as e:
                logger.warning(f"Failed to add relationship {rel}. Error: {e}")
                continue

    def ensure_edge_type_exists(
        self,
        edge_type: str,
        time_label: Optional[str] = None,
    ) -> None:
        r"""Ensures that a specified edge type exists in the NebulaGraph
        database. If the edge type already exists, this method does nothing.

        Args:
            edge_type (str): The name of the edge type to be created.
            time_label (str, optional): A specific timestamp to set as the
                default value for the time label property. If not
                provided, no timestamp will be added. (default: :obj:`None`)

        Raises:
            Exception: If the edge type creation fails after multiple retry
                attempts, an exception is raised with the error message.
        """
        create_edge_stmt = f"CREATE EDGE IF NOT EXISTS {edge_type} ()"
        if time_label is not None:
            time_label = self._validate_time_label(time_label)
            create_edge_stmt = f"""CREATE EDGE IF NOT EXISTS {edge_type}
            (time_label DATETIME DEFAULT {time_label})"""

        for attempt in range(MAX_RETRIES):
            res = self.query(create_edge_stmt)
            if res.is_succeeded():
                return  # Edge type creation succeeded

            if attempt < MAX_RETRIES - 1:
                time.sleep(RETRY_DELAY)
            else:
                # Final attempt failed, raise an exception
                raise Exception(
                    f"Failed to create edge type `{edge_type}` after "
                    f"{MAX_RETRIES} attempts: {res.error_msg()}"
                )

    def ensure_tag_exists(
        self, tag_name: str, time_label: Optional[str] = None
    ) -> None:
        r"""Ensures a tag is created in the NebulaGraph database. If the tag
        already exists, it does nothing.

        Args:
            tag_name (str): The name of the tag to be created.
            time_label (str, optional): A specific timestamp to set as the
                default value for the time label property. If not provided,
                no timestamp will be added. (default: :obj:`None`)

        Raises:
            Exception: If the tag creation fails after retries, an exception
                is raised with the error message.
        """
        create_tag_stmt = f"CREATE TAG IF NOT EXISTS {tag_name} ()"
        if time_label is not None:
            time_label = self._validate_time_label(time_label)
            create_tag_stmt = f"""CREATE TAG IF NOT EXISTS {tag_name}
            (time_label DATETIME DEFAULT {time_label})"""

        for attempt in range(MAX_RETRIES):
            res = self.query(create_tag_stmt)
            if res.is_succeeded():
                return  # Tag creation succeeded, exit the method

            if attempt < MAX_RETRIES - 1:
                time.sleep(RETRY_DELAY)
            else:
                # Final attempt failed, raise an exception
                raise Exception(
                    f"Failed to create tag `{tag_name}` after "
                    f"{MAX_RETRIES} attempts: {res.error_msg()}"
                )

    def add_node(
        self,
        node_id: str,
        tag_name: str,
        time_label: Optional[str] = None,
    ) -> None:
        r"""Add a node with the specified tag and properties.

        Args:
            node_id (str): The ID of the node.
            tag_name (str): The tag name of the node.
            time_label (str, optional): A specific timestamp to set for
                the node's time label property. If not provided, no timestamp
                will be added. (default: :obj:`None`)
        """
        node_id = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '', node_id)
        tag_name = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '', tag_name)

        self.ensure_tag_exists(tag_name, time_label)

        # Insert node with or without time_label property
        if time_label is not None:
            time_label = self._validate_time_label(time_label)
            insert_stmt = (
                f'INSERT VERTEX IF NOT EXISTS {tag_name}(time_label) VALUES '
                f'"{node_id}":("{time_label}")'
            )
        else:
            insert_stmt = (
                f'INSERT VERTEX IF NOT EXISTS {tag_name}() VALUES '
                f'"{node_id}":()'
            )

        for attempt in range(MAX_RETRIES):
            res = self.query(insert_stmt)
            if res.is_succeeded():
                return  # Node creation succeeded, exit the method

            if attempt < MAX_RETRIES - 1:
                time.sleep(RETRY_DELAY)
            else:
                # Final attempt failed, raise an exception
                raise Exception(
                    f"Failed to add node `{node_id}` after"
                    f" {MAX_RETRIES} attempts: {res.error_msg()}"
                )

    def _extract_nodes(self, graph_elements: List[Any]) -> List[Dict]:
        r"""Extracts unique nodes from graph elements.

        Args:
            graph_elements (List[Any]): A list of graph elements containing
                nodes.

        Returns:
            List[Dict]: A list of dictionaries representing nodes.
        """
        nodes = []
        seen_nodes = set()
        for graph_element in graph_elements:
            for node in graph_element.nodes:
                node_key = (node.id, node.type)
                if node_key not in seen_nodes:
                    nodes.append(
                        {
                            'id': node.id,
                            'type': node.type,
                            'properties': node.properties,
                        }
                    )
                    seen_nodes.add(node_key)
        return nodes

    def _extract_relationships(self, graph_elements: List[Any]) -> List[Dict]:
        r"""Extracts relationships from graph elements.

        Args:
            graph_elements (List[Any]): A list of graph elements containing
                relationships.

        Returns:
            List[Dict]: A list of dictionaries representing relationships.
        """
        relationships = []
        for graph_element in graph_elements:
            for rel in graph_element.relationships:
                relationship_dict = {
                    'subj': {'id': rel.subj.id, 'type': rel.subj.type},
                    'obj': {'id': rel.obj.id, 'type': rel.obj.type},
                    'type': rel.type,
                }
                relationships.append(relationship_dict)
        return relationships

    def refresh_schema(self) -> None:
        r"""Refreshes the schema by fetching the latest schema details."""
        self.schema = self.get_schema()
        self.structured_schema = self.get_structured_schema

    @property
    def get_structured_schema(self) -> Dict[str, Any]:
        r"""Generates a structured schema consisting of node and relationship
        properties, relationships, and metadata, including timestamps.

        Returns:
            Dict[str, Any]: A dictionary representing the structured schema.
        """
        _, node_properties = self.get_node_properties()
        _, rel_properties = self.get_relationship_properties()
        relationships = self.get_relationship_types()
        index = self.get_indexes()

        # Build structured_schema
        structured_schema = {
            "node_props": {
                el["labels"]: el["properties"] for el in node_properties
            },
            "rel_props": {
                el["type"]: el["properties"] for el in rel_properties
            },
            "relationships": relationships,
            "metadata": {"index": index},
        }

        return structured_schema

    def get_schema(self):
        r"""Generates a schema string describing node and relationship
        properties and relationships.

        Returns:
            str: A string describing the schema.
        """
        # Get all node and relationship properties
        formatted_node_props, _ = self.get_node_properties()
        formatted_rel_props, _ = self.get_relationship_properties()
        formatted_rels = self.get_relationship_types()

        # Generate schema string
        schema = "\n".join(
            [
                "Node properties are the following:",
                ", ".join(formatted_node_props),
                "Relationship properties are the following:",
                ", ".join(formatted_rel_props),
                "The relationships are the following:",
                ", ".join(formatted_rels),
            ]
        )

        return schema

    def get_indexes(self):
        r"""Fetches the tag indexes from the database.

        Returns:
            List[str]: A list of tag index names.
        """
        result = self.query('SHOW TAG INDEXES')
        indexes = []

        # Get tag indexes
        for row in result.rows():
            index_name = row.values[0].get_sVal().decode('utf-8')
            indexes.append(index_name)

        return indexes

    def add_triplet(
        self,
        subj: str,
        obj: str,
        rel: str,
        time_label: Optional[str] = None,
    ) -> None:
        r"""Adds a relationship (triplet) between two entities in the Nebula
        Graph database.

        Args:
            subj (str): The identifier for the subject entity.
            obj (str): The identifier for the object entity.
            rel (str): The relationship between the subject and object.
            time_label (str, optional): A specific timestamp to set for the
                time label property of the relationship. If not provided,
                no timestamp will be added. (default: :obj:`None`)

        Raises:
            ValueError: If the time_label format is invalid.
            Exception: If creating the relationship fails.
        """
        subj = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '', subj)
        obj = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '', obj)
        rel = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '', rel)

        self.ensure_tag_exists(subj)
        self.ensure_tag_exists(obj)
        self.ensure_edge_type_exists(rel, time_label)
        self.add_node(node_id=subj, tag_name=subj)
        self.add_node(node_id=obj, tag_name=obj)

        # Avoid latency
        time.sleep(1)

        # Create edge with or without time_label property
        if time_label is not None:
            time_label = self._validate_time_label(time_label)
            insert_stmt = (
                f'INSERT EDGE IF NOT EXISTS {rel}(time_label) VALUES '
                f'"{subj}"->"{obj}":("{time_label}")'
            )
        else:
            insert_stmt = (
                f'INSERT EDGE IF NOT EXISTS {rel}() VALUES '
                f'"{subj}"->"{obj}":()'
            )

        res = self.query(insert_stmt)
        if not res.is_succeeded():
            raise Exception(
                f'create relationship `{subj}` -> `{obj}`'
                + f'failed: {res.error_msg()}'
            )

    def delete_triplet(self, subj: str, obj: str, rel: str) -> None:
        r"""Deletes a specific triplet (relationship between two entities)
        from the Nebula Graph database.

        Args:
            subj (str): The identifier for the subject entity.
            obj (str): The identifier for the object entity.
            rel (str): The relationship between the subject and object.
        """
        delete_edge_query = f'DELETE EDGE {rel} "{subj}"->"{obj}";'
        self.query(delete_edge_query)

        if not self._check_edges(subj):
            self.delete_entity(subj)
        if not self._check_edges(obj):
            self.delete_entity(obj)

    def delete_entity(self, entity_id: str) -> None:
        r"""Deletes an entity (vertex) from the graph.

        Args:
            entity_id (str): The identifier of the entity to be deleted.
        """
        delete_vertex_query = f'DELETE VERTEX "{entity_id}";'
        self.query(delete_vertex_query)

    def _check_edges(self, entity_id: str) -> bool:
        r"""Checks if an entity has any remaining edges in the graph.

        Args:
            entity_id (str): The identifier of the entity.

        Returns:
            bool: :obj:`True` if the entity has edges, :obj:`False` otherwise.
        """
        # Combine the outgoing and incoming edge count query
        check_query = f"""
        (GO FROM {entity_id} OVER * YIELD count(*) as out_count) 
        UNION 
        (GO FROM {entity_id} REVERSELY OVER * YIELD count(*) as in_count)
        """

        # Execute the query
        result = self.query(check_query)

        # Check if the result contains non-zero edges
        if result.is_succeeded():
            rows = result.rows()
            total_count = sum(int(row.values[0].get_iVal()) for row in rows)
            return total_count > 0
        else:
            return False

    def get_node_properties(self) -> Tuple[List[str], List[Dict[str, Any]]]:
        r"""Retrieve node properties from the graph.

        Returns:
            Tuple[List[str], List[Dict[str, Any]]]: A tuple where the first
                element is a list of node schema properties, and the second
                element is a list of dictionaries representing node structures.
        """
        # Query all tags
        result = self.query('SHOW TAGS')
        node_schema_props = []
        node_structure_props = []

        # Iterate through each tag to get its properties
        for row in result.rows():
            tag_name = row.values[0].get_sVal().decode('utf-8')
            describe_result = self.query(f'DESCRIBE TAG {tag_name}')
            properties = []

            for prop_row in describe_result.rows():
                prop_name = prop_row.values[0].get_sVal().decode('utf-8')
                node_schema_props.append(f"{tag_name}.{prop_name}")
                properties.append(prop_name)

            node_structure_props.append(
                {"labels": tag_name, "properties": properties}
            )

        return node_schema_props, node_structure_props

    def get_relationship_properties(
        self,
    ) -> Tuple[List[str], List[Dict[str, Any]]]:
        r"""Retrieve relationship (edge) properties from the graph.

        Returns:
            Tuple[List[str], List[Dict[str, Any]]]: A tuple where the first
                element is a list of relationship schema properties, and the
                second element is a list of dictionaries representing
                relationship structures.
        """

        # Query all edge types
        result = self.query('SHOW EDGES')
        rel_schema_props = []
        rel_structure_props = []

        # Iterate through each edge type to get its properties
        for row in result.rows():
            edge_name = row.values[0].get_sVal().decode('utf-8')
            describe_result = self.query(f'DESCRIBE EDGE {edge_name}')
            properties = []

            for prop_row in describe_result.rows():
                prop_name = prop_row.values[0].get_sVal().decode('utf-8')
                rel_schema_props.append(f"{edge_name}.{prop_name}")
                properties.append(prop_name)

            rel_structure_props.append(
                {"type": edge_name, "properties": properties}
            )

        return rel_schema_props, rel_structure_props

    def _validate_time_label(self, time_label: str) -> str:
        r"""Validates the format of a time label string.

        Args:
            time_label (str): The time label string to validate.
                Should be in format 'YYYY-MM-DDThh:mm:ss'.

        Returns:
            str: The validated time label.

        Raises:
            ValueError: If the time label format is invalid.
        """
        try:
            # Check if the format matches YYYY-MM-DDThh:mm:ss
            pattern = r'^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$'
            if not re.match(pattern, time_label):
                raise ValueError(
                    "Time label must be in format 'YYYY-MM-DDThh:mm:ss'"
                )
            return time_label
        except Exception as e:
            raise ValueError(f"Invalid time label format: {e!s}")
