#!/usr/bin/env python3
"""
Graph converter for yumi-static problem.
Created using subagent_prompt.md version: v_02

This problem is about dual-arm robotic scheduling for pick-and-place assembly tasks.
Key challenges: Collision avoidance between arms, sequencing constraints for fixtures and grippers,
travel time optimization, and resource allocation between two robotic arms.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _arm_is_usable(travel, n_locations):
    """Return True if any in-range diagonal entry of the flat travel matrix is >= 0."""
    diagonal = (i * n_locations + i for i in range(n_locations))
    # The bounds check matters when this matrix is shorter than the one
    # n_locations was derived from (e.g. a missing right-arm matrix).
    return any(travel[k] >= 0 for k in diagonal if k < len(travel))


def _add_fixture_constraints(G, fixture_orders, left_durations, right_durations,
                             n_tasks, max_duration, row_len=8):
    """Add a type-1 node per fixture sequence with more than one task; return the count."""
    added = 0
    for f in range(len(fixture_orders) // row_len):
        # Each fixture occupies one fixed-width row; -1 ends the sequence and
        # any other non-positive entry is skipped.
        fixture_tasks = []
        for value in fixture_orders[f * row_len:(f + 1) * row_len]:
            if value > 0:
                fixture_tasks.append(value)
            elif value == -1:
                break

        if len(fixture_tasks) <= 1:
            continue
        added += 1

        # Weight by sequence length and average task duration. Task ids above
        # n_tasks contribute no duration but still count in the denominator.
        total_duration = sum(left_durations[t - 1] + right_durations[t - 1]
                             for t in fixture_tasks if t <= n_tasks)
        avg_seq_duration = total_duration / (2 * len(fixture_tasks))
        seq_weight = min(len(fixture_tasks) / 5.0 + avg_seq_duration / max_duration, 1.0)

        constraint_id = f'fixture_seq_{f+1}'
        G.add_node(constraint_id, type=1, weight=seq_weight)

        # Record the first position of every task once, instead of calling
        # list.index() per task (quadratic); a repeated task keeps its first slot.
        first_position = {}
        for pos, t in enumerate(fixture_tasks):
            first_position.setdefault(t, pos)

        for t, pos in first_position.items():
            task_id = f'task_{t}'
            if t <= n_tasks and G.has_node(task_id):
                # Earlier position in the sequence = higher edge weight
                pos_weight = 1.0 - (pos / len(fixture_tasks) * 0.3)
                G.add_edge(task_id, constraint_id, weight=pos_weight)
    return added


def _add_gripper_constraints(G, gripper_orders, n_tasks):
    """Add a type-1 node for each positive (pick, place) pair read in strides of three."""
    gripper_count = 0
    # Entries are assumed to come as triplets: pick, place, -1
    for i in range(0, len(gripper_orders) - 1, 3):
        pick_task, place_task = gripper_orders[i], gripper_orders[i + 1]
        if pick_task <= 0 or place_task <= 0:
            continue
        gripper_count += 1

        # High weight for pick-place pairs (tight coupling)
        constraint_id = f'gripper_seq_{gripper_count}'
        G.add_node(constraint_id, type=1, weight=0.9)

        if pick_task <= n_tasks:
            G.add_edge(f'task_{pick_task}', constraint_id, weight=0.95)
        if place_task <= n_tasks:
            G.add_edge(f'task_{place_task}', constraint_id, weight=0.85)


def _add_collision_constraints(G, location_order, n_locations, low, high):
    """Add a type-1 node for each nearby location pair that touches an obstruction bound."""
    usable = min(n_locations, len(location_order))
    for i in range(usable):
        # Only pair a location with the next two locations
        for j in range(i + 1, min(i + 3, usable)):
            order_i, order_j = location_order[i], location_order[j]
            if order_i < 0 or order_j < 0:
                continue

            near_bound = any(abs(order - bound) <= 1
                             for order in (order_i, order_j)
                             for bound in (low, high))
            if not near_bound:
                continue

            constraint_id = f'collision_{i+1}_{j+1}'
            # Pairs whose order values are closer together get a higher weight
            collision_weight = 0.6 + 0.3 * math.exp(-abs(order_i - order_j) / 2.0)
            G.add_node(constraint_id, type=1, weight=collision_weight)

            G.add_edge(f'location_{i+1}', constraint_id, weight=0.7)
            G.add_edge(f'location_{j+1}', constraint_id, weight=0.7)


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the yumi-static problem instance.

    Args:
        mzn_file: Path to .mzn file (for reference; not read by this function)
        json_data: Dict containing parsed DZN data. Keys read here:
            'task_durations', 'left_arm_travel_times', 'right_arm_travel_times',
            'fixture_task_orders', 'gripper_pick_tasks_orders', 'location_order',
            'OUTPUT_TASKS' and 'FixtureWorkObstruction'. All arrays are
            treated as flat lists.

    Returns:
        networkx.Graph whose nodes carry 'type' and 'weight' attributes and
        whose edges carry 'weight'. An empty graph is returned when
        'task_durations' has fewer than two entries.

    Strategy: Model as bipartite graph with explicit constraint nodes
    - Task nodes (type 0): Individual assembly tasks requiring scheduling
    - Constraint nodes (type 1): Fixture sequences, gripper constraints, collision zones
    - Resource nodes (type 2): The two robotic arms and shared locations
    - Edge weights reflect travel times, sequence criticality, and resource contention
    """
    G = nx.Graph()

    # Task durations are stored flat: left-arm values followed by right-arm values
    task_durations = json_data.get('task_durations', [])
    if len(task_durations) < 2:
        return G
    n_tasks = len(task_durations) // 2
    left_durations = task_durations[:n_tasks]
    right_durations = task_durations[n_tasks:]

    # Travel time matrices (flat, assumed square); size is taken from the left arm's
    left_travel = json_data.get('left_arm_travel_times', [])
    right_travel = json_data.get('right_arm_travel_times', [])
    n_locations = int(math.sqrt(len(left_travel))) if left_travel else 0

    fixture_orders = json_data.get('fixture_task_orders', [])
    gripper_orders = json_data.get('gripper_pick_tasks_orders', [])
    location_order = json_data.get('location_order', [])
    output_tasks = json_data.get('OUTPUT_TASKS', [])

    # Collision bounds; fall back to the default pair when the entry is
    # present but does not provide two values (would otherwise IndexError).
    fixture_obstruction = json_data.get('FixtureWorkObstruction', [2, 4])
    if not isinstance(fixture_obstruction, (list, tuple)) or len(fixture_obstruction) < 2:
        fixture_obstruction = [2, 4]
    low, high = fixture_obstruction[0], fixture_obstruction[1]

    # Normaliser for duration-based weights; all-zero durations would
    # otherwise cause a ZeroDivisionError below.
    max_duration = max(max(left_durations, default=1), max(right_durations, default=1))
    if max_duration == 0:
        max_duration = 1

    # Task nodes (type 0) with duration-based weights
    for i in range(n_tasks):
        avg_duration = (left_durations[i] + right_durations[i]) / 2
        duration_weight = avg_duration / max_duration
        # Saturating exponential scaling to emphasize long tasks
        weight = 1.0 - math.exp(-3.0 * duration_weight)
        G.add_node(f'task_{i+1}', type=0, weight=weight)

    # Arm resource nodes (type 2)
    G.add_node('left_arm', type=2, weight=0.5)
    G.add_node('right_arm', type=2, weight=0.5)

    # Location nodes (type 2); order values at or beyond the obstruction
    # bounds get the higher weight. Negative order values are skipped.
    for i, loc_order in enumerate(location_order[:n_locations]):
        if loc_order < 0:
            continue
        outside_bounds = loc_order <= low or loc_order >= high
        G.add_node(f'location_{i+1}', type=2, weight=0.8 if outside_bounds else 0.5)

    # Constraint nodes (type 1)
    fixture_count = _add_fixture_constraints(
        G, fixture_orders, left_durations, right_durations, n_tasks, max_duration)
    _add_gripper_constraints(G, gripper_orders, n_tasks)
    _add_collision_constraints(G, location_order, n_locations, low, high)

    # Task-arm edges. Feasibility depends only on the travel matrices, so it
    # is evaluated once rather than once per task.
    left_feasible = _arm_is_usable(left_travel, n_locations)
    right_feasible = _arm_is_usable(right_travel, n_locations)
    for t in range(1, n_tasks + 1):
        task_id = f'task_{t}'
        if left_feasible:
            # Shorter duration on this arm = higher weight
            left_weight = 1.0 - (left_durations[t-1] / max_duration) * 0.5
            G.add_edge(task_id, 'left_arm', weight=left_weight)
        if right_feasible:
            right_weight = 1.0 - (right_durations[t-1] / max_duration) * 0.5
            G.add_edge(task_id, 'right_arm', weight=right_weight)

    # Output tasks are connected to both arms with a fixed high weight; this
    # overwrites any duration-based weight set on the same edge above.
    for output_task in output_tasks:
        output_id = f'task_{output_task}'
        if output_task <= n_tasks and G.has_node(output_id):
            G.add_edge(output_id, 'left_arm', weight=0.9)
            G.add_edge(output_id, 'right_arm', weight=0.9)

    # Global complexity constraint if problem is large
    if n_tasks > 10 or fixture_count > 2:
        G.add_node('global_complexity', type=1, weight=0.8)

        # Snapshot the nodes so edges are not added while iterating the live view
        for node_id, data in list(G.nodes(data=True)):
            if data['type'] == 0 and data['weight'] > 0.7:  # High-weight tasks
                G.add_edge(node_id, 'global_complexity', weight=data['weight'] * 0.6)
            elif data['type'] == 1 and data['weight'] > 0.8:  # Critical constraints
                G.add_edge(node_id, 'global_complexity', weight=data['weight'] * 0.5)

    return G


def main():
    """Command-line entry point.

    Expects three positional arguments (mzn file, dzn file, json file),
    loads the JSON file, builds the graph from it and prints the node and
    edge counts. Prints a usage line and exits with status 1 when the
    argument count is wrong. The dzn path is accepted but not used here.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) != 3:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    mzn_file, _dzn_file, json_file = cli_args

    # Parse the instance data
    with open(json_file, 'r') as handle:
        json_data = json.load(handle)

    # build_graph returns the graph for direct feature extraction
    graph = build_graph(mzn_file, json_data)

    node_total = graph.number_of_nodes()
    edge_total = graph.number_of_edges()
    print(f"Graph built: {node_total} nodes, {edge_total} edges")


# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()