#!/usr/bin/env python3
"""
Graph converter for community-detection problem.
Converter created with subagent_prompt.md v_02

The problem is to find communities in a graph that maximize modularity
while satisfying must-link and cannot-link constraints.
Key challenges: balancing modularity optimization against constraint satisfaction,
handling complex constraint interactions, and dealing with varying graph densities.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _flatten(values):
    """Return *values* as a flat list, expanding one level of nested lists.

    2-D arrays may arrive row-major flattened (``[0, 1, 1, 0]``) or as a list
    of rows (``[[0, 1], [1, 0]]``); both produce the same flat list. ``None``
    is treated as empty.
    """
    flat = []
    for item in values or []:
        if isinstance(item, (list, tuple)):
            flat.extend(item)
        else:
            flat.append(item)
    return flat


def _index_pairs(values, n):
    """Turn a sequence of 1-based vertex ids into 0-based ``(u, v)`` pairs.

    Consecutive values form a pair (nested ``[[a, b], ...]`` is also accepted);
    an unpaired trailing value is dropped. Pairs with an endpoint outside
    ``[0, n)`` are skipped: they would otherwise create ``vertex_*`` nodes that
    lack ``type``/``weight`` attributes, or (for an id of 0) silently read
    matrix rows from the end through negative indexing.
    """
    flat = _flatten(values)
    pairs = []
    for a, b in zip(flat[0::2], flat[1::2]):
        u, v = int(a) - 1, int(b) - 1
        if 0 <= u < n and 0 <= v < n:
            pairs.append((u, v))
    return pairs


def build_graph(mzn_file, json_data):
    """
    Build a graph representation of a community-detection problem instance.

    Args:
        mzn_file: Path to the .mzn model (not read; kept for interface
            compatibility).
        json_data: Dict of parsed DZN data. Keys read: ``n``, ``k``, ``nML``,
            ``nCL``; ``A`` and ``W`` (n x n matrices, flattened row-major or as
            a list of rows); ``deg`` (per-vertex degrees); ``ML`` and ``CL``
            (1-based vertex-id pairs, flattened or as a list of pairs).

    Returns:
        A ``networkx.Graph`` in which every node carries ``type`` and
        ``weight`` attributes:
        - ``vertex_{i}`` (type 0): combines normalized degree and the number
          of ML/CL constraints touching the vertex.
        - ``must_link_{idx}`` (type 1): min-max normalized ``W`` entry of the pair.
        - ``cannot_link_{idx}`` (type 1): 0.8 if the pair is adjacent in ``A``,
          otherwise 0.3.
        - ``edge_mod_{i}_{j}`` (type 1): one per edge of ``A`` with i < j,
          capped at ``3 * n`` nodes; weighted by the endpoints' degrees.
        - ``global_balance`` (type 1): summarizes k and constraint density;
          linked to the ``2 * k`` highest-degree vertices.
    """
    n = json_data.get('n', 0)
    k = json_data.get('k', 1)
    nML = json_data.get('nML', 0)
    nCL = json_data.get('nCL', 0)

    A = _flatten(json_data.get('A', []))
    W = _flatten(json_data.get('W', []))
    deg = json_data.get('deg', [])

    def matrix_row(flat, i):
        # A row that is not fully present in the data is treated as all zeros.
        end = (i + 1) * n
        return flat[i * n:end] if len(flat) >= end else [0] * n

    A_matrix = [matrix_row(A, i) for i in range(n)]
    W_matrix = [matrix_row(W, i) for i in range(n)]

    ML_pairs = _index_pairs(json_data.get('ML', []), n)
    CL_pairs = _index_pairs(json_data.get('CL', []), n)

    def degree_of(i, default):
        # deg may be shorter than n; callers choose the fallback value.
        return deg[i] if i < len(deg) else default

    max_deg = max(deg) if deg else 1

    # W's value range is loop-invariant: compute it once instead of rescanning
    # the whole n x n matrix for every constraint and every edge.
    if W_matrix:
        w_max = max(max(row) for row in W_matrix)
        w_min = min(min(row) for row in W_matrix)
    else:
        w_max, w_min = 1, 0
    w_range = w_max - w_min if w_max > w_min else 1

    G = nx.Graph()

    # --- Vertex nodes (type 0) ---------------------------------------------
    # Number of ML/CL pairs touching each vertex (a self-pair counts once).
    involvement = [0] * n
    for u, v in ML_pairs + CL_pairs:
        involvement[u] += 1
        if v != u:
            involvement[v] += 1

    total_constraints = max(nML + nCL, 1)
    for i in range(n):
        degree_weight = degree_of(i, 0) / max_deg if max_deg > 0 else 0
        constraint_weight = min(involvement[i] / total_constraints, 1.0)
        # Saturating exponential emphasizes high-degree, highly-constrained vertices.
        combined_weight = 1.0 - math.exp(-2.0 * (degree_weight + constraint_weight))
        G.add_node(f'vertex_{i}', type=0, weight=combined_weight)

    # --- Must-link constraint nodes (type 1) --------------------------------
    for idx, (u, v) in enumerate(ML_pairs):
        # Min-max normalize the pair's W entry into [0, 1].
        connection_weight = (W_matrix[u][v] - w_min) / w_range
        connection_weight = max(0.0, min(1.0, connection_weight))
        node = f'must_link_{idx}'
        G.add_node(node, type=1, weight=connection_weight)
        G.add_edge(f'vertex_{u}', node, weight=0.9)
        G.add_edge(f'vertex_{v}', node, weight=0.9)

    # --- Cannot-link constraint nodes (type 1) ------------------------------
    for idx, (u, v) in enumerate(CL_pairs):
        # Separating adjacent vertices is a stronger prohibition than
        # separating already-disconnected ones.
        connection_weight = 0.8 if A_matrix[u][v] > 0 else 0.3
        node = f'cannot_link_{idx}'
        G.add_node(node, type=1, weight=connection_weight)
        edge_weight = connection_weight * 0.8  # Scaled down for edge weight.
        G.add_edge(f'vertex_{u}', node, weight=edge_weight)
        G.add_edge(f'vertex_{v}', node, weight=edge_weight)

    # --- Modularity nodes, one per graph edge (type 1) ----------------------
    # Capped at 3n nodes to avoid explosion on dense graphs. The loop ends by
    # itself after the last edge, so no separate total-edge estimate (which
    # miscounts for self-loops, weighted or one-sided A) is needed.
    max_mod_nodes = 3 * n
    mod_count = 0
    for i in range(n):
        if mod_count >= max_mod_nodes:
            break
        for j in range(i + 1, n):
            if A_matrix[i][j] <= 0:
                continue
            # Missing degree entries default to 1 for this product.
            d_i = degree_of(i, 1)
            d_j = degree_of(j, 1)
            # Higher weight for edges between high-degree vertices.
            modularity_weight = (math.sqrt((d_i * d_j) / (max_deg * max_deg))
                                 if max_deg > 0 else 0.5)
            node = f'edge_mod_{i}_{j}'
            G.add_node(node, type=1, weight=modularity_weight)

            # Edge weight: normalized W entry, floored at 0.1 so no edge vanishes.
            normalized_weight = (W_matrix[i][j] - w_min) / w_range
            normalized_weight = max(0.1, min(1.0, normalized_weight))
            G.add_edge(f'vertex_{i}', node, weight=normalized_weight)
            G.add_edge(f'vertex_{j}', node, weight=normalized_weight)

            mod_count += 1
            if mod_count >= max_mod_nodes:
                break

    # --- Global community-balancing node (type 1) ---------------------------
    k_complexity = min(k / n, 1.0) if n > 0 else 0.5  # More communities = harder
    constraint_complexity = (nML + nCL) / n if n > 0 else 0  # More constraints = harder
    global_weight = 0.5 + 0.3 * k_complexity + 0.2 * min(constraint_complexity, 1.0)
    G.add_node('global_balance', type=1, weight=global_weight)

    # Connect the 2k highest-degree vertices. Missing degrees count as 0 both
    # for ranking and for the edge weight, so a short deg cannot raise IndexError.
    ranked = sorted(range(n), key=lambda v: degree_of(v, 0), reverse=True)
    for i in ranked[:min(k * 2, n)]:
        vertex_importance = degree_of(i, 0) / max_deg if max_deg > 0 else 0.5
        G.add_edge(f'vertex_{i}', 'global_balance', weight=vertex_importance)

    return G


def main():
    """Command-line entry point: build the graph for one instance and report its size.

    Usage: ``python converter.py <mzn_file> <dzn_file> <json_file>``.
    The DZN path is accepted for interface compatibility but is not read;
    instance data comes from the JSON file. Exits with status 1 (usage printed
    to stderr) when the argument count is wrong.
    """
    if len(sys.argv) != 4:
        # Usage errors go to stderr so stdout carries only real output.
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>",
              file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # Explicit encoding: JSON is UTF-8 by spec; don't depend on the locale.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when imported.
    main()