#!/usr/bin/env python3
"""
Graph converter for trains problem.
Created using subagent_prompt.md version: v_02

This problem is about train scheduling and routing optimization.
Key challenges: Platform capacity at stations, track conflicts (single vs double track), 
service timing constraints, engine assignment, and minimizing delays and skip costs.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _clamp_unit(value):
    """Clamp *value* into the closed interval [0.0, 1.0]."""
    return max(0.0, min(1.0, value))


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the train scheduling problem instance.

    Args:
        mzn_file: Path to .mzn file (for reference only; it is not read here).
        json_data: Dict containing parsed DZN data. Keys read: 'makespan',
            'min_sep', 'minimal_wait', 'platform', 'service_start',
            'service_end', 'ENGINE'. Missing keys fall back to defaults.

    Returns:
        networkx.Graph whose nodes carry a ``type`` attribute (0 = service
        variable, 1 = constraint, 2 = resource) and a ``weight`` in [0, 1];
        every edge carries a ``weight`` in [0, 1].

    Strategy: build a typed graph with service variable nodes, constraint
    nodes and resource (station/engine) nodes, plus sparse service-service
    conflict edges. Focus on platform capacity conflicts, timing
    dependencies, and track usage patterns.
    - Services and stations are key entities needing scheduling decisions
    - Platform capacity constraints create bottlenecks at stations
    - Track segments with limited capacity create scheduling conflicts
    - Service timing requirements create precedence relationships
    """
    # Scalar parameters
    makespan = json_data.get('makespan', 240)
    min_sep = json_data.get('min_sep', 4)

    # Per-station data
    minimal_wait = json_data.get('minimal_wait', [])
    platform = json_data.get('platform', [])

    # Per-service data
    service_start = json_data.get('service_start', [])
    service_end = json_data.get('service_end', [])
    engine_data = json_data.get('ENGINE', [])

    # Fall back to fixed sizes when the corresponding arrays are absent.
    num_stops = len(minimal_wait) if minimal_wait else 7
    num_services = len(service_start) if service_start else 12
    num_engines = len(engine_data) if engine_data else 6

    # Stops with no platform entry or zero platforms (e.g. a dummy stop) get
    # neither a station node nor a platform-constraint node.
    active_stations = [
        i for i in range(num_stops) if i < len(platform) and platform[i] > 0
    ]
    # Services with both a start and an end time form a prefix 0..num_timed-1.
    num_timed = min(num_services, len(service_start), len(service_end))

    G = nx.Graph()

    # --- Service nodes (type 0 - decision variables) ---
    if service_end and service_start:
        max_service_duration = max(service_end) - min(service_start)
    else:
        max_service_duration = makespan
    for i in range(num_services):
        start_time = service_start[i] if i < len(service_start) else 0
        end_time = service_end[i] if i < len(service_end) else makespan
        duration = end_time - start_time

        # Services spanning a larger share of the overall horizon and
        # starting earlier receive a higher criticality weight.
        timing_pressure = duration / max_service_duration if max_service_duration > 0 else 0.5
        urgency = 1.0 - (start_time / makespan) if makespan > 0 else 0.5

        # Clamp before sqrt: a start after the makespan or an end before the
        # start would otherwise yield a negative product (math domain error).
        criticality = math.sqrt(_clamp_unit(timing_pressure) * _clamp_unit(urgency))
        G.add_node(f'service_{i}', type=0, weight=criticality)

    # --- Station nodes (type 2 - resource-like) ---
    max_platforms = max(platform) if platform else 1
    for i in active_stations:
        station_platforms = platform[i]
        wait_time = minimal_wait[i] if i < len(minimal_wait) else 0

        # Stations with fewer platforms and longer minimal waits are bottlenecks.
        capacity_scarcity = 1.0 - (station_platforms / max_platforms)
        delay_factor = wait_time / 20.0 if wait_time > 0 else 0.1  # Normalize wait time

        # exp(-2 * (1 - scarcity)) grows with scarcity: from exp(-2) for the
        # best-served station towards 1 for the scarcest ones.
        bottleneck_weight = math.exp(-2.0 * (1.0 - capacity_scarcity)) * (1.0 + delay_factor)
        G.add_node(f'station_{i}', type=2, weight=min(bottleneck_weight, 1.0))

    # --- Engine nodes (type 2 - resource-like) ---
    for i in range(num_engines):
        # Engines are treated as interchangeable; a small deterministic
        # variation keeps their feature values distinguishable.
        engine_weight = 0.7 + 0.2 * (i % 3) / 3.0
        G.add_node(f'engine_{i}', type=2, weight=engine_weight)

    # --- Platform capacity constraint nodes (type 1) ---
    for i in active_stations:
        station_platforms = platform[i]

        # Rough demand estimate: every service is assumed to use every station
        # that has wait data. The factor 4 is a heuristic per-platform service
        # allowance (NOTE(review): confirm intended capacity).
        services_using_station = num_services if i < len(minimal_wait) else 0
        demand_ratio = services_using_station / (station_platforms * 4)

        # Platform constraints are tighter when demand exceeds capacity.
        tightness = min(demand_ratio, 1.0) if demand_ratio > 0.5 else 0.3
        G.add_node(f'platform_constraint_{i}', type=1, weight=tightness)

    # --- Timing constraint nodes (type 1) ---
    for i in range(num_timed):
        window_size = service_end[i] - service_start[i]
        # Narrower windows relative to the makespan are tighter; clamp so that
        # windows longer than the makespan (or inverted ones) stay in [0, 1].
        window_tightness = 1.0 - (window_size / makespan) if makespan > 0 else 0.5
        G.add_node(f'timing_constraint_{i}', type=1, weight=_clamp_unit(window_tightness))

    # --- Track separation constraint node (type 1) ---
    has_track_separation = min_sep > 0
    if has_track_separation:
        separation_criticality = min_sep / 10.0  # Normalize typical separation times
        G.add_node('track_separation', type=1, weight=min(separation_criticality, 1.0))

    # --- Engine assignment constraint node (type 1) ---
    engine_assignment_complexity = num_services / num_engines if num_engines > 0 else 1.0
    G.add_node('engine_assignment', type=1, weight=min(engine_assignment_complexity / 3.0, 1.0))

    # Edges: services to platform capacity constraints, weighted by the
    # per-service utilization impact on that station.
    for s in range(num_services):
        for st in active_stations:
            utilization_impact = 1.0 / platform[st]
            G.add_edge(f'service_{s}', f'platform_constraint_{st}',
                       weight=min(utilization_impact, 1.0))

    # Edges: each timed service to its own timing constraint.
    for s in range(num_timed):
        G.add_edge(f'service_{s}', f'timing_constraint_{s}', weight=1.0)

    # Edges: services to the track separation constraint. Only added when the
    # node exists; otherwise add_edge would create an attribute-less node.
    if has_track_separation:
        for s in range(num_services):
            G.add_edge(f'service_{s}', 'track_separation', weight=0.8)

    # Edges: services to the engine assignment constraint.
    for s in range(num_services):
        G.add_edge(f'service_{s}', 'engine_assignment', weight=0.6)

    # Edges: engines to the engine assignment constraint.
    for e in range(num_engines):
        G.add_edge(f'engine_{e}', 'engine_assignment', weight=0.7)

    # Edges: stations to their platform constraints.
    for st in active_stations:
        G.add_edge(f'station_{st}', f'platform_constraint_{st}', weight=0.9)

    # Sparse conflict edges between nearby services with overlapping windows.
    for s1 in range(num_timed):
        for s2 in range(s1 + 1, min(s1 + 3, num_timed)):  # Only check nearby services
            start1, end1 = service_start[s1], service_end[s1]
            start2, end2 = service_start[s2], service_end[s2]

            overlap = min(end1, end2) - max(start1, start2)
            if overlap <= 0:
                continue

            # Weight by the fraction of the combined span that overlaps.
            total_span = max(end1, end2) - min(start1, start2)
            conflict_intensity = overlap / total_span if total_span > 0 else 0
            if conflict_intensity > 0.3:  # Only add significant conflicts
                G.add_edge(f'service_{s1}', f'service_{s2}',
                           weight=min(conflict_intensity, 1.0))

    return G


def main():
    """Command-line entry point: build the graph from a JSON instance file.

    Expects exactly three positional arguments: <mzn_file> <dzn_file> <json_file>.
    Only the JSON file is read here; the .mzn path is passed through to
    build_graph and the .dzn path is accepted but unused.
    On a wrong argument count, prints a usage message to stderr and exits
    with status 1.
    """
    if len(sys.argv) != 4:
        # Usage errors go to stderr so stdout carries only results.
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>",
              file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # JSON is UTF-8 by specification; don't rely on the locale's default encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction.
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    main()