#!/usr/bin/env python3
"""
Graph converter for Schur Numbers problem.
Created using subagent_prompt.md version: v_02

This problem is about placing n balls (labeled 1..n) into c boxes such that
no box contains a triple {x,y,z} where x+y=z. This is a classic problem in
Ramsey theory and combinatorics.

Key challenges: 
- The number of forbidden triples grows quadratically with n
- Constraint tightness increases as c decreases relative to n
- Small numbers take part in the most triples (about n for ball 1, falling to about n/2 for ball n)
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the Schur Numbers problem.

    Args:
        mzn_file: Path to .mzn file (for reference only; it is not read here)
        json_data: Dict containing parsed DZN data. The keys 'n' (number of
            balls, default 0) and 'c' (number of boxes, default 1) are read.

    Returns:
        An undirected networkx Graph whose nodes carry 'type' and 'weight'
        attributes and whose edges carry a 'weight' attribute.

    Strategy: every edge has a ball node on one end.
    - Ball nodes (type 0): 'ball_<k>', weighted by how many triples they
      participate in
    - Constraint nodes (type 1): 'triple_<i>', one per forbidden triple
      (x, y, z) with x < y and x + y = z <= n, plus one 'global_difficulty'
      node
    - Box resource nodes (type 2): 'box_<b>', one per box
    - Edges connect balls to their forbidden triples, to every box, and
      (for high-involvement balls) to the global difficulty node
    """
    n = json_data.get('n', 0)
    c = json_data.get('c', 1)

    G = nx.Graph()

    # Enumerate forbidden triples. y is capped at n - x because any larger y
    # would push z = x + y past n; the order matches a plain x-then-y scan.
    forbidden_triples = [
        (x, y, x + y)
        for x in range(1, n + 1)
        for y in range(x + 1, n - x + 1)
    ]

    # Count how many triples each ball participates in
    triple_count = {}
    for triple in forbidden_triples:
        for ball in triple:
            triple_count[ball] = triple_count.get(ball, 0) + 1

    # Every recorded count is >= 1, and the default covers the no-triple case,
    # so max_triples is always a safe (non-zero) divisor.
    max_triples = max(triple_count.values(), default=1)

    # Involvement of each ball relative to the busiest ball, in [0, 1].
    # Reused for ball weights, ball-box edges and global-difficulty edges.
    involvement_ratio = {
        ball: triple_count.get(ball, 0) / max_triples
        for ball in range(1, n + 1)
    }

    # Ball nodes (type 0); the square root damps the spread between low- and
    # high-involvement balls.
    for ball in range(1, n + 1):
        G.add_node(f'ball_{ball}', type=0,
                   weight=math.sqrt(involvement_ratio[ball]))

    # Constraint nodes (type 1) - one per forbidden triple.
    # Heuristic weight: highest when the mean of x, y, z equals n/2, clamped
    # to [0.2, 1.0]. center is only divided by when a triple exists (n >= 3).
    center = n / 2
    for i, (x, y, z) in enumerate(forbidden_triples):
        centrality = 1.0 - abs((x + y + z) / 3 - center) / center
        triple_node = f'triple_{i}'
        G.add_node(triple_node, type=1, weight=max(0.2, min(1.0, centrality)))

        # Connect balls to their forbidden triple
        for member in (x, y, z):
            G.add_edge(f'ball_{member}', triple_node, weight=1.0)

    # Box resource nodes (type 2) - all boxes share one weight, so compute it
    # once. Nothing is created when c <= 0.
    total_triples = len(forbidden_triples)
    if c > 0:
        scarcity = 1.0 / c  # fewer boxes means each is more valuable
        expected_triples_per_box = total_triples / (c ** 3)
        load_factor = min(1.0, expected_triples_per_box / 10.0)  # Normalize load
        box_weight = (scarcity + load_factor) / 2

        for box in range(1, c + 1):
            box_node = f'box_{box}'
            G.add_node(box_node, type=2, weight=box_weight)

            # Connect each ball to each box (representing assignment
            # possibility); cost is higher for balls in more triples.
            for ball in range(1, n + 1):
                G.add_edge(f'ball_{ball}', box_node,
                           weight=involvement_ratio[ball])

    # Global problem difficulty node: triples per (ball, box) pair, capped at 1
    difficulty_ratio = total_triples / (n * c) if n * c > 0 else 1.0
    G.add_node('global_difficulty', type=1, weight=min(1.0, difficulty_ratio))

    # Connect high-involvement balls (>= 70% of the maximum count) to it
    high_involvement_threshold = max_triples * 0.7
    for ball in range(1, n + 1):
        if triple_count.get(ball, 0) >= high_involvement_threshold:
            G.add_edge(f'ball_{ball}', 'global_difficulty',
                       weight=involvement_ratio[ball])

    return G


def main():
    """
    Command-line entry point.

    Expects exactly three arguments: <mzn_file> <dzn_file> <json_file>.
    Loads the JSON file, builds the graph from it and prints the node and
    edge counts. Prints a usage message and exits with status 1 when the
    argument count is wrong.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    # argv[2] (the DZN path) is part of the CLI contract but is not read here;
    # the instance data comes from the JSON file.
    mzn_file = sys.argv[1]
    json_file = sys.argv[3]

    # Load JSON data; decode explicitly as UTF-8 instead of relying on the
    # platform's locale default.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    # Build graph
    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()