#!/usr/bin/env python3
"""
Graph converter for Sugiyama2 problem.
Created using subagent_prompt.md version: v_02

The problem asks for an ordering of the nodes within each layer of a layered graph that minimizes the number of edge crossings.
Key challenges: cross-layer edge density, imbalanced layer widths, and potential crossing patterns.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


# Upper bound on crossing-constraint vertices. Edge pairs are enumerated in
# lexicographic (i, j) order and enumeration stops once this many are found,
# instead of computing all O(E^2) pairs and discarding the rest.
_MAX_CROSSING_CONSTRAINTS = 50


def _layer_lookup(width, layers):
    """Partition 1-based node ids into layers.

    Returns ``(boundaries, node_layer)``. ``boundaries[k]`` is the total width
    of layers ``0 .. k-1``, so layer ``k`` owns node ids
    ``boundaries[k] + 1 .. boundaries[k + 1]``; only the first ``layers``
    entries of ``width`` are used. ``node_layer(node_id)`` returns the first
    layer whose id range contains ``node_id``, or 0 when none does. Results
    are memoized because each lookup is a linear scan over the layers.
    """
    boundaries = [0]
    for layer_width in width[:max(layers, 0)]:
        boundaries.append(boundaries[-1] + layer_width)

    cache = {}

    def node_layer(node_id):
        if node_id not in cache:
            # Scan only layers that have a boundary entry; scanning
            # range(layers) would index past the list when layers > len(width).
            cache[node_id] = next(
                (layer for layer in range(len(boundaries) - 1)
                 if boundaries[layer] < node_id <= boundaries[layer + 1]),
                0,
            )
        return cache[node_id]

    return boundaries, node_layer


def _crossing_pairs(edge_list, node_layer, width, limit):
    """Collect up to ``limit`` pairs of edges spanning the same adjacent-layer gap.

    Returns ``[((i, j), potential), ...]`` in lexicographic ``(i, j)`` order
    (``i < j`` index ``edge_list``). Two edges qualify when their start
    endpoints share a layer, their end endpoints share a layer, and those two
    layers differ by exactly 1. ``potential`` is
    ``|(s_i - s_j) - (e_i - e_j)|`` divided by the wider of the two layers and
    clipped to 1.0. It is 0 when edge ``j`` is edge ``i`` shifted by the same
    offset on both layers, and it is at least 2 (before scaling) when the two
    edges actually cross.
    """
    pairs = []
    if limit <= 0:
        return pairs
    edge_layers = [(node_layer(s), node_layer(e)) for s, e in edge_list]
    for i, ((s1, e1), (start_layer, end_layer)) in enumerate(zip(edge_list, edge_layers)):
        if abs(end_layer - start_layer) != 1:
            continue  # only edges between adjacent layers are paired
        spread = max(max(width[start_layer], width[end_layer]), 1)
        for j in range(i + 1, len(edge_list)):
            if edge_layers[j] != (start_layer, end_layer):
                continue
            s2, e2 = edge_list[j]
            # Both start offsets are measured from the same layer boundary
            # (likewise both end offsets), so the boundary cancels out.
            cross = abs((s1 - s2) - (e1 - e2))
            pairs.append(((i, j), min(cross / spread, 1.0)))
            if len(pairs) >= limit:
                return pairs
    return pairs


def _global_complexity(nodes, layers, width, total_edges, width_scale):
    """Blend edge density (60%) with the spread of layer widths (40%), capped at 1.0.

    The spread is the root-mean-square deviation of every ``width`` entry from
    the mean width ``nodes / layers``, normalized by ``width_scale``.
    """
    density = total_edges / nodes if nodes > 0 else 0
    if layers > 0:
        mean_width = nodes / layers
        width_spread = math.sqrt(sum((w - mean_width) ** 2 for w in width) / layers)
    else:
        width_spread = 0
    return min(density * 0.6 + (width_spread / width_scale) * 0.4, 1.0)


def _neighbor_counts(edge_list):
    """Map each endpoint id to its number of distinct neighbours (edges treated as undirected)."""
    neighbours = {}
    for s, e in edge_list:
        neighbours.setdefault(s, set()).add(e)
        neighbours.setdefault(e, set()).add(s)
    return {node_id: len(adjacent) for node_id, adjacent in neighbours.items()}


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the Sugiyama layered graph layout problem.

    Args:
        mzn_file: Path to the .mzn file. Not read here; kept so the signature
            matches the converter interface.
        json_data: Dict containing parsed DZN data. Keys read: 'layers' and
            'nodes' (counts), 'width' (node count per layer), 'start' and
            'end' (parallel lists of edge endpoint ids, matched against the
            1-based ``node_<id>`` vertex names). Missing keys default to
            0 / empty.

    Returns:
        networkx.Graph whose vertices carry ``type`` (0 = node variable,
        1 = constraint) and ``weight`` attributes, and whose edges carry
        ``weight``.

    Strategy: a mostly bipartite graph modeling layout constraints and crossing potential
    - Node variables (type 0): one per node, weighted by centrality in its layer and layer density
    - Constraint nodes (type 1): one per layer ordering, up to _MAX_CROSSING_CONSTRAINTS
      crossing constraints, and one global complexity node
    - Edges: variables to their layer, crossing and global constraints, plus direct
      variable-variable "conflict" edges between highly connected nodes in the same
      layer (the only non-bipartite edges)
    - Weights: based on crossing potential, layer density, and constraint tightness
    """
    layers = json_data.get('layers', 0)
    nodes = json_data.get('nodes', 0)
    width = json_data.get('width', [])
    start = json_data.get('start', [])
    end = json_data.get('end', [])

    boundaries, node_layer = _layer_lookup(width, layers)
    max_width = max(width) if width else 1
    # Never divide by a zero width: empty or all-zero 'width' previously
    # raised ValueError / ZeroDivisionError.
    width_scale = max(max_width, 1)
    # zip() ignores unmatched trailing entries if 'start' and 'end' differ in length.
    edge_list = list(zip(start, end))
    total_edges = len(start)

    G = nx.Graph()

    # Node variables (type 0). Centrality is 1 in the middle of a layer and 0
    # at its ends (heuristic: central nodes are treated as more constrained);
    # density favours nodes in wide layers.
    # NOTE(review): ids beyond sum(width) fall back to layer 0 and can get a
    # negative centrality; the input is assumed to satisfy sum(width) >= nodes.
    for node_id in range(1, nodes + 1):
        layer = node_layer(node_id)
        layer_width = width[layer] if layer < len(width) else 1
        pos_in_layer = (node_id - boundaries[layer] - 1) / max(layer_width - 1, 1)
        centrality = 1.0 - abs(pos_in_layer - 0.5) * 2
        density_factor = layer_width / width_scale
        G.add_node(f'node_{node_id}', type=0,
                   weight=centrality * 0.7 + density_factor * 0.3)

    # Layer ordering constraints (type 1): wider layers weigh more, on a log
    # scale normalized so the widest layer gets weight 1.
    log_scale = math.log(width_scale + 1)
    for layer, layer_width in enumerate(width[:max(layers, 0)]):
        complexity = math.log(max(layer_width, 1) + 1) / log_scale
        G.add_node(f'layer_order_{layer}', type=1, weight=complexity)

    # Edge crossing constraints (type 1), capped to bound the O(E^2) pairing.
    crossing_pairs = _crossing_pairs(edge_list, node_layer, width,
                                     _MAX_CROSSING_CONSTRAINTS)
    for idx, (_, potential) in enumerate(crossing_pairs):
        G.add_node(f'crossing_constraint_{idx}', type=1, weight=potential)

    # Global complexity constraint (type 1).
    G.add_node('global_complexity', type=1,
               weight=_global_complexity(nodes, layers, width, total_edges, width_scale))

    # Every variable joins its layer constraint and the global constraint.
    for node_id in range(1, nodes + 1):
        layer_vertex = f'layer_order_{node_layer(node_id)}'
        # add_edge would silently create an untyped vertex when no layer
        # constraints exist (layers == 0 or empty 'width'), so skip it.
        if layer_vertex in G:
            G.add_edge(f'node_{node_id}', layer_vertex, weight=1.0)
        G.add_edge(f'node_{node_id}', 'global_complexity', weight=0.3)

    # Each crossing constraint links the four endpoints of its edge pair.
    # NOTE(review): endpoints outside 1..nodes would create untyped vertices;
    # the input is assumed to reference valid node ids only.
    for idx, ((i, j), potential) in enumerate(crossing_pairs):
        crossing_vertex = f'crossing_constraint_{idx}'
        (s1, e1), (s2, e2) = edge_list[i], edge_list[j]
        for endpoint in (s1, e1, s2, e2):
            G.add_edge(f'node_{endpoint}', crossing_vertex, weight=potential)

    # Direct conflict edges between same-layer nodes that both have more than
    # two distinct neighbours, weighted by their combined degree relative to
    # the average edges-per-node (floored at 1).
    degree = _neighbor_counts(edge_list)
    layer_members = {}
    for node_id in range(1, nodes + 1):
        layer_members.setdefault(node_layer(node_id), []).append(node_id)
    degree_norm = 2 * max(total_edges / nodes, 1) if nodes > 0 else 2

    for layer in range(layers):
        hubs = [n for n in layer_members.get(layer, []) if degree.get(n, 0) > 2]
        for pos, node1 in enumerate(hubs):
            for node2 in hubs[pos + 1:]:
                conflict_weight = min((degree[node1] + degree[node2]) / degree_norm, 1.0)
                if conflict_weight > 0.3:  # keep only significant conflicts
                    G.add_edge(f'node_{node1}', f'node_{node2}', weight=conflict_weight)

    return G


def main():
    """Command-line entry point: build the graph for one instance and print its size.

    Usage: python converter.py <mzn_file> <dzn_file> <json_file>

    Only <json_file> is read (expected to hold the instance data as JSON). The
    .mzn path is forwarded to build_graph; the .dzn path is accepted for CLI
    compatibility but unused. Exits with status 1 on a wrong argument count.
    """
    if len(sys.argv) != 4:
        # Usage errors go to stderr so stdout carries only the summary line.
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>", file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # JSON text is UTF-8 per RFC 8259; don't depend on the locale's default encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    # Build graph
    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


# Run the command-line entry point only when executed as a script; importing
# this module (e.g. to call build_graph directly) has no side effects.
if __name__ == "__main__":
    main()