#!/usr/bin/env python3
"""
Graph converter for the TSPTW (Traveling Salesman Problem with Time Windows).
Created using subagent_prompt.md version: v_02

This problem is about finding the shortest Hamiltonian circuit visiting all locations
while respecting time windows. Each location has an earliest and latest visit time.
Key challenges: tight time windows, long travel times, and conflicting scheduling requirements.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _clamp(value, low, high=1.0):
    """Clamp ``value`` into the closed interval ``[low, high]``."""
    return max(low, min(high, value))


def build_graph(mzn_file, json_data):
    """
    Build a weighted graph representation of a TSPTW problem instance.

    Args:
        mzn_file: Path to the .mzn file. Not read here; kept for interface
            compatibility with the caller.
        json_data: Dict containing the parsed DZN data. The keys read are
            'numLocations', 'early', 'late' and 'duration' (the duration
            matrix flattened in row-major order; missing cells count as 0).

    Returns:
        A networkx.Graph. It is empty when 'numLocations' is 0 or any of
        the three lists is missing/empty. Otherwise it contains:

        - 'loc_{i}' nodes (type 0), weighted by window tightness and urgency
        - 'tw_constraint_{i}' nodes (type 1), one per location, weighted by
          window size relative to the average window
        - one 'circuit_constraint' node (type 1), weighted by instance size
        - loc -- tw_constraint edges, weighted by window criticality
        - loc -- circuit_constraint edges, weighted by how few close
          neighbours the location has
        - loc -- loc edges between close locations whose deadlines conflict
          (so the graph is not strictly bipartite)

        All weights are clamped into (0, 1].
    """
    num_locations = json_data.get('numLocations', 0)
    early = json_data.get('early', [])
    late = json_data.get('late', [])
    duration = json_data.get('duration', [])

    if num_locations == 0 or not early or not late or not duration:
        return nx.Graph()

    G = nx.Graph()

    # Rebuild the square duration matrix from its flattened row-major form,
    # padding with 0 when the flat list is shorter than n * n.
    cells = len(duration)
    duration_matrix = [
        [
            duration[i * num_locations + j] if i * num_locations + j < cells else 0
            for j in range(num_locations)
        ]
        for i in range(num_locations)
    ]

    # Only locations that have both an 'early' and a 'late' entry get nodes
    # and edges. Iterating this range everywhere keeps add_edge from
    # implicitly creating attribute-less 'loc_{i}' nodes when the lists are
    # shorter than numLocations.
    n_valid = max(0, min(num_locations, len(early), len(late)))
    valid = range(n_valid)
    window_sizes = [late[i] - early[i] for i in valid]

    # Problem-wide statistics used for normalization (computed once).
    max_duration = max(max(row) for row in duration_matrix) if duration_matrix else 1
    max_late = max(late)
    total_time_span = max_late - min(early)
    # Average over exactly the windows that are summed.
    avg_window_size = sum(window_sizes) / n_valid if n_valid else 1

    # Tightness of each window relative to the whole time horizon; used for
    # both the location node weight and the loc -- tw_constraint edge weight.
    tightness = [
        1.0 - (size / total_time_span) if total_time_span > 0 else 0.5
        for size in window_sizes
    ]

    # Location nodes (type 0): tighter windows and earlier deadlines weigh more.
    for i in valid:
        urgency = 1.0 - (late[i] / max_late) if max_late > 0 else 0.5
        combined_weight = math.sqrt(tightness[i] * urgency)
        G.add_node(f'loc_{i}', type=0, weight=_clamp(combined_weight, 0.1))

    # Time window constraint nodes (type 1): one per location, weighted by
    # window size relative to the average window.
    for i in valid:
        constraint_tightness = (
            1.0 - (window_sizes[i] / avg_window_size) if avg_window_size > 0 else 0.5
        )
        # NOTE(review): for any tightness above 0.8 the boosted value is
        # >= 1.0, so after clamping this branch always yields weight 1.0.
        if constraint_tightness > 0.8:
            constraint_tightness = 0.8 + 0.2 * math.exp(5 * (constraint_tightness - 0.8))
        G.add_node(f'tw_constraint_{i}', type=1,
                   weight=_clamp(constraint_tightness, 0.1))

    # Circuit constraint node (type 1): weight grows logarithmically with
    # the number of locations and reaches 1.0 at 100 locations.
    circuit_complexity = math.log(num_locations) / math.log(100) if num_locations > 1 else 0.5
    G.add_node('circuit_constraint', type=1, weight=_clamp(circuit_complexity, 0.3))

    # Edges: location -- its own time window constraint.
    for i in valid:
        G.add_edge(f'loc_{i}', f'tw_constraint_{i}', weight=_clamp(tightness[i], 0.1))

    # Edges: every location -- circuit constraint. Locations with few close
    # neighbours are harder to route and get a heavier edge.
    close_threshold = max_duration * 0.3
    for i in valid:
        # Skip j == i: the denominator below is (n - 1), so a location must
        # not count itself (its diagonal entry) as a close neighbour.
        close_neighbors = sum(
            1 for j in range(num_locations)
            if j != i and duration_matrix[i][j] <= close_threshold
        )
        connectivity = close_neighbors / max(num_locations - 1, 1)
        G.add_edge(f'loc_{i}', 'circuit_constraint',
                   weight=_clamp(1.0 - connectivity, 0.2))

    # Edges: location -- location, only for pairs that are close together
    # and whose deadlines conflict.
    conflict_threshold = max_duration * 0.4
    for i in valid:
        for j in range(i + 1, n_valid):
            travel_time = duration_matrix[i][j]

            # Leaving i at its latest time, would we miss j's deadline?
            overshoot = late[i] + travel_time - late[j]
            if overshoot <= 0 or travel_time > conflict_threshold:
                continue

            conflict_severity = overshoot / total_time_span if total_time_span > 0 else 0.5
            # Exponential decay with travel time. When every duration is 0
            # there is nothing to decay (and nothing to divide by), so the
            # factor is 1.0 instead of raising ZeroDivisionError.
            decay = math.exp(-2.0 * travel_time / max_duration) if max_duration > 0 else 1.0

            G.add_edge(f'loc_{i}', f'loc_{j}',
                       weight=_clamp(conflict_severity * decay, 0.1))

    return G


def main():
    """
    Command-line entry point: load the instance JSON and build its graph.

    Expects exactly three arguments: <mzn_file> <dzn_file> <json_file>.
    Only the JSON file is read; the .mzn path is forwarded to build_graph
    and the .dzn path is accepted but unused. With any other argument count
    a usage line is printed and the process exits with status 1.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:]

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    G = build_graph(mzn_file, json_data)

    # Report the size of the graph returned by build_graph.
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

# Run the converter only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()