#!/usr/bin/env python3
"""
Graph converter for yumi-static problem.
Created using subagent_prompt.md version: v_02

This problem is about dual-arm robotic scheduling with pick-and-place operations.
The YuMi robot has two arms that must coordinate to perform tasks at different locations
while avoiding collisions and satisfying tool constraints (suction/gripper).

Key challenges: 
- Dual-arm coordination and collision avoidance
- Tool-specific task ordering (gripper/suction pick-and-place sequences)
- Location accessibility constraints per arm
- Travel time optimization between locations
- Fixture task dependencies and timing
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the dual-arm robotic scheduling problem.

    Args:
        mzn_file: Path to the .mzn model. Accepted for interface compatibility;
            it is not read here.
        json_data: Dict containing parsed DZN data. Keys read (a missing or
            null key is treated as an empty list): 'task_durations',
            'left_arm_travel_times', 'right_arm_travel_times',
            'gripper_pick_tasks_orders', 'suction_pick_tasks_orders',
            'fixture_task_orders', 'OUTPUT_TASKS'. The 'location_order' entry
            is not used.

    Returns:
        An undirected ``networkx.Graph``; empty when there are no tasks or no
        locations.

    Every node carries ``type`` and ``weight`` attributes:
    - type 0, ``task_<id>`` (id 1..n_tasks): weight mixes the mean left/right
      duration (relative to the largest duration) with how different the two
      arm durations are.
    - type 1, constraint nodes: per-arm capacity, gripper/suction sequences,
      fixtures, location access, long travel legs and output completion.
    - type 2, ``location_<k>`` (k 0..n_locations-1): weight is
      1 - accessibility, so locations reachable by fewer arms weigh more.

    Edges only ever join a constraint node to a task or a location node, so
    the graph is bipartite between constraints and {tasks, locations}. No
    edge links a task node directly to a location node.
    """
    G = nx.Graph()

    # `or []` also covers keys that are present but null.
    task_durations = json_data.get('task_durations') or []
    left_travel_times = json_data.get('left_arm_travel_times') or []
    right_travel_times = json_data.get('right_arm_travel_times') or []
    gripper_orders = json_data.get('gripper_pick_tasks_orders') or []
    suction_orders = json_data.get('suction_pick_tasks_orders') or []
    fixture_orders = json_data.get('fixture_task_orders') or []
    output_tasks = json_data.get('OUTPUT_TASKS') or []

    # task_durations is read as one flat list: n_tasks left-arm durations
    # followed by n_tasks right-arm durations.
    n_tasks = len(task_durations) // 2
    # Travel-time tables are read as flattened square matrices (row-major):
    # entry (i, j) lives at index i * n_locations + j.
    n_locations = int(math.sqrt(len(left_travel_times)))

    if n_tasks == 0 or n_locations == 0:
        return G

    _add_task_nodes(G, task_durations, n_tasks)
    location_accessibility = _add_location_nodes(
        G, left_travel_times, right_travel_times, n_locations)

    # One capacity constraint per arm.
    G.add_node('left_arm_capacity', type=1, weight=0.8)
    G.add_node('right_arm_capacity', type=1, weight=0.8)

    _add_gripper_sequences(G, gripper_orders, n_tasks)
    _add_suction_sequences(G, suction_orders, n_tasks)
    _add_fixture_constraints(G, fixture_orders, n_tasks)
    _add_location_access_constraints(G, location_accessibility)
    _add_travel_constraints(G, left_travel_times, right_travel_times, n_locations)

    # Each task must be assigned to exactly one arm, so it touches both
    # capacity constraints.
    for task_id in range(1, n_tasks + 1):
        task_node = f'task_{task_id}'
        G.add_edge(task_node, 'left_arm_capacity', weight=0.5)
        G.add_edge(task_node, 'right_arm_capacity', weight=0.5)

    if output_tasks:
        G.add_node('output_completion', type=1, weight=0.9)
        for output_task in output_tasks:
            # The lower bound matters: add_edge would otherwise create an
            # attribute-less node such as 'task_0'.
            if 0 < output_task <= n_tasks:
                G.add_edge(f'task_{output_task}', 'output_completion', weight=0.95)

    return G


def _value_at(values, index, default):
    """Return ``values[index]``, or *default* when *index* is past the end."""
    return values[index] if index < len(values) else default


def _add_task_nodes(G, task_durations, n_tasks):
    """Add one type-0 ``task_<id>`` node per task, weighted by its arm durations."""
    max_duration = max(task_durations)
    if max_duration <= 0:
        # All-zero (or non-positive) durations would otherwise divide by zero
        # or flip the sign of every duration weight.
        max_duration = 1

    for task_id in range(1, n_tasks + 1):
        left_duration = task_durations[task_id - 1]
        right_duration = task_durations[task_id - 1 + n_tasks]

        avg_duration = (left_duration + right_duration) / 2
        duration_weight = avg_duration / max_duration

        # 1.0 when both arms take equally long; smaller as the durations diverge.
        arm_flexibility = 1.0 - abs(left_duration - right_duration) / max(left_duration + right_duration, 1)

        # Less flexible tasks (harder to assign to an arm) get a higher weight.
        task_weight = 0.3 * duration_weight + 0.7 * math.exp(-2.0 * arm_flexibility)

        G.add_node(f'task_{task_id}', type=0, weight=min(task_weight, 1.0))


def _add_location_nodes(G, left_travel_times, right_travel_times, n_locations):
    """Add one type-2 ``location_<k>`` node per location; return {k: accessibility}.

    An arm counts as able to reach location k when the diagonal travel entry
    (k, k) of its table exists and is >= 0 (heuristic inherited from the
    original converter; presumably negative entries mark unreachable
    locations -- confirm against the model).
    """
    location_accessibility = {}
    for loc in range(n_locations):
        diagonal = loc * n_locations + loc
        left_accessible = _value_at(left_travel_times, diagonal, -1) >= 0
        right_accessible = _value_at(right_travel_times, diagonal, -1) >= 0

        if left_accessible and right_accessible:
            accessibility = 1.0  # Both arms can reach
        elif left_accessible or right_accessible:
            accessibility = 0.5  # Only one arm can reach
        else:
            accessibility = 0.1  # Neither arm

        location_accessibility[loc] = accessibility
        # Lower accessibility = more constrained = higher weight.
        G.add_node(f'location_{loc}', type=2, weight=1.0 - accessibility)
    return location_accessibility


def _add_gripper_sequences(G, gripper_orders, n_tasks):
    """Add a type-1 ``gripper_seq_<k>`` node per gripper record and link its tasks.

    gripper_orders is read in records of 3. The first two fields are taken
    as the pick and place task ids; the third field is not used. Records
    whose pick id is <= 0 are skipped, as is a trailing partial record.
    """
    log_n_tasks = math.log(max(n_tasks, 2))
    sequence_id = 0
    for i in range(0, len(gripper_orders), 3):
        if i + 2 >= len(gripper_orders) or gripper_orders[i] <= 0:
            continue
        pick_task = gripper_orders[i]
        place_task = gripper_orders[i + 1]

        # Records that mention higher task ids weigh more (log-scaled, capped at 1).
        sequence_weight = 0.6 + 0.4 * math.log(max(pick_task, place_task)) / log_n_tasks
        node = f'gripper_seq_{sequence_id}'
        G.add_node(node, type=1, weight=min(sequence_weight, 1.0))

        for task in (pick_task, place_task):
            # Both bounds are needed so add_edge never creates a stray task node.
            if 0 < task <= n_tasks:
                G.add_edge(f'task_{task}', node, weight=0.9)

        sequence_id += 1


def _add_suction_sequences(G, suction_orders, n_tasks):
    """Add a type-1 ``suction_seq_<k>`` node per suction record and link its tasks.

    suction_orders is read in records of 4. The first three fields are taken
    as the pick, transfer and place task ids; the fourth field is not used.
    Records whose pick id is <= 0 are skipped, as is a trailing partial record.
    """
    # The weight depends only on n_tasks, so every suction record shares it.
    sequence_weight = min(0.7 + 0.3 * (3 / max(n_tasks / 3, 1)), 1.0)
    sequence_id = 0
    for i in range(0, len(suction_orders), 4):
        if i + 3 >= len(suction_orders) or suction_orders[i] <= 0:
            continue
        node = f'suction_seq_{sequence_id}'
        G.add_node(node, type=1, weight=sequence_weight)

        for task in suction_orders[i:i + 3]:
            if 0 < task <= n_tasks:
                G.add_edge(f'task_{task}', node, weight=0.85)

        sequence_id += 1


def _add_fixture_constraints(G, fixture_orders, n_tasks):
    """Add a type-1 ``fixture_<k>`` node per 8-slot fixture record and link its tasks.

    Slots holding ids <= 0 are ignored, and all-empty records produce no
    node. Node and edge weights grow with the number of filled slots.
    """
    fixture_id = 0
    for i in range(0, len(fixture_orders), 8):
        fixture_tasks = [task for task in fixture_orders[i:i + 8] if task > 0]
        if not fixture_tasks:
            continue

        node = f'fixture_{fixture_id}'
        G.add_node(node, type=1, weight=0.5 + 0.5 * len(fixture_tasks) / 8)

        edge_weight = min(0.6 + 0.4 * len(fixture_tasks) / 8, 1.0)
        for task in fixture_tasks:
            if task <= n_tasks:
                G.add_edge(f'task_{task}', node, weight=edge_weight)

        fixture_id += 1


def _add_location_access_constraints(G, location_accessibility):
    """Add a type-1 ``location_access_<k>`` node tied to each ``location_<k>`` node."""
    for loc, accessibility in location_accessibility.items():
        constraint_name = f'location_access_{loc}'
        # More restrictive locations (lower accessibility) weigh more.
        G.add_node(constraint_name, type=1, weight=1.0 - accessibility)
        G.add_edge(constraint_name, f'location_{loc}', weight=0.8)


def _add_travel_constraints(G, left_travel_times, right_travel_times, n_locations):
    """Add a type-1 ``travel_<k>`` node for each ordered location pair with a long leg.

    A pair (i, j) with i != j qualifies when either arm's travel time exceeds
    50. The weight is the larger travel time (negative entries count as 0)
    divided by 150 and capped at 1. Both (i, j) and (j, i) are examined, so
    each direction may produce its own node.
    """
    travel_constraint_id = 0
    for i in range(n_locations):
        for j in range(n_locations):
            if i == j:
                continue
            idx = i * n_locations + j
            left_time = _value_at(left_travel_times, idx, -1)
            right_time = _value_at(right_travel_times, idx, -1)

            if left_time <= 50 and right_time <= 50:
                continue

            max_travel = max(left_time, right_time, 0)
            node = f'travel_{travel_constraint_id}'
            G.add_node(node, type=1, weight=min(max_travel / 150.0, 1.0))
            G.add_edge(node, f'location_{i}', weight=0.7)
            G.add_edge(node, f'location_{j}', weight=0.7)

            travel_constraint_id += 1


def main():
    """Command-line entry point: build the graph for one instance and report its size.

    Expects exactly three positional arguments: <mzn_file> <dzn_file> <json_file>.
    Only the JSON file is read. The .mzn path is passed on to build_graph, and
    the .dzn path is accepted for CLI compatibility but not used.
    On a wrong argument count, the usage line is printed to stderr and the
    process exits with status 1.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>", file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:]

    # JSON text is UTF-8 by specification; do not depend on the platform locale.
    json_data = json.loads(Path(json_file).read_text(encoding='utf-8'))

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing the module
    # (e.g. to call build_graph directly) does not trigger it.
    main()