#!/usr/bin/env python3
"""
Graph converter for latin-squares-hybrid2 problem.
Created using subagent_prompt.md version: v_02

This problem is about constructing Latin squares of size n×n in which every row
and every column is a permutation of the numbers 1..n. The model uses a hybrid
FD/LP approach.
Key challenges: symmetry breaking, constraint propagation efficiency, 
quadratic growth in variables and constraints with problem size.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _centrality(pos, n):
    """Return how close the 1-based index *pos* lies to the middle of 1..n.

    The score falls off linearly with the distance from the midpoint (n+1)/2:
    it is 1.0 at the midpoint and 1/n at either end (pos == 1 or pos == n),
    so it stays positive for 1 <= pos <= n. For n <= 1 the score is 1.0.
    """
    if n <= 1:
        return 1.0
    return 1.0 - abs(pos - (n + 1) / 2) / (n / 2)


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the Latin squares problem instance.

    Args:
        mzn_file: Path to .mzn file (kept for interface compatibility; not read)
        json_data: Dict containing parsed DZN data; only the key 'n'
            (side length of the square, default 3) is used

    Returns:
        networkx.Graph whose nodes carry 'type' and 'weight' attributes and
        whose edges carry a 'weight' attribute.

    Strategy: bipartite graph between assignment variables and constraints.
    - Variables x_{i}_{j}_{k} = "cell (i,j) holds value k" (type 0), n^3 nodes.
      Weight = sqrt(position centrality * value centrality).
    - Constraints (type 1), 3*n^2 nodes:
        cell_{i}_{j}: cell (i,j) holds exactly one value
        row_{i}_{k}:  value k occurs exactly once in row i
        col_{j}_{k}:  value k occurs exactly once in column j
    - For n >= 5, one extra type-1 node 'symmetry_complexity' linked to the
      first three values of the corner cells (plus the centre cell for odd n).
    All weights are heuristics: indices nearer the middle of 1..n score higher,
    and row/column constraints on the first/last row, column or value are
    down-weighted to 0.8.
    """
    n = json_data.get('n', 3)

    G = nx.Graph()

    # Centrality of every 1-based row/column/value index, computed once
    # instead of being re-evaluated inside the O(n^3) loops below.
    cent = {p: _centrality(p, n) for p in range(1, n + 1)}

    # Variable nodes: x[i,j,k] - position (i,j) has value k.
    for i in range(1, n + 1):
        for j in range(1, n + 1):
            position_centrality = (cent[i] + cent[j]) / 2
            for k in range(1, n + 1):
                # Geometric mean: high only when both the position and the
                # value are central.
                weight = math.sqrt(position_centrality * cent[k])
                G.add_node(f'x_{i}_{j}_{k}', type=0, weight=weight)

    # Cell constraints: each cell (i,j) has exactly one value.
    for i in range(1, n + 1):
        for j in range(1, n + 1):
            constraint_id = f'cell_{i}_{j}'
            G.add_node(constraint_id, type=1, weight=(cent[i] + cent[j]) / 2)
            # Every candidate value participates equally.
            for k in range(1, n + 1):
                G.add_edge(f'x_{i}_{j}_{k}', constraint_id, weight=1.0 / n)

    # Row constraints: each row i has exactly one occurrence of value k.
    for i in range(1, n + 1):
        for k in range(1, n + 1):
            # Down-weight constraints on the first/last row or value.
            on_border = i in (1, n) or k in (1, n)
            constraint_id = f'row_{i}_{k}'
            G.add_node(constraint_id, type=1, weight=0.8 if on_border else 1.0)
            # Edge weight follows the column's centrality.
            for j in range(1, n + 1):
                G.add_edge(f'x_{i}_{j}_{k}', constraint_id, weight=cent[j] / n)

    # Column constraints: each column j has exactly one occurrence of value k.
    for j in range(1, n + 1):
        for k in range(1, n + 1):
            # Down-weight constraints on the first/last column or value.
            on_border = j in (1, n) or k in (1, n)
            constraint_id = f'col_{j}_{k}'
            G.add_node(constraint_id, type=1, weight=0.8 if on_border else 1.0)
            # Edge weight follows the row's centrality.
            for i in range(1, n + 1):
                G.add_edge(f'x_{i}_{j}_{k}', constraint_id, weight=cent[i] / n)

    # Global node approximating symmetry-breaking difficulty (larger instances only).
    if n >= 5:
        # Equals 1.0 at n == 25 and grows logarithmically; note that it
        # exceeds 1.0 for n > 25, unlike the other weights in this graph.
        complexity_weight = math.log(n) / math.log(25)
        G.add_node('symmetry_complexity', type=1, weight=complexity_weight)

        # Corner cells, plus the centre cell when n is odd.
        critical_positions = [(1, 1), (1, n), (n, 1), (n, n)]
        if n % 2 == 1:
            center = (n + 1) // 2
            critical_positions.append((center, center))

        for i, j in critical_positions:
            # With n >= 5 every x_{i}_{j}_{k} here (k = 1..3) already exists,
            # so no existence check is needed. Value 1 gets the heavier link.
            for k in range(1, min(4, n + 1)):
                G.add_edge(f'x_{i}_{j}_{k}', 'symmetry_complexity',
                           weight=0.9 if k == 1 else 0.6)

    return G


def main():
    """Command-line entry point: build the graph for one instance and report its size.

    Usage: python converter.py <mzn_file> <dzn_file> <json_file>

    Only <json_file> (the instance data, already parsed from DZN into JSON)
    is read here. <mzn_file> is passed through to build_graph, and <dzn_file>
    is accepted for interface compatibility but not used. Exits with status 1
    on a wrong argument count or on an unreadable or malformed JSON file.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>",
              file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # Load JSON data; report I/O and parse problems instead of dumping a traceback.
    try:
        with open(json_file, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
    except (OSError, json.JSONDecodeError) as err:
        print(f"Error: cannot load instance data from {json_file}: {err}",
              file=sys.stderr)
        sys.exit(1)

    # build_graph looks keys up with .get(), so the top level must be an object.
    if not isinstance(json_data, dict):
        print(f"Error: {json_file} must contain a JSON object at top level",
              file=sys.stderr)
        sys.exit(1)

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, so that
    # importing this module (e.g. to call build_graph) has no side effects.
    main()