#!/usr/bin/env python3
"""
Graph converter for tower problem.
Created using subagent_prompt.md version: v_02

This problem is about cellular tower placement and power optimization.
Key challenges: signal strength vs. interference, capacity vs. demand, power efficiency
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _value_or_default(values, index, default):
    """Return ``values[index]``, or *default* when *values* is too short."""
    return values[index] if index < len(values) else default


def _saturating_ratio(numerator, denominator, scale=1.0):
    """Return ``min(numerator / denominator * scale, 1.0)`` without dividing by zero.

    A non-positive denominator (e.g. a tower with zero capacity) counts as
    fully saturated (1.0) when the numerator is positive and as 0.0 otherwise.
    """
    if denominator <= 0:
        return 1.0 if numerator > 0 else 0.0
    return min(numerator / denominator * scale, 1.0)


def _distance_matrix(distance, n_handsets, n_towers):
    """Reshape the flat, row-major *distance* list into handset x tower rows.

    Missing trailing entries default to 100.0; when *distance* is empty every
    entry defaults to 50.0.
    """
    if not distance:
        return [[50.0] * n_towers for _ in range(n_handsets)]
    return [
        [_value_or_default(distance, h * n_towers + t, 100.0)
         for t in range(n_towers)]
        for h in range(n_handsets)
    ]


def build_graph(mzn_file, json_data):
    """
    Build a weighted graph representation of the tower placement problem.

    Args:
        mzn_file: Path to .mzn file (for reference only; not read here)
        json_data: Dict containing parsed DZN data. Keys read: 'HANDSET',
            'TOWER', 'demand', 'capacity' and 'distance' (a flat list indexed
            as ``handset * n_towers + tower``). Missing keys fall back to
            empty lists; short lists fall back to per-entry defaults
            (demand 1, capacity 300, distance 100.0).

    Returns:
        An undirected ``networkx.Graph`` whose nodes carry ``type`` and
        ``weight`` attributes and whose edges carry ``weight``.

    Graph layout:
    - Handset nodes ``handset_<h>`` (type 0): weighted by demand and
      average accessibility
    - Tower nodes ``tower_<t>`` (type 0): weighted by capacity and nearby
      demand pressure
    - Signal constraint nodes ``signal_<h>_<t>`` (type 1): weighted by
      distance-based difficulty, linked to their handset and tower
    - Capacity constraint nodes ``capacity_<t>`` (type 1): weighted by
      tightness, linked to their tower and to reachable handsets
    - Handset-handset conflict edges for handsets competing for a
      potentially overloaded tower (so the graph is not bipartite)

    Degenerate instances (no towers, all-zero distances, all-zero demand or
    capacity, zero-capacity towers) are handled without raising.
    """
    # Extract data
    handsets = json_data.get('HANDSET', [])
    towers = json_data.get('TOWER', [])
    demand = json_data.get('demand', [])
    capacity = json_data.get('capacity', [])
    distance = json_data.get('distance', [])

    n_handsets = len(handsets)
    n_towers = len(towers)

    dist_matrix = _distance_matrix(distance, n_handsets, n_towers)

    # Resolve per-handset / per-tower values once instead of in every loop.
    demands = [_value_or_default(demand, h, 1) for h in range(n_handsets)]
    capacities = [_value_or_default(capacity, t, 300) for t in range(n_towers)]

    # Normalization constants. Each is forced positive so that all-zero data
    # cannot trigger a division by zero further down.
    max_demand = max(demand) if demand else 1
    if max_demand <= 0:
        max_demand = 1
    max_capacity = max(capacity) if capacity else 1
    if max_capacity <= 0:
        max_capacity = 1
    # Flattened max with a default also covers the "handsets but no towers"
    # case, where every row of the matrix is empty.
    max_distance = max((d for row in dist_matrix for d in row), default=100.0)
    if max_distance <= 0:
        max_distance = 1.0

    G = nx.Graph()

    # Handset nodes (type 0) - weighted by demand and average accessibility
    for h in range(n_handsets):
        # With no towers at all, treat the handset as maximally far away.
        if n_towers:
            avg_distance = sum(dist_matrix[h]) / n_towers
        else:
            avg_distance = max_distance
        accessibility = math.exp(-2.0 * avg_distance / max_distance)

        # Square root dampens the influence of very large demands.
        demand_factor = math.sqrt(demands[h] / max_demand)
        weight = min((demand_factor + accessibility) / 2.0, 1.0)

        G.add_node(f'handset_{h}', type=0, weight=weight)

    # Tower nodes (type 0) - weighted by capacity and demand pressure
    for t in range(n_towers):
        tower_capacity = capacities[t]

        # Demand pressure: proximity-discounted demand of nearby handsets
        # (closer than half the maximum distance).
        demand_pressure = 0
        for h in range(n_handsets):
            dist = dist_matrix[h][t]
            if dist < max_distance * 0.5:
                proximity = math.exp(-3.0 * dist / max_distance)
                demand_pressure += demands[h] * proximity

        capacity_factor = tower_capacity / max_capacity
        pressure_factor = _saturating_ratio(demand_pressure, tower_capacity * 0.8)
        weight = min((capacity_factor + pressure_factor) / 2.0, 1.0)

        G.add_node(f'tower_{t}', type=0, weight=weight)

    # Signal strength constraint nodes (type 1) - one per handset-tower pair
    for h in range(n_handsets):
        for t in range(n_towers):
            dist = dist_matrix[h][t]
            # Longer distance = harder to satisfy the minimum signal.
            signal_difficulty = min(dist / (max_distance * 0.7), 1.0)

            # NOTE(review): exp(x) >= 1 for x >= 0, so this expression is
            # always >= 1.0 and the clamp makes every difficulty above 0.8
            # exactly 1.0. Kept as-is to preserve the existing weights.
            if signal_difficulty > 0.8:
                signal_difficulty = 0.8 + 0.2 * math.exp(5.0 * (signal_difficulty - 0.8))
                signal_difficulty = min(signal_difficulty, 1.0)

            constraint_id = f'signal_{h}_{t}'
            G.add_node(constraint_id, type=1, weight=signal_difficulty)

            # Both endpoints share one edge weight that decays with distance.
            signal_quality = math.exp(-2.0 * dist / max_distance)
            G.add_edge(f'handset_{h}', constraint_id, weight=signal_quality)
            G.add_edge(f'tower_{t}', constraint_id, weight=signal_quality)

    # Tower capacity constraint nodes (type 1) - one per tower
    for t in range(n_towers):
        tower_capacity = capacities[t]

        # Total demand the tower could attract, discounted by proximity.
        potential_demand = 0
        for h in range(n_handsets):
            proximity = math.exp(-1.5 * dist_matrix[h][t] / max_distance)
            potential_demand += demands[h] * proximity

        # Capacity constraint tightness
        if potential_demand > tower_capacity:
            tightness = 1.0 - (tower_capacity / potential_demand)
        else:
            tightness = 0.3  # Some base difficulty even when not overloaded

        constraint_id = f'capacity_{t}'
        G.add_node(constraint_id, type=1, weight=tightness)

        # Connect tower to its capacity constraint
        G.add_edge(f'tower_{t}', constraint_id, weight=1.0)

        # Connect handsets within 80% of the maximum distance; the edge weight
        # is the share of capacity the handset would consume, amplified x3.
        for h in range(n_handsets):
            if dist_matrix[h][t] < max_distance * 0.8:
                edge_weight = _saturating_ratio(demands[h], tower_capacity, 3.0)
                G.add_edge(f'handset_{h}', constraint_id, weight=edge_weight)

    # Conflict edges between handsets competing for the same overloaded tower
    for t in range(n_towers):
        tower_capacity = capacities[t]

        # (handset index, demand, distance) for handsets close enough to compete
        competitors = [
            (h, demands[h], dist_matrix[h][t])
            for h in range(n_handsets)
            if dist_matrix[h][t] < max_distance * 0.6
        ]

        total_competing_demand = sum(d for _, d, _ in competitors)
        if total_competing_demand <= tower_capacity * 1.2:
            continue  # Tower is not potentially overloaded

        # Highest demand-per-distance first; only the top six are paired up.
        competitors.sort(key=lambda c: c[1] / (c[2] + 1), reverse=True)
        top = competitors[:6]

        for i, (h1, demand1, _) in enumerate(top):
            for h2, demand2, _ in top[i + 1:]:
                if demand1 + demand2 > tower_capacity * 0.8:
                    # Conflict strength: combined demand relative to capacity.
                    # If the same pair conflicts at several towers, the edge
                    # keeps the weight from the last tower processed.
                    conflict_strength = _saturating_ratio(
                        demand1 + demand2, tower_capacity
                    )
                    G.add_edge(f'handset_{h1}', f'handset_{h2}', weight=conflict_strength)

    return G


def main():
    """Command-line entry point: load the JSON instance and build its graph.

    Expects exactly three arguments: the .mzn model path, the .dzn data path
    and the JSON file holding the parsed instance data. Only the JSON file is
    read here; the .mzn path is forwarded to ``build_graph`` and the .dzn path
    is accepted but unused. Prints a usage message and exits with status 1
    when the argument count is wrong.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:]

    # Decode explicitly as UTF-8 rather than relying on the platform default.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    G = build_graph(mzn_file, json_data)

    # The graph itself is not persisted here; only a summary is reported.
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()