#!/usr/bin/env python3
"""
Graph converter for the Tower problem.
Created using subagent_prompt.md version: v_02

This problem concerns cellular tower optimization.
Key challenges: balancing tower capacity, signal strength, and handset demand
while accounting for distance-based signal attenuation and power efficiency.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _positive_max(values):
    """Return max(values) when it is positive, else 1 (a safe divisor)."""
    largest = max(values, default=0)
    return largest if largest > 0 else 1


def _reshape_distance(raw, n_handsets, n_towers, default=100.0):
    """
    Return a handset x tower distance matrix (list of rows).

    ``raw`` may be a flat, row-major list (entry for handset h and tower t at
    index ``h * n_towers + t``) or a list of per-handset rows. Any missing
    entry is filled with ``default`` (a large distance).
    """
    if raw and isinstance(raw[0], (list, tuple)):
        # Nested form: one row per handset (2-D arrays encoded as nested lists).
        return [
            [raw[h][t] if h < len(raw) and t < len(raw[h]) else default
             for t in range(n_towers)]
            for h in range(n_handsets)
        ]
    return [
        [raw[h * n_towers + t] if h * n_towers + t < len(raw) else default
         for t in range(n_towers)]
        for h in range(n_handsets)
    ]


def _add_conflict_edges(G, demand, strength, max_demand):
    """
    Link pairs of high-demand handsets that compete for the same towers.

    Only handsets whose demand is at least 70% of ``max_demand`` are
    considered. For each pair, the shared strength is the sum of
    ``min(s1, s2)`` over towers where both signal strengths exceed 0.3. A
    conflict edge is added when that sum exceeds 0.5.
    """
    threshold = 0.7 * max_demand
    heavy = [h for h, d in enumerate(demand) if d >= threshold]
    for i, h1 in enumerate(heavy):
        for h2 in heavy[i + 1:]:
            shared = sum(
                min(s1, s2)
                for s1, s2 in zip(strength[h1], strength[h2])
                if s1 > 0.3 and s2 > 0.3
            )
            if shared > 0.5:
                demand_conflict = (demand[h1] + demand[h2]) / (2 * max_demand)
                G.add_edge(f'handset_{h1}', f'handset_{h2}',
                           weight=0.3 + 0.7 * demand_conflict)


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the tower problem instance.

    Args:
        mzn_file: Path to .mzn file (for reference; not read here)
        json_data: Dict containing parsed DZN data. Reads 'demand'
            (one value per handset), 'capacity' (one value per tower) and
            'distance' (handset x tower distances, either flattened
            row-major or as a list of rows; missing entries default to 100.0).

    Returns:
        networkx.Graph whose nodes carry integer ``type`` and float ``weight``
        attributes and whose edges carry a float ``weight``.

    Strategy: Create bipartite graph with handsets, towers, and constraints
    - Handsets (type 0): weighted by demand
    - Towers (type 2): weighted by capacity utilization pressure
    - Constraints (type 1): capacity and signal strength constraints
    - Use exponential decay for distance-based edge weights
    """
    demand = json_data.get('demand', [])
    capacity = json_data.get('capacity', [])
    n_handsets = len(demand)
    n_towers = len(capacity)

    distance = _reshape_distance(json_data.get('distance', []), n_handsets, n_towers)

    # Normalization factors. Each falls back to 1 when its maximum is not
    # positive, so the divisions below cannot raise ZeroDivisionError.
    max_demand = _positive_max(demand)
    max_capacity = _positive_max(capacity)
    total_demand = sum(demand)
    # Taken over the reshaped matrix (padded defaults included) so every
    # dist / max_distance ratio stays within [0, 1] for non-negative distances.
    max_distance = _positive_max(d for row in distance for d in row)

    # Handset -> tower signal strength (closer = stronger). Reused by the
    # handset-tower edges and by the conflict detection.
    strength = [[math.exp(-2.0 * d / max_distance) for d in row] for row in distance]

    G = nx.Graph()

    # Handset nodes (type 0): higher relative demand = more critical handset
    for h, handset_demand in enumerate(demand):
        G.add_node(f'handset_{h}', type=0, weight=handset_demand / max_demand)

    # Tower nodes (type 2): the smaller the share of total demand a tower
    # could absorb on its own, the higher its pressure
    for t, tower_capacity in enumerate(capacity):
        capacity_ratio = tower_capacity / total_demand if total_demand > 0 else 1
        pressure = 1.0 - min(capacity_ratio, 1.0)
        G.add_node(f'tower_{t}', type=2, weight=0.5 + 0.5 * pressure)

    # Capacity constraint nodes (type 1), one per tower: smaller capacity
    # relative to the largest = tighter constraint
    for t, tower_capacity in enumerate(capacity):
        tightness = 1.0 - tower_capacity / max_capacity
        G.add_node(f'capacity_constraint_{t}', type=1, weight=0.3 + 0.7 * tightness)

    # Signal constraint nodes (type 1), one per handset: the farther the
    # nearest tower, the harder the constraint
    for h, row in enumerate(distance):
        # With no towers at all the constraint is treated as maximally hard.
        difficulty = min(row) / max_distance if row else 1.0
        difficulty = min(max(difficulty, 0.0), 1.0)  # keep sqrt's domain valid
        # sqrt makes the weight less linear in the difficulty
        G.add_node(f'signal_constraint_{h}', type=1,
                   weight=0.2 + 0.8 * math.sqrt(difficulty))

    # Handset -- tower edges (potential connections), weighted by signal strength
    for h, row in enumerate(strength):
        for t, signal_strength in enumerate(row):
            G.add_edge(f'handset_{h}', f'tower_{t}', weight=signal_strength)

    # Handset -- capacity constraint edges: each handset could be assigned to
    # any tower; weight is its demand's share of that tower's capacity (<= 1)
    for h, handset_demand in enumerate(demand):
        for t, tower_capacity in enumerate(capacity):
            impact = handset_demand / tower_capacity if tower_capacity > 0 else 1
            G.add_edge(f'handset_{h}', f'capacity_constraint_{t}',
                       weight=min(impact, 1.0))

    # Each handset participates in its own signal constraint
    for h in range(n_handsets):
        G.add_edge(f'handset_{h}', f'signal_constraint_{h}', weight=1.0)

    # Each tower is constrained by its own capacity constraint
    for t in range(n_towers):
        G.add_edge(f'tower_{t}', f'capacity_constraint_{t}', weight=1.0)

    # Tower -- signal constraint edges: a gentler decay than the handset
    # edges; insignificant links (capability <= 0.1) are skipped
    for t in range(n_towers):
        for h in range(n_handsets):
            capability = math.exp(-1.5 * distance[h][t] / max_distance)
            if capability > 0.1:
                G.add_edge(f'tower_{t}', f'signal_constraint_{h}', weight=capability)

    _add_conflict_edges(G, demand, strength, max_demand)

    return G


def main():
    """
    Command-line entry point.

    Usage: python converter.py <mzn_file> <dzn_file> <json_file>

    Loads the instance data from <json_file>, builds the graph with
    build_graph() and prints its node and edge counts. <dzn_file> is
    accepted for CLI compatibility but is not read here. With the wrong
    number of arguments, prints the usage line to stderr and exits with
    status 1.
    """
    if len(sys.argv) != 4:
        # Diagnostics belong on stderr so stdout stays clean for callers.
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>",
              file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # Explicit encoding: JSON is UTF-8 and must not depend on the locale.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()