import os
import json
import networkx as nx
import argparse
import asyncio
import yaml
import tiktoken
from openai import AsyncOpenAI
from collections import Counter
from typing import List, Dict, Optional
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
from concurrent.futures import ProcessPoolExecutor


# Delimiter that packs several values (descriptions, source ids, keywords, ...)
# into one string attribute of a node or edge.
GRAPH_FIELD_SEP: str = "<SEP>"
# YAML file whose `llm` section supplies api_key / base_url / model for the
# summarizer; the path is relative to the current working directory.
CONFIG_PATH: str = "./Option/Config2.yaml"

class GraphPrompt:
    """Prompt templates used by DescriptionSummarizer.summarize.

    Each attribute is a ``str.format`` template; the caller fills the
    placeholders and sends the result as a single user message.
    """
    # Placeholders: {entity_name}, {description_list} (one "- item" per line).
    SUMMARIZE_ENTITY_DESCRIPTIONS = """
You are a helpful assistant. Please summarize the following list of descriptions for the entity '{entity_name}' into a single, coherent paragraph.
Combine the key information and remove redundant details.

Descriptions to summarize:
{description_list}

Concise Summary:
"""
    # Placeholders: {item_name}, {description_list} (one "- item" per line).
    SUMMARIZE_RELATION_DESCRIPTION = """
You are a helpful assistant. Please summarize the following list of descriptions for the relationship '{item_name}' into a single, coherent paragraph.
Combine the key information and remove redundant details.

Descriptions to summarize:
{description_list}

Concise Summary:
"""
    # Placeholders: {item_name}, {keyword_list} (comma-separated), {separator}
    # (the delimiter the model is asked to join its output keywords with).
    SUMMARIZE_KEYWORDS = """
You are a helpful assistant. The following is a long, repetitive list of keywords for the relationship '{item_name}'.
Your task is to de-duplicate the list and distill it into a concise, representative set of the most important keywords, joined by "{separator}".

Keywords to summarize:
{keyword_list}

Concise and De-duplicated Keywords:
"""

class SummarizerConfig:
    """Model name and token limits used by DescriptionSummarizer.

    Args:
        model_name: Chat model passed as ``model=`` to the completions API.
        summary_max_tokens: ``max_tokens`` cap for each summary response.
        llm_model_max_token_size: Upper token limit used when truncating the
            text to be summarized.
        token_check_threshold: Texts with fewer tokens than this are returned
            unchanged instead of being summarized.
    """

    def __init__(
        self,
        model_name: str,
        summary_max_tokens: int = 2000,
        # NOTE(review): 32678 looks like a transposition of 32768 (2**15);
        # kept as the default for backward compatibility — confirm.
        llm_model_max_token_size: int = 32678,
        token_check_threshold: int = 2000,
    ):
        self.summary_max_tokens = summary_max_tokens
        self.llm_model_max_token_size = llm_model_max_token_size
        self.summarization_model = model_name
        self.token_check_threshold = token_check_threshold

def summarize_node_sync(args):
    """Summarize one node's description inside a worker process.

    Intended for ``ProcessPoolExecutor.map``: ``args`` is a single picklable
    tuple ``(node, description, api_key, base_url, model_name)``. A fresh
    DescriptionSummarizer is built in the worker and its async ``summarize``
    runs on a private event loop via ``asyncio.run``.

    Returns:
        ``(node, summary)`` where ``summary`` is whatever ``summarize`` returns
        (the original description when it is short or every attempt fails).
    """
    node, description, api_key, base_url, model_name = args
    summarizer = DescriptionSummarizer(api_key, base_url, model_name)

    async def _run():
        try:
            return await summarizer.summarize(node, description, text_type='entity_description')
        finally:
            # Close the HTTP client on the loop that used it; asyncio.run()
            # closes that loop right after, leaving open connections dangling.
            await summarizer.client.close()

    return node, asyncio.run(_run())

class DescriptionSummarizer:
    """Summarize long GRAPH_FIELD_SEP-joined attribute values with a chat model.

    Texts shorter than ``config.token_check_threshold`` tokens are returned
    unchanged. Longer texts are truncated to fit the model budget, formatted
    into one of the GraphPrompt templates and sent to the model. If every
    attempt fails, the original text is returned.
    """

    # Attempts per summarize() call before falling back to the original text.
    MAX_RETRIES = 3
    # Base delay in seconds between attempts; doubled after each failure.
    RETRY_BACKOFF_SECONDS = 1.0
    # Tokens reserved for the prompt template wrapped around the input text.
    PROMPT_TOKEN_MARGIN = 512

    def __init__(self, api_key: str, base_url: str, model_name: str):
        """Create the async API client, config and cl100k_base tokenizer.

        Raises:
            ValueError: if ``api_key`` is empty or None.
        """
        if not api_key:
            raise ValueError("OpenAI API key is required for summarization.")
        self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        self.config = SummarizerConfig(model_name)
        self.encoder = tiktoken.get_encoding("cl100k_base")

    def _input_token_budget(self) -> int:
        """Return the maximum number of input tokens to send (at least 1).

        Treats ``llm_model_max_token_size`` as the model's total context
        (assumption — confirm) and reserves ``summary_max_tokens`` for the
        completion plus PROMPT_TOKEN_MARGIN for the template text.
        """
        budget = (
            self.config.llm_model_max_token_size
            - self.config.summary_max_tokens
            - self.PROMPT_TOKEN_MARGIN
        )
        return max(budget, 1)

    @staticmethod
    def _bullet_list(text: str) -> str:
        """Render GRAPH_FIELD_SEP-joined fragments as '- item' lines, dropping blanks."""
        return "\n".join(f"- {piece.strip()}" for piece in text.split(GRAPH_FIELD_SEP) if piece.strip())

    async def summarize(self, item_name: str, text_to_summarize: str, text_type: str) -> str:
        """Return an LLM summary of ``text_to_summarize``, or the text itself.

        Args:
            item_name: Entity/relation label inserted into the prompt.
            text_to_summarize: GRAPH_FIELD_SEP-joined values.
            text_type: ``'entity_description'``, ``'relation_description'`` or
                ``'keywords'``; anything else logs a warning and returns the input.

        Returns:
            The stripped model output, or ``text_to_summarize`` when the text is
            under the threshold, the type is unknown, the model returns empty
            content, or all MAX_RETRIES attempts fail.
        """
        # disallowed_special=() encodes strings like "<|endoftext|>" as plain
        # text instead of raising ValueError on user-provided content.
        tokens = self.encoder.encode(text_to_summarize, disallowed_special=())
        if len(tokens) < self.config.token_check_threshold:
            return text_to_summarize

        # Leave room for the prompt template and the completion tokens.
        budget = self._input_token_budget()
        if len(tokens) > budget:
            use_text = self.encoder.decode(tokens[:budget])
        else:
            use_text = text_to_summarize

        if text_type == 'entity_description':
            prompt = GraphPrompt.SUMMARIZE_ENTITY_DESCRIPTIONS.format(
                entity_name=item_name, description_list=self._bullet_list(use_text))
        elif text_type == 'relation_description':
            prompt = GraphPrompt.SUMMARIZE_RELATION_DESCRIPTION.format(
                item_name=item_name, description_list=self._bullet_list(use_text))
        elif text_type == 'keywords':
            keywords = sorted({k.strip() for k in use_text.split(GRAPH_FIELD_SEP) if k.strip()})
            prompt = GraphPrompt.SUMMARIZE_KEYWORDS.format(
                item_name=item_name, keyword_list=", ".join(keywords), separator=GRAPH_FIELD_SEP)
        else:
            print(f"   ⚠️ WARNING: Unknown text_type '{text_type}'. Skipping summarization.")
            return text_to_summarize

        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                print(f" -> Triggering summarization for '{item_name}' ({text_type})...")
                response = await self.client.chat.completions.create(
                    model=self.config.summarization_model,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=self.config.summary_max_tokens,
                    temperature=0.2,
                )
                summary = response.choices[0].message.content
                return summary.strip() if summary else text_to_summarize
            except Exception as e:  # best-effort: API, network or response-shape errors are logged and retried
                print(f"Attempt {attempt}/{self.MAX_RETRIES} failed: item_name '{item_name}': {e}.")
                if attempt < self.MAX_RETRIES:
                    # Back off before retrying so transient/rate-limit errors can clear.
                    await asyncio.sleep(self.RETRY_BACKOFF_SECONDS * 2 ** (attempt - 1))
        print(f"All attempts failed: item_name '{item_name}' ({text_type}). Returning original text.")
        return text_to_summarize
    
# --- MergeEntity and MergeRelationship combine attribute values (set union, majority vote for entity_type, ---
# --- summed weights). merge_similar_nodes calls them without a summarizer (Pass 1); long values are summarized in Pass 2. ---
class MergeEntity:
    """Combine the attributes of two entity (node) records.

    Only the keys in ``merge_keys`` are combined; every other key keeps the
    value of the existing (target) record. Multi-valued attributes are strings
    joined with GRAPH_FIELD_SEP.
    """

    merge_keys = ["source_id", "entity_type", "description"]

    @staticmethod
    def merge_source_ids(existing_source_ids: str, new_source_ids: str) -> str:
        """Return the sorted, de-duplicated union of two GRAPH_FIELD_SEP-joined id lists."""
        existing_list = existing_source_ids.split(GRAPH_FIELD_SEP) if existing_source_ids else []
        new_list = new_source_ids.split(GRAPH_FIELD_SEP) if new_source_ids else []
        return GRAPH_FIELD_SEP.join(sorted(set(new_list) | set(existing_list)))

    @staticmethod
    def merge_types(existing_entity_types: str, new_entity_types: str) -> str:
        """Return the most frequent entity type across both lists ('' when both are empty).

        Ties go to the type encountered first (existing before new), because
        Counter.most_common orders equal counts by first occurrence.
        """
        existing_list = existing_entity_types.split(GRAPH_FIELD_SEP) if existing_entity_types else []
        new_list = new_entity_types.split(GRAPH_FIELD_SEP) if new_entity_types else []
        entity_type_counts = Counter(existing_list + new_list)
        return entity_type_counts.most_common(1)[0][0] if entity_type_counts else ''

    @staticmethod
    def merge_descriptions(existing_descriptions: str, new_descriptions: str, summarizer: "Optional[DescriptionSummarizer]", entity_name: str) -> str:
        """Union two GRAPH_FIELD_SEP-joined description lists, optionally summarizing.

        When ``summarizer`` is given and the merged text exceeds its
        ``token_check_threshold``, the merged text is summarized synchronously
        via ``asyncio.run`` (so this must not be called from a running event loop).
        """
        existing_list = existing_descriptions.split(GRAPH_FIELD_SEP) if existing_descriptions else []
        new_list = new_descriptions.split(GRAPH_FIELD_SEP) if new_descriptions else []
        description = GRAPH_FIELD_SEP.join(sorted(set(new_list) | set(existing_list)))
        if summarizer:
            # disallowed_special=() counts "<|endoftext|>"-like text instead of raising.
            tokens = summarizer.encoder.encode(description, disallowed_special=())
            if len(tokens) > summarizer.config.token_check_threshold:
                return asyncio.run(summarizer.summarize(entity_name, description, text_type='entity_description'))
        return description

    @classmethod
    def merge_info(cls, existing_node_data, new_node_data, summarizer: "Optional[DescriptionSummarizer]" = None, entity_name: str = "Unknown"):
        """Return a new dict: ``existing_node_data`` with ``new_node_data`` merged in.

        For each key in ``merge_keys``: if both records have it, the values are
        combined; if only ``new_node_data`` has it, its value is carried over.
        Neither input dict is modified.
        """
        merge_function_map = {"source_id": cls.merge_source_ids, "entity_type": cls.merge_types}
        merged_data = existing_node_data.copy()
        for key in cls.merge_keys:
            if key not in new_node_data:
                continue
            if key not in existing_node_data:
                # Target lacks the attribute: keep the merged-in value instead of dropping it.
                merged_data[key] = new_node_data[key]
                continue
            val1, val2 = existing_node_data[key], new_node_data[key]
            if key == "description":
                merged_data[key] = cls.merge_descriptions(val1, val2, summarizer, entity_name)
            elif key in merge_function_map:
                merged_data[key] = merge_function_map[key](val1, val2)
        return merged_data

class MergeRelationship:
    """Combine the attributes of two relationship (edge) records.

    Weights are summed; the other keys in ``merge_keys`` are GRAPH_FIELD_SEP-
    joined strings merged by sorted set union. Keys outside ``merge_keys`` keep
    the existing record's value.
    """

    merge_keys = ["source_id", "weight", "description", "keywords", "relation_name"]
    # key -> merge function; built lazily on first merge_info call.
    merge_function = None

    @staticmethod
    def merge_weight(existing_weight, new_weight):
        """Return the float sum of two weights, treating falsy values as 0.0."""
        return float(existing_weight or 0.0) + float(new_weight or 0.0)

    @staticmethod
    def merge_generic_field(existing_values: str, new_values: str):
        """Return the sorted, de-duplicated union of two GRAPH_FIELD_SEP-joined lists."""
        existing_list = existing_values.split(GRAPH_FIELD_SEP) if existing_values else []
        new_list = new_values.split(GRAPH_FIELD_SEP) if new_values else []
        return GRAPH_FIELD_SEP.join(sorted(set(existing_list + new_list)))

    @classmethod
    def merge_info(cls, existing_edge_data, new_edge_data):
        """Return a new dict: ``existing_edge_data`` with ``new_edge_data`` merged in.

        For each key in ``merge_keys``: a missing/None new value leaves the
        existing value untouched; a missing/None existing value is replaced by
        the new one; otherwise the two are combined. Inputs are not modified.
        """
        if cls.merge_function is None:
            cls.merge_function = {
                "weight": cls.merge_weight,
                "description": cls.merge_generic_field,
                "source_id": cls.merge_generic_field,
                "keywords": cls.merge_generic_field,
                "relation_name": cls.merge_generic_field,
            }
        merged_data = existing_edge_data.copy()
        for key in cls.merge_keys:
            new_val = new_edge_data.get(key)
            if new_val is None:
                continue
            old_val = existing_edge_data.get(key)
            if old_val is None:
                # Nothing to combine with: carry the incoming value over instead of dropping it.
                merged_data[key] = new_val
            else:
                merged_data[key] = cls.merge_function[key](old_val, new_val)
        return merged_data

def build_canonical_map(pairs: List[Dict]) -> Dict[str, str]:
    """Group nodes linked by similarity pairs and map each to its group's canonical node.

    Runs union-find over ``pair['node1']`` / ``pair['node2']``. Roots are always
    linked toward the smaller one, so each group's canonical node is its
    minimum member (by ``<``). Every node seen in ``pairs`` appears as a key;
    canonical nodes map to themselves.
    """
    parent: Dict[str, str] = {}

    def find_set(v):
        # Iterative find with path compression: the recursive version overflows
        # the interpreter stack on long chains of unions.
        parent.setdefault(v, v)
        root = v
        while parent[root] != root:
            root = parent[root]
        while v != root:
            nxt = parent[v]
            parent[v] = root
            v = nxt
        return root

    def unite_sets(a, b):
        a_root, b_root = find_set(a), find_set(b)
        if a_root != b_root:
            if a_root < b_root:
                parent[b_root] = a_root
            else:
                parent[a_root] = b_root

    for pair in pairs:
        unite_sets(pair['node1'], pair['node2'])
    return {node: find_set(node) for node in list(parent)}

# --- Main pipeline: merge similar nodes, summarize long attributes, prune low-score edges, save ---
# merge_type values accepted by merge_similar_nodes.
_MERGE_TYPES = ("reduction_only", "reduction_synonym", "synonym_only")


def _redirect_edges(G, node_to_merge, target_node, remove_old_edges: bool) -> None:
    """Re-attach every edge of ``node_to_merge`` to ``target_node``.

    The old edge's ``src_id``/``tgt_id`` attributes are rewritten in place. If
    the target already has an edge to the same neighbor, the attribute dicts
    are combined with MergeRelationship.merge_info; otherwise the edge is
    copied. An edge between the two nodes becomes a self-loop on the target
    (self-loops are removed later). With ``remove_old_edges`` the original edge
    is deleted after being copied.
    """
    for neighbor in list(G.neighbors(node_to_merge)):
        edge_data = G.get_edge_data(node_to_merge, neighbor)
        if edge_data['tgt_id'] == node_to_merge:
            edge_data['tgt_id'] = target_node
        if edge_data['src_id'] == node_to_merge:
            edge_data['src_id'] = target_node
        if G.has_edge(target_node, neighbor):
            existing_edge_data = G.get_edge_data(target_node, neighbor)
            G.add_edge(target_node, neighbor, **MergeRelationship.merge_info(existing_edge_data, edge_data))
        else:
            G.add_edge(target_node, neighbor, **edge_data)
        if remove_old_edges:
            G.remove_edge(node_to_merge, neighbor)


def _summarize_relations(G, summarizer: "DescriptionSummarizer", count_tokens, max_concurrency: int) -> int:
    """Summarize over-threshold edge descriptions/keywords in place; return #edges touched.

    Each summary is written back to exactly the (edge, attribute) it came from.
    At most ``max_concurrency`` requests run at once. The summarizer's HTTP
    client is closed on the same event loop once all requests finish.
    """
    threshold = summarizer.config.token_check_threshold
    jobs = []  # (u, v, attribute name, summarize coroutine)
    for u, v, data in G.edges(data=True):
        rel_description = data.get("description", "")
        if rel_description and count_tokens(rel_description) > threshold:
            jobs.append((u, v, "description", summarizer.summarize(
                f"Relation from '{u}' to '{v}'", rel_description, text_type='relation_description')))
        keywords = data.get("keywords", "")
        if keywords and count_tokens(keywords) > threshold:
            jobs.append((u, v, "keywords", summarizer.summarize(
                f"Keywords for relation from '{u}' to '{v}'", keywords, text_type='keywords')))
    if not jobs:
        return 0

    async def _run_all():
        semaphore = asyncio.Semaphore(max(1, max_concurrency))

        async def _guarded(coro):
            async with semaphore:
                return await coro

        try:
            # tqdm_asyncio.gather returns results in input order, so they line up with `jobs`.
            return await tqdm_asyncio.gather(*(_guarded(job[3]) for job in jobs), desc="Summarizing relations")
        finally:
            await summarizer.client.close()

    summaries = asyncio.run(_run_all())
    for (u, v, attr, _), summary in zip(jobs, summaries):
        G[u][v][attr] = summary
    return len({(u, v) for u, v, _, _ in jobs})


def _low_score_edges(file_path: str, threshold: float) -> List[tuple]:
    """Return ``(source, destination)`` pairs from a JSONL file whose ``score`` < ``threshold``.

    Blank lines are skipped; malformed lines are reported and skipped.
    """
    edges = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                edge = json.loads(line)
                if edge['score'] < threshold:
                    edges.append((edge['source'], edge['destination']))
            except (KeyError, TypeError, ValueError) as e:
                print(f"Error loading edge data: {e}")
    return edges


def merge_similar_nodes(dataset_name: str, graph_path: str, similar_nodes_path: str, output_path: str, threshold_edge_reason: float, merge_type: str, process_num: Optional[int] = None):
    """Merge similar nodes of a GraphML graph, summarize long attributes, prune edges and save.

    Args:
        dataset_name: Currently unused; kept for interface compatibility.
        graph_path: Input GraphML file. The edge-reason file is read from
            ``<prefix>/edge_process/edge_reason.jsonl``, where ``<prefix>`` is
            the part of this path before ``"graph_storage"``.
        similar_nodes_path: JSON list of ``{"node1": ..., "node2": ...}`` pairs.
        output_path: Destination GraphML file; its parent directory is created.
        threshold_edge_reason: Edges whose ``score`` is below this are removed.
        merge_type: ``"reduction_only"`` folds each node into its canonical node
            and deletes it. ``"reduction_synonym"`` folds attributes/edges into
            the canonical node but keeps the node, linked by a "synonym of" edge.
            ``"synonym_only"`` only writes "synonym of" edges (see NOTE below).
        process_num: Worker processes for node summarization (``None`` lets
            ProcessPoolExecutor choose); also caps concurrent relation requests.

    Raises:
        ValueError: for an unknown ``merge_type`` or a missing API key.
    """
    print("--- Starting Node Merging Process ---")

    # Step 0: load config and build the summarizer (raises if no API key).
    print(f"\n[Step 0] Loading configuration from {CONFIG_PATH}...")
    with open(CONFIG_PATH, 'r') as f:
        config = yaml.safe_load(f) or {}
    llm_config = config.get('llm') or {}
    api_key = llm_config.get('api_key')
    base_url = llm_config.get('base_url')
    model_name = llm_config.get('model')
    summarizer = DescriptionSummarizer(api_key, base_url, model_name)
    print(" -> Summarizer initialized successfully.")

    def count_tokens(text) -> int:
        # disallowed_special=() counts "<|endoftext|>"-like text instead of raising ValueError.
        return len(summarizer.encoder.encode(text or "", disallowed_special=()))

    # Steps 1-2: load graph and similarity pairs; build the canonical map.
    print(f"\n[Step 1] Loading data...")
    G = nx.read_graphml(graph_path)
    with open(similar_nodes_path, 'r', encoding='utf-8') as f:
        similar_node_pairs = json.load(f)
    print(f"✅ Data loaded. Initial graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges.")

    print("\n[Step 2] Building canonical map for all similar nodes...")
    canonical_map = build_canonical_map(similar_node_pairs)
    print(f" -> Map built. {len({n for n, t in canonical_map.items() if n != t})} nodes will be merged.")

    # Step 3 (Pass 1): merge nodes/edges without summarization.
    print("\n[Step 3] Pass 1: Merging all nodes/edges and concatenating attributes...")
    if merge_type not in _MERGE_TYPES:
        raise ValueError(f"Invalid merge type: {merge_type}")
    merged_count = 0
    for node_to_merge, target_node in tqdm(canonical_map.items()):
        if node_to_merge == target_node or node_to_merge not in G or target_node not in G:
            continue
        if merge_type == "synonym_only":
            # NOTE(review): the synonym edge is only written where an edge between the
            # two nodes already exists, overwriting that edge's src_id/tgt_id/
            # relation_name/keywords/description/weight. Confirm this is intended
            # (vs. `not G.has_edge(...)`).
            if not G.has_edge(node_to_merge, target_node):
                continue
            synonym_data = {"src_id": node_to_merge, "tgt_id": target_node, "relation_name": "synonym of", "keywords": f"synonym cluster={node_to_merge}", "description": f"{target_node} is the synonym of {node_to_merge}", "weight": 1.0}
            G.add_edge(node_to_merge, target_node, **synonym_data)
            merged_count += 1
            continue

        source_data, target_data = G.nodes[node_to_merge], G.nodes[target_node]
        merged_data = MergeEntity.merge_info(target_data, source_data, summarizer=None)
        # Updates target_data in place, so the synonym edge below sees merged source_ids.
        nx.set_node_attributes(G, {target_node: merged_data})
        if merge_type == "reduction_only":
            _redirect_edges(G, node_to_merge, target_node, remove_old_edges=False)
            G.remove_node(node_to_merge)
        else:  # "reduction_synonym": keep the node, linked to its canonical node.
            _redirect_edges(G, node_to_merge, target_node, remove_old_edges=True)
            synonym_data = {"source_id": target_data["source_id"], "src_id": node_to_merge, "tgt_id": target_node, "relation_name": "synonym of", "keywords": f"synonym cluster={node_to_merge}", "description": f"{target_node} is the synonym of {node_to_merge}", "weight": 1.0}
            G.add_edge(node_to_merge, target_node, **synonym_data)
        merged_count += 1

    print(f"✅ Merging complete. {merged_count} nodes were merged.")

    # Step 4 (Pass 2): summarize long node and relation attributes.
    print("\n[Step 4] Pass 2: Finding and summarizing long attributes...")
    threshold = summarizer.config.token_check_threshold

    print(" -> Summarizing long node descriptions...")
    nodes_to_summarize = [node for node, data in G.nodes(data=True) if count_tokens(data.get("description", "")) > threshold]
    if nodes_to_summarize:
        print(f" -> Found {len(nodes_to_summarize)} nodes to summarize.")
        # Each worker process builds its own client from these picklable arguments.
        summarize_args = [(node, G.nodes[node].get("description"), api_key, base_url, model_name) for node in nodes_to_summarize]
        with ProcessPoolExecutor(max_workers=process_num) as executor:
            results = list(tqdm(executor.map(summarize_node_sync, summarize_args),
                                total=len(summarize_args),
                                desc="Summarizing nodes"))
        for node, summary in results:
            G.nodes[node]["description"] = summary
    else:
        print(" -> No nodes required summarization.")

    print(" -> Summarizing long relation attributes (descriptions and keywords)...")
    relations_summarized_count = _summarize_relations(G, summarizer, count_tokens, process_num or os.cpu_count() or 1)
    if relations_summarized_count > 0:
        print(f" -> Summarized attributes for {relations_summarized_count} relations.")
    else:
        print(" -> No relations required summarization.")
    print("✅ Summarization complete.")

    # Step 5: drop edges whose reason score is below the threshold.
    print("\n[Step 5] Removing edges with reason score less than threshold_edge_reason...")
    file_path = os.path.join(graph_path.split("graph_storage")[0], "edge_process", "edge_reason.jsonl")
    edges_to_remove = _low_score_edges(file_path, threshold_edge_reason)
    removed = 0
    for u, v in edges_to_remove:
        if G.has_edge(u, v):
            G.remove_edge(u, v)
            removed += 1
    print(f" -> Removed {removed} edges ({len(edges_to_remove)} flagged below threshold).")

    # Step 6: remove self-loops (e.g. created when adjacent nodes were merged).
    print("\n[Step 6] Removing self-loops...")
    self_loops = list(nx.selfloop_edges(G))
    if self_loops:
        print(f" -> Found and removed {len(self_loops)} self-loops.")
        G.remove_edges_from(self_loops)
    else:
        print(" -> No self-loops found.")
    print(f" -> Final graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges.")

    # Step 7: save.
    print(f"\n[Step 7] Saving new graph to: {output_path}")
    output_dir = os.path.dirname(output_path)
    if output_dir:  # os.makedirs('') raises for a bare filename
        os.makedirs(output_dir, exist_ok=True)
    nx.write_graphml(G, output_path)
    print("✅ New graph saved successfully.")
    print("\n--- Process Finished ---")


if __name__ == "__main__":
    # Command-line entry point: parse options and run the merge pipeline.
    parser = argparse.ArgumentParser(description="Merge similar nodes and summarize attributes in a GraphML file.")
    parser.add_argument("--threshold_edge_reason", type=float, default=0.2, help="Edges whose score in edge_process/edge_reason.jsonl is below this value are removed (e.g., 0.2).")
    parser.add_argument("--dataset_name", type=str, default="mini_cs", help="Name of the dataset.")
    parser.add_argument("--graph_path", type=str, default="Result/mini_cs/rkg_graph/graph_storage/graph_storage_nx_data.graphml", help="Path to the input GraphML file.")
    parser.add_argument("--similar_nodes_path", type=str, default="Result/mini_cs/rkg_graph/node_neighbors/mini_cs_0.2_0_llm_node_only_reduction_only.json", help="Path to the JSON file of similar node pairs.")
    parser.add_argument("--output_path", type=str, default="Result/mini_cs/rkg_graph/graph_storage/graph_storage_nx_data_nodes_20.0.graphml", help="Path of the output GraphML file.")
    parser.add_argument("--merge_type", type=str, default="reduction_synonym", choices=["reduction_only", "reduction_synonym", "synonym_only"], help="How similar nodes are merged.")
    parser.add_argument("--process_num", type=int, default=64, help="Number of worker processes used for node summarization (default: 64).")
    args = parser.parse_args()
    merge_similar_nodes(args.dataset_name, args.graph_path, args.similar_nodes_path, args.output_path, args.threshold_edge_reason, args.merge_type, args.process_num)