#!/usr/bin/env python3
"""
Graph converter for CELAR Radio Link Frequency Assignment problem.
Converter created with subagent_prompt.md v_02

This problem is about assigning frequencies to radio links while minimizing interference.
Key challenges: Hard equality constraints must be satisfied exactly, while soft inequality 
constraints contribute to the objective when violated. Different constraint weights create 
a complex optimization landscape.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _cost_weight(costs, weight_idx, max_cost):
    """Return the normalized cost for a 1-based index into costs, or 0.5 if invalid."""
    if 1 <= weight_idx <= len(costs):
        return costs[weight_idx - 1] / max_cost
    return 0.5


def _link_endpoints(G, constraint_node, x, y, num_variables, weight):
    """Connect a constraint node to its in-range endpoint variables, once each."""
    if 1 <= x <= num_variables:
        G.add_edge(f'var_{x}', constraint_node, weight=weight)
    # Skip y when it equals x so a self-referencing constraint yields one edge.
    if 1 <= y <= num_variables and y != x:
        G.add_edge(f'var_{y}', constraint_node, weight=weight)


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the CELAR frequency assignment instance.

    Args:
        mzn_file: Path to .mzn file (for reference only; not read here)
        json_data: Dict containing parsed DZN data. Keys read: num_variables,
            num_hardconstraints, num_softconstraints, costs, domains,
            num_categories, hardctrx/hardctry/hardctrk,
            softctrx/softctry/softctrk/softctrw, min_freq, max_freq.
            Missing keys fall back to defaults. Variable indices and the
            cost indices in softctrw are treated as 1-based.

    Returns:
        networkx.Graph whose nodes carry ``type`` and ``weight`` attributes
        and whose edges carry ``weight``:
        - Type 0 nodes ``var_<i>``: variables (frequency assignments)
        - Type 1 nodes ``hard_<i>`` / ``soft_<i>``: individual constraints
        - Variable-constraint edges weighted by tightness (and cost for soft)
        - Variable-variable "conflict" edges between pairs that share at
          least two soft constraints with a high average cost (so the graph
          is not strictly bipartite)

    Constraint entries whose index runs past the end of any array they need,
    or whose variable index is outside 1..num_variables, are skipped.
    An instance with no variables, or with all-zero costs, is handled without
    raising.
    """
    # Extract data
    num_variables = json_data.get('num_variables', 0)
    num_hardconstraints = json_data.get('num_hardconstraints', 0)
    num_softconstraints = json_data.get('num_softconstraints', 0)

    costs = json_data.get('costs', [1, 1, 1, 1])  # Default costs
    domains = json_data.get('domains', [])
    num_categories = json_data.get('num_categories', 1)

    # Hard constraint data
    hardctrx = json_data.get('hardctrx', [])
    hardctry = json_data.get('hardctry', [])
    hardctrk = json_data.get('hardctrk', [])

    # Soft constraint data
    softctrx = json_data.get('softctrx', [])
    softctry = json_data.get('softctry', [])
    softctrk = json_data.get('softctrk', [])
    softctrw = json_data.get('softctrw', [])

    # Frequency range for normalization (never below 1 to avoid dividing by 0)
    min_freq = json_data.get('min_freq', 1)
    max_freq = json_data.get('max_freq', 100)
    freq_range = max(max_freq - min_freq, 1)

    # Largest violation cost, used to normalize costs into [0, 1].
    max_cost = max(costs) if costs else 1
    if max_cost == 0:
        # All costs are zero: normalize by 1 instead of dividing by zero.
        max_cost = 1

    G = nx.Graph()

    # Count constraint participation for each variable (index 0 is unused)
    var_hard_count = [0] * (num_variables + 1)
    var_soft_count = [0] * (num_variables + 1)

    for i in range(min(num_hardconstraints, len(hardctrx), len(hardctry))):
        for v in (hardctrx[i], hardctry[i]):
            if 1 <= v <= num_variables:
                var_hard_count[v] += 1

    for i in range(min(num_softconstraints, len(softctrx), len(softctry))):
        for v in (softctrx[i], softctry[i]):
            if 1 <= v <= num_variables:
                var_soft_count[v] += 1

    # Largest single hard-or-soft participation count, floored at 1.
    if num_variables > 0:
        max_constraints = max(max(var_hard_count[1:]), max(var_soft_count[1:]), 1)
    else:
        max_constraints = 1

    # Add variable nodes (Type 0) with constraint-based weights
    for var in range(1, num_variables + 1):
        # Higher weight for more constrained variables; hard constraints count double
        constraint_weight = (
            var_hard_count[var] * 2 + var_soft_count[var]
        ) / (max_constraints * 2)

        # Domain weight: category index relative to the number of categories
        domain_weight = 0.5  # Default when no domain entry exists for this variable
        if var <= len(domains):
            domain_weight = min(domains[var - 1] / max(num_categories, 1), 1.0)

        # Combine weights with non-linear scaling, capped at 1.0
        final_weight = min(math.sqrt(constraint_weight * 0.7 + domain_weight * 0.3), 1.0)

        G.add_node(f'var_{var}', type=0, weight=final_weight)

    # Add hard constraint nodes (Type 1): high base weight plus separation term
    for i in range(min(num_hardconstraints, len(hardctrk))):
        separation_weight = min(hardctrk[i] / freq_range, 1.0)
        G.add_node(f'hard_{i}', type=1, weight=0.8 + 0.2 * separation_weight)

    # Add soft constraint nodes (Type 1) with cost-based weights
    for i in range(min(num_softconstraints, len(softctrk), len(softctrw))):
        cost_weight = _cost_weight(costs, softctrw[i], max_cost)
        separation_weight = min(softctrk[i] / freq_range, 1.0)

        # Exponential scaling emphasizes high-cost constraints
        final_weight = min(
            math.exp(-1.0 + 2.0 * cost_weight) * 0.4 + separation_weight * 0.3, 1.0
        )

        G.add_node(f'soft_{i}', type=1, weight=final_weight)

    # Hard constraint participation edges (variables to constraints)
    for i in range(min(num_hardconstraints, len(hardctrx), len(hardctry), len(hardctrk))):
        tightness = min(hardctrk[i] / freq_range, 1.0)
        edge_weight = 0.7 + 0.3 * tightness  # Hard constraints get high edge weights
        _link_endpoints(G, f'hard_{i}', hardctrx[i], hardctry[i], num_variables, edge_weight)

    # Soft constraint participation edges
    for i in range(
        min(num_softconstraints, len(softctrx), len(softctry), len(softctrk), len(softctrw))
    ):
        cost_weight = _cost_weight(costs, softctrw[i], max_cost)
        tightness = min(softctrk[i] / freq_range, 1.0)
        edge_weight = min(cost_weight * 0.6 + tightness * 0.4, 1.0)
        _link_endpoints(G, f'soft_{i}', softctrx[i], softctry[i], num_variables, edge_weight)

    # Collect the soft-constraint costs shared by each unordered variable pair
    variable_conflicts = {}
    for i in range(min(num_softconstraints, len(softctrx), len(softctry), len(softctrw))):
        x, y = softctrx[i], softctry[i]
        weight_idx = softctrw[i]

        if 1 <= x <= num_variables and 1 <= y <= num_variables and x != y:
            pair = tuple(sorted([x, y]))
            # An invalid cost index counts as a raw cost of 1
            cost = costs[weight_idx - 1] if 1 <= weight_idx <= len(costs) else 1
            variable_conflicts.setdefault(pair, []).append(cost)

    # Add conflict edges for pairs with >= 2 shared constraints and high average cost
    for (var1, var2), cost_list in variable_conflicts.items():
        avg_cost = sum(cost_list) / len(cost_list)

        if avg_cost >= max_cost * 0.5 and len(cost_list) >= 2:
            conflict_weight = min(avg_cost / max_cost * len(cost_list) / 10, 1.0)
            G.add_edge(f'var_{var1}', f'var_{var2}', weight=conflict_weight)

    return G


def main():
    """
    Command-line entry point.

    Expects exactly three arguments (mzn file, dzn file, json file); prints a
    usage line and exits with status 1 otherwise. Only the JSON file is read;
    its parsed contents are passed to build_graph and a node/edge summary of
    the resulting graph is printed.
    """
    argv = sys.argv
    if len(argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    # The dzn path is accepted for interface compatibility but not used here.
    mzn_file, _dzn_file, json_file = argv[1:]

    # Parse the JSON instance data
    with open(json_file, 'r') as handle:
        instance = json.load(handle)

    # Construct the graph and report its size
    G = build_graph(mzn_file, instance)
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()