#!/usr/bin/env python3
"""
Graph converter for openshop problem.
Created using subagent_prompt.md version: v_02

This problem is about scheduling jobs on machines, where each job consists of tasks that can be
performed in any order (unlike job shop, where tasks must be done in a fixed sequence).
Key challenges: minimizing the makespan while ensuring that no machine processes two tasks
simultaneously and no job has two tasks running simultaneously.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _task_node(job, machine):
    """Return the node id used for the task of ``job`` on ``machine``."""
    return f'task_j{job}_m{machine}'


def _reshape_durations(raw, n_jobs, n_machines):
    """Return task durations as an ``n_jobs`` x ``n_machines`` list of rows.

    ``raw`` may be a flat row-major list (entry ``job * n_machines + machine``),
    a nested list with one row per job, or ``None``. Missing entries are padded
    with 0 and surplus entries are dropped; downstream code treats a
    non-positive duration as "no task".
    """
    raw = raw or []
    if raw and all(isinstance(row, (list, tuple)) for row in raw):
        rows = list(raw)
    else:
        # Flat row-major layout: slice it into per-job rows.
        rows = [raw[j * n_machines:(j + 1) * n_machines] for j in range(n_jobs)]

    durations = []
    for j in range(n_jobs):
        row = list(rows[j][:n_machines]) if j < len(rows) else []
        row.extend([0] * (n_machines - len(row)))
        durations.append(row)
    return durations


def _task_weight(duration, max_duration):
    """Square-root-normalised duration in (0, 1]; longer tasks weigh more.

    Requires ``0 < duration <= max_duration``.
    """
    return math.sqrt(duration / max_duration)


def _job_weight(job_tasks, max_duration, n_machines):
    """Complexity score in (0, 1] for one job (``job_tasks`` must be non-empty).

    Combines the job's total work, relative to ``max_duration * n_machines``,
    with the standard deviation of the job's own task durations, relative to
    ``max_duration``. The result is capped at 1.0.
    """
    total = sum(job_tasks)
    mean = total / len(job_tasks)
    # Spread is measured around this job's own mean, so a job whose tasks are
    # uniformly long is not mistaken for a high-variance job.
    variance = sum((d - mean) ** 2 for d in job_tasks) / len(job_tasks)
    complexity = total / (max_duration * n_machines) + math.sqrt(variance) / max_duration
    return min(complexity, 1.0)


def _machine_weight(machine_tasks, max_duration, n_jobs):
    """Load score in (0, 1]: utilisation plus half the fraction of jobs using the machine."""
    utilization = sum(machine_tasks) / (max_duration * n_jobs)
    task_density = len(machine_tasks) / n_jobs  # fraction of jobs using this machine
    return min(utilization + task_density * 0.5, 1.0)


def _conflict_weight(d1, d2):
    """Weight in (0, 0.8] for two positive-duration tasks sharing a machine.

    Similar durations are harder to order, so equal durations get 0.8 and the
    weight shrinks towards 0 as the durations diverge.
    """
    duration_similarity = 1.0 - abs(d1 - d2) / max(d1, d2)
    return duration_similarity * 0.8


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the open shop scheduling problem instance.

    Args:
        mzn_file: Path to .mzn file (not read by this function)
        json_data: Dict containing parsed DZN data. Reads ``n_jobs``,
            ``n_machines`` and ``job_task_duration``; the latter may be a flat
            row-major list or a nested list with one row per job.

    Returns:
        A ``networkx.Graph``. It is empty when there are no jobs, no machines,
        or no positive durations.

    Structure:
    - Task nodes (type 0): one per (job, machine) pair with positive duration,
      weighted by sqrt(duration / max_duration).
    - Job constraint nodes (type 1): one per job with at least one task.
    - Machine constraint nodes (type 1): one per machine with at least one task.
    - Task-job and task-machine edges, weighted by the task's share of the
      job's total duration and of the machine's total load, respectively.
    - Task-task conflict edges between tasks of different jobs on the same
      machine, so the graph is not strictly bipartite. Same-job conflicts are
      represented only through the job constraint node.
    """
    n_jobs = json_data.get('n_jobs') or 0
    n_machines = json_data.get('n_machines') or 0
    G = nx.Graph()
    if n_jobs <= 0 or n_machines <= 0:
        return G

    durations = _reshape_durations(json_data.get('job_task_duration'), n_jobs, n_machines)

    positive = [d for row in durations for d in row if d > 0]
    if not positive:
        return G
    max_duration = max(positive)

    # Task nodes, added in (job, machine) order.
    for j, row in enumerate(durations):
        for m, d in enumerate(row):
            if d > 0:
                G.add_node(_task_node(j, m), type=0, weight=_task_weight(d, max_duration))

    # Job constraint nodes: no two tasks of the same job may overlap.
    job_totals = []
    for j, row in enumerate(durations):
        job_tasks = [d for d in row if d > 0]
        job_totals.append(sum(job_tasks))
        if job_tasks:
            G.add_node(f'job_constraint_{j}', type=1,
                       weight=_job_weight(job_tasks, max_duration, n_machines))

    # Machine constraint nodes: a machine processes at most one task at a time.
    machine_totals = []
    for m in range(n_machines):
        machine_tasks = [durations[j][m] for j in range(n_jobs) if durations[j][m] > 0]
        machine_totals.append(sum(machine_tasks))
        if machine_tasks:
            G.add_node(f'machine_constraint_{m}', type=1,
                       weight=_machine_weight(machine_tasks, max_duration, n_jobs))

    # Task -> job edges. Totals are precomputed once, and every total used here
    # is positive because it includes the (positive) task itself.
    for j, row in enumerate(durations):
        for m, d in enumerate(row):
            if d > 0:
                G.add_edge(_task_node(j, m), f'job_constraint_{j}', weight=d / job_totals[j])

    # Task -> machine edges, weighted by the task's share of the machine's load.
    for j, row in enumerate(durations):
        for m, d in enumerate(row):
            if d > 0:
                G.add_edge(_task_node(j, m), f'machine_constraint_{m}',
                           weight=d / machine_totals[m])

    # Conflict edges between tasks of different jobs that compete for a machine.
    for m in range(n_machines):
        jobs_on_m = [j for j in range(n_jobs) if durations[j][m] > 0]
        for i, j1 in enumerate(jobs_on_m):
            for j2 in jobs_on_m[i + 1:]:
                G.add_edge(_task_node(j1, m), _task_node(j2, m),
                           weight=_conflict_weight(durations[j1][m], durations[j2][m]))

    return G


def main():
    """Command-line entry point: build the graph from a JSON data file and report its size.

    Usage: ``python converter.py <mzn_file> <dzn_file> <json_file>``.
    The JSON file is parsed and passed to ``build_graph`` together with
    ``mzn_file``; ``dzn_file`` is accepted but not used here. On a wrong
    argument count, a usage line is written to stderr and the process exits
    with status 1.
    """
    if len(sys.argv) != 4:
        # Usage errors belong on stderr so they don't pollute captured stdout.
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>", file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    G = build_graph(mzn_file, json_data)

    # build_graph returns the graph for direct feature extraction; only report its size here.
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    main()