import os
import re
import json


from neo4j import GraphDatabase
from dotenv import load_dotenv
load_dotenv()

class GraphUpdater:
    """Write a tag hierarchy and config-to-tag links into a Neo4j graph.

    The instance owns a Neo4j driver. It can be used as a context manager,
    in which case the driver is closed when the ``with`` block exits.
    """

    def __init__(self, uri, user, password):
        """Create the Neo4j driver for ``uri`` with ``(user, password)`` auth."""
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        """Close the underlying Neo4j driver."""
        self.driver.close()

    def __enter__(self):
        """Return this updater so it can be bound by ``with ... as``."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the driver on exit; exceptions are not suppressed."""
        self.close()

    def read_basic_labels_from_file(self, file_path):
        """Read the whole label outline file as UTF-8 text.

        Args:
            file_path: Path of the text file to read.

        Returns:
            The file content as a string, or ``None`` if the file is missing
            or any other error occurs (a message is printed in both cases).
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read()
        except FileNotFoundError:
            print(f"ERROR: NO Found File {file_path}")
            return None
        except Exception as e:
            print(f"Error in read Label：{e}")
            return None

    def read_config_label_json_file(self, file_path):
        """Load a UTF-8 JSON file.

        Args:
            file_path: Path of the JSON file to read.

        Returns:
            The decoded JSON value, or ``None`` if the file is missing or is
            not valid JSON (a message is printed in both cases).
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                return json.load(file)
        except FileNotFoundError:
            print(f"ERROR: File {file_path} not Found.")
            return None
        except json.JSONDecodeError:
            print(f"ERROR: Can not decode {file_path} JSON data")
            return None

    def create_basic_labels_and_relations(self, basic_labels_text):
        """Create ``Tag`` nodes and ``IS_SUBCATEGORY_OF`` edges from an outline.

        Each non-blank line of ``basic_labels_text`` is one label, optionally
        prefixed with ``-``. A line indented deeper than a preceding line
        becomes a subcategory of the nearest such shallower line.

        Args:
            basic_labels_text: Indented outline text. ``None`` or an empty
                string results in no writes.
        """
        entries = self._parse_label_hierarchy(basic_labels_text)

        with self.driver.session() as session:
            for label_name, parent_label in entries:
                # Create (or reuse) the node for the current label.
                session.execute_write(self._create_label_node, label_name)

                # Link it to its parent when it is not a top-level label.
                if parent_label is not None:
                    session.execute_write(
                        self._create_label_relationship, parent_label, label_name
                    )

    def update_config_label(self, config_label_map):
        """Merge ``HAS_LABEL`` edges from each config to each of its labels.

        Args:
            config_label_map: Mapping of config name to an iterable of label
                names.

        A failure on one (config, label) pair is printed and does not stop
        the remaining pairs from being processed.
        """
        with self.driver.session() as session:
            for config_name, labels in config_label_map.items():
                for label_name in labels:
                    try:
                        session.execute_write(self._create_label_config_relationship, config_name, label_name)
                        print(f"为配置项 '{config_name}' 与标签 '{label_name}' 建立了 HAS_LABEL 关系。")
                    except Exception as e:
                        print(f"处理配置项 '{config_name}' 与标签 '{label_name}' 时出错: {e}")

    @classmethod
    def _parse_label_hierarchy(cls, basic_labels_text):
        """Parse outline text into ``(label_name, parent_name)`` pairs.

        Args:
            basic_labels_text: Indented outline text, or ``None``.

        Returns:
            A list of ``(label_name, parent_name)`` tuples in file order,
            where ``parent_name`` is ``None`` for labels without a parent.
        """
        if not basic_labels_text:
            return []

        lines = basic_labels_text.strip().split('\n')

        # Detect the indent unit (spaces or a tab) from the outline itself.
        indent_unit = cls._detect_indent_unit(lines)

        parent_stack = []  # (label_name, level) of the currently open ancestors
        entries = []
        for line in lines:
            match = re.match(r'^(\s*)-?\s*(.+)', line)
            if not match:
                continue

            indent_str, label_name = match.group(1), match.group(2).strip()
            if not label_name:
                # The pattern also matches whitespace-only lines (by
                # backtracking); skip them so no empty-named Tag is created
                # and the parent stack is left untouched.
                continue

            level = cls._calculate_level(indent_str, indent_unit)

            # Drop ancestors that are at the same depth or deeper.
            while parent_stack and parent_stack[-1][1] >= level:
                parent_stack.pop()

            parent_label = parent_stack[-1][0] if parent_stack else None
            entries.append((label_name, parent_label))

            # The current label is the candidate parent for following lines.
            parent_stack.append((label_name, level))

        return entries

    @staticmethod
    def _create_label_node(tx, label_name):
        """Transaction function: merge a ``Tag`` node named ``label_name``."""
        query = "MERGE (l:Tag {name: $label_name}) RETURN l"
        tx.run(query, label_name=label_name)

    @staticmethod
    def _create_label_relationship(tx, parent_name, child_name):
        """Transaction function: merge ``(child)-[:IS_SUBCATEGORY_OF]->(parent)``.

        Both ``Tag`` nodes are matched, not created, so they must already exist.
        """
        query = """
        MATCH (parent:Tag {name: $parent_name})
        MATCH (child:Tag {name: $child_name})
        MERGE (child)-[:IS_SUBCATEGORY_OF]->(parent)
        """
        tx.run(query, parent_name=parent_name, child_name=child_name)

    @staticmethod
    def _create_label_config_relationship(tx, config_name, label_name):
        """Transaction function: merge ``(Config)-[:HAS_LABEL]->(Tag)``.

        The ``Config`` and ``Tag`` nodes are merged as well, so missing nodes
        are created.
        """
        query = """
        MERGE (c:Config {name: $config_name})
        MERGE (l:Tag {name: $label_name})
        MERGE (c)-[:HAS_LABEL]->(l)
        """
        tx.run(query, config_name=config_name, label_name=label_name)

    @staticmethod
    def _detect_indent_unit(lines):
        """Return the indent unit taken from the first indented, non-blank line.

        Returns ``'\\t'`` if that line's indentation contains a tab, otherwise
        a run of spaces as long as that indentation. Falls back to four
        spaces when no line is indented.
        """
        for line in lines:
            if match := re.match(r'^(\s+)-?\s*\S+', line):
                indent_str = match.group(1)
                if '\t' in indent_str:
                    return '\t'
                return ' ' * len(indent_str)
        return '    '  # default: four spaces

    @staticmethod
    def _calculate_level(indent_str, indent_unit):
        """Return the nesting depth of ``indent_str`` measured in ``indent_unit``.

        With a tab unit the depth is the number of tabs; otherwise it is the
        indentation length divided (floor) by the unit length.
        """
        if indent_unit == '\t':
            return indent_str.count('\t')
        else:
            return len(indent_str) // len(indent_unit)

def main():
    """Import the label outline from ``basic_labels.txt`` into Neo4j.

    Connection settings are read from the ``NEO4J_URI``, ``NEO4J_USERNAME``
    and ``NEO4J_PASSWORD`` environment variables.
    """
    uri = os.getenv("NEO4J_URI")  # URI
    user = os.getenv("NEO4J_USERNAME")  # username
    password = os.getenv("NEO4J_PASSWORD")  # password
    file_path = "basic_labels.txt"  # Path to the file containing basic labels

    # The context manager closes the driver even if the import raises.
    with GraphUpdater(uri, user, password) as updater:
        # NOTE: to import config-to-label links instead, load a JSON file with
        # updater.read_config_label_json_file(...) and pass the result to
        # updater.update_config_label(...).
        basic_labels = updater.read_basic_labels_from_file(file_path)
        if basic_labels:
            # Build Tag nodes and their hierarchy from the outline text.
            updater.create_basic_labels_and_relations(basic_labels)


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
    