import os
import json
import time
import logging

from neo4j import GraphDatabase
from dotenv import load_dotenv
from openai import OpenAI
from openai import RateLimitError , APIError , APIConnectionError

from prompt_pool.common import get_prompt

# Configure the root logger once at import time: INFO level, "<timestamp> - <level> - <message>" format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Load environment variables via python-dotenv; LabelSystem.__init__ reads its LLM API settings with os.getenv.
load_dotenv()

class LabelSystem:
    def __init__(self, label_system, uri, name, password):
        self.label_system = label_system
        self.all_labels = {}
        self.all_config_labels = {}
        self.all_branch_labels = {}
        self.output_dir = "output/new_labels"
        self.model_name = "Pro/deepseek-ai/DeepSeek-V3"
        self.max_retries = 3
        self.initial_retry_delay = 1
        self.client = OpenAI(
            api_key=os.getenv("SILLICONFLOW_API_KEY"), 
            base_url=os.getenv("SILLICONFLOW_BASE_URL"),
        )
        self.driver = GraphDatabase.driver(
            uri=uri, 
            auth=(name, password)
        )

    def get_parent_paths(self, target_name):
        """
        查找指定 name 的父辈路径结点
        """
        with self.driver.session() as session:
            query = """
            MATCH p=(root)-[:PARENT*]->(target {name: $target_name})
            WHERE NOT ()-[:PARENT]->(root)
            RETURN [node IN nodes(p) | node.name] AS path
            """
            result = session.run(query, target_name=target_name)
            path_strs = []
            for record in result:
                path = record["path"]
                path_str = " -> ".join(path)  # 将路径列表转换为字符串
                path_strs.append(path_str)
            return path_strs
        
    def get_children_by_name(self, branch_name):
        """Get all sub-nodes by Branch name"""
        with self.driver.session() as session:
            result = session.run(
                """
                MATCH (parent {name: $branch_name})-[:PARENT]->(child)
                RETURN child
                """,
                branch_name=branch_name
            )
            children = [record["child"] for record in result]
            logging.info(f"Found {len(children)} children for menu: {branch_name}")
            return children
        
    def get_children_count_by_name(self, branch_name):
        """Get the count of all sub - nodes by Branch name"""
        with self.driver.session() as session:
            result = session.run(
                """
                MATCH (parent {name: $branch_name})-[:PARENT]->(child)
                RETURN COUNT(child) AS child_count
                """,
                branch_name=branch_name
            )
            record = result.single()
            if record:
                child_count = record["child_count"]
                logging.info(f"Found {child_count} children for menu: {branch_name}")
                return child_count
            return 0

    def get_parent_by_name(self, child_name):
        """Get the parent node by child node name"""
        with self.driver.session() as session:
            result = session.run(
                """
                MATCH (parent)-[:PARENT]->(child {name: $child_name})
                RETURN parent.name
                """,
                child_name=child_name
            )
            parents = [record["parent.name"] for record in result]
            logging.info(f"Found {len(parents)} parents for child: {child_name}")
            return parents
        
    def get_node_labels(self, node_name):
        """get the type of the node"""
        with self.driver.session() as session:
            result = session.run(
                """
                MATCH (node {name: $node_name})
                RETURN labels(node) as labels
                """,
                node_name=node_name
            )
            record = result.single()
            return record["labels"] if record else []
        
    def get_description(self, node_name):
        """Get the information of a Config node"""
        with self.driver.session() as session:
            result = session.run(
                """
                MATCH (c:Config {name: $node_name})
                RETURN c.help as help
                """,
                node_name=node_name
            )
            record = result.single()
            if record:
                return {
                    "help": record["help"],
                }
            return None


    def get_llm_completion(self, prompt, system_message=None):
        """
            request llm api and get response
        """
        messages = []
        if system_message:
            messages.append({'role': 'system', 'content': system_message})
        messages.append({'role': 'user', 'content': prompt})

        retry_delay = self.initial_retry_delay
        for attempt in range(self.max_retries):
            try:
                completion = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                )
                return completion.choices[0].message.content
            except (RateLimitError, APIError, APIConnectionError) as e:
                logging.warning(f"Attempt {attempt + 1} failed: {e}. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
                retry_delay *= 2  # Exponential backoff
            except Exception as e:
                logging.error(f"LLM API request failed with unexpected error: {e}")
                raise

        raise Exception(f"Max retries ({self.max_retries}) reached. Unable to complete the request.")
    
    def _parse_llm_json_response(self, result, context_info=""):
        """Parse the JSON object contained in an LLM response.

        The whole text is parsed first. If that fails, the span from the first
        "{" to the last "}" is extracted and parsed instead, which covers
        replies that wrap the JSON in prose or a code fence.

        Args:
            result: original response text from LLM (``None`` is tolerated)
            context_info: context information for logging

        Returns:
            dict: parsed JSON object; empty dict if the response is missing,
            cannot be parsed, or parses to something other than an object
        """
        import re  # kept local: this module does not import re at file level

        if isinstance(result, (bytes, bytearray)):
            result = result.decode("utf-8", errors="replace")
        if not isinstance(result, str) or not result.strip():
            # An absent message content would otherwise raise TypeError in json.loads.
            logging.error(f"Empty or non-text LLM response for {context_info}")
            return {}

        try:
            parsed = json.loads(result)
        except json.JSONDecodeError:
            # Greedy match: first "{" through last "}" in the reply.
            match = re.search(r'\{[\s\S]*\}', result)
            if match is None:
                logging.error(f"Failed to parse LLM response as JSON for {context_info}")
                logging.error(f"LLM response content: {result}")
                return {}
            try:
                parsed = json.loads(match.group())
            except json.JSONDecodeError:
                logging.warning(f"Failed to parse extracted JSON from LLM response for {context_info}")
                return {}

        if not isinstance(parsed, dict):
            # Callers look up keys such as "new_labels"; lists and scalars are unusable.
            logging.warning(f"LLM response for {context_info} is valid JSON but not an object")
            return {}

        logging.info(f"Successfully parsed LLM response for {context_info}")
        return parsed
            
  
    def _save_json_to_file(self, data, filename_prefix):
        """A General Method for Saving Data as a JSON File"""
        os.makedirs(self.output_dir, exist_ok=True)
        output_file = os.path.join(self.output_dir, f"{filename_prefix}.json")
        
        try:
            with open(output_file, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=4, ensure_ascii=False)
            logging.info(f"Data written to: {output_file}")
        except Exception as e:
            logging.error(f"Saving file failed: {e}") 


    def process_branch_configs(self, configs, level_path):
        branch_label_state = {}
        for config in configs:
            #Label configs combine help information and hierarchical paths to assign possible labels, label classification paths, confidence levels/reasons
            config_name = config.get("name")
            if not config_name: continue
            help_info = config.get("help", "")
            
            #sibling configs label status
            sib_label_state = {}
            near_items_state = list(branch_label_state.items())[-5:]
            for key, values in near_items_state:  
                new_values = []
                for value in values:
                    new_value = {
                        "label_name": value["label_name"],
                        "suggested_label_path": value["suggested_label_path"],
                        "confidence": value["confidence"]
                    }
                    new_values.append(new_value)
                if new_values:
                    sib_label_state[key] = new_values
                
                
            try:
                config_info = f"{config_name}: \n description: {help_info}"
                if sib_label_state:
                    sibling_labels_str = json.dumps(sib_label_state)
                else:
                    sibling_labels_str = ""
                
                config_prompt = get_prompt("config_label").format(
                    basic_labels = self.label_system,
                    sibling_labels = sibling_labels_str,
                    level_info = level_path,
                    config_info = config_info
                )
                
                result = self.get_llm_completion(prompt=config_prompt)
                label_data = self._parse_llm_json_response(result, "config label mining")
                if label_data and "new_labels" in label_data:
                    new_labels = label_data["new_labels"]
                    branch_label_state[config_name] = new_labels
                
            except Exception as e:
                logging.error(f"Error processing config {config_name}: {str(e)}. Sibling labels: {sib_label_state}, Prompt: {config_prompt}")
                branch_label_state[config_name] = []  
                continue
        #output: branch_label_state :{"config_name": [], "config_name": []}
        return branch_label_state
    
    def process_menu_branch(self, branch_name, level_path, parent_label):
        """
            Ask the LLM to label a Menu item.

            The "menu_label" prompt template is filled with the branch name,
            its hierarchy path and the parent label. Returns the parsed JSON
            reply, or an empty dict when the request or parsing raises.
        """
        template = get_prompt("menu_label")
        menu_prompt = template.format(
            branch_name=branch_name,
            level_path=level_path,
            parent_label=parent_label,
        )
        try:
            llm_reply = self.get_llm_completion(prompt=menu_prompt)
            logging.info(f"LLM response: {llm_reply}")
            parsed_reply = self._parse_llm_json_response(llm_reply, "config label mining")
        except Exception as e:
            logging.error(f"LLM API response error: {e}")
            return {}
        return parsed_reply
        
    def process_menuconfig_choice_branch(self, prompt_file, branch_name, help_info, level_path, parent_label):
        """
            Ask the LLM to label a single branch node.

            The prompt template named ``prompt_file`` is filled with the base
            label system, the parent label, the hierarchy path and the
            branch's name plus help text. Returns the parsed JSON reply, or an
            empty dict when the request or parsing raises.
        """
        template = get_prompt(prompt_file)
        prompt = template.format(
            basic_labels=self.label_system,
            parent_label=parent_label,
            level_info=level_path,
            branch_info=f"{branch_name}: \n description: {help_info}",
        )
        try:
            llm_reply = self.get_llm_completion(prompt=prompt)
            return self._parse_llm_json_response(llm_reply, "config label mining")
        except Exception as e:
            logging.error(f"LLM API response error: {e}")
            return {}
        
    def config_label_mining(self, branch_data, output_fpath="all_labels"):
        """
        Traverse the branches and assign labels to Menu or Config/Choice nodes.

        - Menu branch: label its child configs with ``process_branch_configs``.
        - Config/Choice branch with no label recorded under its first parent:
          label the branch itself from its help text and hierarchy path.
        - Config/Choice branch that already has a label under its first parent:
          label its child configs with ``process_mconfig_branch_configs``,
          passing the existing label as parent label.

        Results accumulate in ``self.all_labels`` and are written to
        ``<output_fpath>.json`` through ``_save_json_to_file``.

        Args:
            branch_data: Mapping ``branch_name -> list of item dicts``; items
                whose "type" is "config" are treated as child configs.
            output_fpath: File name prefix (without extension) for the JSON dump.

        Returns:
            dict: ``self.all_labels`` after processing.
        """
        for branch_name, items in branch_data.items():
            logging.info(f"Processing branch: {branch_name}")
            branch_type = self.get_node_labels(branch_name)

            if not branch_type:
                logging.warning(f"No labels found for {branch_name}, skipping.")
                continue

            is_menu = "Menu" in branch_type
            is_config = "Config" in branch_type
            if not (is_menu or is_config or "Choice" in branch_type):
                logging.warning(f"Unknown branch type for {branch_name}, skipping.")
                continue

            # Query the path only for branches that will actually be processed.
            level_path = self.get_parent_paths(branch_name)
            # Child configs of THIS branch, computed for every branch kind so
            # the Config path never sees a list left over from an earlier Menu.
            configs = [item for item in (items or []) if item.get("type") == "config"]

            if is_menu:
                if not configs:
                    continue
                try:
                    branch_configs_label = self.process_branch_configs(configs, level_path)
                    if branch_configs_label:
                        if branch_name in self.all_labels:
                            logging.warning(f"{branch_name} already exists in all_labels. Skipping addition.")
                        else:
                            self.all_labels[branch_name] = branch_configs_label
                            logging.info(f"Added label results for Menu {branch_name}")
                except Exception as e:
                    logging.error(f"Error processing Menu node {branch_name}: {e}")
                continue

            logging.info(f"Branch {branch_name} is a Config node")

            # A node without a parent yields an empty list, so guard before
            # indexing. Only the first parent is consulted.
            parent_names = self.get_parent_by_name(branch_name)
            curr_label = None
            if parent_names and parent_names[0]:
                parent_entry = self.all_labels.get(parent_names[0], {})
                curr_label = parent_entry.get(branch_name, {})

            if not curr_label:
                parent_labels = "None"
                # get_description returns {"help": ...} or None; the prompt needs the text only.
                description = self.get_description(branch_name) or {}
                help_info = description.get("help") or ""
                prompt_file = "menuconfig_label" if is_config else "choice_label"
                try:
                    label = self.process_menuconfig_choice_branch(
                        prompt_file, branch_name, help_info, level_path, parent_labels
                    )
                    if label and "new_labels" in label:
                        self.all_labels[branch_name] = label
                        logging.info(f"Added label results for Config {branch_name}")
                except Exception as e:
                    logging.error(f"Error processing Config node {branch_name}: {e}")
                continue

            logging.info(f"Branch {branch_name} already has labels: {curr_label}")
            if not configs:
                continue
            try:
                branch_configs_label = self.process_mconfig_branch_configs(
                    configs=configs, level_path=level_path, mconfig_label=curr_label
                )
                if branch_configs_label:
                    if branch_name in self.all_labels:
                        logging.warning(f"{branch_name} already exists in all_labels. Skipping addition.")
                    else:
                        self.all_labels[branch_name] = branch_configs_label
                        logging.info(f"Added label results for Config {branch_name}")
            except Exception as e:
                logging.error(f"Error processing Config node {branch_name}: {e}")

        # Save the results to a JSON file
        self._save_json_to_file(self.all_labels, output_fpath)
        return self.all_labels
    
    
    
    def process_mconfig_branch_configs(self, configs, level_path, mconfig_label, prompt_fp="mconfig_label"):
        branch_label_state = {}
        for config in configs:
            
            config_name = config.get("name")
            if not config_name: continue
            help_info = config.get("help", "")
            
            #sibling configs label status
            sib_label_state = {}
            near_items_state = list(branch_label_state.items())[-5:]
            for key, values in near_items_state:  
                new_values = []
                for value in values:
                    new_value = {
                        "label_name": value["label_name"],
                        "suggested_label_path": value["suggested_label_path"],
                        "confidence": value["confidence"]
                    }
                    new_values.append(new_value)
                if new_values:
                    sib_label_state[key] = new_values
                
                
            try:
                config_info = f"{config_name}: \n description: {help_info}"
                if sib_label_state:
                    sibling_labels_str = json.dumps(sib_label_state)
                else:
                    sibling_labels_str = ""
                
                config_prompt = get_prompt(prompt_fp).format(
                    basic_labels = self.label_system,
                    sibling_labels = sibling_labels_str,
                    parent_label = mconfig_label,
                    level_info = level_path,
                    config_info = config_info
                )
                
                result = self.get_llm_completion(prompt=config_prompt)
                label_data = self._parse_llm_json_response(result, "config label mining")
                if label_data and "new_labels" in label_data:
                    new_labels = label_data["new_labels"]
                    branch_label_state[config_name] = new_labels
                
            except Exception as e:
                logging.error(f"Error processing config {config_name}: {str(e)}. Sibling labels: {sib_label_state}, Prompt: {config_prompt}")
                branch_label_state[config_name] = []  
                continue
        #output: branch_label_state :{"config_name": [], "config_name": []}
        return branch_label_state
    
    
    
    def branch_mconfig_label_mining(self, branch_data, branch_label, min_sub_nodes=10):
        """Label the child configs of large, already-labelled Config branches.

        For every branch in ``branch_label`` whose node labels include
        "Config" and which has at least ``min_sub_nodes`` direct children in
        the graph, the branch's child configs are labelled with
        ``process_mconfig_branch_configs``, using the branch's own
        "new_labels" as parent label. Branches that do not qualify are
        skipped individually. The collected results are written to
        "branch_configs_label.json" through ``_save_json_to_file``.

        Args:
            branch_data: Mapping ``branch_name -> list of item dicts``; items
                whose "type" is "config" are treated as child configs.
            branch_label: Mapping ``branch_name -> label data``; the
                "new_labels" entry of the label data is used when present.
            min_sub_nodes: Minimum number of direct children a branch needs
                to be processed (default 10).

        Returns:
            dict: ``{branch_name: {config_name: labels}}`` for every processed
            branch that produced results.
        """
        branch_all_labels = {}

        for branch_name, labels in branch_label.items():
            branch_type = self.get_node_labels(branch_name)
            if "Config" not in branch_type:
                continue

            level_path = self.get_parent_paths(branch_name)
            sub_nodes_count = self.get_children_count_by_name(branch_name)
            if sub_nodes_count == 0:
                logging.warning(f"No sub-nodes found for {branch_name}.")
                continue
            if sub_nodes_count < min_sub_nodes:
                logging.warning(f"Sub-nodes count for {branch_name} is less than {min_sub_nodes}: {sub_nodes_count}.")
                continue

            items = branch_data.get(branch_name, [])
            sub_configs = [item for item in items if item.get("type") == "config"]
            mconfig_label = labels.get("new_labels", []) if isinstance(labels, dict) else []
            if not sub_configs:
                # Skip only this branch; returning here would drop the results
                # already collected and skip the save below.
                logging.warning(f"No sub-configs found for {branch_name}.")
                continue

            mc_config_labels = self.process_mconfig_branch_configs(
                sub_configs, level_path=level_path, mconfig_label=mconfig_label
            )
            if mc_config_labels:
                branch_all_labels[branch_name] = mc_config_labels
                logging.info(f"Added label results for {branch_name}")

        # Save the results to a JSON file
        filename_prefix = "branch_configs_label"
        self._save_json_to_file(branch_all_labels, filename_prefix)

        return branch_all_labels
    
                
                