#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Neo4j数据库工具模块
提供Neo4j数据库的连接和操作功能
"""

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Any, Optional, Union

try:
    import neo4j
    NEO4J_AVAILABLE = True
except ImportError:
    NEO4J_AVAILABLE = False
    print("警告: neo4j驱动未安装，使用模拟模式")

# Built-in Neo4j client (does not depend on other project modules)
class Neo4jClient:
    """Minimal wrapper around the official ``neo4j`` Python driver.

    If the driver package could not be imported (``NEO4J_AVAILABLE`` is
    False) the client works in a no-op mode: ``self.driver`` is ``None``,
    no server is contacted and every method returns its empty default.
    ``self.driver`` is therefore the only thing the methods check before
    touching the database.

    Statements whose rows are not needed are explicitly consumed, so each
    call blocks until the server has finished the statement and a
    server-side failure is raised to the caller instead of depending on
    what happens at session teardown.
    """

    def __init__(self, uri: str, user: str, password: str):
        """Create the underlying driver, or fall back to no-op mode.

        Args:
            uri: Connection URI handed to ``neo4j.GraphDatabase.driver``.
            user: User name for basic authentication.
            password: Password for basic authentication.
        """
        if NEO4J_AVAILABLE:
            self.driver = neo4j.GraphDatabase.driver(uri, auth=(user, password))
        else:
            self.driver = None

    def close(self) -> None:
        """Close the driver if one was created."""
        if self.driver:
            self.driver.close()

    def execute_query(self, query: str, parameters: Optional[Dict] = None) -> List[Dict[str, Any]]:
        """Run a Cypher statement and return its rows as plain dicts.

        Args:
            query: Cypher statement.
            parameters: Optional statement parameters.

        Returns:
            One ``record.data()`` dict per returned row; an empty list in
            no-op mode.
        """
        if not self.driver:
            return []
        with self.driver.session() as session:
            result = session.run(query, parameters or {})
            # Materialise inside the session: the result cannot be read
            # once the session has been closed.
            return [record.data() for record in result]

    def clear_database(self) -> bool:
        """Delete every node together with its relationships.

        Returns:
            Always True (also in no-op mode); failures are raised.
        """
        if self.driver:
            with self.driver.session() as session:
                session.run("MATCH (n) DETACH DELETE n").consume()
        return True

    def init_n10s_config(self) -> None:
        """Initialise the graph configuration of the n10s (neosemantics) RDF plugin."""
        if not self.driver:
            return
        with self.driver.session() as session:
            session.run("""
                CALL n10s.graphconfig.init({
                    handleVocabUris: 'MAP',
                    handleMultival: 'OVERWRITE',
                    keepLangTag: false,
                    keepCustomDataTypes: false,
                    applyNeo4jNaming: false
                })
            """).consume()

    def import_rdf_from_sparql(self, sparql_query: str) -> None:
        """Import the RDF produced by a SPARQL query against Wikidata.

        The query text is passed as the ``$sparql`` parameter and
        URL-encoded server-side with ``apoc.text.urlencode``.

        Args:
            sparql_query: SPARQL query sent to the Wikidata endpoint.
        """
        if not self.driver:
            return
        with self.driver.session() as session:
            session.run(
                """
                CALL n10s.rdf.import.fetch(
                    'https://query.wikidata.org/sparql?query=' + apoc.text.urlencode($sparql),
                    'Turtle',
                    { headerParams: { Accept: 'application/x-turtle' } }
                )
                """,
                sparql=sparql_query
            ).consume()

    def rename_node_properties(self) -> None:
        """Rename the name and description properties on ``:node`` nodes.

        ``node`` is moved to ``name`` and, where present, ``description``
        is moved to ``wikipedia_description``.
        """
        if not self.driver:
            return
        with self.driver.session() as session:
            # Move the name property.
            session.run("""
                MATCH (n:node)
                SET n.name = n.node
                REMOVE n.node
            """).consume()

            # Move the description property (only where it exists).
            session.run("""
                MATCH (n:node)
                WHERE n.description IS NOT NULL
                SET n.wikipedia_description = n.description
                REMOVE n.description
            """).consume()

    def remove_isolated_nodes(self) -> int:
        """Delete ``:node`` nodes that have no relationships.

        Returns:
            The ``deleted`` count reported by the statement; 0 in no-op
            mode or if no row comes back.
        """
        if not self.driver:
            return 0
        with self.driver.session() as session:
            record = session.run("""
                MATCH (n:node)
                WHERE NOT (n)--()
                DELETE n
                RETURN count(n) as deleted
            """).single()
            return record["deleted"] if record else 0

    def get_node_count(self) -> int:
        """Return the total number of nodes (0 in no-op mode)."""
        if not self.driver:
            return 0
        with self.driver.session() as session:
            result = session.run("MATCH (n) RETURN count(n) as count")
            return result.single()["count"]

    def get_relationship_count(self) -> int:
        """Return the total number of relationships (0 in no-op mode)."""
        if not self.driver:
            return 0
        with self.driver.session() as session:
            result = session.run("MATCH ()-[r]->() RETURN count(r) as count")
            return result.single()["count"]

from .logger_utils import get_logger

logger = get_logger(__name__)

@dataclass
class Neo4jConnectionConfig:
    """Connection settings for a Neo4j server.

    Attributes:
        uri: Connection URI of the server.
        user: User name used for authentication.
        password: Password used for authentication. It is left out of the
            generated ``repr`` so that logging or printing a config object
            does not expose the credential; equality still compares it.
        database: Optional database name; ``None`` when not specified.
    """
    uri: str
    user: str
    # repr=False keeps the secret out of log lines and tracebacks.
    password: str = field(repr=False)
    database: Optional[str] = None

class Neo4jManager:
    """High-level manager around a :class:`Neo4jClient`.

    The manager owns the client, tracks whether a connection has been
    established and offers convenience queries. Query helpers raise
    ``RuntimeError`` when called before a successful :meth:`connect`;
    apart from :meth:`execute_query`, they log database errors and return
    an empty (or False) result instead of raising.

    It can be used as a context manager: the connection is opened on entry
    and closed on exit.
    """

    # Projection shared by every node-returning query.
    _NODE_RETURN = "RETURN id(n) as id, labels(n) as labels, properties(n) as properties"

    def __init__(self, config: "Neo4jConnectionConfig"):
        """
        Store the configuration; no connection is opened yet.

        Args:
            config: Neo4j connection settings.
        """
        self.config = config
        self.client = None
        self._connected = False

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _discard_client(self) -> None:
        """Close and forget the current client, marking the manager disconnected."""
        client, self.client = self.client, None
        self._connected = False
        if client is not None:
            try:
                client.close()
            except Exception as e:
                # Cleanup is best effort; the caller is already handling
                # the original failure.
                logger.warning(f"关闭Neo4j客户端失败: {e}")

    @staticmethod
    def _quote_label(label: str) -> str:
        """Return *label* as a backtick-quoted Cypher identifier.

        Labels cannot be passed as statement parameters, so the value is
        quoted instead; embedded backticks are doubled so the label cannot
        terminate the quoting and inject Cypher.
        """
        text = str(label)
        # Some server versions expand the escape \\u0060 to a backtick
        # before parsing, so normalise it first and then double it.
        text = text.replace("\\u0060", "`")
        return "`" + text.replace("`", "``") + "`"

    @staticmethod
    def _apply_limit(query: str, limit: Optional[int], parameters: Dict[str, Any]) -> str:
        """Append ``LIMIT $limit`` to *query* when *limit* is given.

        The value is stored in *parameters* rather than formatted into the
        statement text. ``None`` means "no limit"; ``0`` is a real limit.
        """
        if limit is None:
            return query
        parameters["limit"] = int(limit)
        return query + " LIMIT $limit"

    def _fetch_nodes(self, query: str, parameters: Optional[Dict] = None) -> List[Dict[str, Any]]:
        """Run a node query using ``_NODE_RETURN`` and return id/labels/properties dicts."""
        records = self.client.execute_query(query, parameters)
        return [
            {
                "id": record["id"],
                "labels": record["labels"],
                "properties": record["properties"],
            }
            for record in records
        ]

    # ------------------------------------------------------------------
    # Connection handling
    # ------------------------------------------------------------------
    def connect(self) -> bool:
        """
        Connect to the Neo4j database.

        A client is created and, when it has a real driver, a probe query
        must succeed for the connection to count as established. When the
        client has no driver (driver package missing, no-op mode) there is
        nothing to probe and the connection is reported as successful.
        On failure the client is closed so no driver is leaked.

        Returns:
            True if the connection was established, False otherwise.
        """
        # Release a client left over from an earlier connect() call.
        if self.client is not None:
            self._discard_client()

        try:
            self.client = Neo4jClient(
                uri=self.config.uri,
                user=self.config.user,
                password=self.config.password
            )

            # test_connection() reports failure through its return value
            # (it logs and swallows the exception), so it must be checked.
            if self.client.driver is not None and not self.test_connection():
                raise ConnectionError("连接测试未通过")

            self._connected = True
            logger.info(f"成功连接到Neo4j数据库: {self.config.uri}")
            return True

        except Exception as e:
            logger.error(f"连接Neo4j数据库失败: {e}")
            self._discard_client()
            return False

    def disconnect(self) -> None:
        """Close the client, if any, and mark the manager as disconnected."""
        if self.client:
            try:
                self.client.close()
            finally:
                # Never keep reporting "connected" after a close attempt.
                self._connected = False
            logger.info("已断开Neo4j数据库连接")

    def test_connection(self) -> bool:
        """
        Probe the connection with ``RETURN 1``.

        Returns:
            True if the probe returned at least one row; False if it
            returned nothing or raised (the error is logged).
        """
        try:
            result = self.client.execute_query("RETURN 1 as test")
            return len(result) > 0
        except Exception as e:
            logger.error(f"Neo4j连接测试失败: {e}")
            return False

    def is_connected(self) -> bool:
        """Return True once connect() has succeeded and a client is present."""
        return self._connected and self.client is not None

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def execute_query(self, query: str, parameters: Optional[Dict] = None):
        """
        Execute a Cypher query and let database errors propagate.

        Args:
            query: Cypher statement.
            parameters: Optional statement parameters.

        Returns:
            Whatever the client's ``execute_query`` returns.

        Raises:
            RuntimeError: If the manager is not connected.
        """
        if not self.is_connected():
            raise RuntimeError("未连接到Neo4j数据库")

        return self.client.execute_query(query, parameters)

    def get_database_stats(self) -> Dict[str, Any]:
        """
        Collect basic statistics about the database.

        Returns:
            Dict with ``node_count``, ``relationship_count``,
            ``node_labels``, ``relationship_types``, ``labels_count`` and
            ``rel_types_count``; an empty dict if any query fails.

        Raises:
            RuntimeError: If the manager is not connected.
        """
        if not self.is_connected():
            raise RuntimeError("未连接到Neo4j数据库")

        try:
            # Node count
            node_count_result = self.client.execute_query("MATCH (n) RETURN count(n) as count")
            node_count = node_count_result[0]["count"] if node_count_result else 0

            # Relationship count
            rel_count_result = self.client.execute_query("MATCH ()-[r]->() RETURN count(r) as count")
            rel_count = rel_count_result[0]["count"] if rel_count_result else 0

            # Node labels
            labels_result = self.client.execute_query("CALL db.labels()")
            labels = [record["label"] for record in labels_result]

            # Relationship types
            rel_types_result = self.client.execute_query("CALL db.relationshipTypes()")
            rel_types = [record["relationshipType"] for record in rel_types_result]

            stats = {
                "node_count": node_count,
                "relationship_count": rel_count,
                "node_labels": labels,
                "relationship_types": rel_types,
                "labels_count": len(labels),
                "rel_types_count": len(rel_types)
            }

            logger.info(f"数据库统计: {node_count} 节点, {rel_count} 关系")
            return stats

        except Exception as e:
            logger.error(f"获取数据库统计失败: {e}")
            return {}

    def clear_database(self) -> bool:
        """
        Delete every node and relationship.

        Returns:
            True on success, False if the statement failed (error logged).

        Raises:
            RuntimeError: If the manager is not connected.
        """
        if not self.is_connected():
            raise RuntimeError("未连接到Neo4j数据库")

        try:
            logger.warning("正在清空Neo4j数据库...")
            self.client.execute_query("MATCH (n) DETACH DELETE n")
            logger.info("数据库已清空")
            return True
        except Exception as e:
            logger.error(f"清空数据库失败: {e}")
            return False

    def get_all_nodes(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Fetch all nodes.

        Args:
            limit: Maximum number of nodes to return; ``None`` for no
                limit. ``0`` returns no nodes.

        Returns:
            List of ``{"id", "labels", "properties"}`` dicts; an empty
            list if the query fails.

        Raises:
            RuntimeError: If the manager is not connected.
        """
        if not self.is_connected():
            raise RuntimeError("未连接到Neo4j数据库")

        try:
            parameters: Dict[str, Any] = {}
            query = self._apply_limit(f"MATCH (n) {self._NODE_RETURN}", limit, parameters)
            nodes = self._fetch_nodes(query, parameters)

            logger.info(f"获取到 {len(nodes)} 个节点")
            return nodes

        except Exception as e:
            logger.error(f"获取节点失败: {e}")
            return []

    def get_nodes_by_label(self, label: str, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Fetch the nodes carrying a given label.

        Args:
            label: Node label; treated as one literal label name.
            limit: Maximum number of nodes to return; ``None`` for no
                limit. ``0`` returns no nodes.

        Returns:
            List of ``{"id", "labels", "properties"}`` dicts; an empty
            list if the query fails.

        Raises:
            RuntimeError: If the manager is not connected.
        """
        if not self.is_connected():
            raise RuntimeError("未连接到Neo4j数据库")

        try:
            parameters: Dict[str, Any] = {}
            query = self._apply_limit(
                f"MATCH (n:{self._quote_label(label)}) {self._NODE_RETURN}",
                limit,
                parameters,
            )
            nodes = self._fetch_nodes(query, parameters)

            logger.info(f"获取到 {len(nodes)} 个标签为 '{label}' 的节点")
            return nodes

        except Exception as e:
            logger.error(f"获取标签节点失败: {e}")
            return []

    def search_nodes_by_name(self, name: str, exact_match: bool = False) -> List[Dict[str, Any]]:
        """
        Search nodes by their ``name`` property.

        Args:
            name: Text to look for.
            exact_match: If True require equality, otherwise use a
                ``CONTAINS`` substring match.

        Returns:
            List of ``{"id", "labels", "properties"}`` dicts; an empty
            list if the query fails.

        Raises:
            RuntimeError: If the manager is not connected.
        """
        if not self.is_connected():
            raise RuntimeError("未连接到Neo4j数据库")

        try:
            operator = "=" if exact_match else "CONTAINS"
            query = f"MATCH (n) WHERE n.name {operator} $name {self._NODE_RETURN}"
            nodes = self._fetch_nodes(query, {"name": name})

            logger.info(f"搜索名称 '{name}' 找到 {len(nodes)} 个节点")
            return nodes

        except Exception as e:
            logger.error(f"搜索节点失败: {e}")
            return []

    def get_node_relationships(self, node_id: int) -> List[Dict[str, Any]]:
        """
        Fetch every relationship attached to a node, in either direction.

        Args:
            node_id: Internal node id, compared against ``id(n)``.

        Returns:
            List of dicts with ``type``, ``start_id``, ``end_id``,
            ``properties`` (of the relationship) and ``target_properties``
            (of the node at the other end); an empty list if the query
            fails.

        Raises:
            RuntimeError: If the manager is not connected.
        """
        if not self.is_connected():
            raise RuntimeError("未连接到Neo4j数据库")

        try:
            query = """
            MATCH (n)-[r]-(m)
            WHERE id(n) = $node_id
            RETURN type(r) as relationship_type,
                   id(startNode(r)) as start_id,
                   id(endNode(r)) as end_id,
                   properties(r) as properties,
                   properties(m) as target_properties
            """

            result = self.client.execute_query(query, {"node_id": node_id})
            relationships = [
                {
                    "type": record["relationship_type"],
                    "start_id": record["start_id"],
                    "end_id": record["end_id"],
                    "properties": record["properties"],
                    "target_properties": record["target_properties"],
                }
                for record in result
            ]

            logger.info(f"节点 {node_id} 有 {len(relationships)} 个关系")
            return relationships

        except Exception as e:
            logger.error(f"获取节点关系失败: {e}")
            return []

    def execute_custom_query(self, query: str, parameters: Optional[Dict] = None) -> List[Dict[str, Any]]:
        """
        Execute a caller-supplied query, swallowing database errors.

        Unlike :meth:`execute_query`, a failing query is logged and an
        empty list is returned instead of raising.

        Args:
            query: Cypher statement.
            parameters: Optional statement parameters.

        Returns:
            One dict per returned row; an empty list if the query fails.

        Raises:
            RuntimeError: If the manager is not connected.
        """
        if not self.is_connected():
            raise RuntimeError("未连接到Neo4j数据库")

        try:
            result = self.client.execute_query(query, parameters or {})
            results = [dict(record) for record in result]

            logger.info(f"自定义查询返回 {len(results)} 条结果")
            return results

        except Exception as e:
            logger.error(f"执行自定义查询失败: {e}")
            return []

    # ------------------------------------------------------------------
    # Context manager protocol
    # ------------------------------------------------------------------
    def __enter__(self):
        """Open the connection and return the manager.

        A failed connection is not raised here; check ``is_connected()``.
        """
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection; exceptions from the ``with`` body propagate."""
        self.disconnect()

def create_neo4j_manager(config_dict: Dict[str, str]) -> Neo4jManager:
    """
    Build a Neo4j manager from a plain settings mapping.

    Args:
        config_dict: Mapping that must contain 'uri', 'user' and
            'password'; 'database' is optional and defaults to None.

    Returns:
        A Neo4jManager wrapping the resulting connection config.

    Raises:
        KeyError: If 'uri', 'user' or 'password' is missing.
    """
    # Fields are passed positionally in declaration order:
    # uri, user, password, database.
    connection_settings = Neo4jConnectionConfig(
        config_dict['uri'],
        config_dict['user'],
        config_dict['password'],
        config_dict.get('database'),
    )
    manager = Neo4jManager(connection_settings)
    return manager

if __name__ == "__main__":
    # Manual smoke test for the Neo4j manager.
    # Connection settings can be overridden through the NEO4J_URI,
    # NEO4J_USER and NEO4J_PASSWORD environment variables so real
    # credentials do not have to be edited into the source; the defaults
    # are the previous hard-coded local values.
    config = Neo4jConnectionConfig(
        uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
        user=os.environ.get("NEO4J_USER", "neo4j"),
        password=os.environ.get("NEO4J_PASSWORD", "12345678")
    )

    manager = Neo4jManager(config)

    try:
        # Try to connect
        if manager.connect():
            print("Neo4j连接测试成功")

            # Database statistics
            stats = manager.get_database_stats()
            print(f"数据库统计: {stats}")

            # A small sample of nodes
            nodes = manager.get_all_nodes(limit=5)
            print(f"获取到 {len(nodes)} 个节点")

        else:
            print("Neo4j连接测试失败")

    finally:
        # Always release the connection, even if a step above raised.
        manager.disconnect()
