#!/usr/bin/env python3
"""
Graph converter for train rescheduling problem.
Created using subagent_prompt.md version: v_02

This problem is about rescheduling trains on a single track with multiple stations
to minimize passenger travel time when a delay occurs. 
Key challenges: train ordering constraints, capacity limits, passenger flow optimization,
delay propagation effects, and temporal dependencies between trains.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _var_node(train_idx, station_idx):
    """Return the node id of the (train, station) decision-variable node."""
    return f'train_{train_idx}_station_{station_idx}'


def _grid_value(values, train_idx, station_idx, n_stations):
    """Return the (train, station) entry of a per-train, per-station parameter.

    Accepts either a flat, train-major list (``n_stations`` consecutive entries
    per train) or a nested list of per-train rows. Missing or out-of-range
    entries yield 0 so incomplete instance data does not abort graph building.
    """
    if not values:
        return 0
    if isinstance(values[0], (list, tuple)):
        # Nested list-of-rows layout.
        if train_idx < len(values) and station_idx < len(values[train_idx]):
            return values[train_idx][station_idx]
        return 0
    flat_idx = train_idx * n_stations + station_idx
    return values[flat_idx] if flat_idx < len(values) else 0


def _capacity_tightness(utilization):
    """Map a non-negative flow/capacity ratio to a tightness in [0, 1].

    Linear up to 0.8, then saturating smoothly towards 1.0. The curve is
    continuous at 0.8 and never exceeds 1.0, so heavily overloaded stations
    stay distinguishable from mildly overloaded ones (the previous formula
    jumped straight to a clamped 1.0 for every ratio above 0.8).
    """
    if utilization <= 0.8:
        return utilization
    return 0.8 + 0.2 * (1.0 - math.exp(-5.0 * (utilization - 0.8)))


def _add_variable_nodes(G, n, m, delay_train, passenger_flow, max_flow,
                        scheduled_arrival, scheduled_departure):
    """Add one type-0 node per (train, station) pair, weighted by criticality.

    Criticality mixes whether the train is the delayed one, the station's
    relative passenger flow, and how little scheduled dwell time there is.
    """
    for i in range(n):
        delay_impact = 1.0 if i == delay_train else 0.5
        for j in range(m):
            station_flow = passenger_flow[j] if j < len(passenger_flow) else 0
            flow_importance = station_flow / max_flow if max_flow > 0 else 0.0

            # Less dwell buffer -> tighter schedule (20 = assumed typical dwell).
            dwell_time = (_grid_value(scheduled_departure, i, j, m)
                          - _grid_value(scheduled_arrival, i, j, m))
            schedule_tightness = 1.0 - min(dwell_time / 20.0, 1.0)

            score = (delay_impact * 0.4 + flow_importance * 0.4
                     + schedule_tightness * 0.2)
            # Guard sqrt against a negative score (e.g. negative flow data).
            weight = math.sqrt(max(score, 0.0))
            weight = max(0.1, min(weight, 1.0))

            G.add_node(_var_node(i, j), type=0, weight=weight)


def _add_ordering_constraints(G, n, m, delay_train):
    """Add first-station ordering and intermediate-station non-overlap nodes.

    Station ids in node names are 0-indexed (the data's 1-based station 1 is
    ``station_0``), so intermediate stations are ``1 .. m-2``.
    """
    if m == 0:
        # No station nodes exist; adding edges would create phantom nodes.
        return
    for i in range(n - 1):
        # Consecutive trains must keep their order at the first station.
        order_node = f'order_station_0_trains_{i}_{i+1}'
        G.add_node(order_node, type=1, weight=0.8)
        G.add_edge(_var_node(i, 0), order_node, weight=0.9)
        G.add_edge(_var_node(i + 1, 0), order_node, weight=0.9)

        # Pairs adjacent to the delayed train are at higher risk of overlap.
        near_delay = abs(i - delay_train) <= 1 or abs(i + 1 - delay_train) <= 1
        overlap_risk = 1.0 if near_delay else 0.6
        for j in range(1, m - 1):
            overlap_node = f'no_overlap_station_{j}_trains_{i}_{i+1}'
            G.add_node(overlap_node, type=1, weight=overlap_risk)
            G.add_edge(_var_node(i, j), overlap_node, weight=0.8)
            G.add_edge(_var_node(i + 1, j), overlap_node, weight=0.8)


def _add_travel_constraints(G, n, m, delay_train, distance):
    """Add one travel-time node per (train, track segment) with a known distance."""
    n_segments = min(m - 1, len(distance))
    for i in range(n):
        # The delayed train and every later train feel the delay.
        delay_multiplier = 1.5 if i >= delay_train else 1.0
        for j in range(n_segments):
            base_tightness = min(distance[j] / 10.0, 1.0)  # 10 = assumed typical distance
            travel_node = f'travel_train_{i}_segment_{j}'
            G.add_node(travel_node, type=1,
                       weight=min(base_tightness * delay_multiplier, 1.0))
            G.add_edge(_var_node(i, j), travel_node, weight=0.7)
            G.add_edge(_var_node(i, j + 1), travel_node, weight=0.7)


def _add_capacity_constraints(G, n, m, passenger_flow, capacity):
    """Add one capacity node per (train, station) where passenger flow is positive."""
    # Tightness depends only on the station, so compute it once per station.
    station_tightness = {}
    for j in range(min(m, len(passenger_flow))):
        flow = passenger_flow[j]
        if flow > 0:
            # A non-positive capacity cannot absorb any flow: treat as saturated.
            utilization = flow / capacity if capacity > 0 else math.inf
            station_tightness[j] = min(_capacity_tightness(utilization), 1.0)
    for i in range(n):
        for j, tightness in station_tightness.items():
            cap_node = f'capacity_train_{i}_station_{j}'
            G.add_node(cap_node, type=1, weight=tightness)
            G.add_edge(_var_node(i, j), cap_node, weight=0.8)


def _add_delay_propagation(G, n, m, delay_train, delay_duration, max_time):
    """Add one delay hub node linked to the delayed train and its successors."""
    if delay_duration <= 0:
        return
    # Guard a zero/negative horizon: treat the delay as maximally disruptive.
    impact = min(delay_duration / max_time * 2.0, 1.0) if max_time > 0 else 1.0
    G.add_node('delay_propagation', type=1, weight=impact)
    # Clamp the start so an out-of-range delayTrain cannot create train_-1_* nodes.
    for i in range(max(delay_train, 0), n):
        # Impact decays exponentially with distance from the delayed train.
        edge_weight = min(impact * math.exp(-0.5 * (i - delay_train)), 1.0)
        if edge_weight <= 0.1:
            break  # weights only decrease for later trains; skip negligible links
        for j in range(m):
            G.add_edge(_var_node(i, j), 'delay_propagation', weight=edge_weight)


def _add_flow_conservation(G, n, m, passenger_flow, max_flow):
    """Add flow-conservation nodes linking a train's consecutive station visits."""
    segments = [j for j in range(min(m - 1, len(passenger_flow)))
                if passenger_flow[j] > 0]
    for i in range(n):
        for j in segments:
            flow_node = f'flow_conservation_train_{i}_segment_{j}'
            # max_flow >= passenger_flow[j] > 0 here, so the ratio is well-defined.
            G.add_node(flow_node, type=1, weight=passenger_flow[j] / max_flow)
            G.add_edge(_var_node(i, j), flow_node, weight=0.6)
            G.add_edge(_var_node(i, j + 1), flow_node, weight=0.6)


def _add_passenger_competition(G, n, m, passenger_flow, max_flow):
    """Add a competition node per high-flow station, linked to every train."""
    # Trains near the middle of the sequence are weighted as most competitive.
    train_weights = [0.5 + 0.3 * (1.0 - abs(i - n // 2) / (n // 2 + 1))
                     for i in range(n)]
    for j in range(min(m, len(passenger_flow))):
        flow = passenger_flow[j]
        # "High flow" = positive and above half of the busiest station's flow.
        if flow <= 0 or flow <= max_flow * 0.5:
            continue
        comp_node = f'passenger_competition_station_{j}'
        G.add_node(comp_node, type=1, weight=flow / max_flow)
        for i, weight in enumerate(train_weights):
            G.add_edge(_var_node(i, j), comp_node, weight=weight)


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the train rescheduling problem instance.

    Args:
        mzn_file: Path to .mzn file (for reference only; not read)
        json_data: Dict containing parsed DZN data

    Returns:
        networkx.Graph whose nodes carry ``type`` (0 = variable, 1 = constraint)
        and ``weight`` attributes, and whose edges carry ``weight``.

    Strategy: Model as bipartite graph with trains/stations as variables and constraints as explicit nodes
    - Trains at stations are decision variables (type 0), node ids 0-indexed
    - Temporal ordering constraints between trains (type 1)
    - Travel-time constraints per track segment (type 1)
    - Capacity constraints for passenger collection (type 1)
    - Delay propagation constraint (type 1)
    - Station-specific passenger flow constraints (type 1)
    - Edge weights reflect constraint tightness and temporal criticality

    Every edge connects to a node created beforehand, so no attribute-less
    nodes are introduced even for degenerate instances (m < 2, invalid
    delayTrain, zero capacity or horizon).
    """
    n = json_data.get('n', 0)  # number of trains
    m = json_data.get('m', 0)  # number of stations
    capacity = json_data.get('capacity', 500)
    delay_train = json_data.get('delayTrain', 1) - 1  # Convert to 0-indexed
    delay_duration = json_data.get('delayDuration', 0)
    max_time = json_data.get('maxTime', 1140)
    # NOTE: 'delayTime' is present in the data but not used by this encoding.

    distance = json_data.get('distance', [])
    passenger_flow = json_data.get('passengerFlow', [])
    scheduled_arrival = json_data.get('scheduledArrival', [])
    scheduled_departure = json_data.get('scheduledDeparture', [])

    max_flow = max(passenger_flow) if passenger_flow else 1

    G = nx.Graph()
    # Insertion order mirrors the constraint families listed in the docstring.
    _add_variable_nodes(G, n, m, delay_train, passenger_flow, max_flow,
                        scheduled_arrival, scheduled_departure)
    _add_ordering_constraints(G, n, m, delay_train)
    _add_travel_constraints(G, n, m, delay_train, distance)
    _add_capacity_constraints(G, n, m, passenger_flow, capacity)
    _add_delay_propagation(G, n, m, delay_train, delay_duration, max_time)
    _add_flow_conservation(G, n, m, passenger_flow, max_flow)
    _add_passenger_competition(G, n, m, passenger_flow, max_flow)
    return G


def main():
    """Command-line entry point: build the graph for one instance and report its size.

    Usage: converter.py <mzn_file> <dzn_file> <json_file>

    The DZN path is accepted for interface compatibility but not read; the
    instance data comes from the JSON file. Usage and load errors go to
    stderr with exit status 1.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>",
              file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # JSON is UTF-8 by specification; don't depend on the locale's default encoding.
    try:
        with open(json_file, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
    except (OSError, json.JSONDecodeError) as err:
        print(f"Error: cannot load JSON data from {json_file}: {err}",
              file=sys.stderr)
        sys.exit(1)

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()