#!/usr/bin/env python3
"""
Graph converter for JSP0 (Job Shop Scheduling with Speed Scaling) problem.
Created using subagent_prompt.md version: v_02

This problem is about scheduling jobs on machines with precedence constraints
and speed scaling options to minimize makespan plus energy consumption.
Key challenges: precedence constraints, machine conflicts, speed-time-energy tradeoffs
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def build_graph(mzn_file, json_data):
    """
    Build a graph representation of a JSP0 problem instance.

    Args:
        mzn_file: Path to the .mzn file. Accepted for interface
            compatibility; it is not read by this function.
        json_data: Dict containing parsed DZN data. Reads 'precedence'
            (a flat list interpreted as a row-major jobs x machines matrix)
            and the optional 'SPEED' value (defaults to 1).

    Returns:
        A networkx.Graph. When 'precedence' is missing or empty the graph
        is returned empty (no nodes, no edges).

    Strategy: variable nodes plus constraint nodes, with extra conflict edges
    - Job-machine pairs as variable nodes (type 0)
    - Precedence constraints as constraint nodes (type 1)
    - Machine capacity constraints as constraint nodes (type 1)
    - Direct conflict edges between job-machine pairs on the same machine
      (so the graph is not strictly bipartite)
    - One global 'speed_complexity' node (type 1) when SPEED > 1
    """
    precedence_flat = json_data.get('precedence') or []
    speed = json_data.get('SPEED', 1)

    G = nx.Graph()

    n_elements = len(precedence_flat)
    if n_elements == 0:
        # Nothing to model; also avoids dividing by n_jobs == 0 below.
        return G

    # Dimensions are inferred from the flat length alone, which is exact
    # only for square (jobs == machines) instances. For other lengths any
    # trailing entries beyond n_jobs * n_machines are ignored.
    n_jobs = int(math.sqrt(n_elements))
    n_machines = n_elements // n_jobs

    # Reshape the flat list into one row per job (row-major).
    precedence = [
        list(precedence_flat[j * n_machines:(j + 1) * n_machines])
        for j in range(n_jobs)
    ]

    # Variable nodes: job-machine pairs (type 0).
    # Later positions in the precedence order get a higher weight.
    for j in range(n_jobs):
        for m in range(n_machines):
            weight = (precedence[j][m] + 1) / n_machines
            weight = min(max(weight, 0.1), 1.0)  # Clamp to [0.1, 1.0]
            G.add_node(f'job_{j}_machine_{m}', type=0, weight=weight)

    # Constraint nodes for precedence relationships (type 1).
    # Each job contributes n_machines - 1 constraints, one per pair of
    # operations that are adjacent once sorted by precedence value.
    for j in range(n_jobs):
        prec_pairs = sorted((precedence[j][m], m) for m in range(n_machines))

        for step, ((curr_prec, curr_machine), (next_prec, next_machine)) in \
                enumerate(zip(prec_pairs, prec_pairs[1:])):
            # Larger gaps in precedence value are treated as more flexible,
            # with a floor of 0.2 on the resulting tightness.
            prec_gap = next_prec - curr_prec
            tightness = max(1.0 - (prec_gap / n_machines), 0.2)

            constraint_id = f'precedence_job_{j}_step_{step}'
            G.add_node(constraint_id, type=1, weight=tightness)

            # Connect to the two job-machine pairs involved.
            G.add_edge(f'job_{j}_machine_{curr_machine}', constraint_id, weight=0.8)
            G.add_edge(f'job_{j}_machine_{next_machine}', constraint_id, weight=0.8)

    # Constraint nodes for machine capacity (disjunctive constraints, type 1).
    # Every job is connected to every machine here, so the per-machine job
    # count is n_jobs and the load factor is always 1.0.
    jobs_on_machine = n_jobs
    load_factor = jobs_on_machine / max(n_jobs, 1)
    capacity_edge_weight = 1.0 / max(jobs_on_machine, 1)  # More jobs = lower weight
    for m in range(n_machines):
        constraint_id = f'machine_capacity_{m}'
        G.add_node(constraint_id, type=1, weight=load_factor)

        for j in range(n_jobs):
            G.add_edge(f'job_{j}_machine_{m}', constraint_id,
                       weight=capacity_edge_weight)

    # Conflict edges between jobs on the same machine. Each job is paired
    # only with the next three job indices to limit the number of edges.
    for m in range(n_machines):
        for j1 in range(n_jobs):
            for j2 in range(j1 + 1, min(n_jobs, j1 + 4)):
                # Closer precedence values = higher conflict (exponential decay).
                prec_diff = abs(precedence[j1][m] - precedence[j2][m])
                conflict_weight = math.exp(-prec_diff / n_machines)

                if conflict_weight > 0.1:  # Only add significant conflicts
                    G.add_edge(f'job_{j1}_machine_{m}',
                               f'job_{j2}_machine_{m}',
                               weight=conflict_weight)

    # Global speed-scaling node, attached to every job-machine pair,
    # added only when SPEED > 1.
    if speed > 1:
        G.add_node('speed_complexity', type=1, weight=math.log(speed + 1) / 10)

        speed_weight = 0.3  # Moderate impact
        for j in range(n_jobs):
            for m in range(n_machines):
                G.add_edge(f'job_{j}_machine_{m}', 'speed_complexity',
                           weight=speed_weight)

    return G


def main():
    """
    Command-line entry point: load the instance JSON and build its graph.

    Expects exactly three arguments: <mzn_file> <dzn_file> <json_file>.
    Prints a usage message and exits with status 1 otherwise. On success,
    prints the node and edge counts of the graph that was built.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    # The .dzn path is a required positional argument but is not read here;
    # the instance data comes from the JSON file.
    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # Read as UTF-8 explicitly instead of relying on the platform default
    # text encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    # Build graph
    G = build_graph(mzn_file, json_data)

    # Report the size of the graph returned by build_graph
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


# Run the converter only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()