#!/usr/bin/env python3
"""
Graph converter for the Vehicle Routing Problem (VRP).
Created using subagent_prompt.md version: v_02

This problem is about finding optimal routes for multiple vehicles to serve customers from a depot.
Key challenges: vehicle capacity constraints, depot routing, subtour elimination, distance minimization.
The difficulty depends on tight capacity constraints, customer distribution, and number of vehicles needed.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def build_graph(mzn_file, json_data):
    """
    Build a weighted graph representation of a VRP instance.

    Args:
        mzn_file: Path to the .mzn file. Accepted for interface compatibility;
            it is not read by this function.
        json_data: Dict containing parsed DZN data. Keys read:
            'N' (number of customers, depot excluded; default 0),
            'Capacity' (default 100),
            'Demand' (list, entry i-1 belongs to customer i; default []),
            'Distance' (flattened (N+1) x (N+1) matrix, index 0 is the depot;
            default []).

    Returns:
        A networkx Graph. Every node has 'type' and 'weight' attributes and
        every edge has a 'weight' attribute.

    Strategy: customers and the depot are linked to the constraints they
    take part in, plus direct proximity edges (so the graph is not strictly
    bipartite).
    - Customer nodes (type 0): weighted by demand/capacity ratio
    - Constraint nodes (type 1): capacity, routing, depot flow, subtour
    - Depot node (type 2): weighted by the estimated number of vehicles
    - Distance-based edge weights with exponential decay
    """
    N = json_data.get('N', 0)  # Number of customers (excluding depot)
    capacity = json_data.get('Capacity', 100)
    demand = json_data.get('Demand', [])
    distance_matrix = json_data.get('Distance', [])

    G = nx.Graph()

    # Aggregate metrics used for weighting
    total_demand = sum(demand) if demand else 0
    max_demand = max(demand) if demand else 1
    min_vehicles_needed = math.ceil(total_demand / capacity) if capacity > 0 else 1

    # Distance matrix is flattened row-major: (N+1) x (N+1), index 0 is depot
    matrix_size = N + 1
    has_full_matrix = len(distance_matrix) >= matrix_size * matrix_size

    def get_distance(i, j):
        """Get distance between nodes i and j from the flattened matrix."""
        if has_full_matrix:
            return distance_matrix[i * matrix_size + j]
        return 1  # fallback when the matrix is missing or too short

    def get_demand(i):
        """Demand of customer i (1-based); 0 when the list is too short."""
        return demand[i - 1] if i - 1 < len(demand) else 0

    # Customer nodes (type 0) - weighted by demand pressure
    for i in range(1, N + 1):
        customer_demand = get_demand(i)
        # Demand relative to vehicle capacity (higher demand = higher weight)
        demand_pressure = customer_demand / capacity if capacity > 0 else 0.5
        # Demand relative to the largest demand in the instance
        difficulty_bonus = customer_demand / max_demand if max_demand > 0 else 0.5
        weight = min((demand_pressure + difficulty_bonus) / 2, 1.0)
        G.add_node(f'customer_{i}', type=0, weight=weight)

    # Depot node (type 2) - more vehicles needed = higher weight, capped at 1
    depot_weight = min(min_vehicles_needed / 10.0, 1.0)
    G.add_node('depot', type=2, weight=depot_weight)

    # Vehicle capacity constraints (type 1) - one per estimated vehicle needed
    for v in range(min_vehicles_needed):
        # Slack left after v+1 vehicles, in units of one vehicle's capacity
        remaining_capacity_ratio = (capacity * (v + 1) - total_demand) / capacity if capacity > 0 else 0.5
        tightness = max(0.1, 1.0 - remaining_capacity_ratio)  # tighter = higher weight
        G.add_node(f'vehicle_capacity_{v}', type=1, weight=min(tightness, 1.0))

    # Route connectivity constraints (type 1): one in/out pair per customer
    for i in range(1, N + 1):
        G.add_node(f'route_in_{i}', type=1, weight=0.8)   # indegree = 1 constraint
        G.add_node(f'route_out_{i}', type=1, weight=0.8)  # outdegree = 1 constraint

    # Depot flow constraint (type 1) - vehicle count relative to customer count
    depot_flow_tightness = min_vehicles_needed / N if N > 0 else 0.5
    G.add_node('depot_flow', type=1, weight=min(depot_flow_tightness, 1.0))

    # Subtour elimination constraint (type 1) - log-scaled by customer count
    subtour_complexity = math.log(N + 1) / math.log(100) if N > 0 else 0.5
    G.add_node('subtour_elimination', type=1, weight=min(subtour_complexity, 1.0))

    # Edges: customer participation in constraints
    for i in range(1, N + 1):
        customer = f'customer_{i}'
        # Share of one vehicle's capacity this customer consumes
        capacity_usage = get_demand(i) / capacity if capacity > 0 else 0.5
        capacity_edge_weight = min(capacity_usage, 1.0)

        for v in range(min_vehicles_needed):
            G.add_edge(customer, f'vehicle_capacity_{v}', weight=capacity_edge_weight)

        G.add_edge(customer, f'route_in_{i}', weight=1.0)
        G.add_edge(customer, f'route_out_{i}', weight=1.0)
        G.add_edge(customer, 'depot_flow', weight=0.7)
        G.add_edge(customer, 'subtour_elimination', weight=0.8)

    # Connect depot to its flow constraint
    G.add_edge('depot', 'depot_flow', weight=1.0)

    # Proximity edges between customers that are relatively close together
    if has_full_matrix:
        pair_distances = [
            (i, j, get_distance(i, j))
            for i in range(1, N + 1)
            for j in range(i + 1, N + 1)
        ]

        if pair_distances:
            avg_distance = sum(dist for _, _, dist in pair_distances) / len(pair_distances)
            close_threshold = avg_distance * 0.7  # "close" = below 70% of average

            for i, j, dist in pair_distances:
                if dist > close_threshold:
                    continue

                if close_threshold != 0:
                    # Exponential decay: nearer pairs get a larger weight
                    distance_weight = math.exp(-3.0 * dist / close_threshold)
                else:
                    # Zero threshold means the average pair distance is zero;
                    # dividing would raise ZeroDivisionError, so use the
                    # decay's value at distance zero instead.
                    distance_weight = 1.0

                # Higher combined demand relative to capacity = more conflict
                demand_conflict = (get_demand(i) + get_demand(j)) / (2 * capacity) if capacity > 0 else 0.3

                final_weight = min((distance_weight + demand_conflict) / 2, 1.0)

                if final_weight > 0.1:  # only add meaningful conflicts
                    G.add_edge(f'customer_{i}', f'customer_{j}', weight=final_weight)

    # Depot-to-customer edges: weight grows with distance from the depot.
    # The maximum depot distance is loop-invariant, so compute it once.
    if has_full_matrix:
        max_depot_dist = max((get_distance(0, j) for j in range(1, N + 1)), default=0)
    else:
        max_depot_dist = 0

    for i in range(1, N + 1):
        if max_depot_dist > 0:
            # Normalised distance in [0, 1] for non-negative distances
            distance_difficulty = get_distance(0, i) / max_depot_dist
            # Exponential scaling to emphasise distance differences
            weight = math.exp(-2.0 * (1.0 - distance_difficulty))
            G.add_edge('depot', f'customer_{i}', weight=min(weight, 1.0))
        else:
            # No usable distance information: neutral weight
            G.add_edge('depot', f'customer_{i}', weight=0.5)

    return G


def main():
    """
    Command-line entry point: load the instance JSON and build the VRP graph.

    Expects exactly three positional arguments:
    <mzn_file> <dzn_file> <json_file>. Only the JSON file is read here; the
    .mzn path is passed through to build_graph and the .dzn path is accepted
    but unused. With any other argument count, prints a usage line and exits
    with status 1. On success, prints the node and edge counts of the graph.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:]

    # Read as UTF-8 explicitly instead of relying on the platform's
    # locale-dependent default encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    # Build graph
    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


# Run the converter only when this file is executed as a script, so that
# importing the module (e.g. to call build_graph directly) has no side effects.
if __name__ == "__main__":
    main()