#!/usr/bin/env python3
"""
Graph converter for latin-squares-hybrid problem.
Created using subagent_prompt.md version: v_02

This problem is about filling an n×n grid (a Latin square) so that every
number from 1 to n appears exactly once in each row and exactly once in
each column.
Key challenges: constraint propagation, symmetry breaking, and keeping the
row, column, and one-value-per-cell constraints consistent simultaneously.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


# Base weight of a conflict edge between two distinct cells that share a row
# or a column. Two distinct cells can share at most one of the two, so this is
# the only starting strength a conflict edge can have.
_CONFLICT_BASE_WEIGHT = 0.5
# Conflict edges whose distance-decayed weight is not strictly above this
# threshold are dropped.
_MIN_CONFLICT_WEIGHT = 0.1


def _cell_centrality(i, j, center, max_dist):
    """Return the centrality weight of grid cell (i, j).

    The weight decays exponentially with the Euclidean distance from the grid
    centre, normalised by the centre-to-corner distance ``max_dist``. The
    centre gets weight 1.0 and the corners exp(-2). A 1x1 grid
    (``max_dist == 0``) gets 1.0.
    """
    if max_dist > 0:
        center_dist = math.sqrt((i - center)**2 + (j - center)**2)
        return math.exp(-2.0 * center_dist / max_dist)
    return 1.0


def _add_conflict_edges(G, n):
    """Connect every pair of distinct cells that share a row or a column.

    Each edge has weight ``_CONFLICT_BASE_WEIGHT * exp(-d / n)``, where d is
    the Manhattan distance between the two cells (d >= 1 for distinct cells).
    Each unordered pair is visited once, from the cell that comes first in
    row-major order. That is O(n**3) work instead of scanning all n**4
    ordered cell pairs.
    """
    for i in range(n):
        for j in range(n):
            # Peers that come later in row-major order: the rest of row i,
            # then the cells below (i, j) in column j.
            later_peers = [(i, j2) for j2 in range(j + 1, n)]
            later_peers += [(i2, j) for i2 in range(i + 1, n)]
            for i2, j2 in later_peers:
                distance = abs(i - i2) + abs(j - j2)
                weight = _CONFLICT_BASE_WEIGHT * math.exp(-distance / n)
                # Peers are at most n - 1 apart, so the weight stays above
                # 0.5 / e (about 0.18). The threshold is kept as a guard.
                if weight > _MIN_CONFLICT_WEIGHT:
                    G.add_edge(f'cell_{i}_{j}', f'cell_{i2}_{j2}', weight=weight)


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the Latin squares problem instance.

    Args:
        mzn_file: Path to the .mzn file. Not used by this function; accepted
            for interface compatibility.
        json_data: Dict containing parsed DZN data. Only the key 'n' (grid
            size, default 3) is read.

    Returns:
        A networkx.Graph as described below. It is empty when n <= 0.

    Nodes (attributes ``type`` and ``weight``):
    - type 0, ``cell_{i}_{j}``: one per grid cell. The weight is a centrality
      score: 1.0 at the grid centre, decaying to exp(-2) at the corners.
    - type 1, ``row_{i}_val_{k}`` / ``col_{j}_val_{k}``: "value k occurs once
      in row i / column j". The weight is 1 - exp(-n / 5), which approaches
      1.0 as n grows.
    - type 1, ``cell_{i}_{j}_unique``: "cell (i, j) holds exactly one value".
      The weight is min(log(n + 1) / log(26), 1.0), which saturates at n = 25.

    Edges (attribute ``weight``):
    - cell to each row/column constraint of its own row/column: weight 1 / n.
    - cell to its uniqueness constraint: weight 1.0.
    - cell to cell for distinct cells sharing a row or column: weight
      0.5 * exp(-manhattan_distance / n).
    """
    n = json_data.get('n', 3)

    G = nx.Graph()
    if n <= 0:
        # No cells. Returning early also avoids a math domain error from
        # log(n + 1) below when n <= -1.
        return G

    # Variable nodes: one per cell, weighted by closeness to the grid centre.
    center = (n - 1) / 2.0
    max_dist = math.sqrt(2 * center**2)  # centre-to-corner distance
    for i in range(n):
        for j in range(n):
            G.add_node(f'cell_{i}_{j}', type=0,
                       weight=_cell_centrality(i, j, center, max_dist))

    # Constraint weights depend only on n, so compute them once.
    line_weight = 1.0 - math.exp(-n / 5.0)
    unique_weight = min(math.log(n + 1) / math.log(26), 1.0)

    for i in range(n):
        for k in range(n):
            G.add_node(f'row_{i}_val_{k}', type=1, weight=line_weight)
    for j in range(n):
        for k in range(n):
            G.add_node(f'col_{j}_val_{k}', type=1, weight=line_weight)
    for i in range(n):
        for j in range(n):
            G.add_node(f'cell_{i}_{j}_unique', type=1, weight=unique_weight)

    # Bipartite edges: each cell takes part equally in the n per-value
    # constraints of its row and of its column, and directly in its own
    # uniqueness constraint.
    participation_weight = 1.0 / n
    for i in range(n):
        for j in range(n):
            cell = f'cell_{i}_{j}'
            for k in range(n):
                G.add_edge(cell, f'row_{i}_val_{k}', weight=participation_weight)
            for k in range(n):
                G.add_edge(cell, f'col_{j}_val_{k}', weight=participation_weight)
            G.add_edge(cell, f'cell_{i}_{j}_unique', weight=1.0)

    # Cell-cell conflict edges for cells in the same row or column.
    _add_conflict_edges(G, n)

    return G


def main():
    """Command-line entry point: build the graph and report its size.

    Usage: python converter.py <mzn_file> <dzn_file> <json_file>

    Only <json_file> (the parsed DZN data) is read. <mzn_file> is passed
    through to build_graph, and <dzn_file> is accepted but not used.
    Exits with status 1 on a wrong argument count or when the JSON file
    cannot be read or parsed.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>",
              file=sys.stderr)
        sys.exit(1)

    # argv[2] (the .dzn file) is unused: the JSON file carries the same data.
    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # Load JSON data. An explicit encoding keeps this independent of the
    # platform's default locale. ValueError covers both JSONDecodeError and
    # UnicodeDecodeError.
    try:
        with open(json_file, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
    except (OSError, ValueError) as err:
        print(f"Error: cannot load JSON data from {json_file}: {err}",
              file=sys.stderr)
        sys.exit(1)

    # Build graph
    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()