#!/usr/bin/env python3
"""
Graph converter for Prize Collecting problem.
Created using subagent_prompt.md version: v_02

This problem is about finding a path through nodes to maximize collected prizes.
Key challenges: Path connectivity constraints, profit/penalty tradeoffs, node selection decisions.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


# Sentinel stored in the flattened profit matrix for a forbidden transition.
_INVALID_PROFIT = -1000


def _clamp_unit(value):
    """Clamp ``value`` into the closed interval [0, 1]."""
    return max(0.0, min(1.0, value))


def _row_stats(p, n, i):
    """Summarise row ``i`` of the flattened ``n x (n + 1)`` profit matrix.

    Returns:
        Tuple ``(valid_count, routed_count, profit_sum)`` where
        ``valid_count`` counts every non-sentinel entry in the row (including
        column 0, the "not used" option), ``routed_count`` counts only the
        non-sentinel entries with ``j > 0`` and ``profit_sum`` is the sum of
        those ``j > 0`` entries.
    """
    valid_count = 0
    routed_count = 0
    profit_sum = 0
    row_start = i * (n + 1)
    for j in range(n + 1):
        idx = row_start + j
        # Tolerate a truncated matrix: missing entries are treated as invalid.
        if idx >= len(p) or p[idx] == _INVALID_PROFIT:
            continue
        valid_count += 1
        if j > 0:  # Column 0 is "not used" and carries no routing profit
            routed_count += 1
            profit_sum += p[idx]
    return valid_count, routed_count, profit_sum


def _transition_weight(profit_val, min_profit, profit_range):
    """Return the constraint-node weight for one valid transition.

    Non-negative profits are min-max normalised; negative profits are scaled
    by the largest penalty magnitude. A square root is applied afterwards to
    lift small values (non-linear emphasis).
    """
    if profit_val >= 0:
        raw_weight = (profit_val - min_profit) / profit_range
    else:
        max_penalty = abs(min_profit) if min_profit < 0 else 1
        raw_weight = abs(profit_val) / max_penalty
    return math.sqrt(raw_weight)


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the Prize Collecting problem instance.

    Args:
        mzn_file: Path to .mzn file (for reference only; not read here)
        json_data: Dict containing parsed DZN data. Reads ``'n'`` (number of
            nodes, default 0) and ``'p'`` (profit matrix flattened row-major
            as ``n`` rows of ``n + 1`` columns, default ``[]``). Column 0 of
            each row is the "not used" option; the value -1000 marks an
            invalid transition.

    Returns:
        An undirected ``networkx.Graph``. It is empty when ``n <= 0`` or when
        ``p`` holds no valid (non -1000) entry.

    Strategy: Create bipartite graph with nodes and routing constraints
    - ``node_{i}`` (type 0): decision entities, weighted by average routed
      profit and connectivity
    - ``path_order_{i}`` (type 1): one per node, weighted by how few valid
      options the node has
    - ``transition_{i}_{j}`` (type 1): one per valid matrix entry, linked to
      its source node and, for ``j > 0``, to destination ``node_{j-1}``
    - ``start_constraint`` (type 1): attached to ``node_0``
    - ``global_connectivity`` (type 1): added only when both positive and
      negative profits exist; linked to the up-to-five most profitable nodes
    """
    n = json_data.get('n', 0)
    p = json_data.get('p', [])

    G = nx.Graph()

    # Valid profits drive every normalisation below.
    profits = [val for val in p if val != _INVALID_PROFIT]
    # Without nodes or valid profits there is nothing to model. Returning
    # early also prevents add_edge from implicitly creating an
    # attribute-less 'node_0' when n <= 0.
    if n <= 0 or not profits:
        return G

    min_profit = min(profits)
    max_profit = max(profits)
    # Guard against a zero range so the divisions below are always defined.
    profit_range = max_profit - min_profit if max_profit != min_profit else 1

    row_stats = [_row_stats(p, n, i) for i in range(n)]

    # Add node entities (Type 0: decision variables)
    for i, (valid_count, routed_count, profit_sum) in enumerate(row_stats):
        if valid_count > 0:
            # Average over the routed (j > 0) transitions actually present,
            # regardless of whether the "not used" column is valid.
            avg_profit = profit_sum / routed_count if routed_count else 0.0
            # Clamp: with no routed transitions the zero average may fall
            # outside [min_profit, max_profit], which would otherwise yield
            # a negative node weight.
            normalized_profit = _clamp_unit(
                (avg_profit - min_profit) / profit_range
            )
            connectivity_factor = min(valid_count / n, 1.0)

            # Saturating exponential emphasises high-value, well-connected nodes
            base_weight = (normalized_profit + connectivity_factor) / 2
            node_weight = 1.0 - math.exp(-3.0 * base_weight)
        else:
            node_weight = 0.1  # Isolated nodes have minimal weight

        # Starting node (node 1, index 0) is always treated as important
        if i == 0:
            node_weight = max(node_weight, 0.8)

        G.add_node(f'node_{i}', type=0, weight=node_weight)

    # Add constraint nodes (Type 1: path and connectivity constraints)

    # 1. Path ordering constraints (one per node)
    for i, (valid_count, _, _) in enumerate(row_stats):
        # Fewer valid options = tighter constraint
        constraint_tightness = 1.0 - (valid_count / (n + 1))
        G.add_node(f'path_order_{i}', type=1, weight=constraint_tightness)
        G.add_edge(f'node_{i}', f'path_order_{i}', weight=0.8)

    # 2. Transition validity constraints (one per valid matrix entry)
    for i in range(n):
        for j in range(n + 1):
            idx = i * (n + 1) + j
            if idx >= len(p) or p[idx] == _INVALID_PROFIT:
                continue

            constraint_weight = _transition_weight(
                p[idx], min_profit, profit_range
            )
            transition = f'transition_{i}_{j}'
            G.add_node(transition, type=1, weight=constraint_weight)

            # Connect to source node
            G.add_edge(f'node_{i}', transition, weight=0.6)

            # Connect to destination node unless this is the "not used" option.
            # j is 1-based here, so the destination index is j - 1 (< n).
            # NOTE(review): when j - 1 == i this re-adds the source edge and
            # its weight replaces the 0.6 set above — confirm that
            # self-transitions are always marked invalid in the data.
            if j > 0:
                edge_weight = min(constraint_weight * 1.5, 1.0)
                G.add_edge(f'node_{j - 1}', transition, weight=edge_weight)

    # 3. Starting position constraint (node 1 must be at position 1)
    G.add_node('start_constraint', type=1, weight=1.0)
    G.add_edge('node_0', 'start_constraint', weight=1.0)

    # 4. Global path connectivity constraint
    total_positive_profit = sum(val for val in profits if val > 0)
    total_negative_profit = abs(sum(val for val in profits if val < 0))

    if total_positive_profit > 0 and total_negative_profit > 0:
        # Share of penalty mass in the total profit mass
        difficulty = total_negative_profit / (
            total_positive_profit + total_negative_profit
        )
        G.add_node('global_connectivity', type=1, weight=difficulty)

        # Rank nodes by summed routed profit; only positive sums qualify
        node_profits = [
            (i, profit_sum)
            for i, (_, _, profit_sum) in enumerate(row_stats)
            if profit_sum > 0
        ]
        node_profits.sort(key=lambda item: item[1], reverse=True)

        for rank, (node_idx, _) in enumerate(node_profits[:5]):
            # Exponentially weaker connection for lower-ranked nodes
            connection_weight = math.exp(-0.5 * rank)
            G.add_edge(
                f'node_{node_idx}', 'global_connectivity',
                weight=connection_weight,
            )

    return G


def main():
    """Command-line entry point: load the instance JSON and report graph size.

    Expects exactly three positional arguments (mzn file, dzn file, json
    file). With any other argument count the usage line is printed and the
    process exits with status 1.
    """
    argv = sys.argv
    if len(argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    # dzn_file is part of the CLI contract but is not read by this script.
    mzn_file, dzn_file, json_file = argv[1:]

    # Parse the instance data
    with open(json_file, 'r') as handle:
        instance = json.load(handle)

    # Build the graph and print a one-line summary of its size
    graph = build_graph(mzn_file, instance)
    node_count = graph.number_of_nodes()
    edge_count = graph.number_of_edges()
    print(f"Graph built: {node_count} nodes, {edge_count} edges")


# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()