#!/usr/bin/env python3
"""
Graph converter for jobshop_1 problem.
Created using subagent_prompt.md version: v_02

This problem is about job-shop scheduling where jobs consist of tasks that must be executed
on specific machines in a specific order. Each machine can only process one task at a time.
Key challenges: resource contention on machines, precedence constraints within jobs, and 
optimizing makespan (total completion time).
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _flatten(values):
    """Return *values* as a flat list, expanding any nested lists/tuples.

    Tasks are indexed row-major below (task k of job j at
    ``j * n_machines + k``). A flat list is returned unchanged in content;
    a list of per-job rows is flattened into that same row-major order.
    """
    flat = []
    for item in values:
        if isinstance(item, (list, tuple)):
            flat.extend(_flatten(item))
        else:
            flat.append(item)
    return flat


def _safe_ratio(numerator, denominator, fallback=0.5):
    """Return ``numerator / denominator``, or *fallback* if denominator <= 0."""
    return numerator / denominator if denominator > 0 else fallback


def _collect_tasks(n_jobs, n_machines, job_task_machine, job_task_duration):
    """Return a ``(job, task, machine, duration)`` tuple for every usable task.

    Task k of job j is read from flat index ``j * n_machines + k``. A task is
    kept only when that index exists in both arrays and its machine id lies
    in ``[0, n_machines)`` -- the same filter applied to the machine workload
    totals -- so every node and edge built from the result refers to an
    existing, attributed machine node.
    """
    tasks = []
    for j in range(n_jobs):
        for k in range(n_machines):
            idx = j * n_machines + k
            if idx >= len(job_task_machine) or idx >= len(job_task_duration):
                continue
            machine_id = job_task_machine[idx]
            if 0 <= machine_id < n_machines:
                tasks.append((j, k, machine_id, job_task_duration[idx]))
    return tasks


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the job-shop scheduling problem.

    Args:
        mzn_file: Path to .mzn file (for reference only; not read here)
        json_data: Dict containing parsed DZN data. Reads ``n_jobs``,
            ``n_machines``, ``job_task_machine`` and ``job_task_duration``.
            The two task arrays may be flat row-major lists (task k of job j
            at index ``j * n_machines + k``) or lists of per-job rows.

    Returns:
        An ``nx.Graph``; empty when either task array is missing or empty.

    Strategy: Create a graph with:
    - Task nodes (type 0): Each job-task combination, weighted by duration/criticality
    - Machine constraint nodes (type 1): Capacity constraints for each machine
    - Job precedence nodes (type 1): Ordering constraints within each job
    - Machine nodes (type 2): Physical resources with utilization-based weights
    Task nodes link to their machine-capacity, job-precedence and machine
    nodes; additional task-task conflict edges join the longest tasks on
    heavily loaded machines (so the graph is not strictly bipartite).

    Tasks whose data is missing or whose machine id is outside
    ``[0, n_machines)`` are skipped consistently: no node and no edges.

    What makes instances hard:
    - High machine utilization (many tasks per machine)
    - Long task durations creating bottlenecks
    - Complex precedence patterns
    - Unbalanced workload distribution across machines
    """
    n_jobs = json_data.get('n_jobs', 0)
    n_machines = json_data.get('n_machines', 0)
    # `or []` also covers keys present with a JSON null value.
    job_task_machine = _flatten(json_data.get('job_task_machine') or [])
    job_task_duration = _flatten(json_data.get('job_task_duration') or [])

    if not job_task_machine or not job_task_duration:
        # Empty instance
        return nx.Graph()

    G = nx.Graph()

    # Instance-wide duration statistics (over every listed duration).
    max_duration = max(job_task_duration)
    avg_duration = sum(job_task_duration) / len(job_task_duration)

    # Per-machine workload and task count; invalid machine ids are ignored.
    machine_workload = [0] * n_machines
    machine_task_count = [0] * n_machines
    for machine_id, duration in zip(job_task_machine, job_task_duration):
        if 0 <= machine_id < n_machines:
            machine_workload[machine_id] += duration
            machine_task_count[machine_id] += 1

    max_machine_workload = max(machine_workload, default=1)
    max_task_count = max(machine_task_count, default=0)

    tasks = _collect_tasks(n_jobs, n_machines, job_task_machine, job_task_duration)

    # Group tasks once, instead of rescanning all tasks per job / per machine.
    tasks_by_job = [[] for _ in range(n_jobs)]
    tasks_by_machine = [[] for _ in range(n_machines)]
    for j, k, machine_id, duration in tasks:
        tasks_by_job[j].append((machine_id, duration))
        tasks_by_machine[machine_id].append((f'task_j{j}_k{k}', duration))

    # Task nodes (type 0) - each represents a (job, task_in_job) pair.
    for j, k, machine_id, duration in tasks:
        # Tasks with longer durations and those on busy machines are more critical
        duration_weight = _safe_ratio(duration, max_duration)
        machine_pressure = _safe_ratio(machine_workload[machine_id], max_machine_workload)
        position_weight = 1.0 - (k / n_machines) * 0.3  # Earlier tasks slightly more critical

        task_criticality = duration_weight * 0.4 + machine_pressure * 0.4 + position_weight * 0.2
        G.add_node(f'task_j{j}_k{k}', type=0, weight=math.sqrt(task_criticality))

    # Machine resource nodes (type 2), weighted by utilization and contention.
    for m in range(n_machines):
        utilization = _safe_ratio(machine_workload[m], max_machine_workload)
        # Fallback 0.5 when no machine received any valid task (avoids 0/0).
        contention = _safe_ratio(machine_task_count[m], max_task_count)
        machine_weight = math.tanh(2.0 * (utilization * 0.6 + contention * 0.4))
        G.add_node(f'machine_{m}', type=2, weight=machine_weight)

    # Machine capacity constraint nodes (type 1) - one per machine.
    for m in range(n_machines):
        # Normalized by expected load; exponential scaling for tightness.
        constraint_tightness = min(1.0, machine_task_count[m] / (n_jobs + 1))
        tightness_weight = 1.0 - math.exp(-2.0 * constraint_tightness)
        G.add_node(f'machine_capacity_{m}', type=1, weight=tightness_weight)

    # Job precedence constraint nodes (type 1) - one per job.
    for j, job_tasks in enumerate(tasks_by_job):
        if job_tasks:
            durations_in_job = [d for _, d in job_tasks]
            machines_in_job = {m for m, _ in job_tasks}
            # Jobs with high duration variance are harder to schedule optimally
            avg_job_duration = sum(durations_in_job) / len(durations_in_job)
            duration_variance = (
                sum((d - avg_job_duration) ** 2 for d in durations_in_job) / len(durations_in_job)
            )
            normalized_variance = duration_variance / max_duration ** 2 if max_duration > 0 else 0
            precedence_weight = math.sqrt(
                normalized_variance * 0.7 + (len(machines_in_job) / n_machines) * 0.3
            )
        else:
            precedence_weight = 0.5
        G.add_node(f'job_precedence_{j}', type=1, weight=precedence_weight)

    # 1. Tasks to machine capacity constraints: how much the task loads the machine.
    for j, k, machine_id, duration in tasks:
        G.add_edge(f'task_j{j}_k{k}', f'machine_capacity_{machine_id}',
                   weight=_safe_ratio(duration, max_duration))

    # 2. Tasks to their job's precedence constraint (only tasks that exist,
    #    so add_edge never creates attribute-less task nodes).
    for j, k, _machine_id, _duration in tasks:
        precedence_strength = 1.0 - (k / n_machines) * 0.2  # Earlier tasks slightly stronger
        G.add_edge(f'task_j{j}_k{k}', f'job_precedence_{j}', weight=precedence_strength)

    # 3. Tasks to machine resource nodes: resource consumption intensity, capped at 1.0.
    for j, k, machine_id, duration in tasks:
        consumption_weight = min(_safe_ratio(duration, avg_duration), 1.0)
        G.add_edge(f'task_j{j}_k{k}', f'machine_{machine_id}', weight=consumption_weight)

    # 4. Conflict edges among the 4 longest tasks on each overloaded machine.
    # NOTE(review): if every job visits each machine once, each machine has
    # n_jobs tasks and this 0.8 * n_jobs threshold is always met -- confirm
    # the intended cutoff.
    for m in range(n_machines):
        if machine_task_count[m] <= n_jobs * 0.8:
            continue
        # sorted() is stable, so ties keep (job, task) order.
        top_tasks = sorted(tasks_by_machine[m], key=lambda t: t[1], reverse=True)[:4]
        for i, (task1, dur1) in enumerate(top_tasks):
            for task2, dur2 in top_tasks[i + 1:]:
                # Fallback 0.5 when avg_duration is 0 (all-zero durations).
                conflict_weight = min(1.0, _safe_ratio(dur1 + dur2, 2.0 * avg_duration))
                G.add_edge(task1, task2, weight=conflict_weight)

    return G


def main():
    """Command-line entry point: build the graph from a JSON instance file.

    Usage: ``python converter.py <mzn_file> <dzn_file> <json_file>``

    ``<mzn_file>`` is passed through to ``build_graph``; ``<dzn_file>`` is
    accepted but not read here; ``<json_file>`` holds the parsed instance
    data. Prints the node/edge counts of the resulting graph. Exits with
    status 1 (usage on stderr) when the argument count is wrong.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>", file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
    json_data = json.loads(Path(json_file).read_text(encoding="utf-8"))

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    main()