#!/usr/bin/env python3
"""
Graph converter for depot_placement problem.
Created using subagent_prompt.md version: v_02

This problem optimizes the placement of a shared depot for two warehouses, each
with its own truck; either truck can serve the other warehouse's customers by
passing through the depot. The goal is to minimize the maximum distance traveled
by either truck.

Key challenges:
- Depot location affects both trucks' route lengths
- Cross-warehouse deliveries require depot visits
- Balancing workload between trucks
- Distance matrix creates complex routing constraints
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _distance_matrix(all_dist, tsize):
    """Return ``(matrix, flat)``: AllDist as a tsize x tsize list-of-lists plus a flat copy.

    Accepts either a flat row-major list, or a nested list of rows (in case the
    DZN-to-JSON step emits the 2-D array that way). Entries missing from the
    matrix are padded with 0 and surplus entries are left out of the matrix;
    for flat input this matches the converter's original indexing exactly.
    ``flat`` contains every supplied distance and is used for normalization.
    """
    if all_dist and isinstance(all_dist[0], (list, tuple)):
        rows = [list(r) for r in all_dist]
        flat = [d for r in rows for d in r]
    else:
        flat = list(all_dist)
        # Row i covers flat indices [i*tsize, (i+1)*tsize); slices past the end are short/empty.
        rows = [flat[i * tsize:(i + 1) * tsize] for i in range(tsize)]

    matrix = []
    for i in range(tsize):
        row = rows[i] if i < len(rows) else []
        matrix.append([row[j] if j < len(row) else 0 for j in range(tsize)])
    return matrix, flat


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the depot placement problem.

    Args:
        mzn_file: Path to .mzn file (for reference only; it is not read).
        json_data: Dict containing parsed DZN data. Reads ``PSize`` (int) and
            ``AllDist`` (the TSize x TSize distance matrix, as a flat row-major
            list or as a list of rows). Missing/null keys fall back to 0 / [].

    Returns:
        networkx.Graph where every node has ``type`` (0 = location,
        1 = constraint) and ``weight`` attributes, and every edge has ``weight``.

    Strategy: location nodes plus constraint nodes
    - Location nodes (type 0): Warehouses and customers with distance-based weights
    - Depot constraint nodes (type 1): Model depot placement decisions
    - Tour constraint nodes (type 1): Model routing constraints for each truck
    - Cross-delivery constraint nodes (type 1): Model inter-warehouse delivery rules
    - Conflict edges directly between distant location pairs (these make the
      graph not strictly bipartite)
    """
    # `or` also covers keys that are present but null in the JSON.
    psize = json_data.get('PSize') or 0
    all_dist = json_data.get('AllDist') or []

    # Calculate derived parameters from the MZN model
    tsize = (psize + 1) * 2  # Total locations: 2 warehouses + 2*psize customers
    tour_length = psize + 1

    # Warehouse locations (1-based)
    a_warehouse = 1
    b_warehouse = tour_length + 1

    # Each truck's block of locations: its warehouse followed by its customers.
    a_tour = list(range(1, tour_length + 1))
    b_tour = list(range(tour_length + 1, tsize + 1))
    # Customers only (warehouse excluded), consistent with the customer_A/_B labels.
    a_customers = [loc for loc in a_tour if loc != a_warehouse]
    b_customers = [loc for loc in b_tour if loc != b_warehouse]

    G = nx.Graph()

    dist_matrix, flat_dist = _distance_matrix(all_dist, tsize)

    # Normalization constant; 1 when no distances were supplied so divisions stay defined.
    max_dist = max(flat_dist) if flat_dist else 1

    # Location nodes (type 0) - warehouses and customers
    for loc in range(1, tsize + 1):
        row = dist_matrix[loc - 1]
        # Mean distance to all locations (the zero self-distance included).
        avg_dist = sum(row) / max(len(row), 1)
        centrality = 1.0 - (avg_dist / max_dist) if max_dist > 0 else 0.5

        if loc in (a_warehouse, b_warehouse):
            # Warehouses are more critical (higher weight)
            weight = 0.9 + 0.1 * centrality
            node_type = "warehouse"
        else:
            # Customer weight grows with how "difficult" (far away) it is to reach
            weight = 0.3 + 0.7 * (1.0 - centrality)
            node_type = "customer_A" if loc <= tour_length else "customer_B"

        G.add_node(f'loc_{loc}', type=0, weight=min(weight, 1.0),
                   location_type=node_type)

    # Depot placement constraint nodes (type 1), one per potential depot location
    for depot_loc in range(1, tsize + 1):
        depot_row = dist_matrix[depot_loc - 1]
        total_dist_from_depot = sum(depot_row)
        depot_centrality = (1.0 - (total_dist_from_depot / (max_dist * tsize))
                            if max_dist > 0 else 0.5)

        G.add_node(f'depot_choice_{depot_loc}', type=1,
                   weight=0.4 + 0.6 * depot_centrality)

        # Connect to all locations; closer locations get exponentially stronger edges
        for loc in range(1, tsize + 1):
            distance = depot_row[loc - 1]
            edge_weight = math.exp(-3.0 * distance / max_dist) if max_dist > 0 else 0.5
            G.add_edge(f'loc_{loc}', f'depot_choice_{depot_loc}',
                       weight=min(edge_weight, 1.0))

    # Tour constraints (type 1), one per truck, linked to its warehouse and customers.
    # NOTE(review): len(a_tour) == len(b_tour) == psize + 1 always, so these
    # ratios are 1.0 whenever psize > 0; kept as-is to preserve node weights.
    a_tour_complexity = len(a_tour) / (psize + 1) if psize > 0 else 0.5
    G.add_node('tour_A_constraint', type=1, weight=0.5 + 0.5 * a_tour_complexity)
    for loc in a_tour:
        G.add_edge(f'loc_{loc}', 'tour_A_constraint', weight=0.8)

    b_tour_complexity = len(b_tour) / (psize + 1) if psize > 0 else 0.5
    G.add_node('tour_B_constraint', type=1, weight=0.5 + 0.5 * b_tour_complexity)
    for loc in b_tour:
        G.add_edge(f'loc_{loc}', 'tour_B_constraint', weight=0.8)

    # Cross-delivery constraints (type 1): one truck's warehouse linked to the
    # *customers* (not the warehouse) of the other truck.
    G.add_node('cross_delivery_A_to_B', type=1, weight=0.7)
    G.add_edge(f'loc_{a_warehouse}', 'cross_delivery_A_to_B', weight=1.0)
    for loc in b_customers:
        G.add_edge(f'loc_{loc}', 'cross_delivery_A_to_B', weight=0.6)

    G.add_node('cross_delivery_B_to_A', type=1, weight=0.7)
    G.add_edge(f'loc_{b_warehouse}', 'cross_delivery_B_to_A', weight=1.0)
    for loc in a_customers:
        G.add_edge(f'loc_{loc}', 'cross_delivery_B_to_A', weight=0.6)

    # Conflict edges between distant location pairs (hard to share a route).
    # Only dist[i][j] with i < j is inspected, so an asymmetric matrix's
    # lower triangle is ignored.
    distance_threshold = max_dist * 0.7
    for i in range(1, tsize + 1):
        for j in range(i + 1, tsize + 1):
            distance = dist_matrix[i - 1][j - 1]
            if distance > distance_threshold:
                conflict_strength = min(distance / max_dist, 1.0) if max_dist > 0 else 0.5
                G.add_edge(f'loc_{i}', f'loc_{j}', weight=conflict_strength)

    # All-different constraint (type 1) over every customer (warehouses excluded)
    G.add_node('alldiff_tours', type=1, weight=0.6)
    for loc in a_customers + b_customers:
        G.add_edge(f'loc_{loc}', 'alldiff_tours', weight=0.5)

    return G


def main():
    """Command-line entry point: build the graph for one instance and report its size.

    Usage: ``converter.py <mzn_file> <dzn_file> <json_file>``. Only the JSON
    file is read. The .mzn path is passed through to ``build_graph``, and the
    .dzn path is accepted for interface compatibility but not used.
    Exits with status 1 on a usage error or when the JSON cannot be loaded.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>",
              file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # OSError: missing/unreadable file; ValueError covers json.JSONDecodeError
    # and UnicodeDecodeError.
    try:
        json_data = json.loads(Path(json_file).read_text(encoding="utf-8"))
    except (OSError, ValueError) as err:
        print(f"Error: could not load {json_file}: {err}", file=sys.stderr)
        sys.exit(1)

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()