#!/usr/bin/env python3
"""
Graph converter for CostasArray problem.
Created using subagent_prompt.md version: v_02

This problem is about finding permutations in which the entries within each row of the
triangular difference table are pairwise distinct.
Key challenges: a quadratic number of difference constraints and complex interdependencies between differences.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the CostasArray problem instance.

    Args:
        mzn_file: Path to .mzn file (for reference only; not read here)
        json_data: Dict containing parsed DZN data. Only ``'n'`` (the order of
            the array, defaulting to 8) is used.

    Returns:
        networkx.Graph whose nodes carry ``type`` (0 = variable, 1 = constraint)
        and ``weight`` attributes, and whose edges carry a ``weight`` attribute.

    Strategy: position variables connected to constraint nodes
    - Positions (type 0): ``pos_1`` .. ``pos_n``, weighted by a centrality heuristic
    - AllDifferent constraint (type 1): ``alldiff_main``, linked to every position
    - Difference constraints (type 1): ``diff_row_i`` for i = 1..n-1, one per row
      of the difference triangle, linked to the positions appearing in that row
    - For n > 4, position-position "conflict" edges join mirror positions
      i and n+1-i, so the graph is not strictly bipartite
    - All weights are heuristics for constraint scope / interaction strength
      and lie in [0, 1]
    """
    n = json_data.get('n', 8)

    G = nx.Graph()

    # Position nodes (type 0) - decision variables for the permutation.
    for i in range(1, n + 1):
        # Heuristic weight: 1 at the middle, falling linearly to 0 at both ends.
        # (Every position is paired with each of the other n-1 positions exactly
        # once across the difference triangle, so this is not a participation
        # count - it is a modelling choice.)
        centrality = 1.0 - abs(i - (n + 1) / 2) / ((n - 1) / 2) if n > 1 else 1.0
        # sqrt adds slight non-linearity (lifts mid-range values toward 1).
        weight = math.sqrt(centrality)
        G.add_node(f'pos_{i}', type=0, weight=weight)

    # Main alldifferent constraint node (type 1); it touches every position.
    G.add_node('alldiff_main', type=1, weight=1.0)
    for i in range(1, n + 1):
        G.add_edge(f'pos_{i}', 'alldiff_main', weight=1.0)

    # Difference constraint nodes (type 1) - one per row of the difference
    # triangle. Row i holds the n-i differences between positions j and j-i.
    half_n = n / 2
    for i in range(1, n):
        scope_size = n - i  # Number of differences in this row
        # Heuristic: rows with fewer differences get more weight (exponential
        # decay in scope). The 0.1 floor is only a safeguard: the exponent is
        # always > -2, so the value never drops below exp(-2) ~= 0.135.
        scope_weight = math.exp(-2.0 * (scope_size - 1) / (n - 1))
        G.add_node(f'diff_row_{i}', type=1, weight=max(scope_weight, 0.1))

        # A position can occur in two pairs of the same row (as j and as j-i).
        # nx.Graph keeps one edge per node pair and add_edge overwrites its
        # attributes, so weights are aggregated first (strongest participation
        # wins) to stay order-independent and mirror-symmetric.
        row_weights = {}
        for j in range(i + 1, n + 1):
            pos1, pos2 = j, j - i  # pos2 >= 1 because j > i
            # Pairs closer to either boundary contribute more to tightness.
            boundary_factor = min(pos1, pos2, n + 1 - pos1, n + 1 - pos2) / half_n
            participation_weight = 0.5 + 0.5 * (1.0 - boundary_factor)
            for pos in (pos1, pos2):
                row_weights[pos] = max(row_weights.get(pos, 0.0), participation_weight)

        # Dict preserves first-seen order, matching the original edge order.
        for pos, w in row_weights.items():
            G.add_edge(f'pos_{pos}', f'diff_row_{i}', weight=w)

    # Conflict edges between mirror positions i and n+1-i, representing
    # symmetric interactions in the difference structure (larger instances only).
    if n > 4:
        mid = (n + 1) // 2
        # i runs up to n//2 so that for even n the innermost pair
        # (n/2, n/2 + 1) is included; for odd n the centre has no partner.
        for i in range(1, n // 2 + 1):
            opposite = n + 1 - i  # always > i and <= n in this range
            # Conflict strength grows as the pair moves toward the centre.
            conflict_strength = (1.0 - abs(i - mid) / mid) * 0.7
            G.add_edge(f'pos_{i}', f'pos_{opposite}', weight=conflict_strength)

    return G


def main():
    """
    Command-line entry point: build the graph for one instance and report its size.

    Usage: python converter.py <mzn_file> <dzn_file> <json_file>

    Only the JSON file is read here. The .mzn path is forwarded to
    ``build_graph``; the .dzn path is accepted for interface compatibility
    but not used. With the wrong number of arguments, a usage message is
    written to stderr and the process exits with status 1.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>", file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # JSON text is UTF-8 by specification; don't depend on the locale default.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction; the CLI
    # only reports its size.
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, so importing this module
    # (e.g. to call build_graph directly) has no side effects.
    main()