#!/usr/bin/env python3
"""
Graph converter for whirlpool-x problem.
Created using subagent_prompt.md version: v_02

This problem is about placing numbers 1..n² in an n×n matrix such that:
- All numbers are different (permutation)
- Every 2×2 submatrix follows whirlpool ordering (clockwise/counter-clockwise)
- Outer rings follow whirlpool ordering
- Both diagonals sum to n*(n+1)*(n+1)/2

Key challenges: Complex structural constraints with quadratic scaling, diagonal constraints create tight bottlenecks
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _ring_cells(n, k):
    """Return the cells of ring ``k`` of an n×n grid, walked clockwise.

    Ring 0 is the outer border and ring ``k`` lies ``k`` cells further in.
    The walk starts at the ring's top-left corner and visits every cell once.
    The caller must ensure ``k < n // 2`` so the ring is at least 2×2.
    """
    lo, hi = k, n - 1 - k
    cells = [(lo, j) for j in range(lo, hi + 1)]        # top row, left to right
    cells += [(i, hi) for i in range(lo + 1, hi)]       # right column, corners excluded
    cells += [(hi, j) for j in range(hi, lo - 1, -1)]   # bottom row, right to left
    cells += [(i, lo) for i in range(hi - 1, lo, -1)]   # left column, corners excluded
    return cells


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the whirlpool-x problem instance.

    Args:
        mzn_file: Path to .mzn file (for reference only; not read here)
        json_data: Dict containing parsed DZN data; only the key 'n'
            (matrix side length, default 8) is used

    Returns:
        networkx.Graph whose nodes carry 'type' (0 = variable, 1 = constraint)
        and 'weight' attributes, and whose edges carry a 'weight' attribute.

    Strategy: Create a variable/constraint graph
    - Matrix positions are variables (type 0). Weight is the fraction of the
      four possible 2×2 windows the cell belongs to, plus 0.3 for cells on
      either diagonal, capped at 1.0.
    - 'alldifferent' (weight 1.0) touches every position.
    - One node per 2×2 whirlpool constraint. Weight decays exponentially with
      Manhattan distance from the grid centre (0.9 at the centre, about 0.64
      in the corners), because corner windows are easier to satisfy.
    - One node per ring (n // 2 rings; the single centre cell of an odd grid
      is not a ring). Weight grows with ring perimeter on a normalized
      exponential scale, reaching 0.8 for the outer ring.
    - Two diagonal sum constraints (weight 0.9).
    - For n >= 6, weak position-position "conflict" edges link diagonal cells
      that are close in row-major order, so the graph is then no longer
      strictly bipartite.
    """
    n = json_data.get('n', 8)

    G = nx.Graph()

    # Variable nodes: a cell belongs to one 2×2 window per valid top-left
    # corner, i.e. (valid window rows) × (valid window columns).
    for i in range(n):
        row_windows = int(i >= 1) + int(i <= n - 2)
        for j in range(n):
            col_windows = int(j >= 1) + int(j <= n - 2)
            centrality_weight = (row_windows * col_windows) / 4  # 4 = interior maximum
            # Diagonal cells additionally take part in a sum constraint.
            diagonal_bonus = 0.3 if (i == j or i == n - 1 - j) else 0.0
            G.add_node(f'pos_{i}_{j}', type=0,
                       weight=min(centrality_weight + diagonal_bonus, 1.0))

    # Constraint nodes (type 1):

    # 1. Global alldifferent constraint (highest weight - affects all variables)
    G.add_node('alldifferent', type=1, weight=1.0)

    # 2. 2×2 whirlpool constraints, one per top-left corner (i, j).
    span = n - 2  # largest Manhattan distance of a window from the grid centre
    for i in range(n - 1):
        for j in range(n - 1):
            if span > 0:
                distance_from_center = abs(i - span / 2) + abs(j - span / 2)
                # Decaying exponential: 1.0 at the centre, exp(-2) in the
                # corners, so harder central windows get the higher weight.
                centrality = math.exp(-2.0 * distance_from_center / span)
            else:
                # n == 2: a single window, no notion of distance.
                centrality = 0.8
            G.add_node(f'whirlpool_2x2_{i}_{j}', type=1,
                       weight=0.6 + 0.3 * centrality)

    # 3. Ring whirlpool constraints. Rings are computed once and reused for
    # both the node weights and the edges below.
    rings = [_ring_cells(n, k) for k in range(n // 2)]
    outer_ring_size = 4 * n - 4
    exp_scale = math.expm1(2.0)
    for k, cells in enumerate(rings):
        complexity = len(cells) / outer_ring_size
        # Normalized exponential keeps the growth in [0, 1], so the weight
        # spans (0.4, 0.8] and rings of different size stay distinguishable
        # instead of saturating at an upper clamp.
        ring_weight = 0.4 + 0.4 * math.expm1(2.0 * complexity) / exp_scale
        G.add_node(f'ring_whirlpool_{k}', type=1, weight=ring_weight)

    # 4. Diagonal sum constraints (2 constraints, very tight)
    diagonal_weight = 0.9  # High weight due to exact sum requirement
    G.add_node('main_diagonal_sum', type=1, weight=diagonal_weight)
    G.add_node('anti_diagonal_sum', type=1, weight=diagonal_weight)

    # Edges: Variable-Constraint participation

    # Every position takes part in alldifferent with an equal share.
    for i in range(n):
        for j in range(n):
            G.add_edge(f'pos_{i}_{j}', 'alldifferent', weight=1.0 / n**2)

    # Each 2×2 constraint involves exactly its four cells, equally.
    for i in range(n - 1):
        for j in range(n - 1):
            constraint_node = f'whirlpool_2x2_{i}_{j}'
            for pi, pj in ((i, j), (i, j + 1), (i + 1, j), (i + 1, j + 1)):
                G.add_edge(f'pos_{pi}_{pj}', constraint_node, weight=0.25)

    # Ring cells share their ring constraint equally (weight 1 / ring size).
    for k, cells in enumerate(rings):
        constraint_node = f'ring_whirlpool_{k}'
        edge_weight = 1.0 / len(cells)
        for pi, pj in cells:
            G.add_edge(f'pos_{pi}_{pj}', constraint_node, weight=edge_weight)

    # Diagonal cells share their sum constraint equally.
    for i in range(n):
        G.add_edge(f'pos_{i}_{i}', 'main_diagonal_sum', weight=1.0 / n)
        G.add_edge(f'pos_{i}_{n-1-i}', 'anti_diagonal_sum', weight=1.0 / n)

    # Optional conflict edges between diagonal positions, only for larger
    # instances to avoid over-connecting small ones.
    if n >= 6:
        # sorted() gives a reproducible row-major order, so the window below
        # links cells that are actually close together on the diagonals.
        diagonal_positions = sorted(
            {(i, i) for i in range(n)} | {(i, n - 1 - i) for i in range(n)}
        )
        conflict_weight = 0.3  # Moderate conflict: competing for similar value ranges
        for idx, (i1, j1) in enumerate(diagonal_positions):
            # Limit each cell to the next three positions in the order.
            for i2, j2 in diagonal_positions[idx + 1:idx + 4]:
                G.add_edge(f'pos_{i1}_{j1}', f'pos_{i2}_{j2}',
                           weight=conflict_weight)

    return G


def main():
    """Command-line entry point: load the instance JSON and build its graph.

    Expects exactly three arguments after the program name: the .mzn path,
    the .dzn path and the .json path. With any other count, a usage line is
    printed and the process exits with status 1. Only the JSON file is read;
    the .mzn path is forwarded to build_graph and the .dzn path is unused.
    On success a one-line node/edge summary is printed.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) != 3:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    mzn_file, _dzn_file, json_file = cli_args

    # Parse the instance parameters.
    with open(json_file, 'r') as handle:
        json_data = json.load(handle)

    # build_graph returns the graph for direct feature extraction; here we
    # only report its size.
    G = build_graph(mzn_file, json_data)
    node_total = G.number_of_nodes()
    edge_total = G.number_of_edges()
    print(f"Graph built: {node_total} nodes, {edge_total} edges")


# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()