#!/usr/bin/env python3
"""
Graph converter for Latin Squares problem.
Converter created with subagent_prompt.md v_02

This problem is about placing numbers 1..n in an n×n grid such that
each row and column contains each number exactly once.
Key challenges: constraint propagation cost grows quadratically with n (the
grid has n² cells), and the search space grows exponentially.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _center_distance(i, j, n):
    """Return the Manhattan distance from cell (i, j) to the geometric centre of an n x n grid.

    The centre is ((n - 1) / 2, (n - 1) / 2). For odd n this is the middle cell
    (identical to using n // 2). For even n it lies between the four middle cells,
    so mirror-image cells get equal distances instead of being skewed toward the
    lower-right quadrant.
    """
    center = (n - 1) / 2
    return abs(i - center) + abs(j - center)


def _add_cell_nodes(G, n):
    """Add one variable node per cell, weighted by closeness to the grid centre."""
    # Largest Manhattan distance from the centre (reached at the corners).
    max_distance = n - 1
    for i in range(n):
        for j in range(n):
            if max_distance > 0:
                # Distance is scaled by 2 * max_distance, so centrality lies in [0.5, 1.0].
                centrality = 1.0 - _center_distance(i, j, n) / (max_distance * 2)
            else:
                centrality = 1.0
            # Resulting weight lies in [0.75, 1.0]; the clamp is purely defensive.
            weight = max(0.3, min(1.0, 0.5 + 0.5 * centrality))
            G.add_node(f'cell_{i}_{j}', type=0, weight=weight)


def _add_constraint_nodes(G, n):
    """Add one all-different constraint node per row, then one per column."""
    # Every row/column constraint has scope n, so they share one weight that
    # grows with n and saturates at 1.0 once n >= 25.
    scope_weight = min(1.0, 0.3 + 0.7 * (n / 25.0))
    for i in range(n):
        G.add_node(f'row_constraint_{i}', type=1, weight=scope_weight)
    for j in range(n):
        G.add_node(f'col_constraint_{j}', type=1, weight=scope_weight)


def _add_participation_edges(G, n):
    """Connect every cell to the row and the column constraint it belongs to."""
    if n <= 0:
        return  # no cells; also avoids dividing by zero below
    # Each cell is one of the n variables in its constraint's scope.
    participation_weight = 1.0 / n
    for i in range(n):
        for j in range(n):
            cell = f'cell_{i}_{j}'
            G.add_edge(cell, f'row_constraint_{i}', weight=participation_weight)
            G.add_edge(cell, f'col_constraint_{j}', weight=participation_weight)


def _add_adjacency_edges(G, n):
    """Link each cell to its 8-neighbourhood: 0.6 for orthogonal, 0.3 for diagonal neighbours."""
    # Heuristic spatial feature: diagonal neighbours share no row/column
    # constraint in a Latin square. Visiting only the "forward" half of the
    # neighbourhood adds every unordered adjacent pair exactly once.
    forward_offsets = ((0, 1, 0.6), (1, 0, 0.6), (1, -1, 0.3), (1, 1, 0.3))
    for i in range(n):
        for j in range(n):
            for di, dj, conflict_weight in forward_offsets:
                ni, nj = i + di, j + dj
                if 0 <= ni < n and 0 <= nj < n:
                    G.add_edge(f'cell_{i}_{j}', f'cell_{ni}_{nj}', weight=conflict_weight)


def _add_global_complexity_node(G, n):
    """Add a hub node linked to cells within Manhattan radius max(1, n // 4) of the centre."""
    # Scales with grid area and saturates at 1.0 once n >= 20.
    complexity_factor = min(1.0, (n * n) / 400.0)
    G.add_node('global_complexity', type=1, weight=complexity_factor)

    radius = max(1, n // 4)
    for i in range(n):
        for j in range(n):
            distance = _center_distance(i, j, n)
            if distance <= radius:
                # Exponential decay with distance from the centre.
                proximity_weight = math.exp(-2.0 * distance / radius)
                G.add_edge('global_complexity', f'cell_{i}_{j}', weight=proximity_weight)


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the Latin squares instance.

    Args:
        mzn_file: Path to .mzn file (kept for interface compatibility; not read here)
        json_data: Dict containing parsed DZN data; only key 'n' (grid size,
            default 3) is read

    Returns:
        networkx.Graph containing
        - n*n cell variable nodes 'cell_{i}_{j}' (type=0), weighted by closeness
          to the grid centre (weights in [0.75, 1.0])
        - n row and n column all-different constraint nodes (type=1)
        - cell-to-constraint participation edges of weight 1/n
        - for n >= 5: edges between 8-adjacent cells (0.6 orthogonal, 0.3 diagonal)
        - for n >= 10: a 'global_complexity' hub node (type=1) linked to the
          cells nearest the centre

    Every cell lies in exactly one row and one column constraint, so all cells
    are structurally equivalent. The centrality weights, adjacency edges, and
    hub are heuristic features, not part of the constraint model. For n <= 0
    the graph is empty.
    """
    n = json_data.get('n', 3)

    G = nx.Graph()
    _add_cell_nodes(G, n)
    _add_constraint_nodes(G, n)
    _add_participation_edges(G, n)

    # Only for larger instances, to keep small graphs sparse.
    if n >= 5:
        _add_adjacency_edges(G, n)

    if n >= 10:
        _add_global_complexity_node(G, n)

    return G


def main():
    """Command-line entry point: build the graph for one instance and report its size.

    Usage: converter.py <mzn_file> <dzn_file> <json_file>

    Only the JSON file is read. The .mzn path is forwarded to build_graph, and
    the .dzn path is accepted for CLI compatibility but not used. Exits with
    status 1 (message on stderr) on a wrong argument count or when the JSON
    file cannot be read or parsed.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>", file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:]

    # JSON is UTF-8 by specification; don't depend on the platform locale.
    # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
    try:
        with open(json_file, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
    except (OSError, ValueError) as err:
        print(f"Error: cannot load {json_file}: {err}", file=sys.stderr)
        sys.exit(1)

    # Build graph
    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    main()