#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Neo4j数据库操作工具类
"""

from typing import List, Dict, Any, Optional
from neo4j import GraphDatabase
import pandas as pd


class Neo4jClient:
    """Neo4j database client.

    Wraps a ``neo4j`` driver and offers helpers to import Wikidata RDF through
    the n10s plugin, normalize the imported graph (property renaming,
    relationship conversion/deduplication, isolated-node removal) and read or
    update node properties. Usable as a context manager; the driver is closed
    on exit.
    """

    # Relationship priority used by clean_duplicate_relationships(): when several
    # relationships connect the same ordered node pair, the one with the lowest
    # rank is kept. Each entry is (semantic name, Wikidata property code);
    # reverse_relationships() and fix_relationship_types() turn the code into the
    # semantic name, so both spellings share a rank.
    _RELATION_PRIORITY = (
        ("instance_of", "P31"),
        ("subclass_of", "P279"),
        ("part_of", "P361"),
        ("has_part", "P527"),
    )

    # Case-insensitive substrings that mark an import error as transient/retryable.
    _RETRYABLE_PATTERNS = (
        'timeout',
        'connection',
        'network',
        'temporary',
        'unavailable',
        'busy',
        'overload',
        '503',  # Service Unavailable
        '502',  # Bad Gateway
        '504',  # Gateway Timeout
        'read timeout',
        'connect timeout',
    )

    def __init__(self, uri: str, user: str, password: str):
        """Create the underlying driver.

        Args:
            uri: Neo4j connection URI.
            user: Database user name.
            password: Database password.
        """
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        """Close the database connection (driver), if one exists."""
        if self.driver:
            self.driver.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always release the driver; exceptions are not suppressed (returns None).
        self.close()

    @classmethod
    def _relation_rank_map(cls) -> Dict[str, int]:
        """Map every prioritized relationship type (semantic name or code) to its rank (0 = highest)."""
        return {
            rel_type: rank
            for rank, aliases in enumerate(cls._RELATION_PRIORITY)
            for rel_type in aliases
        }

    def clear_database(self):
        """Delete every node and relationship in the database (irreversible)."""
        with self.driver.session() as session:
            session.run("MATCH (n) DETACH DELETE n")

    def init_n10s_config(self):
        """Initialize the n10s (neosemantics) RDF plugin graph configuration."""
        with self.driver.session() as session:
            session.run("""
                CALL n10s.graphconfig.init({
                    handleVocabUris: 'MAP',
                    handleMultival: 'OVERWRITE',
                    keepLangTag: false,
                    keepCustomDataTypes: false,
                    applyNeo4jNaming: false
                })
            """)

    def import_rdf_from_sparql(self, sparql_query: str, max_retries: int = 5, retry_delay: float = 2.0):
        """Import RDF returned by a Wikidata SPARQL query via n10s, with retries.

        The query is URL-encoded server-side (``apoc.text.urlencode``) and fetched
        from ``https://query.wikidata.org/sparql`` as Turtle (so it is presumably
        a CONSTRUCT/DESCRIBE query — confirm with callers).

        Args:
            sparql_query: SPARQL query text.
            max_retries: Maximum number of attempts.
            retry_delay: Base delay in seconds; doubled after each failed attempt.

        Raises:
            Exception: the original error when it is not retryable (see
                ``_is_retryable_error``), or a generic ``Exception`` chained to the
                last error once every attempt has failed.
        """
        import logging
        import time
        logger = logging.getLogger(__name__)

        last_error: Optional[Exception] = None
        for attempt in range(max_retries):
            try:
                with self.driver.session() as session:
                    logger.debug(f"🌐 开始执行n10s.rdf.import.fetch (尝试 {attempt + 1}/{max_retries})")
                    logger.debug(f"📝 查询长度: {len(sparql_query)} 字符")

                    result = session.run(
                        """
                        CALL n10s.rdf.import.fetch(
                            'https://query.wikidata.org/sparql?query=' + apoc.text.urlencode($sparql),
                            'Turtle',
                            { headerParams: { Accept: 'application/x-turtle' } }
                        )
                        """,
                        sparql=sparql_query
                    )

                    # Materialize the procedure's result rows so they can be logged.
                    records = list(result)
                    logger.debug(f"📊 n10s导入结果记录数: {len(records)}")
                    for record in records:
                        logger.debug(f"📋 导入记录: {dict(record)}")

                    logger.info(f"✅ SPARQL导入成功 (尝试 {attempt + 1}/{max_retries})")
                    return

            except Exception as e:
                last_error = e
                error_msg = str(e)
                logger.warning(f"❌ SPARQL导入失败 (尝试 {attempt + 1}/{max_retries}): {error_msg}")

                if not self._is_retryable_error(error_msg):
                    # Non-retryable: propagate the original exception unchanged.
                    logger.error(f"❌ 检测到不可重试错误，停止重试: {error_msg}")
                    raise

                if attempt < max_retries - 1:
                    wait_time = retry_delay * (2 ** attempt)  # exponential backoff
                    logger.info(f"⏳ 等待 {wait_time:.1f} 秒后重试...")
                    time.sleep(wait_time)

        # Every attempt failed with a retryable error; keep the cause for debugging.
        logger.error(f"❌ SPARQL导入最终失败，已重试 {max_retries} 次")
        raise Exception(f"SPARQL导入失败，已重试 {max_retries} 次") from last_error

    def _is_retryable_error(self, error_msg: str) -> bool:
        """Return True if ``error_msg`` contains any of ``_RETRYABLE_PATTERNS`` (case-insensitive)."""
        error_lower = error_msg.lower()
        return any(pattern in error_lower for pattern in self._RETRYABLE_PATTERNS)

    def rename_node_properties(self):
        """Rename imported ``:node`` properties: ``node`` -> ``name`` and ``description`` -> ``wikipedia_description``."""
        with self.driver.session() as session:
            # Guard on n.node: without it, nodes lacking n.node (e.g. on a second
            # call) would have their existing name overwritten with null.
            session.run("""
                MATCH (n:node)
                WHERE n.node IS NOT NULL
                SET n.name = n.node
                REMOVE n.node
            """)

            session.run("""
                MATCH (n:node)
                WHERE n.description IS NOT NULL
                SET n.wikipedia_description = n.description
                REMOVE n.description
            """)

    def clean_duplicate_relationships(self) -> int:
        """Keep only one relationship per ordered node pair, chosen by priority.

        Priority (highest first): instance_of/P31 > subclass_of/P279 >
        part_of/P361 > has_part/P527, so this works both before relationship
        conversion (Wikidata property codes) and after it (semantic names).
        For every pair ``(a)->(b)`` with more than one relationship, the
        highest-priority relationship is kept and all others between that pair
        are deleted (including parallel copies of the same type). Pairs whose
        relationships are all of other types are left untouched.

        Returns:
            Number of deleted relationships.
        """
        with self.driver.session() as session:
            duplicate_pairs = session.run("""
                MATCH (a)-[r]->(b)
                WITH a, b, count(r) AS rel_count
                WHERE rel_count > 1
                RETURN count(*) AS duplicate_pairs
            """).single()['duplicate_pairs']

            if duplicate_pairs == 0:
                print("没有发现重复关系")
                return 0

            print(f"发现 {duplicate_pairs} 对节点有重复关系，开始清理...")

            deleted = session.run("""
                MATCH (a)-[r]->(b)
                WITH a, b, collect(r) AS rels
                WHERE size(rels) > 1
                // only relationships with a known priority can be the one kept
                WITH rels, [rel IN rels WHERE $rank[type(rel)] IS NOT NULL] AS ranked
                WHERE size(ranked) > 0
                WITH rels, reduce(best = head(ranked), rel IN tail(ranked) |
                         CASE WHEN $rank[type(rel)] < $rank[type(best)] THEN rel ELSE best END
                     ) AS keep_rel
                UNWIND [rel IN rels WHERE rel <> keep_rel] AS dup
                DELETE dup
                RETURN count(dup) AS deleted
            """, rank=self._relation_rank_map()).single()['deleted']

            print("重复关系清理完成")
            return deleted

    def reverse_relationships(self):
        """Reverse Wikidata property relationships and give them semantic names.

        For each (name, code) in ``_RELATION_PRIORITY``, every
        ``(s:node)-[:code]->(e:node)`` becomes ``(e)-[:name]->(s)`` with the
        original relationship properties copied. Any other relationship type
        between ``:node`` nodes is reversed into ``subclass_of``. Progress and the
        final type distribution are printed.
        """
        known_types = [name for name, _ in self._RELATION_PRIORITY]
        with self.driver.session() as session:
            existing_relations = session.run("""
                MATCH ()-[r]->()
                RETURN DISTINCT type(r) as relation_type
            """).values()

            print(f"发现的关系类型: {[r[0] for r in existing_relations]}")

            # Relationship types come from the class constant, so interpolation is safe.
            for semantic_name, wikidata_code in self._RELATION_PRIORITY:
                processed = session.run(f"""
                    MATCH (s:node)-[r:{wikidata_code}]->(e:node)
                    CREATE (e)-[r2:{semantic_name}]->(s)
                    SET r2 = r
                    DELETE r
                    RETURN count(r2) as processed
                """).single()["processed"]
                if processed > 0:
                    print(f"处理了 {processed} 个 {wikidata_code} ({semantic_name}) 关系")

            remaining_relations = session.run("""
                MATCH (s:node)-[r]->(e:node)
                WHERE NOT (type(r) IN $known)
                RETURN DISTINCT type(r) as relation_type
                LIMIT 5
            """, known=known_types).values()

            if remaining_relations:
                print(f"发现剩余关系类型: {[r[0] for r in remaining_relations]}")
                # Existing behavior: every remaining type is renamed to subclass_of.
                remaining_count = session.run("""
                    MATCH (s:node)-[r]->(e:node)
                    WHERE NOT (type(r) IN $known)
                    CREATE (e)-[r2:subclass_of]->(s)
                    SET r2 = r
                    DELETE r
                    RETURN count(r2) as processed
                """, known=known_types).single()["processed"]
                if remaining_count > 0:
                    print(f"处理了 {remaining_count} 个剩余关系，统一命名为 subclass_of")

            final_relations = session.run("""
                MATCH ()-[r]->()
                RETURN type(r) as relation_type, count(r) as count
                ORDER BY count DESC
            """).data()

            print("最终关系类型分布:")
            for rel in final_relations:
                print(f"  - {rel['relation_type']}: {rel['count']} 个")

    def remove_isolated_nodes(self):
        """Delete nodes that have no relationships.

        The n10s ``_GraphConfig`` node is isolated by design and is kept, because
        deleting it would make later n10s imports fail. It is also excluded by the
        fetch methods below.
        """
        with self.driver.session() as session:
            session.run("""
                MATCH (n)
                WHERE NOT EXISTS {(n)--()} AND NOT '_GraphConfig' IN labels(n)
                DETACH DELETE n
            """)

    def fetch_all_nodes(self) -> List[Dict[str, Any]]:
        """Return ``{"id": str, "properties": dict}`` for every node except ``_masked``/``_GraphConfig`` ones."""
        with self.driver.session() as session:
            result = session.run(
                """
                MATCH (n)
                WHERE NOT '_masked' IN labels(n) AND NOT '_GraphConfig' IN labels(n)
                RETURN toString(ID(n)) AS id, properties(n) AS properties
                """
            )
            return [record.data() for record in result]

    def fetch_nodes_with_names(self) -> List[Dict[str, Any]]:
        """Return ``{"nodeId": int, "node_name": str}`` per node; a missing name becomes ``""``."""
        with self.driver.session() as session:
            result = session.run(
                """
                MATCH (n) WHERE NOT '_masked' IN labels(n) AND NOT '_GraphConfig' IN labels(n)
                RETURN toString(ID(n)) AS id, n.name AS name
                """
            )
            return [{"nodeId": int(r["id"]), "node_name": r["name"] or ""} for r in result]

    def update_node_property(self, node_id: int, property_name: str, value: Any):
        """Set ``property_name`` to ``value`` on the node with internal ID ``node_id``.

        The property name is passed as a parameter map (``SET n += $props``)
        instead of being spliced into the query text, so arbitrary names are safe.
        """
        with self.driver.session() as session:
            session.run(
                """
                MATCH (n) WHERE ID(n) = $nid
                SET n += $props
                """,
                nid=node_id,
                props={property_name: value},
            )

    @staticmethod
    def _property_update_batches(df: pd.DataFrame, id_col: str, property_name: str,
                                 value_col: str, batch_size: int) -> List[List[Dict[str, Any]]]:
        """Split ``df`` into UNWIND-ready batches of at most ``batch_size`` rows.

        Each row becomes ``{"nid": int(id), "props": {property_name: float(value)}}``.

        Raises:
            ValueError: if ``batch_size`` < 1, or an id/value cannot be converted.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        rows = [
            {"nid": int(node_id), "props": {property_name: float(value)}}
            for node_id, value in zip(df[id_col], df[value_col])
        ]
        return [rows[i:i + batch_size] for i in range(0, len(rows), batch_size)]

    def batch_update_node_properties(self, df: pd.DataFrame,
                                   id_col: str, property_name: str, value_col: str,
                                   batch_size: int = 1000):
        """Set ``property_name`` on many nodes from a DataFrame.

        Args:
            df: Source rows.
            id_col: Column holding internal node IDs (converted with ``int``).
            property_name: Node property to write.
            value_col: Column holding values (converted with ``float``).
            batch_size: Rows sent per ``UNWIND`` query, which replaces the
                previous one-query-per-row round trips.
        """
        batches = self._property_update_batches(df, id_col, property_name, value_col, batch_size)
        with self.driver.session() as session:
            for batch in batches:
                session.run(
                    """
                    UNWIND $rows AS row
                    MATCH (n) WHERE ID(n) = row.nid
                    SET n += row.props
                    """,
                    rows=batch,
                )

    def delete_nodes_not_in_list(self, keep_node_ids: List[int]):
        """Delete (DETACH) every node whose internal ID is not in ``keep_node_ids``.

        NOTE(review): an empty list deletes every node, including the n10s
        ``_GraphConfig`` node if it is not listed — confirm this is intended.
        """
        with self.driver.session() as session:
            session.run(
                "MATCH (n) WHERE NOT ID(n) IN $ids DETACH DELETE n",
                ids=keep_node_ids
            )

    def get_node_count(self) -> int:
        """Return the total number of nodes."""
        with self.driver.session() as session:
            result = session.run("MATCH (n) RETURN count(n) as count")
            return result.single()["count"]

    def get_relationship_count(self) -> int:
        """Return the total number of relationships."""
        with self.driver.session() as session:
            result = session.run("MATCH ()-[r]->() RETURN count(r) as count")
            return result.single()["count"]

    def get_relationship_type_distribution(self) -> Dict[str, int]:
        """Return a mapping of relationship type -> count, built in descending count order."""
        with self.driver.session() as session:
            result = session.run("""
                MATCH ()-[r]->()
                RETURN type(r) as rel_type, count(*) as count
                ORDER BY count DESC
            """)
            return {record["rel_type"]: record["count"] for record in result}

    def get_node_names(self, limit: int = 10) -> List[str]:
        """Return up to ``limit`` distinct non-empty node names.

        Args:
            limit: Maximum number of names (passed as a query parameter).
        """
        with self.driver.session() as session:
            result = session.run("""
                MATCH (n)
                WHERE n.name IS NOT NULL AND n.name <> ''
                RETURN DISTINCT n.name as name
                LIMIT $limit
            """, limit=int(limit))
            return [record["name"] for record in result]

    def get_nodes_without_wikipedia_summary(self, limit: int = 10) -> List[str]:
        """Return up to ``limit`` distinct node names whose ``wikipedia_summary`` is null or empty.

        Args:
            limit: Maximum number of names (passed as a query parameter).
        """
        with self.driver.session() as session:
            result = session.run("""
                MATCH (n)
                WHERE n.name IS NOT NULL AND n.name <> ''
                AND (n.wikipedia_summary IS NULL OR n.wikipedia_summary = '')
                RETURN DISTINCT n.name as name
                LIMIT $limit
            """, limit=int(limit))
            return [record["name"] for record in result]

    def add_wikipedia_summary_to_node(self, node_name: str, wikipedia_summary: str):
        """Set ``wikipedia_summary`` on every node named ``node_name``.

        Returns:
            Number of nodes updated (all nodes sharing the name are updated).
        """
        with self.driver.session() as session:
            result = session.run("""
                MATCH (n {name: $node_name})
                SET n.wikipedia_summary = $wikipedia_summary
                RETURN count(n) as updated_count
            """, node_name=node_name, wikipedia_summary=wikipedia_summary)
            return result.single()["updated_count"]

    def update_node_toxicity(self,
                            node_id: int,
                            average_toxicity: float,
                            max_toxicity: float,
                            harmful_ratio: Optional[float] = None,
                            total_prompts: Optional[int] = None):
        """Store toxicity scores on a node and mark it ``toxicity_analyzed = true``.

        Args:
            node_id: Internal node ID.
            average_toxicity: Average toxicity score.
            max_toxicity: Maximum toxicity score.
            harmful_ratio: Share of harmful prompts; written only if not None.
            total_prompts: Number of prompts; written only if not None.
        """
        set_properties = [
            "n.average_toxicity = $avg_toxicity",
            "n.max_toxicity = $max_toxicity",
            "n.toxicity_analyzed = true"
        ]
        params = {
            "node_id": node_id,
            "avg_toxicity": average_toxicity,
            "max_toxicity": max_toxicity
        }

        # Optional fields: only the fixed clause strings are interpolated; values stay parameters.
        if harmful_ratio is not None:
            set_properties.append("n.harmful_ratio = $harmful_ratio")
            params["harmful_ratio"] = harmful_ratio

        if total_prompts is not None:
            set_properties.append("n.total_prompts = $total_prompts")
            params["total_prompts"] = total_prompts

        query = f"""
            MATCH (n) WHERE ID(n) = $node_id
            SET {", ".join(set_properties)}
        """
        with self.driver.session() as session:
            session.run(query, **params)

    def get_node_by_id(self, node_id: int) -> Optional[Dict[str, Any]]:
        """Return ``{"id", "labels", "properties"}`` for the node with internal ID ``node_id``, or None."""
        with self.driver.session() as session:
            result = session.run("""
                MATCH (n) WHERE ID(n) = $node_id
                RETURN n, ID(n) as node_id
            """, node_id=node_id)

            record = result.single()
            if record:
                node = record["n"]
                return {
                    "id": record["node_id"],
                    "labels": list(node.labels),
                    "properties": dict(node)
                }
            return None

    def fix_relationship_types(self):
        """Convert ``wdt:Pxxx`` relationships into semantic names (same direction).

        For each (name, code) in ``_RELATION_PRIORITY``, every ``(a)-[:`wdt:code`]->(b)``
        is replaced by ``(a)-[:name]->(b)`` with its properties copied. If all
        relationships end up being ``subclass_of``, a warning is printed because
        the real types cannot be reconstructed from the current data. The type
        distribution is printed before and after.
        """
        with self.driver.session() as session:
            current_rels = session.run("MATCH ()-[r]->() RETURN type(r) as rel_type, count(*) as count")
            print("修复前的关系类型分布:")
            for record in current_rels:
                print(f"  {record['rel_type']}: {record['count']}")

            # Strategy 1: convert raw wdt:Pxxx relationships if present.
            # Types come from the class constant, so interpolation is safe.
            for semantic_name, wdt_prop in self._RELATION_PRIORITY:
                count = session.run(
                    f"MATCH ()-[r:`wdt:{wdt_prop}`]->() RETURN count(r) as count"
                ).single()['count']

                if count > 0:
                    print(f"转换 wdt:{wdt_prop} -> {semantic_name} ({count} 条关系)")
                    session.run(f"""
                    MATCH (a)-[old:`wdt:{wdt_prop}`]->(b)
                    CREATE (a)-[converted:{semantic_name}]->(b)
                    SET converted = properties(old)
                    DELETE old
                    """)

            # Strategy 2: detect the "everything is subclass_of" case (not fixable here).
            subclass_count = session.run("MATCH ()-[r:subclass_of]->() RETURN count(r) as count").single()['count']
            total_count = session.run("MATCH ()-[r]->() RETURN count(r) as count").single()['count']

            if subclass_count > 0 and subclass_count == total_count:
                print(f"发现所有{total_count}条关系都是subclass_of，需要基于原始数据重建正确关系类型")
                print("警告：无法从当前数据重建关系类型，建议使用分类型查询重新导入")

            final_rels = session.run("MATCH ()-[r]->() RETURN type(r) as rel_type, count(*) as count ORDER BY count DESC")
            print("修复后的关系类型分布:")
            for record in final_rels:
                print(f"  {record['rel_type']}: {record['count']}")
