#!/usr/bin/env python3
"""
Graph converter for train scheduling problem.
Created using subagent_prompt.md version: v_02

This problem is about rescheduling trains to minimize passenger travel time after a delay.
Key challenges: ordering constraints, capacity limits, passenger flow optimization, delay propagation.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _flatten_numbers(values):
    """Collect the numeric leaves of a (possibly nested) list into a flat list.

    The DZN-to-JSON step may emit 2D arrays such as ``scheduledArrival`` either
    flat (row-major) or as nested lists; this accepts both shapes. Booleans and
    non-numeric leaves are ignored.

    Args:
        values: A number, or an arbitrarily nested list/tuple of numbers.

    Returns:
        list: The numeric leaves in traversal order (empty if there are none).
    """
    if isinstance(values, bool):
        return []
    if isinstance(values, (int, float)):
        return [values]
    if isinstance(values, (list, tuple)):
        flat = []
        for item in values:
            flat.extend(_flatten_numbers(item))
        return flat
    return []


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the train scheduling problem instance.

    Args:
        mzn_file: Path to .mzn file (kept for interface compatibility; not read).
        json_data: Dict containing parsed DZN data. Keys read: 'n' (trains),
            'm' (stations), 'capacity', 'delayTrain' (1-based train index),
            'delayDuration', 'distance', 'passengerFlow', 'scheduledArrival',
            'scheduledDeparture'. Missing keys fall back to defaults.
            'passengerStart' and 'delayTime' are not used.

    Returns:
        networkx.Graph: Nodes carry a ``type`` attribute (0 = train-station
        variable, 1 = constraint) and a ``weight``; edges carry a ``weight``.

    Strategy: model trains, stations and constraints as a mostly bipartite graph
    - Train-station pairs are Type 0 nodes (decision variables).
    - Constraints are Type 1 nodes (ordering, dwelling, travel time, delay,
      capacity, schedule adherence).
    - Variable-constraint edges encode participation. Extra conflict edges link
      trains at the same high-demand intermediate station, so the graph is
      not strictly bipartite.
    - Passenger flows, distances and the delay determine weights.
    """
    n = json_data.get('n', 0)  # number of trains
    m = json_data.get('m', 0)  # number of stations
    capacity = json_data.get('capacity', 500)
    delay_train = json_data.get('delayTrain', 1)
    delay_duration = json_data.get('delayDuration', 0)

    distance = json_data.get('distance', [])
    passenger_flow = json_data.get('passengerFlow', [])
    scheduled_arrival = json_data.get('scheduledArrival', [])
    scheduled_departure = json_data.get('scheduledDeparture', [])

    # Per-station passenger flow, zero-padded to m entries so a short
    # passengerFlow array can never raise IndexError below.
    flows = [passenger_flow[j] if j < len(passenger_flow) else 0 for j in range(m)]

    # delayTrain is 1-based; a value outside 1..n names no existing train, and
    # networkx.add_edge would otherwise create untyped phantom nodes for it.
    delayed_idx = delay_train - 1
    has_delayed_train = 0 <= delayed_idx < n

    G = nx.Graph()

    # Type 0 nodes: Train-station pairs (decision variables)
    max_flow = max(passenger_flow) if passenger_flow else 1
    for i in range(n):
        # Higher weight for the delayed train
        delay_impact = 1.0 if i == delayed_idx else 0.7
        for j in range(m):
            demand_weight = flows[j] / max_flow if max_flow > 0 else 0.5
            # First/last stations are treated as critical
            station_criticality = 1.0 if j in (0, m - 1) else 0.8
            weight = demand_weight * delay_impact * station_criticality
            weight = min(max(weight, 0.1), 1.0)  # clamp to [0.1, 1.0]
            G.add_node(f'train_{i}_station_{j}', type=0, weight=weight)

    # Type 1 nodes: Constraint nodes

    # 1. Ordering constraints at station 1
    G.add_node('ordering_station_1', type=1, weight=1.0)

    # 2. Station dwelling constraints (intermediate stations only)
    for j in range(1, m - 1):
        tightness = flows[j] / capacity if capacity > 0 else 0.5
        G.add_node(f'dwelling_station_{j}', type=1, weight=min(tightness, 1.0))

    # 3. Travel time constraints between consecutive stations
    max_distance = max(distance) if distance else 1
    for j in range(m - 1):
        dist = distance[j] if j < len(distance) else 1
        if max_distance > 0:
            # Longer legs get heavier weights; clamp because the default of 1
            # for a missing entry may exceed max_distance, and guard against
            # negative distances reaching sqrt().
            constraint_weight = min(math.sqrt(max(dist, 0) / max_distance), 1.0)
        else:
            constraint_weight = 0.5
        G.add_node(f'travel_constraint_{j}_{j+1}', type=1, weight=constraint_weight)

    # 4. Delay propagation constraint: severity is the delay as a fraction of
    # the timetable's span, so it is graded rather than a constant. Without
    # schedule data any positive delay counts as maximal.
    schedule_times = (_flatten_numbers(scheduled_arrival)
                      + _flatten_numbers(scheduled_departure))
    horizon = max(schedule_times) - min(schedule_times) if schedule_times else 0
    if delay_duration <= 0:
        delay_severity = 0.0
    elif horizon > 0:
        delay_severity = min(delay_duration / horizon, 1.0)
    else:
        delay_severity = 1.0
    G.add_node('delay_constraint', type=1, weight=delay_severity)

    # 5. Capacity constraints for passenger collection (intermediate stations)
    for j in range(1, m - 1):
        total_flow = sum(flows[j:])  # demand from station j to the end of the line
        capacity_pressure = min(total_flow / capacity, 1.0) if capacity > 0 else 0.5
        G.add_node(f'capacity_station_{j}', type=1, weight=capacity_pressure)

    # 6. Schedule adherence constraints
    for i in range(n):
        schedule_criticality = 0.8 if i == delayed_idx else 0.6
        G.add_node(f'schedule_constraint_train_{i}', type=1, weight=schedule_criticality)

    # Edges: variable-constraint participation

    # Trains at the first station -> ordering constraint (needs station 0 to exist)
    if m > 0:
        for i in range(n):
            # Earlier trains have stronger ordering constraints
            order_strength = 1.0 - (i / n) if n > 1 else 0.5
            G.add_edge(f'train_{i}_station_0', 'ordering_station_1', weight=order_strength)

    # Train-station pairs -> dwelling constraints
    for i in range(n):
        for j in range(1, m - 1):
            dwelling_importance = 0.8 if flows[j] > 0 else 0.3
            G.add_edge(f'train_{i}_station_{j}', f'dwelling_station_{j}',
                       weight=dwelling_importance)

    # Both endpoint stations of each leg -> that leg's travel time constraint
    for i in range(n):
        for j in range(m - 1):
            travel_node = f'travel_constraint_{j}_{j+1}'
            G.add_edge(f'train_{i}_station_{j}', travel_node, weight=0.9)
            G.add_edge(f'train_{i}_station_{j+1}', travel_node, weight=0.9)

    # Delayed train -> delay constraint (only if delayTrain names a real train)
    if has_delayed_train:
        for j in range(m):
            # Impact decays with distance from the middle station
            delay_impact = math.exp(-0.5 * abs(j - m // 2)) if m > 1 else 0.7
            G.add_edge(f'train_{delayed_idx}_station_{j}', 'delay_constraint',
                       weight=delay_impact)

    # Trains -> capacity constraints at stations with passengers
    for i in range(n):
        for j in range(1, m - 1):
            if flows[j] > 0:
                capacity_involvement = min(flows[j] / capacity, 1.0) if capacity > 0 else 0.5
                G.add_edge(f'train_{i}_station_{j}', f'capacity_station_{j}',
                           weight=capacity_involvement)

    # Train-station pairs -> their train's schedule constraint
    for i in range(n):
        for j in range(m):
            # Terminal stations more important
            schedule_importance = 0.9 if j in (0, m - 1) else 0.7
            G.add_edge(f'train_{i}_station_{j}', f'schedule_constraint_train_{i}',
                       weight=schedule_importance)

    # Conflict edges between trains at the same station (resource competition).
    # The weight depends only on the station, so compute it once per station.
    time_overlap = 0.6  # Simplified - could be derived from the schedules
    for j in range(1, m - 1):
        if flows[j] <= 0:
            continue
        demand_pressure = min(flows[j] / capacity, 1.0) if capacity > 0 else 0.5
        conflict_weight = time_overlap * demand_pressure
        if conflict_weight <= 0.3:  # Only add significant conflicts
            continue
        for i1 in range(n):
            for i2 in range(i1 + 1, n):
                G.add_edge(f'train_{i1}_station_{j}', f'train_{i2}_station_{j}',
                           weight=min(conflict_weight, 1.0))

    return G


def main():
    """Command-line entry point: build the graph for one instance and report its size.

    Usage: python converter.py <mzn_file> <dzn_file> <json_file>

    Only <json_file> (the DZN data already parsed to JSON) is read. <mzn_file>
    is passed through to build_graph, and <dzn_file> is accepted for CLI
    compatibility but not used. Exits with status 1, printing usage to stderr,
    when the argument count is wrong.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>",
              file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # JSON text is UTF-8 by specification; don't depend on the platform locale.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    main()