#!/usr/bin/env python3
"""
Graph converter for the WCSP (Weighted Constraint Satisfaction Problem).
Created using subagent_prompt.md version: v_02

This problem is about finding variable assignments that minimize total cost from
unary and binary cost functions. Each variable has a finite domain, and cost
functions define preferences over variable assignments.
Key challenges: Cost function interactions, large domains, constraint tightness.
"""

import json
import math
import sys
from collections import Counter
from pathlib import Path

import networkx as nx


def _is_valid_var(index, num_variables):
    """Return True if *index* is a valid 1-based variable index."""
    return 1 <= index <= num_variables


def _domain_size(index, domains, max_domain):
    """Return the domain size of 1-based variable *index*, or *max_domain* if unknown.

    Indices below 1 are treated as unknown instead of wrapping around to the
    end of *domains* through negative indexing.
    """
    if 1 <= index <= len(domains):
        return domains[index - 1]
    return max_domain


def _domain_tightness(domain_size, max_domain):
    """Return ``1 - domain_size / max_domain`` clamped to [0, 1].

    Returns 0.5 when *max_domain* is not positive. The clamp keeps the value
    non-negative, so the ``math.sqrt`` calls in build_graph never receive a
    negative argument when a domain is larger than *max_domain*.
    """
    if max_domain <= 0:
        return 0.5
    return min(max(1.0 - domain_size / max_domain, 0.0), 1.0)


def _constraint_costs(costs, cum_tuples, num_tuples, j, stride):
    """Return the cost entries of constraint *j* from a flattened tuple table.

    Constraint *j* owns ``num_tuples[j]`` tuples of *stride* integers that
    start at tuple offset ``cum_tuples[j]`` in *costs*. The first integer of
    every tuple is collected. Returns [] when the offset/count metadata is
    missing, the count is not positive, or the slice lies outside *costs*.

    NOTE(review): like the original extraction, this treats the FIRST integer
    of each tuple as its cost. If the DZN layout is (value(s)..., cost), the
    cost is at offset ``stride - 1`` instead. Confirm against instance data.
    """
    if j >= len(cum_tuples) or j >= len(num_tuples) or num_tuples[j] <= 0:
        return []
    start = stride * cum_tuples[j]
    stop = min(start + stride * num_tuples[j], len(costs))
    return [costs[k] for k in range(start, stop, stride)]


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the WCSP instance.

    Args:
        mzn_file: Path to .mzn file (for reference only; not read here)
        json_data: Dict containing parsed DZN data. Missing keys fall back to
            empty lists or neutral defaults. A missing ``max_domain`` defaults
            to the largest entry of ``domains`` (1 if there are none).

    Returns:
        nx.Graph with one node per variable and one node per constraint.

    Strategy: a bipartite variable/constraint graph plus variable-variable
    conflict edges.
    - Variables (type 0): weighted by domain tightness and constraint involvement
    - Unary constraints (type 1): weighted by cost variation and domain tightness
    - Binary constraints (type 1): weighted by cost variation and domain tightness
    - Variable-constraint edges carry the constraint weight (x0.8 for binary)
    - Conflict edges join the two distinct variables of a binary constraint
      whose average cost exceeds 30% of max_costs2, or whose maximum cost
      exceeds 70% of it
    Edges are created only for variable indices in 1..num_variables, so
    malformed indices never produce attribute-less ``var_*`` nodes.
    """
    num_variables = json_data.get('num_variables', 0)
    domains = json_data.get('domains', [])
    # A default of 1 would make any larger domain drive the tightness terms
    # negative (and math.sqrt would raise), so fall back to the largest domain.
    max_domain = json_data.get('max_domain', max(domains, default=1))

    # Unary constraint data
    num_constraints1 = json_data.get('num_constraints1', 0)
    func1x = json_data.get('func1x', [])
    num_tuples1 = json_data.get('num_tuples1', [])
    cum_tuples1 = json_data.get('cum_tuples1', [])
    costs1 = json_data.get('costs1', [])
    max_costs1 = json_data.get('max_costs1', 1)

    # Binary constraint data
    num_constraints2 = json_data.get('num_constraints2', 0)
    func2x = json_data.get('func2x', [])
    func2y = json_data.get('func2y', [])
    num_tuples2 = json_data.get('num_tuples2', [])
    cum_tuples2 = json_data.get('cum_tuples2', [])
    costs2 = json_data.get('costs2', [])
    max_costs2 = json_data.get('max_costs2', 1)

    G = nx.Graph()

    # Number of constraints each variable takes part in, computed in one pass.
    # A binary constraint whose two scope entries coincide counts once.
    involvement = Counter(func1x[:max(num_constraints1, 0)])
    for j in range(min(num_constraints2, len(func2x))):
        scope = {func2x[j]}
        if j < len(func2y):
            scope.add(func2y[j])
        involvement.update(scope)

    # Variable nodes (type 0): a tighter domain and more incident constraints
    # both raise the weight. The +0.1 floor keeps every weight positive.
    for i in range(num_variables):
        domain_size = domains[i] if i < len(domains) else max_domain
        domain_pressure = _domain_tightness(domain_size, max_domain)
        constraint_pressure = min(involvement[i + 1] / max(num_variables, 1), 1.0)
        weight = min(math.sqrt(domain_pressure * constraint_pressure + 0.1), 1.0)
        G.add_node(f'var_{i + 1}', type=0, weight=weight)

    # Unary constraint nodes (type 1), tuples of stride 2
    for j in range(num_constraints1):
        if j >= len(func1x) or j >= len(num_tuples1):
            continue

        constraint_id = f'unary_{j}'
        var_index = func1x[j]
        costs = _constraint_costs(costs1, cum_tuples1, num_tuples1, j, stride=2)

        if costs:
            # A larger cost spread means the constraint discriminates more.
            spread = max(costs) - min(costs)
            variation = spread / max_costs1 if max_costs1 > 0 else 0.5
            tightness = _domain_tightness(
                _domain_size(var_index, domains, max_domain), max_domain)
            # The square root compresses the product so mid-range values
            # stand out more than under a linear scale.
            weight = min(math.sqrt(variation * (0.7 + 0.3 * tightness)), 1.0)
        else:
            weight = 0.5

        G.add_node(constraint_id, type=1, weight=weight)
        if _is_valid_var(var_index, num_variables):
            G.add_edge(f'var_{var_index}', constraint_id, weight=weight)

    # Binary constraint nodes (type 1), tuples of stride 3, plus conflict edges
    for j in range(num_constraints2):
        if j >= len(func2x) or j >= len(func2y) or j >= len(num_tuples2):
            continue

        constraint_id = f'binary_{j}'
        var1_index = func2x[j]
        var2_index = func2y[j]
        costs = _constraint_costs(costs2, cum_tuples2, num_tuples2, j, stride=3)

        if costs:
            spread = max(costs) - min(costs)
            variation = spread / max_costs2 if max_costs2 > 0 else 0.5
            # Heuristic boost: a binary constraint couples two variables.
            variation *= 1.2
            avg_domain = (_domain_size(var1_index, domains, max_domain)
                          + _domain_size(var2_index, domains, max_domain)) / 2
            tightness = _domain_tightness(avg_domain, max_domain)
            weight = min(math.sqrt(variation * (0.6 + 0.4 * tightness)), 1.0)
        else:
            weight = 0.7  # default slightly above the unary default

        G.add_node(constraint_id, type=1, weight=weight)

        edge_weight = weight * 0.8  # slightly below the constraint weight
        for var_index in (var1_index, var2_index):
            if _is_valid_var(var_index, num_variables):
                G.add_edge(f'var_{var_index}', constraint_id, weight=edge_weight)

        # Direct var-var conflict edge for high-cost binary constraints. Only
        # this code adds var-var edges, so the has_edge check keeps the first
        # qualifying constraint (lowest j) for each variable pair.
        if (costs and max_costs2 > 0 and var1_index != var2_index
                and _is_valid_var(var1_index, num_variables)
                and _is_valid_var(var2_index, num_variables)):
            avg_cost = sum(costs) / len(costs)
            if avg_cost > max_costs2 * 0.3 or max(costs) > max_costs2 * 0.7:
                var1_id = f'var_{var1_index}'
                var2_id = f'var_{var2_index}'
                if not G.has_edge(var1_id, var2_id):
                    G.add_edge(var1_id, var2_id,
                               weight=min(avg_cost / max_costs2, 1.0))

    return G


def main():
    """Command-line entry point: build the graph for one instance and report its size.

    Usage: python converter.py <mzn_file> <dzn_file> <json_file>

    Only the JSON file (pre-parsed DZN data) is read. The .mzn path is passed
    through to build_graph, and the .dzn path is accepted but unused. On a
    wrong argument count, prints usage to stderr and exits with status 1.
    """
    if len(sys.argv) != 4:
        # Usage errors go to stderr, so stdout carries only the summary line.
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>",
              file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # JSON text is UTF-8 (RFC 8259); do not depend on the locale's default encoding.
    json_data = json.loads(Path(json_file).read_text(encoding='utf-8'))

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()