#!/usr/bin/env python3
"""
Enhanced Fast Topological Graph Module

This module creates a topological graph from NLI relationship matrices with proper
datapoint-to-code mapping tracking throughout all merging operations.

FIXED VERSION: Implements proper scoring-based merging and consistent edge accumulation.
"""

import os
import numpy as np
import pandas as pd
import networkx as nx
from typing import List, Dict, Any, Tuple, Set
from scipy import sparse
import json
from collections import defaultdict

def detect_and_save_cliques_parquet(graph: nx.DiGraph, output_dir: str):
    """
    Detect cliques and save them in optimized parquet format for fast retrieval

    A "clique" here is a weakly connected component of ``graph`` with more than
    one node; single-node components are skipped. The clique ID is the
    component's position in the ``nx.weakly_connected_components`` output, so
    IDs can have gaps where single-node components were skipped.

    Four tables are written to ``<output_dir>/cliques`` (nothing is written when
    no component has more than one node):

    * ``clique_metadata.parquet``: size, edge count, density and strong
      connectivity of each clique
    * ``clique_nodes.parquet``: one row per (clique, node)
    * ``clique_edges.parquet``: one row per directed edge
    * ``clique_hierarchies.parquet``: parents, children and siblings of every
      node, each stored as a JSON string (node labels must therefore be
      JSON-serializable)

    Rows and the JSON lists follow the node/adjacency insertion order of
    ``graph`` rather than set iteration order, so repeated runs on the same
    graph produce the same row order.

    Args:
        graph: The NetworkX graph
        output_dir: Output directory to save cliques
    """
    print("🔍 Detecting cliques with parquet-based storage for optimal retrieval...")

    # Create cliques directory
    cliques_dir = os.path.join(output_dir, "cliques")
    os.makedirs(cliques_dir, exist_ok=True)

    # Detect weakly connected components (cliques)
    components = list(nx.weakly_connected_components(graph))

    print(f"   Found {len(components)} connected components")

    # Components are sets; iterating them directly yields an arbitrary order
    # (string hashing is randomized per process), so rank nodes by the order in
    # which the graph stores them and sort each component by that rank.
    insertion_rank = {node: rank for rank, node in enumerate(graph.nodes())}

    # Prepare data structures for parquet storage
    clique_nodes_data = []
    clique_edges_data = []
    clique_hierarchies_data = []
    clique_metadata_data = []

    for i, component in enumerate(components):
        if len(component) <= 1:  # Only process components with more than 1 node
            continue

        component_graph = graph.subgraph(component)
        ordered_nodes = sorted(component, key=insertion_rank.__getitem__)
        clique_size = len(ordered_nodes)
        num_edges = component_graph.number_of_edges()

        # 1. Clique metadata
        clique_metadata_data.append({
            'clique_id': i,
            'size': clique_size,
            'num_edges': num_edges,
            'density': nx.density(component_graph),
            'is_strongly_connected': nx.is_strongly_connected(component_graph)
        })

        # 2. Node-clique relationships
        for node in ordered_nodes:
            clique_nodes_data.append({
                'clique_id': i,
                'node': node,
                'clique_size': clique_size
            })

        # 3. Edge relationships
        # A weakly connected component contains every neighbour of its nodes,
        # so the out-edges of its nodes in the full graph are exactly the
        # edges of the component.
        for source, target in graph.edges(ordered_nodes):
            clique_edges_data.append({
                'clique_id': i,
                'source': source,
                'target': target
            })

        # 4. Hierarchical relationships
        for node in ordered_nodes:
            # Parents: nodes with an edge into the current node
            parents = list(graph.predecessors(node))

            # Children: nodes the current node has an edge into
            children = list(graph.successors(node))

            # Siblings: other children of any of this node's parents,
            # de-duplicated in first-seen order
            siblings = dict.fromkeys(
                child for parent in parents for child in graph.successors(parent)
            )
            siblings.pop(node, None)  # Remove self from siblings
            sibling_list = list(siblings)

            clique_hierarchies_data.append({
                'clique_id': i,
                'node': node,
                'parents': json.dumps(parents),  # Store as JSON string for parquet
                'children': json.dumps(children),
                'siblings': json.dumps(sibling_list),
                'num_parents': len(parents),
                'num_children': len(children),
                'num_siblings': len(sibling_list)
            })

        print(f"   Processed clique {i}: {clique_size} nodes, {num_edges} edges")

    if not clique_metadata_data:
        print("   ⚠️ No cliques found (all components have size 1)")
        return

    # Save all data as parquet files
    tables = (
        ("clique_metadata.parquet", clique_metadata_data),
        ("clique_nodes.parquet", clique_nodes_data),
        ("clique_edges.parquet", clique_edges_data),
        ("clique_hierarchies.parquet", clique_hierarchies_data),
    )
    for filename, records in tables:
        pd.DataFrame(records).to_parquet(os.path.join(cliques_dir, filename), index=False)

    print(f"   ✅ Saved {len(clique_metadata_data)} cliques in parquet format to {cliques_dir}")
    print(f"   📊 Files created:")
    for filename, records in tables:
        print(f"      • {filename}: {len(records)} records")

def load_clique_data_for_retrieval(cliques_dir: str) -> Dict[str, pd.DataFrame]:
    """
    Read the clique parquet tables found in a directory

    Each table is looked up under its fixed file name; a file that does not
    exist is reported on stdout and simply left out of the result.

    Args:
        cliques_dir: Directory containing clique parquet files

    Returns:
        Dictionary keyed by 'metadata', 'nodes', 'edges' and 'hierarchies'
        (only the ones whose file was found) holding the loaded DataFrames
    """
    print("📊 Loading clique data for retrieval...")

    # (result key, file name) pairs in load order; the key is the file name
    # without its 'clique_' prefix and '.parquet' suffix.
    tables = (
        ('metadata', 'clique_metadata.parquet'),
        ('nodes', 'clique_nodes.parquet'),
        ('edges', 'clique_edges.parquet'),
        ('hierarchies', 'clique_hierarchies.parquet'),
    )

    loaded = {}
    for key, filename in tables:
        path = os.path.join(cliques_dir, filename)
        if not os.path.exists(path):
            print(f"   ⚠️ File not found: {filename}")
            continue

        frame = pd.read_parquet(path)
        loaded[key] = frame
        print(f"   ✅ Loaded {filename}: {len(frame)} records")

    return loaded

def get_cliques_for_node(node: str, clique_data: Dict[str, pd.DataFrame]) -> List[int]:
    """
    Look up which cliques a node belongs to

    Args:
        node: Node to search for
        clique_data: Loaded clique data; the 'nodes' table is consulted

    Returns:
        The 'clique_id' values of all 'nodes' rows whose 'node' equals the
        given node, or an empty list when the 'nodes' table is not loaded
    """
    if 'nodes' in clique_data:
        membership = clique_data['nodes']
        is_requested_node = membership['node'] == node
        return membership.loc[is_requested_node, 'clique_id'].tolist()

    return []

def get_node_relationships(node: str, clique_id: int, clique_data: Dict[str, pd.DataFrame]) -> Dict[str, Any]:
    """
    Fetch the stored parents/children/siblings of a node within one clique

    Args:
        node: Node to get relationships for
        clique_id: Clique ID
        clique_data: Loaded clique data; the 'hierarchies' table is consulted

    Returns:
        Dictionary with the decoded 'parents', 'children' and 'siblings' lists
        plus their stored counts, taken from the first matching row. Empty
        when the 'hierarchies' table is not loaded or no row matches.
    """
    if 'hierarchies' not in clique_data:
        return {}

    hierarchy = clique_data['hierarchies']

    # Rows that match both the node and the clique
    selector = (hierarchy['node'] == node) & (hierarchy['clique_id'] == clique_id)
    matches = hierarchy[selector]
    if matches.empty:
        return {}

    first_match = matches.iloc[0]

    # The three relationship lists are stored as JSON strings
    relationships = {
        key: json.loads(first_match[key])
        for key in ('parents', 'children', 'siblings')
    }
    for key in ('num_parents', 'num_children', 'num_siblings'):
        relationships[key] = first_match[key]

    return relationships

def get_clique_edges(clique_id: int, clique_data: Dict[str, pd.DataFrame]) -> List[tuple]:
    """
    Get all edges in a specific clique

    Args:
        clique_id: Clique ID
        clique_data: Loaded clique data; the 'edges' table is consulted

    Returns:
        List of (source, target) tuples in table order, or an empty list when
        the 'edges' table is not loaded or the clique has no rows
    """
    if 'edges' not in clique_data:
        return []

    edges = clique_data['edges']
    clique_edges = edges[edges['clique_id'] == clique_id]

    # Pair the two columns directly instead of iterating row by row:
    # iterrows() builds a Series per row, which is slow on large tables.
    return list(zip(clique_edges['source'].tolist(), clique_edges['target'].tolist()))



class DatapointMapper:
    """Handles datapoint-to-code mapping throughout all operations

    Keeps, per code, the list of datapoints currently attributed to it and an
    incoming-edge count ("merging frequency"). Merging codes moves the source
    codes' datapoints and edge counts onto the target code and records the
    operation in ``merge_history``.
    """

    def __init__(self, original_mapping: Dict[str, List[str]] = None):
        """
        Initialize with original datapoint-to-code mapping

        The mapping is copied (including its lists), so later merges never
        modify ``original_mapping`` or the dict passed in by the caller.

        Args:
            original_mapping: Dict mapping code -> list of datapoints
        """
        initial = original_mapping or {}
        # A shallow dict copy would share the datapoint lists, and the in-place
        # extends done while merging would then leak into the original data.
        self.original_mapping = {code: list(datapoints) for code, datapoints in initial.items()}
        self.current_mapping = {code: list(datapoints) for code, datapoints in initial.items()}
        self.merge_history = []  # Track all merge operations
        self.removed_codes = set()  # Track codes that were removed/merged
        self.merging_frequencies = {}  # Track edge counts for each code

    @staticmethod
    def _deduplicate(datapoints):
        """Return the datapoints without repeats, keeping first-seen order."""
        return list(dict.fromkeys(datapoints))

    def add_datapoint_mapping(self, code: str, datapoints: List[str]):
        """Add datapoints to a code, creating the code if needed and dropping duplicates"""
        existing = self.current_mapping.get(code, [])
        self.current_mapping[code] = self._deduplicate([*existing, *datapoints])

    def update_merging_frequency(self, code: str, edge_count: int):
        """Update the merging frequency (incoming edge count) for a code"""
        self.merging_frequencies[code] = edge_count

    def get_global_frequency(self, code: str) -> int:
        """Get global frequency (number of datapoints currently mapped to this code)"""
        return len(self.current_mapping.get(code, []))

    def get_merging_frequency(self, code: str) -> int:
        """Get merging frequency (recorded incoming edge count; 0 if none recorded)"""
        return self.merging_frequencies.get(code, 0)

    def get_current_node_properties(self, code: str) -> Dict[str, Any]:
        """Get real-time properties of a node for decision making

        Returns:
            Dict with 'global_frequency', 'incoming_edges', the weighted
            'merge_score' (0.3 * global_frequency + 0.7 * incoming_edges) and
            'exists' (whether the code is currently mapped)
        """
        global_freq = self.get_global_frequency(code)
        incoming_edges = self.get_merging_frequency(code)
        return {
            'global_frequency': global_freq,
            'incoming_edges': incoming_edges,
            'merge_score': 0.3 * global_freq + 0.7 * incoming_edges,
            'exists': code in self.current_mapping
        }

    def merge_codes(self, source_codes: List[str], target_code: str, merge_type: str = "unknown"):
        """
        Merge multiple codes into a target code, updating datapoint mappings

        Source codes other than the target are removed from the mapping and
        remembered in ``removed_codes``. The target ends up with the union of
        all datapoints and the sum of the incoming-edge counts of every code
        involved, each code counted once even when the target is itself listed
        in ``source_codes``. Source codes that are not currently mapped are
        ignored.

        Args:
            source_codes: List of codes to merge
            target_code: The code that will represent the merged group
            merge_type: Type of merge (e.g., "mutual", "frequency", "manual")
        """
        merged_datapoints = []
        merged_merging_freq = 0
        target_counted = False

        for source_code in source_codes:
            if source_code not in self.current_mapping:
                continue

            merged_datapoints.extend(self.current_mapping[source_code])

            if source_code == target_code:
                # The target survives the merge; count its edges only once.
                if not target_counted:
                    merged_merging_freq += self.merging_frequencies.get(target_code, 0)
                    target_counted = True
                continue

            merged_merging_freq += self.merging_frequencies.get(source_code, 0)
            self.removed_codes.add(source_code)
            del self.current_mapping[source_code]
            self.merging_frequencies.pop(source_code, None)

        # Add the target's existing edges unless the loop above already did
        if not target_counted:
            merged_merging_freq += self.merging_frequencies.get(target_code, 0)

        # Add to target code, removing duplicates
        target_datapoints = self.current_mapping.get(target_code, [])
        self.current_mapping[target_code] = self._deduplicate([*target_datapoints, *merged_datapoints])

        # Update incoming edge frequency for target code
        self.merging_frequencies[target_code] = merged_merging_freq

        # Record merge operation
        unique_count = len(self.current_mapping[target_code])
        self.merge_history.append({
            'source_codes': list(source_codes),  # copy: the caller may reuse its list
            'target_code': target_code,
            'merge_type': merge_type,
            'datapoint_count': len(merged_datapoints),
            'unique_datapoint_count': unique_count,
            'global_frequency': unique_count,
            'incoming_edge_frequency': merged_merging_freq
        })

    def get_current_codes(self) -> List[str]:
        """Get list of current active codes"""
        return list(self.current_mapping.keys())

    def get_datapoints_for_code(self, code: str) -> List[str]:
        """Get datapoints associated with a code (empty list for unknown codes)"""
        return self.current_mapping.get(code, [])

    def validate_mapping(self, expected_codes: Set[str]) -> Dict[str, Any]:
        """Validate that mapping is consistent with expected codes

        Returns:
            Dict with 'missing_codes' (expected but not mapped), 'extra_codes'
            (mapped but not expected), 'total_datapoints',
            'codes_with_no_datapoints' and 'is_valid' (True when there are no
            missing and no extra codes)
        """
        current_codes = set(self.current_mapping.keys())
        missing_codes = expected_codes - current_codes
        extra_codes = current_codes - expected_codes

        return {
            'missing_codes': missing_codes,
            'extra_codes': extra_codes,
            'total_datapoints': sum(len(dp) for dp in self.current_mapping.values()),
            'codes_with_no_datapoints': [code for code, dp in self.current_mapping.items() if len(dp) == 0],
            'is_valid': not missing_codes and not extra_codes
        }

    def save_mapping_report(self, output_path: str):
        """Save detailed mapping report as JSON to ``output_path``"""
        report_data = {
            'merge_history': self.merge_history,
            # Sorted by text form so the report does not depend on set order
            'removed_codes': sorted(self.removed_codes, key=str),
            'current_mapping_summary': {
                code: len(datapoints) for code, datapoints in self.current_mapping.items()
            },
            'total_codes': len(self.current_mapping),
            'total_datapoints': sum(len(dp) for dp in self.current_mapping.values())
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(report_data, f, indent=2)


class EnhancedTopologicalGraphBuilder:
    """Enhanced topological graph builder with proper datapoint tracking"""
    
    def __init__(self, datapoint_mapper: DatapointMapper = None):
        self.graph = None
        self.relationship_matrix = None
        self.unique_codes = None
        self.code_to_idx = None
        self.datapoint_mapper = datapoint_mapper or DatapointMapper()
    
    def _calculate_merge_score(self, code: str) -> float:
        """Calculate merge score: 0.3 * global_freq + 0.7 * incoming_edges"""
        properties = self.datapoint_mapper.get_current_node_properties(code)
        return properties['merge_score']
    
    def _select_representative_by_score(self, code_a: str, code_b: str) -> str:
        """Select representative based on weighted scoring (w1=0.3, w2=0.7)"""
        score_a = self._calculate_merge_score(code_a)
        score_b = self._calculate_merge_score(code_b)
        
        print(f"     Scoring: {code_a} ({score_a:.2f}) vs {code_b} ({score_b:.2f})")
        
        if score_a >= score_b:
            return code_a
        else:
            return code_b
    
    def _accumulate_incoming_edges(self, source_codes: List[str], target_code: str):
        """Accumulate incoming edges from source codes to target code"""
        total_incoming = 0
        
        # Sum incoming edges from all source codes
        for source_code in source_codes:
            total_incoming += self.datapoint_mapper.get_merging_frequency(source_code)
        
        # Add target's existing incoming edges if it's not in source codes
        if target_code not in source_codes:
            total_incoming += self.datapoint_mapper.get_merging_frequency(target_code)
        
        # Update target with accumulated edges
        self.datapoint_mapper.update_merging_frequency(target_code, total_incoming)
        
        return total_incoming
    def _update_edge_frequencies_after_merge(self):
        """Update edge frequencies based on current graph structure after merge operations"""
        # Create temporary graph from current matrix to calculate accurate edge frequencies
        temp_graph = nx.DiGraph()
        temp_graph.add_nodes_from(self.unique_codes)
        
        # Add edges from current matrix
        for i in range(len(self.unique_codes)):
            for j in range(len(self.unique_codes)):
                if i != j and self.relationship_matrix[i, j] > 0:
                    temp_graph.add_edge(self.unique_codes[i], self.unique_codes[j])
        
        # Update merging frequencies based on incoming edges
        for node in temp_graph.nodes():
            incoming_edge_count = temp_graph.in_degree(node)
            self.datapoint_mapper.update_merging_frequency(node, incoming_edge_count)
        
        print(f"   Updated edge frequencies for {len(temp_graph.nodes())} nodes after merge")

    @classmethod
    def from_corpus(cls, corpus_path: str):
        """Create builder from corpus file"""
        print(f"📊 Loading corpus from {corpus_path}")
        
        if not os.path.exists(corpus_path):
            raise FileNotFoundError(f"Corpus file not found: {corpus_path}")
        
        corpus_df = pd.read_parquet(corpus_path)
        
        # Create original datapoint mapping
        original_mapping = defaultdict(list)
        for _, row in corpus_df.iterrows():
            code = row['tag']
            datapoint = row['chunk_text']
            original_mapping[code].append(datapoint)
        
        # Convert to regular dict
        original_mapping = dict(original_mapping)
        
        print(f"   Loaded {len(original_mapping)} codes with {sum(len(dp) for dp in original_mapping.values())} datapoints")
        
        datapoint_mapper = DatapointMapper(original_mapping)
        return cls(datapoint_mapper)
    
    def load_relationship_matrix(self, matrix_path: str, codes_path: str):
        """Load relationship matrix and code mapping"""
        print("📊 Loading relationship matrix...")
        
        # Load matrix
        self.relationship_matrix = np.load(matrix_path)
        
        # Load codes
        codes_df = pd.read_parquet(codes_path)
        self.unique_codes = codes_df['code'].tolist()
        self.code_to_idx = {code: idx for idx, code in enumerate(self.unique_codes)}
        
        print(f"   Loaded {self.relationship_matrix.shape[0]}x{self.relationship_matrix.shape[1]} matrix")
        print(f"   Loaded {len(self.unique_codes)} unique codes")
        
        # Initialize incoming edge frequencies from the matrix
        # Treat mutual relationships as incoming edges for both nodes
        print("   Initializing incoming edge frequencies from matrix (mutual = incoming for both)...")
        for j, code in enumerate(self.unique_codes):
            incoming_edges = 0
            for i in range(len(self.unique_codes)):
                if self.relationship_matrix[i, j] > 0:  # There's a relationship from i to j
                    incoming_edges += 1
                    # If it's mutual (both directions), count it as incoming for both
                    if self.relationship_matrix[j, i] > 0:  # Mutual relationship
                        # This edge is already counted above, but we need to ensure
                        # that when we process the reverse direction, we don't double count
                        pass
            self.datapoint_mapper.update_merging_frequency(code, int(incoming_edges))
        
        # Validate that we have datapoint mappings for all codes
        current_codes = set(self.datapoint_mapper.get_current_codes())
        matrix_codes = set(self.unique_codes)
        
        validation = self.datapoint_mapper.validate_mapping(matrix_codes)
        if not validation['is_valid']:
            print(f"   ⚠️ Mapping validation issues:")
            if validation['missing_codes']:
                print(f"      Missing codes: {len(validation['missing_codes'])}")
            if validation['extra_codes']:
                print(f"      Extra codes: {len(validation['extra_codes'])}")
    
    def merge_mutual_relationships(self):
        """Merge nodes that have mutual relationships with proper scoring-based selection"""
        print("🔄 Merging mutual relationships with scoring-based selection...")
        
        if self.relationship_matrix is None or self.unique_codes is None:
            raise ValueError("Matrix and codes must be loaded first")
        
        # Find mutual relationships
        mutual_pairs = []
        for i in range(len(self.unique_codes)):
            for j in range(i + 1, len(self.unique_codes)):
                if (self.relationship_matrix[i, j] > 0 and 
                    self.relationship_matrix[j, i] > 0):
                    mutual_pairs.append((self.unique_codes[i], self.unique_codes[j]))
        
        print(f"   Found {len(mutual_pairs)} mutual relationship pairs")
        
        if not mutual_pairs:
            return self.relationship_matrix, self.unique_codes
        
        # Process mutual pairs for merging with scoring
        processed_codes = set()
        merge_groups = []
        
        for code_a, code_b in mutual_pairs:
            if code_a in processed_codes or code_b in processed_codes:
                continue
            
            # Use scoring to choose representative (w1=w2=0.5)
            representative = self._select_representative_by_score(code_a, code_b)
            source_codes = [code_a, code_b]
            
            # Merge in datapoint mapper with proper edge accumulation
            self.datapoint_mapper.merge_codes(source_codes, representative, "mutual")
            
            merge_groups.append((source_codes, representative))
            processed_codes.update(source_codes)
        
        # Create new matrix with representative nodes
        new_codes = []
        old_to_new_mapping = {}
        
        # Add representative codes and non-merged codes
        for code in self.unique_codes:
            if code in processed_codes:
                # Find which group this code belongs to
                for source_codes, representative in merge_groups:
                    if code in source_codes:
                        old_to_new_mapping[code] = representative
                        if representative not in new_codes:
                            new_codes.append(representative)
                        break
            else:
                old_to_new_mapping[code] = code
                new_codes.append(code)
        
        # Create new matrix
        n_new = len(new_codes)
        new_to_idx = {code: idx for idx, code in enumerate(new_codes)}
        merged_matrix = np.zeros((n_new, n_new), dtype=int)
        
        # Map relationships to new matrix
        for i, code_i in enumerate(self.unique_codes):
            for j, code_j in enumerate(self.unique_codes):
                if self.relationship_matrix[i, j] > 0:
                    new_i_code = old_to_new_mapping[code_i]
                    new_j_code = old_to_new_mapping[code_j]
                    
                    if new_i_code != new_j_code:  # Skip self-loops
                        new_i_idx = new_to_idx[new_i_code]
                        new_j_idx = new_to_idx[new_j_code]
                        merged_matrix[new_i_idx, new_j_idx] = 1
        
        # Merge operation completed (print removed)
        
        # Update internal state
        self.relationship_matrix = merged_matrix
        self.unique_codes = new_codes
        self.code_to_idx = {code: idx for idx, code in enumerate(new_codes)}
        
        # Update edge frequencies after merge
        self._update_edge_frequencies_after_merge()
        
        return merged_matrix, new_codes

    def merge_directional_relationships(self):
        """Merge codes based on directional relationships with proper edge accumulation"""
        print("🔄 Merging directional relationships with edge accumulation...")
        
        if self.relationship_matrix is None or self.unique_codes is None:
            raise ValueError("Matrix and codes must be loaded first")
        
        # Find directional relationships where one code points to another but not vice versa
        directional_pairs = []
        for i in range(len(self.unique_codes)):
            for j in range(len(self.unique_codes)):
                if i != j:
                    # Check if i points to j but j doesn't point to i (directional relationship)
                    if (self.relationship_matrix[i, j] > 0 and 
                        self.relationship_matrix[j, i] == 0):
                        source_code = self.unique_codes[i]
                        target_code = self.unique_codes[j]
                        directional_pairs.append((source_code, target_code))
        
        print(f"   Found {len(directional_pairs)} directional relationship pairs")
        
        if not directional_pairs:
            return self.relationship_matrix, self.unique_codes
        
        # Process directional pairs for merging
        processed_codes = set()
        merge_operations = []
        
        for source_code, target_code in directional_pairs:
            if source_code in processed_codes or target_code in processed_codes:
                continue
            
            # Merge source into target (source_code into target_code)
            merge_operations.append((source_code, target_code))
            processed_codes.add(source_code)
            processed_codes.add(target_code)
        
        # Apply merge operations with proper edge accumulation
        for source_code, target_code in merge_operations:
            self.datapoint_mapper.merge_codes([source_code], target_code, "directional")
        
        # Create new matrix with merged nodes
        new_codes = []
        old_to_new_mapping = {}
        
        # Add target codes and non-merged codes
        for code in self.unique_codes:
            if code in processed_codes:
                # Find which merge operation this code belongs to
                for source_code, target_code in merge_operations:
                    if code == source_code:
                        old_to_new_mapping[code] = target_code
                        if target_code not in new_codes:
                            new_codes.append(target_code)
                        break
                    elif code == target_code:
                        old_to_new_mapping[code] = target_code
                        if target_code not in new_codes:
                            new_codes.append(target_code)
                        break
            else:
                old_to_new_mapping[code] = code
                new_codes.append(code)
        
        # Create new matrix
        n_new = len(new_codes)
        new_to_idx = {code: idx for idx, code in enumerate(new_codes)}
        merged_matrix = np.zeros((n_new, n_new), dtype=int)
        
        # Map relationships to new matrix
        for i, code_i in enumerate(self.unique_codes):
            for j, code_j in enumerate(self.unique_codes):
                if self.relationship_matrix[i, j] > 0:
                    new_i_code = old_to_new_mapping[code_i]
                    new_j_code = old_to_new_mapping[code_j]
                    
                    if new_i_code != new_j_code:  # Skip self-loops
                        new_i_idx = new_to_idx[new_i_code]
                        new_j_idx = new_to_idx[new_j_code]
                        merged_matrix[new_i_idx, new_j_idx] = 1
        
        # Merge operation completed (print removed)
        
        # Update internal state
        self.relationship_matrix = merged_matrix
        self.unique_codes = new_codes
        self.code_to_idx = {code: idx for idx, code in enumerate(new_codes)}
        
        # Update edge frequencies after merge
        self._update_edge_frequencies_after_merge()
        
        return merged_matrix, new_codes

    def merge_low_frequency_codes(self, min_frequency: int = 2):
        """Merge codes with 0 or fewer incoming edges, properly accumulating incoming edges"""
        print(f"🔄 Merging codes with 0 or fewer incoming edges (min_frequency={min_frequency})...")
        
        # Get current merging frequencies (incoming edge counts) from datapoint mapper
        code_merging_frequencies = {}
        for code in self.datapoint_mapper.get_current_codes():
            merging_freq = self.datapoint_mapper.get_merging_frequency(code)
            code_merging_frequencies[code] = merging_freq
        
        # Find codes with 0 or fewer incoming edges that exist in our matrix
        zero_or_negative_freq_codes = []
        positive_freq_codes = []
        
        for code in self.unique_codes:
            merging_freq = code_merging_frequencies.get(code, 0)
            if merging_freq <= 5:  # Only merge codes with 5 or fewer incoming edges
                zero_or_negative_freq_codes.append(code)
            else:
                positive_freq_codes.append(code)
        
        print(f"   Found {len(zero_or_negative_freq_codes)} codes with 0 or fewer incoming edges")
        
        if not zero_or_negative_freq_codes:
            return self.relationship_matrix, self.unique_codes
        
        # For each zero/negative frequency code, find the best target based on relationship direction
        merge_operations = []
        
        for zero_freq_code in zero_or_negative_freq_codes:
            best_match = None
            max_relationships = 0
            
            zero_idx = self.code_to_idx[zero_freq_code]
            
            # First, try to find a code that this zero_freq_code points to (outgoing relationship)
            for target_code in positive_freq_codes:
                if target_code in self.code_to_idx:
                    target_idx = self.code_to_idx[target_code]
                    
                    # Check if zero_freq_code points to target_code (outgoing relationship)
                    if self.relationship_matrix[zero_idx, target_idx] > 0:
                        relationships = self.relationship_matrix[zero_idx, target_idx]
                        if relationships > max_relationships:
                            max_relationships = relationships
                            best_match = target_code
            
            # If no outgoing relationship found, try incoming relationships (reverse direction)
            if not best_match:
                for target_code in positive_freq_codes:
                    if target_code in self.code_to_idx:
                        target_idx = self.code_to_idx[target_code]
                        
                        # Check if target_code points to zero_freq_code (incoming relationship)
                        if self.relationship_matrix[target_idx, zero_idx] > 0:
                            relationships = self.relationship_matrix[target_idx, zero_idx]
                            if relationships > max_relationships:
                                max_relationships = relationships
                                best_match = target_code
            
            # If still no match, try bidirectional relationships
            if not best_match:
                for target_code in positive_freq_codes:
                    if target_code in self.code_to_idx:
                        target_idx = self.code_to_idx[target_code]
                        
                        # Count bidirectional relationships
                        relationships = (self.relationship_matrix[zero_idx, target_idx] + 
                                       self.relationship_matrix[target_idx, zero_idx])
                        
                        if relationships > max_relationships:
                            max_relationships = relationships
                            best_match = target_code
            
            if best_match and max_relationships > 0:
                merge_operations.append((zero_freq_code, best_match))
            else:
                # Keep as separate node if no good match found
                positive_freq_codes.append(zero_freq_code)
        
        # Apply merge operations with proper edge accumulation
        for zero_freq_code, target_code in merge_operations:
            self.datapoint_mapper.merge_codes([zero_freq_code], target_code, "frequency")
        
        # Create new matrix
        final_codes = list(set(positive_freq_codes + [target for _, target in merge_operations]))
        final_codes = [code for code in final_codes if code in self.datapoint_mapper.get_current_codes()]
        
        n_final = len(final_codes)
        final_to_idx = {code: idx for idx, code in enumerate(final_codes)}
        merged_matrix = np.zeros((n_final, n_final), dtype=int)
        
        # Create mapping from old to new codes
        code_mapping = {}
        for code in self.unique_codes:
            code_mapping[code] = code  # Default to itself
        
        for zero_freq_code, target_code in merge_operations:
            code_mapping[zero_freq_code] = target_code
        
        # Map relationships to new matrix
        for i, code_i in enumerate(self.unique_codes):
            for j, code_j in enumerate(self.unique_codes):
                if self.relationship_matrix[i, j] > 0:
                    new_i_code = code_mapping[code_i]
                    new_j_code = code_mapping[code_j]
                    
                    if (new_i_code in final_to_idx and new_j_code in final_to_idx and 
                        new_i_code != new_j_code):
                        new_i_idx = final_to_idx[new_i_code]
                        new_j_idx = final_to_idx[new_j_code]
                        merged_matrix[new_i_idx, new_j_idx] = 1
        
        # Merge operation completed (print removed)
        
        # Update internal state
        self.relationship_matrix = merged_matrix
        self.unique_codes = final_codes
        self.code_to_idx = {code: idx for idx, code in enumerate(final_codes)}
        
        # Update edge frequencies after merge
        self._update_edge_frequencies_after_merge()
        
        return merged_matrix, final_codes
    
    def create_graph_from_matrix(self):
        """Build the directed NetworkX graph described by the relationship matrix.

        Every entry of ``self.unique_codes`` becomes a node. A directed edge
        ``unique_codes[i] -> unique_codes[j]`` with ``weight=1`` is added for
        each off-diagonal cell where ``relationship_matrix[i, j] > 0``;
        diagonal cells (self-loops) are ignored. Edges are inserted in
        row-major order of the matrix. The graph is stored on ``self.graph``
        and its node set is passed to the datapoint mapper's
        ``validate_mapping``; a failed validation only prints a warning.

        Returns:
            The newly created ``nx.DiGraph`` (also stored as ``self.graph``).

        Raises:
            ValueError: If the relationship matrix or the code list is unset,
                or if the matrix is not 2-D / does not cover every code.
        """
        print("🔗 Creating topological graph...")

        if self.relationship_matrix is None or self.unique_codes is None:
            raise ValueError("Matrix and codes must be available")

        codes = self.unique_codes
        n_codes = len(codes)

        matrix = self.relationship_matrix
        if sparse.issparse(matrix):
            # Densify so the vectorised scan below works on SciPy sparse input too
            matrix = matrix.toarray()
        matrix = np.asarray(matrix)
        if matrix.ndim != 2 or matrix.shape[0] < n_codes or matrix.shape[1] < n_codes:
            raise ValueError(
                f"Relationship matrix shape {matrix.shape} does not cover {n_codes} codes"
            )

        # Create directed graph; adding nodes first keeps isolated codes in it
        self.graph = nx.DiGraph()
        self.graph.add_nodes_from(codes)

        # Locate all positive cells at once instead of probing each (i, j) from
        # Python. np.nonzero yields indices in row-major order, so edges are
        # inserted in the same order as a nested i/j loop would produce.
        sources, targets = np.nonzero(matrix[:n_codes, :n_codes] > 0)
        off_diagonal = sources != targets  # Skip self-loops
        edges = [
            (codes[i], codes[j])
            for i, j in zip(sources[off_diagonal].tolist(), targets[off_diagonal].tolist())
        ]
        self.graph.add_edges_from(edges, weight=1)
        edge_count = len(edges)

        print(f"   Added {edge_count} edges")
        print(f"   Graph has {self.graph.number_of_nodes()} nodes and {self.graph.number_of_edges()} edges")

        # Validate graph matches datapoint mapper
        validation = self.datapoint_mapper.validate_mapping(set(self.graph.nodes()))
        if not validation['is_valid']:
            print("   ⚠️ Graph-mapper mismatch detected")

        return self.graph

    def update_merging_frequencies_from_graph(self):
        """Update merging frequencies based on incoming edges only"""
        if self.graph is None:
            return
            
        for node in self.graph.nodes():
            # Count only incoming edges (predecessors)
            incoming_edge_count = self.graph.in_degree(node)
            self.datapoint_mapper.update_merging_frequency(node, incoming_edge_count)
        
        print(f"   Updated merging frequencies (incoming edges) for {len(self.graph.nodes())} nodes")

    def perform_topological_sort(self):
        """Perform topological sort"""
        print("📊 Performing topological sort...")
        
        if self.graph is None:
            raise ValueError("Graph not created. Call create_graph_from_matrix first.")
        
        try:
            sorted_nodes = list(nx.topological_sort(self.graph))
            print(f"   ✅ Topological sort completed: {len(sorted_nodes)} nodes")
            
            levels = {node: i for i, node in enumerate(sorted_nodes)}
            return sorted_nodes, levels
            
        except (nx.NetworkXError, nx.NetworkXUnfeasible) as e:
            print(f"   ⚠️ Topological sort failed: {e}")
            print("   Using fallback: sorting by node degree")
            
            degrees = dict(self.graph.degree())
            sorted_nodes = sorted(degrees.keys(), key=lambda x: degrees[x], reverse=True)
            levels = {node: i for i, node in enumerate(sorted_nodes)}
            
            return sorted_nodes, levels

    def save_datapoint_code_mapping_files(self, output_dir: str):
        """Save datapoint-code mapping files in the format expected by context retrievers.

        Everything is written to ``<output_dir>/datapoint_code_mapping``
        (created if missing):

        * ``code_frequencies.parquet`` - one row per current code.
        * ``datapoint_to_codes.parquet`` and ``code_to_datapoints.parquet`` -
          one row per (code, datapoint) pair; same rows, different column order.
        * ``datapoint_codes_list.parquet`` - one row per datapoint with the
          list of codes that reference it.
        * ``mapping_summary.json`` - aggregate statistics.

        A parquet file is skipped when it would contain no rows; the JSON
        summary is always written.

        Args:
            output_dir: Directory under which the mapping sub-directory is created.
        """
        print("💾 Saving datapoint-code mapping files...")

        # Create datapoint_code_mapping subdirectory
        mapping_dir = os.path.join(output_dir, "datapoint_code_mapping")
        os.makedirs(mapping_dir, exist_ok=True)

        def write_rows(rows, filename, label):
            # Persist rows as parquet; nothing is written for an empty table.
            if rows:
                pd.DataFrame(rows).to_parquet(os.path.join(mapping_dir, filename), index=False)
                print(f"   Saved {len(rows)} {label}")

        mapper = self.datapoint_mapper
        codes = list(mapper.get_current_codes())

        # Query the mapper once per code and derive every table from that single
        # pass, so all files are guaranteed to describe the same snapshot.
        code_frequencies = []
        datapoint_to_codes = []
        code_to_datapoints = []
        datapoint_codes_dict = defaultdict(list)
        for code in codes:
            datapoints = mapper.get_datapoints_for_code(code)
            # Global frequency = number of datapoints attached to the code
            global_frequency = len(datapoints)
            incoming_edges = mapper.get_merging_frequency(code)
            # Merge score: 0.3 * datapoint count + 0.7 * incoming edges
            merge_score = 0.3 * global_frequency + 0.7 * incoming_edges

            code_frequencies.append({
                'code': code,
                'frequency': global_frequency,
                'global_frequency': global_frequency,
                'incoming_edges': incoming_edges,
                'merge_score': merge_score,
                'level': -1  # Placeholder; no level is computed in this method
            })

            for datapoint in datapoints:
                datapoint_to_codes.append({
                    'datapoint': datapoint,
                    'code': code,
                    'global_frequency': global_frequency,
                    'incoming_edges': incoming_edges,
                    'merge_score': merge_score
                })
                code_to_datapoints.append({
                    'code': code,
                    'datapoint': datapoint,
                    'global_frequency': global_frequency,
                    'incoming_edges': incoming_edges,
                    'merge_score': merge_score
                })
                # Reverse mapping: datapoint -> list of codes
                datapoint_codes_dict[datapoint].append(code)

        reverse_mapping = [
            {'datapoint': datapoint, 'codes': code_list, 'code_count': len(code_list)}
            for datapoint, code_list in datapoint_codes_dict.items()
        ]

        write_rows(code_frequencies, "code_frequencies.parquet", "code frequencies")
        write_rows(datapoint_to_codes, "datapoint_to_codes.parquet", "datapoint-to-codes mappings")
        write_rows(code_to_datapoints, "code_to_datapoints.parquet", "code-to-datapoints mappings")
        write_rows(reverse_mapping, "datapoint_codes_list.parquet", "datapoint -> codes list mappings")

        # Save mapping summary with detailed statistics
        unique_datapoints = set()
        for datapoints in mapper.current_mapping.values():
            unique_datapoints.update(datapoints)

        def value_range(column):
            # Min/max of one code_frequencies column; both 0 when there are no codes.
            values = [row[column] for row in code_frequencies]
            return {'min': min(values, default=0), 'max': max(values, default=0)}

        total_mappings = len(datapoint_to_codes)
        summary = {
            'total_codes': len(codes),
            'total_unique_datapoints': len(unique_datapoints),
            'total_mappings': total_mappings,
            'avg_datapoints_per_code': (
                sum(row['frequency'] for row in code_frequencies) / len(codes) if codes else 0
            ),
            'avg_codes_per_datapoint': (
                total_mappings / len(unique_datapoints) if unique_datapoints else 0
            ),
            'frequency_range': value_range('frequency'),
            'incoming_edges_range': value_range('incoming_edges'),
            'merge_score_range': value_range('merge_score')
        }

        with open(os.path.join(mapping_dir, "mapping_summary.json"), 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)

        print(f"   ✅ Saved datapoint-code mapping files to {mapping_dir}")
        print(f"       - {len(code_frequencies)} codes")
        print(f"       - {len(unique_datapoints)} unique datapoints")
        print(f"       - {total_mappings} total mappings")

    def save_enhanced_results(self, sorted_nodes: List[str], levels: Dict[str, int], output_dir: str):
        """Save enhanced results with full datapoint tracking"""
        print("💾 Saving enhanced results...")
        
        os.makedirs(output_dir, exist_ok=True)
        
        # Save topological sort
        sort_data = [{'node': node, 'level': levels[node]} for node in sorted_nodes]
        sort_df = pd.DataFrame(sort_data)
        sort_df.to_parquet(os.path.join(output_dir, "topological_sort.parquet"), index=False)
        
        # Save hierarchy
        hierarchy = {}
        for node, level in levels.items():
            if level not in hierarchy:
                hierarchy[level] = []
            hierarchy[level].append(node)
        
        hierarchy_data = []
        for level, nodes in hierarchy.items():
            for node in nodes:
                hierarchy_data.append({'level': level, 'node': node})
        hierarchy_df = pd.DataFrame(hierarchy_data)
        hierarchy_df.to_parquet(os.path.join(output_dir, "hierarchy.parquet"), index=False)
        
        # Save detailed datapoint mappings with enhanced information
        datapoint_data = []
        for code in self.datapoint_mapper.get_current_codes():
            datapoints = self.datapoint_mapper.get_datapoints_for_code(code)
            global_freq = len(datapoints)
            incoming_edges = self.datapoint_mapper.get_merging_frequency(code)
            merge_score = 0.3 * global_freq + 0.7 * incoming_edges
            
            for datapoint in datapoints:
                datapoint_data.append({
                    'code': code,
                    'datapoint': datapoint,
                    'level': levels.get(code, -1),
                    'global_frequency': global_freq,
                    'incoming_edges': incoming_edges,
                    'merge_score': merge_score
                })
        
        if datapoint_data:
            datapoint_df = pd.DataFrame(datapoint_data)
            datapoint_df.to_parquet(os.path.join(output_dir, "code_datapoints_enhanced.parquet"), index=False)
            print(f"   Saved {len(datapoint_data)} enhanced datapoint mappings")
        
        # Save mapping report
        mapping_report_path = os.path.join(output_dir, "mapping_report.json")
        self.datapoint_mapper.save_mapping_report(mapping_report_path)
        
        # Save enhanced graph analysis
        unique_datapoints = set()
        for datapoints in self.datapoint_mapper.current_mapping.values():
            unique_datapoints.update(datapoints)
            
        analysis_data = [{
            'num_nodes': self.graph.number_of_nodes(),
            'num_edges': self.graph.number_of_edges(),
            'density': nx.density(self.graph),
            'is_directed': self.graph.is_directed(),
            'is_acyclic': nx.is_directed_acyclic_graph(self.graph),
            'num_components': nx.number_weakly_connected_components(self.graph),
            'total_unique_datapoints': len(unique_datapoints),
            'total_datapoint_mappings': len(datapoint_data),
            'codes_with_datapoints': len([code for code in self.datapoint_mapper.get_current_codes() 
                                        if len(self.datapoint_mapper.get_datapoints_for_code(code)) > 0]),
            'avg_global_frequency': sum(len(self.datapoint_mapper.get_datapoints_for_code(code)) 
                                      for code in self.datapoint_mapper.get_current_codes()) / len(self.datapoint_mapper.get_current_codes()) if self.datapoint_mapper.get_current_codes() else 0,
            'avg_incoming_edges': sum(self.datapoint_mapper.get_merging_frequency(code) 
                                    for code in self.datapoint_mapper.get_current_codes()) / len(self.datapoint_mapper.get_current_codes()) if self.datapoint_mapper.get_current_codes() else 0
        }]
        analysis_df = pd.DataFrame(analysis_data)
        analysis_df.to_parquet(os.path.join(output_dir, "enhanced_graph_analysis.parquet"), index=False)
        
        # Save datapoint-code mapping files for context retrievers
        self.save_datapoint_code_mapping_files(output_dir)
        
        # Detect and save cliques (connected components)
        detect_and_save_cliques_parquet(self.graph, output_dir)
        
        print(f"   ✅ Saved enhanced results to {output_dir}")


def build_enhanced_topological_graph(corpus_path: str, matrix_path: str, codes_path: str, 
                                    output_dir: str, min_frequency: int = 2):
    """
    Run the full pipeline: load, merge, build graph, sort, and save.

    Args:
        corpus_path: Path to corpus parquet file
        matrix_path: Path to relationship matrix
        codes_path: Path to codes parquet file
        output_dir: Output directory
        min_frequency: Threshold forwarded to the low-frequency merge; that
            step is skipped entirely when the value is not positive

    Returns:
        The builder instance after all results have been saved
    """
    print("🚀 Building enhanced topological graph...")

    # Start from the corpus, then attach the relationship matrix and its codes
    graph_builder = EnhancedTopologicalGraphBuilder.from_corpus(corpus_path)
    graph_builder.load_relationship_matrix(matrix_path, codes_path)

    # Merging phase: mutual relationships always, low-frequency codes optionally
    print("📋 Applying merging operations with scoring and edge accumulation...")
    graph_builder.merge_mutual_relationships()
    if min_frequency > 0:
        graph_builder.merge_low_frequency_codes(min_frequency)

    # Graph phase: build the final graph, then refresh the merging frequencies
    # from its incoming edges
    graph_builder.create_graph_from_matrix()
    graph_builder.update_merging_frequencies_from_graph()

    # Ordering and persistence
    ordering, level_by_node = graph_builder.perform_topological_sort()
    graph_builder.save_enhanced_results(ordering, level_by_node, output_dir)

    return graph_builder


if __name__ == "__main__":
    # Example usage: paths are relative to the current working directory
    corpus_path = "../temp_files/corpus.parquet"
    matrix_path = "../temp_files/conflict_detection/final_relationship_matrix.npy"
    codes_path = "../temp_files/conflict_detection/unique_codes.parquet"
    output_dir = "../temp_files/enhanced_topological_graph"

    required_inputs = (corpus_path, matrix_path, codes_path)

    if not all(map(os.path.exists, required_inputs)):
        # Report which of the inputs is missing instead of running the pipeline
        print("Required files not found. Check file paths:")
        for path in required_inputs:
            marker = '✅' if os.path.exists(path) else '❌'
            print(f"  {marker} {path}")
    else:
        builder = build_enhanced_topological_graph(
            corpus_path, matrix_path, codes_path, output_dir,
            min_frequency=2
        )
        print("Enhanced topological graph building completed!")