#!/usr/bin/env python3
"""
Graph converter for the Traveling Purchaser Problem (TPP).
Created using subagent_prompt.md version: v_02

This problem is about finding an optimal route through cities while purchasing
required products at different locations to minimize total travel and purchase costs.
Key challenges: Route optimization combined with purchase location decisions,
asymmetric travel costs, and varying product prices across cities.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _unflatten(flat, rows, cols, fill):
    """Reshape row-major *flat* into ``rows`` x ``cols`` lists, padding with *fill*."""
    size = len(flat)
    return [
        [flat[r * cols + c] if r * cols + c < size else fill for c in range(cols)]
        for r in range(rows)
    ]


def _scaled(value, scale, fallback=0.5):
    """Return ``value / scale``, or *fallback* when *scale* is not positive."""
    return value / scale if scale > 0 else fallback


def _decay(value, scale, fallback=0.5):
    """Return ``exp(-2 * value / scale)``, or *fallback* when *scale* is not positive."""
    return math.exp(-2.0 * value / scale) if scale > 0 else fallback


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the TPP problem instance.

    Args:
        mzn_file: Path to .mzn file (for reference; not read here)
        json_data: Dict containing parsed DZN data. Keys read: 'numproducts',
            'numcities', 'maxdist', 'maxprice', 'dist' (flattened
            numcities x numcities) and 'price' (flattened
            (numcities-1) x numproducts). Missing 'dist' entries are
            filled with -1 and missing 'price' entries with maxprice.

    Returns:
        An undirected networkx.Graph whose nodes carry 'type' and 'weight'
        attributes and whose edges carry a 'weight' attribute.

    Strategy: variable nodes (cities, products) are linked to constraint nodes,
    and additionally to each other, so the graph is NOT strictly bipartite.
    - Cities: Type 0 nodes (route decision points)
    - Products: Type 0 nodes (purchase location decisions)
    - Route constraints: Type 1 nodes (subcircuit, travel costs)
    - Purchase constraints: Type 1 nodes (product availability, pricing)
    - Travel edges: Connect cities i, j whenever dist[i][j] > 0
    - Purchase edges: Connect each product to every city that has a price row
    """
    numproducts = json_data.get('numproducts', 0)
    numcities = json_data.get('numcities', 0)
    maxdist = json_data.get('maxdist', 1)
    maxprice = json_data.get('maxprice', 1)

    G = nx.Graph()

    # Rebuild the 2D tables from their flattened, row-major form.
    dist = _unflatten(json_data.get('dist', []), numcities, numcities, -1)
    # The last city gets no price row (it sells nothing).
    price = _unflatten(json_data.get('price', []), numcities - 1, numproducts, maxprice)

    # Positive outgoing distances per city; non-positive entries mean "no move".
    out_costs = [[d for d in row if d > 0] for row in dist]
    # Prices of each product across all selling cities (one column per product).
    price_columns = [[row[p] for row in price] for p in range(numproducts)]

    # City nodes (Type 0): weight = relative number of outgoing moves.
    connectivity = [len(costs) for costs in out_costs]
    max_connectivity = max(connectivity) if connectivity else 1
    last_city = numcities - 1
    for i, degree in enumerate(connectivity):
        weight = degree / max_connectivity if max_connectivity > 0 else 0.5
        # The last city is treated as the starting point and gets a weight floor.
        if i == last_city:
            weight = max(weight, 0.8)
        G.add_node(f'city_{i}', type=0, weight=weight)

    # Product nodes (Type 0): mean of relative price range and relative mean price.
    for p, column in enumerate(price_columns):
        if column:
            variation = _scaled(max(column) - min(column), maxprice)
            importance = _scaled(sum(column) / len(column), maxprice)
            weight = (variation + importance) / 2
        else:
            weight = 0.5
        # Cap at 1.0 like every other computed node weight (guards against
        # prices that exceed the declared maxprice).
        G.add_node(f'product_{p}', type=0, weight=min(weight, 1.0))

    # Subcircuit constraint (Type 1): log-scaled by city count, saturating at 20.
    routing_complexity = math.log(numcities) / math.log(20) if numcities > 1 else 0.5
    G.add_node('subcircuit_constraint', type=1, weight=min(routing_complexity, 1.0))

    # Travel cost constraints (Type 1): std-dev of outgoing costs over maxdist.
    for i, costs in enumerate(out_costs):
        if len(costs) > 1:
            mean_cost = sum(costs) / len(costs)
            variance = sum((c - mean_cost) ** 2 for c in costs) / len(costs)
            weight = _scaled(math.sqrt(variance), maxdist)
        else:
            weight = 0.3  # Zero or one outgoing move: little to decide
        G.add_node(f'travel_constraint_{i}', type=1, weight=min(weight, 1.0))

    # Purchase location constraints (Type 1): relative price spread per product.
    for p, column in enumerate(price_columns):
        # _scaled falls back to 0.5 when maxprice <= 0 instead of dividing by zero.
        weight = _scaled(max(column) - min(column), maxprice) if column else 0.5
        G.add_node(f'purchase_constraint_{p}', type=1, weight=min(weight, 1.0))

    # Visit constraint: ties purchase decisions to visited cities.
    G.add_node('visit_constraint', type=1, weight=0.9)

    # Edges: cities to routing/visit constraints.
    for i in range(numcities):
        G.add_edge(f'city_{i}', 'subcircuit_constraint', weight=0.8)
        G.add_edge(f'city_{i}', f'travel_constraint_{i}', weight=1.0)
        G.add_edge(f'city_{i}', 'visit_constraint', weight=0.7)

    # Edges: products to purchase/visit constraints.
    for p in range(numproducts):
        G.add_edge(f'product_{p}', f'purchase_constraint_{p}', weight=1.0)
        G.add_edge(f'product_{p}', 'visit_constraint', weight=0.8)

    # Travel network edges: shorter distance -> weight closer to 1.
    # NOTE: G is undirected, so when both dist[i][j] and dist[j][i] are
    # positive the edge added later (from the higher-numbered city's row)
    # overwrites the weight of the earlier one.
    for i, row in enumerate(dist):
        for j, d in enumerate(row):
            if i != j and d > 0:
                G.add_edge(f'city_{i}', f'city_{j}', weight=_decay(d, maxdist))

    # Purchase option edges: cheapest city for a product gets weight 1,
    # decaying exponentially with the price gap to that cheapest offer.
    for p, column in enumerate(price_columns):
        if not column:
            continue
        cheapest = min(column)  # computed once per product, not per edge
        for i, current in enumerate(column):
            G.add_edge(f'product_{p}', f'city_{i}',
                       weight=_decay(current - cheapest, maxprice))

    return G


def main():
    """
    Command-line entry point: load the instance JSON and build its graph.

    Expects exactly three arguments: <mzn_file> <dzn_file> <json_file>.
    Prints a usage message to stderr and exits with status 1 otherwise.
    On success prints the node and edge counts of the built graph.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>",
              file=sys.stderr)
        sys.exit(1)

    # The .dzn path is part of the CLI contract but is not read here;
    # the instance data comes from the JSON file.
    _, mzn_file, _dzn_file, json_file = sys.argv

    # JSON is read as UTF-8 explicitly rather than with the locale default.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    G = build_graph(mzn_file, json_data)

    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()