import os
import time
import logging
import json

from dotenv import load_dotenv
from neo4j import GraphDatabase
from langchain_openai import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from openai import OpenAI
from openai import RateLimitError , APIError , APIConnectionError

from prompt_pool.common import get_prompt

# Configure the root logger once at import time: INFO and above, with each
# record rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class LabelExpander:
    """Relate a new label to existing ``Tag`` / ``Config`` nodes stored in Neo4j.

    The instance owns a Neo4j driver, a HuggingFace embedding model, a LangChain
    ``ChatOpenAI`` model and a raw ``OpenAI`` client. All ``Tag`` names are loaded
    from the graph once, at construction time, into ``self.label_collection``.

    The Neo4j driver stays open for the lifetime of the instance so that every
    query method can be called repeatedly; call :meth:`close` (or use the
    instance as a context manager) when finished.
    """

    def __init__(self, new_label, uri, user, password, api_key, api_url, embedding_model_name,
                 model_name="Pro/deepseek-ai/DeepSeek-V3", device="cuda"):
        """Create the clients and load the existing tag names.

        Args:
            new_label: the label to be matched against the existing tags.
            uri: Neo4j connection URI.
            user: Neo4j user name.
            password: Neo4j password.
            api_key: API key used for both the LangChain chat model and the OpenAI client.
            api_url: base URL used for both the LangChain chat model and the OpenAI client.
            embedding_model_name: model name passed to ``HuggingFaceEmbeddings``.
            model_name: chat model name sent with every LLM request.
            device: device passed to the embedding model via ``model_kwargs``.
        """
        self.max_retries = 3
        self.initial_retry_delay = 1
        self.new_label = new_label
        self.model_name = model_name
        self.driver = GraphDatabase.driver(uri, auth=(user, password))
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_name,
            model_kwargs={"device": device, "trust_remote_code": True},
        )
        self.label_collection = self.get_label_collection_from_neo4j()
        self.chat = ChatOpenAI(api_key=api_key, base_url=api_url, model=self.model_name)
        self.client = OpenAI(api_key=api_key, base_url=api_url)

    def close(self):
        """Close the Neo4j driver; no graph query may be issued afterwards."""
        self.driver.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def get_llm_completion(self, prompt, system_message=None):
        """Send a chat completion request and return the reply text.

        API errors are retried up to ``self.max_retries`` attempts with
        exponential backoff starting at ``self.initial_retry_delay`` seconds.

        Args:
            prompt: content of the user message.
            system_message: optional content of a leading system message.

        Returns:
            The ``content`` of the first choice's message.

        Raises:
            Exception: when every attempt failed with an API error (the last API
                error is chained as the cause), or immediately for any other
                unexpected error (re-raised unchanged).
        """
        messages = []
        if system_message:
            messages.append({'role': 'system', 'content': system_message})
        messages.append({'role': 'user', 'content': prompt})

        retry_delay = self.initial_retry_delay
        last_error = None
        for attempt in range(1, self.max_retries + 1):
            try:
                completion = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                )
                return completion.choices[0].message.content
            except (RateLimitError, APIError, APIConnectionError) as e:
                last_error = e
                if attempt >= self.max_retries:
                    # Nothing left to retry: do not waste time sleeping before failing.
                    logging.warning(f"Attempt {attempt} failed: {e}. No retries left.")
                    break
                logging.warning(f"Attempt {attempt} failed: {e}. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
                retry_delay *= 2  # Exponential backoff
            except Exception as e:
                logging.error(f"LLM API request failed with unexpected error: {e}")
                raise

        raise Exception(
            f"Max retries ({self.max_retries}) reached. Unable to complete the request."
        ) from last_error

    def _parse_llm_json_response(self, result, context_info=""):
        """Parse the JSON response from the LLM.

        The text is first parsed as-is; if that fails, the span from the first
        ``{`` to the last ``}`` is extracted and parsed instead (this copes with
        replies that wrap the JSON in prose or code fences).

        Args:
            result: original response text from the LLM (``None`` is tolerated).
            context_info: context information for logging.

        Returns:
            The decoded JSON value (a dict when the reply is a JSON object);
            an empty dict if the input is not text or parsing fails.
        """
        import re

        if not isinstance(result, (str, bytes, bytearray)):
            # e.g. the API returned no message content at all.
            logging.error(f"Failed to parse LLM response as JSON for {context_info}")
            logging.error(f"LLM response content: {result}")
            return {}
        if not isinstance(result, str):
            result = result.decode("utf-8", errors="replace")

        try:
            # Try to parse the response directly.
            labels = json.loads(result)
            logging.info(f"Successfully parsed LLM response for {context_info}")
            return labels
        except json.JSONDecodeError:
            pass

        # Fall back to extracting a JSON object embedded in the text.
        json_match = re.search(r'\{[\s\S]*\}', result)
        if json_match is None:
            logging.error(f"Failed to parse LLM response as JSON for {context_info}")
            logging.error(f"LLM response content: {result}")
            return {}
        try:
            return json.loads(json_match.group())
        except json.JSONDecodeError:
            logging.warning(f"Failed to parse extracted JSON from LLM response for {context_info}")
            return {}

    def get_label_collection_from_neo4j(self):
        """
        Query all name attributes of nodes with the label "Tag" from the Neo4j knowledge graph to obtain the tag set.

        Returns:
            list: the ``t.name`` value of every ``Tag`` node; an empty list if the query fails.
        """
        try:
            with self.driver.session() as session:
                result = session.run("MATCH (t:Tag) RETURN t.name")
                # The driver is deliberately left open: later methods reuse it.
                return [record["t.name"] for record in result]
        except Exception as e:
            logging.error(f"Error fetching labels from Neo4j: {e}")
            return []

    def get_embedding(self, text):
        """
        Obtain the embedding vector of the text
        """
        return self.embeddings.embed_query(text)

    def find_similar_labels(self, top_n=5, new_label=None):
        """Find the existing labels closest to the new label in embedding space.

        A FAISS index is built from ``self.label_collection`` on every call.

        Args:
            top_n: number of nearest labels to retrieve.
            new_label: label to search for; defaults to ``self.new_label``.

        Returns:
            tuple: ``(top_labels, top_similarities)`` - the matched label texts and
            the scores reported by ``similarity_search_with_score``, in the same
            order. Both lists are empty when no labels are loaded.
            NOTE(review): with LangChain's default FAISS index the score is
            presumably a distance (lower = closer) - confirm before thresholding.
        """
        query_label = self.new_label if new_label is None else new_label

        if not self.label_collection:
            # An index cannot be built from zero texts (e.g. the Neo4j fetch failed).
            logging.warning("Label collection is empty; no similar labels can be retrieved.")
            return [], []

        vectorstore = FAISS.from_texts(self.label_collection, self.embeddings)
        similar_docs = vectorstore.similarity_search_with_score(query_label, k=top_n)

        top_labels = [doc.page_content for doc, _score in similar_docs]
        top_similarities = [score for _doc, score in similar_docs]

        return top_labels, top_similarities

    def labels_rerank_chain(self, label_rerank_prompt):
        """Build a LangChain pipeline: prompt template -> chat model -> plain string.

        Args:
            label_rerank_prompt: template text passed to ``ChatPromptTemplate.from_template``.

        Returns:
            The composed runnable chain.
        """
        labels_rerank_prompt_template = ChatPromptTemplate.from_template(label_rerank_prompt)

        labels_rerank_chain = (
            labels_rerank_prompt_template
            | self.chat
            | StrOutputParser()
        )
        return labels_rerank_chain

    def expand_label_subgraph(self, reranked_labels, sibling_count=3, child_count=3):
        """Expand labels with randomly chosen siblings and children from the graph.

        Args:
            reranked_labels: label names to start from.
            sibling_count: maximum number of sibling tags (same parent via
                ``IS_SUBCATEGORY_OF``) added per label.
            child_count: maximum number of child tags added per label.

        Returns:
            list: the input labels plus the sampled siblings and children, without
            duplicates and in no particular order. If a Neo4j error occurs the
            labels collected so far are returned.
        """
        expanded_labels = set(reranked_labels)

        # Both queries alias the returned column as "name" so the records can be
        # read uniformly with record["name"].
        sibling_query = """
                    MATCH (t:Tag {name: $label})-[:IS_SUBCATEGORY_OF]->(parent:Tag)<-[:IS_SUBCATEGORY_OF]-(sibling:Tag)
                    WHERE sibling.name <> $label
                    RETURN sibling.name AS name
                    ORDER BY rand()
                    LIMIT $limit
                    """
        child_query = """
                        MATCH (t:Tag {name: $label})<-[:IS_SUBCATEGORY_OF]-(child:Tag)
                        RETURN child.name AS name
                        ORDER BY rand()
                        LIMIT $limit
                    """
        try:
            with self.driver.session() as session:
                for label in reranked_labels:
                    # Horizontal expansion: random tags sharing a parent with this label.
                    sibling_result = session.run(sibling_query, label=label, limit=sibling_count)
                    expanded_labels.update(record["name"] for record in sibling_result)

                    # Vertical expansion: random direct children of this label.
                    child_result = session.run(child_query, label=label, limit=child_count)
                    expanded_labels.update(record["name"] for record in child_result)
        except Exception as e:
            logging.error(f"Error expanding labels from Neo4j: {e}")
        return list(expanded_labels)

    def label_expand_configs(self, new_label, expanded_labels):
        """
        For each label item in expanded_labels,
        retrieve the configuration item node corresponding to the label in the graph,
        call the LLM to determine its relevance to the new label,
        and return a mapping list of new labels to configuration items.

        Args:
            new_label: the label whose relevance to each config is judged.
            expanded_labels: tag names whose ``Config`` nodes are inspected.

        Returns:
            list: one dict per config judged relevant, with keys ``config``,
            ``new_label``, ``reason`` and ``confidence``. A failed LLM call skips
            that config only; a Neo4j error stops processing and returns what has
            been collected so far.
        """
        label_config_mapping = []
        config_query = """
                    MATCH (c:Config)-[:HAS_LABEL]->(t:Tag {name: $label})
                    OPTIONAL MATCH (c)-[:HAS_LABEL]->(other:Tag)
                    RETURN c.name AS name, c.help AS help, collect(other.name) AS config_labels
                    """

        try:
            with self.driver.session() as session:
                for label in expanded_labels:
                    logging.info(f"[LabelExpand] Processing label: {label}")
                    # Fetch every record up front so the result stream is not left
                    # half-consumed while the (slow) LLM calls below are running.
                    records = list(session.run(config_query, label=label))
                    for record in records:
                        config_name = record.get("name")
                        config_help = record.get("help") or ""
                        config_labels = record.get("config_labels") or []

                        prompt = get_prompt("new_label_config").format(
                            new_label=new_label,
                            config_name=config_name,
                            config_help=config_help,
                            config_labels=", ".join(config_labels)
                        )

                        try:
                            response = self.get_llm_completion(prompt)
                            parsed = self._parse_llm_json_response(response, f"{config_name}")
                            # Only a JSON object can carry the "relevant" verdict.
                            if isinstance(parsed, dict) and parsed.get("relevant", False):
                                label_config_mapping.append({
                                    "config": config_name,
                                    "new_label": new_label,
                                    "reason": parsed.get("reason", ""),
                                    "confidence": parsed.get("confidence", ""),
                                })
                                logging.info(f"[LabelExpand] + Relevant: {config_name}")
                            else:
                                logging.debug(f"[LabelExpand] - Irrelevant: {config_name}")
                        except Exception as llm_error:
                            logging.warning(f"LLM interaction failed for {config_name}: {llm_error}")
        except Exception as e:
            logging.error(f"[LabelExpand] Error accessing Neo4j or processing configs: {e}")

        return label_config_mapping
        

if __name__ == '__main__':
    # Demo run: find the existing tags closest to a target label, then ask the
    # LLM to rerank them. Connection settings are read from the environment
    # (optionally populated from a .env file).
    load_dotenv()

    uri = os.getenv("NEO4J_URI")
    user = os.getenv("NEO4J_USERNAME")
    password = os.getenv("NEO4J_PASSWORD")
    api_key = os.getenv("SILLICONFLOW_API_KEY")
    api_url = os.getenv("SILLICONFLOW_BASE_URL")
    embedding_model_name = "stella_en_1.5B_v5"
    target = "test"

    # Argument order must follow LabelExpander.__init__:
    # (new_label, uri, user, password, api_key, api_url, embedding_model_name).
    # The chat model is not passed: LabelExpander selects it itself
    # (Pro/deepseek-ai/DeepSeek-V3).
    label_expender = LabelExpander(target, uri, user, password, api_key, api_url, embedding_model_name)
    try:
        # The label searched for is the one given to the constructor.
        top_labels, top_similarities = label_expender.find_similar_labels()
        logging.info("Top labels: %s (scores: %s)", top_labels, top_similarities)

        rerank_prompt = get_prompt("label_rerank_prompt")
        labels_rerank_chain = label_expender.labels_rerank_chain(rerank_prompt)
        rerank_result = labels_rerank_chain.invoke({"target": target, "labels": top_labels})
        logging.info("Rerank result: %s", rerank_result)
    finally:
        # Release the Neo4j connections even if a step above failed.
        label_expender.driver.close()
    
    
