#!/usr/bin/env python3
"""
Graph converter for Flexible Job Shop Scheduling Problem (FJSP).
Converter created with subagent_prompt.md v_02

This problem is about scheduling jobs on flexible machines where each task
can be processed on multiple machines with different durations.
Key challenges: precedence constraints, machine conflicts, resource utilization
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _job_task_ids(tasks, no_jobs, no_task):
    """Map each job number (1-based) to the list of task ids it owns.

    ``tasks`` is read as a flat list of inclusive ``start, end`` pairs, one
    pair per job.  When the list is too short for that layout, tasks are
    split into consecutive chunks of ``no_task // no_jobs`` (at least 1)
    tasks per job; tasks beyond the last full chunk stay unassigned.
    """
    jobs = range(1, no_jobs + 1)
    if len(tasks) >= no_jobs * 2:
        # The length check guarantees both bounds exist for every job.
        return {
            job: list(range(tasks[2 * (job - 1)], tasks[2 * (job - 1) + 1] + 1))
            for job in jobs
        }

    # Fallback is only reachable with no_jobs >= 1, so the division is safe.
    tasks_per_job = max(1, no_task // no_jobs)
    return {
        job: list(range((job - 1) * tasks_per_job + 1,
                        min(job * tasks_per_job, no_task) + 1))
        for job in jobs
    }


def _task_machine_options(optts, optt_mach, optt_dur, no_task, no_optt):
    """Map each task number (1-based) to its ``(machine, duration)`` options.

    ``optts`` is read as a flat list of inclusive, 1-based ``start, end``
    option indices, one pair per task.  When the list is too short for that
    layout, each task is given ``no_optt // no_task`` (at least 1)
    consecutive options.  Option indices that fall past the end of
    ``optt_mach`` or ``optt_dur`` are skipped.
    """
    usable = min(len(optt_mach), len(optt_dur))

    def collect(first, stop):
        # 0-based half-open index range -> list of (machine, duration).
        return [(optt_mach[i], optt_dur[i])
                for i in range(first, stop) if i < usable]

    task_ids = range(1, no_task + 1)
    if len(optts) >= no_task * 2:
        # 1-based inclusive [start, end] == 0-based half-open [start-1, end).
        return {
            task: collect(optts[2 * (task - 1)] - 1, optts[2 * (task - 1) + 1])
            for task in task_ids
        }

    # Fallback is only reachable with no_task >= 1, so the division is safe.
    options_per_task = max(1, no_optt // no_task)
    return {
        task: collect((task - 1) * options_per_task,
                      min(task * options_per_task, no_optt))
        for task in task_ids
    }


def _machine_usage(task_options, no_mach):
    """Summarise, in one pass, how the task options use each machine.

    Returns three dicts keyed by machine number ``1..no_mach``:

    - expected load: sum of ``duration / len(options)`` over all options,
    - option count: number of options that name the machine,
    - competitors: ``(task, duration)`` for each task that can run on the
      machine, using the task's first option on that machine.

    Raises:
        KeyError: if an option names a machine outside ``1..no_mach``.
    """
    machines = range(1, no_mach + 1)
    load = {machine: 0 for machine in machines}
    option_count = {machine: 0 for machine in machines}
    competitors = {machine: [] for machine in machines}

    for task, options in task_options.items():
        seen = set()
        for machine, duration in options:
            # Each option is treated as equally likely to be chosen.
            load[machine] += duration / len(options)
            option_count[machine] += 1
            if machine not in seen:
                seen.add(machine)
                competitors[machine].append((task, duration))

    return load, option_count, competitors


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the FJSP instance.

    Args:
        mzn_file: Path to .mzn file (for reference; not read here)
        json_data: Dict containing parsed DZN data. Keys read: 'no_mach',
            'no_jobs', 'no_task', 'no_optt', 'tasks', 'optts', 'optt_mach',
            'optt_dur'. Missing counts default to 0 and missing arrays to [].

    Returns:
        An undirected networkx.Graph whose nodes carry 'type' and 'weight'
        attributes and whose edges carry a 'weight' attribute.

    Raises:
        KeyError: if an option references a machine outside 1..no_mach.

    Strategy:
    - Jobs and Tasks are variable-like nodes (type 0)
    - Precedence and machine-capacity constraints are constraint-like nodes
      (type 1); machines are shared-resource nodes (type 2)
    - Model flexibility through task-machine option connectivity
    - Weight by processing complexity, resource contention, and criticality
    """
    no_mach = json_data.get('no_mach', 0)
    no_jobs = json_data.get('no_jobs', 0)
    no_task = json_data.get('no_task', 0)
    no_optt = json_data.get('no_optt', 0)

    tasks = json_data.get('tasks', [])
    optts = json_data.get('optts', [])
    optt_mach = json_data.get('optt_mach', [])
    optt_dur = json_data.get('optt_dur', [])

    G = nx.Graph()

    job_tasks = _job_task_ids(tasks, no_jobs, no_task)
    task_options = _task_machine_options(
        optts, optt_mach, optt_dur, no_task, no_optt)

    # Normaliser for durations. Falls back to 1 both when there are no
    # durations and when every duration is 0, so the ratios below never
    # divide by zero.
    max_dur = max(optt_dur, default=0) or 1

    # Job nodes (type 0), weighted by the job's share of all tasks.
    for job in range(1, no_jobs + 1):
        complexity = len(job_tasks[job]) / max(no_task, 1)
        G.add_node(f'job_{job}', type=0, weight=min(complexity * 2, 1.0))

    # Task nodes (type 0): fewer options and longer average duration both
    # push the weight up; a task without any option gets the maximum.
    for task in range(1, no_task + 1):
        options = task_options[task]
        if options:
            avg_duration = sum(dur for _, dur in options) / len(options)
            flexibility = 1.0 / len(options)
            duration_complexity = avg_duration / max_dur
            weight = (flexibility * 0.6 + duration_complexity * 0.4)
        else:
            weight = 1.0
        G.add_node(f'task_{task}', type=0, weight=min(weight, 1.0))

    machine_load, option_count, competitors = _machine_usage(
        task_options, no_mach)

    # Falls back to 1 when there are no machines or when no machine carries
    # any load, so the utilisation ratio never divides by zero.
    max_load = max(machine_load.values(), default=0) or 1

    # Machine nodes (type 2), weighted by load relative to the busiest one.
    for machine in range(1, no_mach + 1):
        utilization = machine_load[machine] / max_load
        G.add_node(f'machine_{machine}', type=2, weight=utilization)

    # Consecutive task pairs inside each job define the precedence relation.
    precedence_pairs = [
        (task, next_task)
        for job in range(1, no_jobs + 1)
        for task, next_task in zip(job_tasks[job], job_tasks[job][1:])
    ]

    # Precedence constraint nodes (type 1): the fewer options the two tasks
    # have between them, the more critical the constraint.
    for task, next_task in precedence_pairs:
        total_options = (len(task_options.get(task, []))
                         + len(task_options.get(next_task, [])))
        criticality = 2.0 / total_options if total_options > 0 else 1.0
        G.add_node(f'prec_{task}_{next_task}', type=1,
                   weight=min(criticality, 1.0))

    # Machine capacity constraint nodes (type 1), one per machine, weighted
    # by how many options name that machine.
    for machine in range(1, no_mach + 1):
        contention = option_count[machine] / max(no_task, 1)
        G.add_node(f'capacity_{machine}', type=1,
                   weight=min(contention * 2, 1.0))

    # Edges: job-task participation.
    for job in range(1, no_jobs + 1):
        for task in job_tasks[job]:
            G.add_edge(f'job_{job}', f'task_{task}', weight=1.0)

    # Edges: task-machine and task-capacity, weighted by relative duration.
    for task, options in task_options.items():
        for machine, duration in options:
            intensity = duration / max_dur
            G.add_edge(f'task_{task}', f'machine_{machine}', weight=intensity)
            G.add_edge(f'task_{task}', f'capacity_{machine}', weight=intensity)

    # Edges: both tasks of a precedence pair attach to its constraint node.
    for task, next_task in precedence_pairs:
        constraint_id = f'prec_{task}_{next_task}'
        G.add_edge(f'task_{task}', constraint_id, weight=1.0)
        G.add_edge(f'task_{next_task}', constraint_id, weight=1.0)

    # Edges: direct conflicts between the (up to) three longest tasks on
    # each machine that has more candidate tasks than the threshold.
    for machine in range(1, no_mach + 1):
        competing_tasks = competitors[machine]
        # NOTE(review): the threshold is derived from the job count while
        # the list counts tasks -- confirm this is the intended heuristic.
        if len(competing_tasks) > no_jobs // no_mach + 1:
            longest = sorted(competing_tasks,
                             key=lambda item: item[1], reverse=True)[:3]
            for i, (task1, dur1) in enumerate(longest):
                for task2, dur2 in longest[i + 1:]:
                    conflict_strength = (dur1 + dur2) / (2 * max_dur)
                    G.add_edge(f'task_{task1}', f'task_{task2}',
                               weight=min(conflict_strength, 1.0))

    return G


def main():
    """Command-line entry point.

    Expects exactly three arguments (mzn file, dzn file, json file). On any
    other argument count the usage line is printed and the process exits
    with status 1. Otherwise the JSON file is loaded, the graph is built
    from it, and a one-line node/edge summary is printed.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) != 3:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    # The dzn path is accepted for interface compatibility but is not read.
    mzn_file, _dzn_file, json_file = cli_args

    with open(json_file, 'r') as handle:
        json_data = json.load(handle)

    graph = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    node_total = graph.number_of_nodes()
    edge_total = graph.number_of_edges()
    print(f"Graph built: {node_total} nodes, {edge_total} edges")


# Run the converter only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()