#!/usr/bin/env python3
"""
Graph converter for pax_model problem.
# Converter created with subagent_prompt.md v_02

This problem is about passenger assignment optimization in public transportation.
Passengers need to be assigned to services that travel between stations, respecting capacity constraints and timing windows.
Key challenges: complex time window constraints, service capacity limits, multi-station routing, overcrowding penalties.
"""

import json
import math
import sys
from collections import Counter
from pathlib import Path

import networkx as nx


def build_graph(mzn_file, json_data):
    """
    Build a graph representation of the passenger assignment problem.

    Args:
        mzn_file: Path to .mzn file (for reference only; not read here)
        json_data: Dict containing parsed DZN data. Keys read:
            'service_line', 'flat_pax' (triples of time, origin, destination
            with 1-based station ids), 'service_schedule' (flat, row-major
            services x stations), 'MAX_PAX', 'forward_TW', 'backward_TW'.

    Returns:
        networkx.Graph whose nodes carry 'type' and 'weight' attributes and
        whose edges carry a 'weight' attribute.

    Node families:
    - passenger_{i}       (type 0): timing pressure plus relative route length
    - service_{s}         (type 1): route complexity and line factor
    - station_{st}        (type 2): service connectivity and passenger flow
    - capacity_s{s}_st{st} (type 1): capacity tightness at an active stop
    - timing_{t}          (type 1): share of passengers requesting time t

    Edge families (added in this order): passenger-service (best-fitting
    compatible services only), passenger-station (origin/destination),
    service-station, passenger-capacity (sampled), passenger-timing.
    """
    # Sampling limits that keep the graph from exploding on large instances.
    max_services_per_passenger = 5
    capacity_passenger_limit = 1500
    capacity_service_limit = 5

    # ---- Extract data -----------------------------------------------------
    service_line = json_data.get('service_line', [])
    n_services = len(service_line)
    flat_pax = json_data.get('flat_pax', [])
    schedule_flat = json_data.get('service_schedule', [])
    max_pax = json_data.get('MAX_PAX', 1500)
    forward_tw = json_data.get('forward_TW', 5)
    backward_tw = json_data.get('backward_TW', 10)

    # The schedule is stored row-major, one row of n_stations per service.
    n_stations = len(schedule_flat) // n_services if n_services > 0 else 0
    service_schedule = [
        schedule_flat[s * n_stations:(s + 1) * n_stations]
        for s in range(n_services)
    ]

    # Passenger trips as (time, origin, destination) with 0-based stations.
    n_passengers = len(flat_pax) // 3
    trips = [
        (flat_pax[3 * i], flat_pax[3 * i + 1] - 1, flat_pax[3 * i + 2] - 1)
        for i in range(n_passengers)
    ]

    def is_station(idx):
        # Rejects negative ids too, which would otherwise wrap around via
        # Python's negative indexing and hit the wrong station.
        return 0 <= idx < n_stations

    G = nx.Graph()

    # ---- Passenger nodes (type 0) ----------------------------------------
    tw_span = forward_tw + backward_tw
    # Wider time windows mean lower pressure; identical for every passenger.
    timing_pressure = 1.0 / (tw_span + 1)
    for i, (_, origin, dest) in enumerate(trips):
        route_distance = abs(dest - origin) if dest != origin else 1
        if n_stations > 0:
            distance_factor = min(route_distance / n_stations, 1.0)
        else:
            # No stations known: treat the ratio as saturated instead of
            # dividing by zero.
            distance_factor = 1.0
        weight = min(timing_pressure + 0.5 * distance_factor, 1.0)
        G.add_node(f'passenger_{i}', type=0, weight=weight)

    # ---- Service nodes (type 1) ------------------------------------------
    max_line = max(service_line, default=0)  # hoisted: loop-invariant
    for s, line in enumerate(service_line):
        # Non-zero schedule entries mark the stations this service calls at.
        active_stations = sum(1 for t in service_schedule[s] if t > 0)
        route_complexity = active_stations / n_stations if n_stations > 0 else 0.5
        # Guard against an all-zero line list (would divide by zero).
        line_factor = min(line / max_line, 1.0) if max_line != 0 else 0.5
        weight = 0.3 * route_complexity + 0.7 * line_factor
        G.add_node(f'service_{s}', type=1, weight=weight)

    # ---- Station nodes (type 2) ------------------------------------------
    origin_counts = Counter(origin for _, origin, _ in trips)
    dest_counts = Counter(dest for _, _, dest in trips)
    for st in range(n_stations):
        service_count = sum(1 for row in service_schedule if row[st] > 0)
        connectivity = service_count / n_services if n_services > 0 else 0
        if n_passengers > 0:
            passenger_flow = (origin_counts[st] + dest_counts[st]) / n_passengers
        else:
            passenger_flow = 0
        weight = 0.6 * connectivity + 0.4 * passenger_flow
        G.add_node(f'station_{st}', type=2, weight=weight)

    # ---- Capacity constraint nodes (type 1) ------------------------------
    # Passengers potentially on board when leaving station st. This does not
    # depend on the service, so compute it once per station. Only trips with
    # origin < destination are counted.
    station_load = [
        sum(1 for _, origin, dest in trips if origin <= st < dest)
        for st in range(n_stations)
    ]
    for s in range(n_services):
        for st in range(n_stations):
            load = station_load[st]
            if not service_schedule[s][st] > 0 or load == 0:
                continue
            # A non-positive capacity is treated as fully saturated.
            capacity_pressure = min(load / max_pax, 1.0) if max_pax > 0 else 1.0
            # Exponential weighting emphasises tight capacities.
            weight = 1.0 - math.exp(-3.0 * capacity_pressure)
            G.add_node(f'capacity_s{s}_st{st}', type=1, weight=weight)

    # ---- Timing constraint nodes (type 1) --------------------------------
    # Counter iterates in first-seen order, so node insertion is deterministic.
    time_counts = Counter(trip_time for trip_time, _, _ in trips)
    for time, count in time_counts.items():
        # Non-linear weighting of the share of passengers at this time.
        G.add_node(f'timing_{time}', type=1, weight=math.sqrt(count / n_passengers))

    # ---- Edges: passenger -> best compatible services --------------------
    max_connections_per_passenger = min(max_services_per_passenger, n_services)
    for i, (trip_time, origin, dest) in enumerate(trips):
        if not (is_station(origin) and is_station(dest)):
            continue
        route_efficiency = 1.0 / (abs(dest - origin) + 1)
        compatible_services = []
        for s, row in enumerate(service_schedule):
            service_time = row[origin]
            # The service must call at both the origin and the destination.
            # NOTE(review): visiting order (origin before destination) is not
            # checked here - confirm against the MiniZinc model.
            if not (service_time > 0 and row[dest] > 0):
                continue
            if not (trip_time - forward_tw <= service_time <= trip_time + backward_tw):
                continue
            if tw_span > 0:
                time_fit = 1.0 - abs(service_time - trip_time) / tw_span
            else:
                # Zero-width window: only an exact time match gets here.
                time_fit = 1.0
            compatible_services.append((s, 0.7 * time_fit + 0.3 * route_efficiency))

        # Keep only the best-fitting services to limit density (stable sort
        # preserves service order among equal weights).
        compatible_services.sort(key=lambda item: item[1], reverse=True)
        for s, weight in compatible_services[:max_connections_per_passenger]:
            G.add_edge(f'passenger_{i}', f'service_{s}', weight=weight)

    # ---- Edges: passenger -> origin/destination stations -----------------
    for i, (_, origin, dest) in enumerate(trips):
        for st in (origin, dest):
            # Skip out-of-range ids; add_edge would otherwise create a
            # station node without 'type'/'weight' attributes.
            if is_station(st):
                G.add_edge(f'passenger_{i}', f'station_{st}', weight=1.0)

    # ---- Edges: service -> stations it serves ----------------------------
    # Latest scheduled time over the whole timetable (hoisted out of the
    # loops). Edges are only added for entries > 0, so max_time is positive
    # whenever it is used as a divisor.
    max_time = max((t for row in service_schedule for t in row), default=1)
    for s, row in enumerate(service_schedule):
        for st, scheduled in enumerate(row):
            if scheduled > 0:
                # Earlier stops get a higher weight.
                time_importance = 1.0 - (scheduled / max_time)
                G.add_edge(f'service_{s}', f'station_{st}', weight=time_importance)

    # ---- Edges: passenger -> capacity constraints (sampled) --------------
    # Each passenger contributes one unit of load relative to the capacity.
    unit_load = 1.0 / max_pax if max_pax > 0 else 1.0
    sampled_services = range(min(n_services, capacity_service_limit))
    for i, (_, origin, dest) in enumerate(trips[:capacity_passenger_limit]):
        # Only origin, destination and the midpoint station are linked.
        key_stations = (origin, dest, (origin + dest) // 2)
        for s in sampled_services:
            for st in key_stations:
                if not is_station(st):
                    continue
                constraint_id = f'capacity_s{s}_st{st}'
                if G.has_node(constraint_id):
                    G.add_edge(f'passenger_{i}', constraint_id, weight=unit_load)

    # ---- Edges: passenger -> timing constraint ---------------------------
    for i, (trip_time, _, _) in enumerate(trips):
        timing_node = f'timing_{trip_time}'
        if G.has_node(timing_node):
            G.add_edge(f'passenger_{i}', timing_node, weight=1.0)

    return G


def main():
    """
    Command-line entry point: load the instance JSON and report graph size.

    Expects exactly three arguments, ``<mzn_file> <dzn_file> <json_file>``;
    otherwise prints the usage line and exits with status 1. Only the JSON
    file is opened here. The .mzn path is forwarded to build_graph, and the
    .dzn path is accepted for interface compatibility but not used.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:]

    # JSON text is UTF-8 by specification; do not depend on the platform's
    # default encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    # Build graph
    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()