#!/usr/bin/env python3
"""
Graph converter for valve-network problem.
Created using subagent_prompt.md version: v_02

This problem is about optimizing valve operations in a network to maximize flow.
Key challenges: coordinating two agents (Me, Elephant), timing valve openings,
network traversal constraints, and maximizing cumulative flow over the time horizon.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the valve network problem instance.

    Args:
        mzn_file: Path to .mzn file. Not read: the network structure below
            is hard-coded from a prior analysis of that file.
        json_data: Dict containing parsed DZN data; only 'horizon' is read
            (defaults to 6 when absent).

    Returns:
        networkx.Graph whose nodes carry 'type' and 'weight' attributes and
        whose edges carry a 'weight' attribute.

    Strategy: Create a graph that captures the valve network structure,
    flow opportunities, and temporal coordination challenges.
    - Valve nodes (type 0): Decision points for opening valves
    - Constraint nodes (type 1): Timing, movement, and coordination constraints
    - Resource nodes (type 2): Time slots and positions
    - Edge weights reflect flow potential, connectivity, and coordination
      difficulty
    """
    # Local import so this block needs no new file-level dependency.
    from itertools import combinations

    # Get horizon from JSON data
    horizon = json_data.get('horizon', 6)

    # Network structure is embedded in the MZN file; it is hard-coded here
    # from a prior analysis of that file.
    nodes = ['GJ', 'HE', 'ET', 'SG', 'LC', 'EE', 'AA', 'TF', 'GO', 'QE', 'MI', 'BR', 'UV', 'EH', 'WK', 'NT', 'KI', 'AH', 'EL', 'GP', 'GM', 'LU', 'LB', 'QC', 'JJ', 'MM', 'VI', 'NV', 'VT', 'RE', 'FO', 'DV', 'SQ', 'OQ', 'FF', 'IV', 'HY', 'ML', 'JS', 'KU', 'QA', 'EU', 'SV', 'JG', 'DW', 'UD', 'QJ', 'HU', 'ZR', 'YA', 'JH', 'OS', 'LG', 'SB', 'UU', 'VL', 'AO', 'EM']

    # Flow rates for each valve, parallel to `nodes` (0 means no flow benefit)
    flow_rates = [14, 0, 0, 0, 0, 13, 0, 0, 0, 24, 0, 0, 0, 0, 0, 0, 0, 22, 0, 0, 0, 9, 0, 0, 0, 0, 18, 0, 0, 0, 0, 10, 12, 23, 3, 0, 8, 0, 0, 5, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 15, 0, 4, 0, 0, 0, 0]

    # Network connections (adjacency lists)
    connections = {
        'GJ': ['UV', 'AO', 'MM', 'UD', 'GM'],
        'HE': ['QE', 'SV'],
        'ET': ['LU', 'SB'],
        'SG': ['FF', 'SB'],
        'LC': ['QJ', 'GM'],
        'EE': ['RE', 'BR'],
        'AA': ['QC', 'ZR', 'NT', 'JG', 'FO'],  # Starting position
        'TF': ['LU', 'MM'],
        'GO': ['LB', 'AH'],
        'QE': ['LG', 'HE'],
        'MI': ['KU', 'FF'],
        'BR': ['HY', 'EE'],
        'UV': ['GP', 'GJ'],
        'EH': ['UU', 'FF'],
        'WK': ['HY', 'EL'],
        'NT': ['FF', 'AA'],
        'KI': ['OQ', 'AO'],
        'AH': ['GO', 'RE'],
        'EL': ['WK', 'SQ'],
        'GP': ['SB', 'UV'],
        'GM': ['LC', 'GJ'],
        'LU': ['UU', 'DW', 'TF', 'ET', 'ML'],
        'LB': ['GO', 'VI'],
        'QC': ['ML', 'AA'],
        'JJ': ['QJ', 'DV'],
        'MM': ['TF', 'GJ'],
        'VI': ['LB'],
        'NV': ['SB', 'KU'],
        'VT': ['HY', 'JG'],
        'RE': ['AH', 'EE'],
        'FO': ['SB', 'AA'],
        'DV': ['JH', 'UD', 'JJ'],
        'SQ': ['EL', 'QA'],
        'OQ': ['KI', 'IV', 'JS'],
        'FF': ['EU', 'NT', 'SG', 'MI', 'EH'],
        'IV': ['LG', 'OQ'],
        'HY': ['VT', 'BR', 'WK'],
        'ML': ['LU', 'QC'],
        'JS': ['EM', 'OQ'],
        'KU': ['MI', 'VL', 'NV', 'HU', 'DW'],
        'QA': ['OS', 'SQ'],
        'EU': ['FF', 'OS'],
        'SV': ['QJ', 'HE'],
        'JG': ['AA', 'VT'],
        'DW': ['LU', 'KU'],
        'UD': ['DV', 'GJ'],
        'QJ': ['JJ', 'SV', 'LC', 'EM', 'YA'],
        'HU': ['JH', 'KU'],
        'ZR': ['AA', 'VL'],
        'YA': ['QJ', 'OS'],
        'JH': ['HU', 'DV'],
        'OS': ['EU', 'YA', 'QA'],
        'LG': ['QE', 'IV'],
        'SB': ['FO', 'SG', 'NV', 'GP', 'ET'],
        'UU': ['EH', 'LU'],
        'VL': ['ZR', 'KU'],
        'AO': ['GJ', 'KI'],
        'EM': ['QJ', 'JS']
    }

    G = nx.Graph()

    # O(1) name -> flow lookup; replaces repeated nodes.index() scans.
    flow_of = dict(zip(nodes, flow_rates))

    # Calculate metrics for weighting
    max_flow = max(flow_rates) if flow_rates else 1
    total_flow = sum(flow_rates)
    num_valuable = sum(1 for flow in flow_rates if flow > 0)
    num_nodes = len(nodes)

    # Largest adjacency-list length (5 for this instance); used to normalise
    # movement tightness instead of a hard-coded constant.
    max_degree = max((len(adj) for adj in connections.values()), default=0) or 1

    def movement_tightness(node, floor):
        """Return 1 - degree/max_degree clamped to [floor, 1]; fewer exits = tighter."""
        degree = len(connections.get(node, []))
        return max(floor, min(1.0 - degree / max_degree, 1.0))

    # Type 0: Valve nodes (decision points)
    for node, flow_rate in zip(nodes, flow_rates):
        if flow_rate > 0:
            flow_importance = flow_rate / max_flow
            # Share of total flow. Equals (flow*horizon)/(total*horizon), the
            # original "time value" ratio, but does not divide by zero when
            # horizon == 0.
            flow_share = flow_rate / total_flow
            weight = min(flow_importance * 0.7 + flow_share * 0.3, 1.0)
        else:
            # Non-valuable valves still matter for positioning/movement
            weight = 0.1
        G.add_node(f'valve_{node}', type=0, weight=weight)

    # Type 1: Constraint nodes for different constraint types

    # 1. Movement constraints: fewer connections = tighter = higher weight
    for node in nodes:
        G.add_node(f'movement_constraint_{node}', type=1,
                   weight=movement_tightness(node, 0.2))

    # 2. Timing constraints - harder with longer horizons (normalised by 30,
    #    clamped to [0, 1] so a bad horizon cannot produce a negative weight)
    horizon_complexity = max(0.0, min(horizon / 30.0, 1.0))
    G.add_node('timing_constraint', type=1, weight=horizon_complexity)

    # 3. Coordination constraints - managing two agents; share of valves
    #    that have positive flow
    G.add_node('coordination_constraint', type=1,
               weight=num_valuable / num_nodes)

    # 4. Flow optimization constraint - balancing immediate vs future gains
    G.add_node('flow_optimization_constraint', type=1, weight=0.8)

    # Type 2: Resource nodes

    # 1. Time slots: weight = sqrt(minute / horizon), so later slots weigh more
    for minute in range(1, horizon + 1):
        G.add_node(f'time_{minute}', type=2, weight=(minute / horizon) ** 0.5)

    # 2. Position resources: the starting position 'AA' gets full weight;
    #    distance from start is not modelled, so all others get a flat 0.5
    for node in nodes:
        weight = 1.0 if node == 'AA' else 0.5
        G.add_node(f'position_{node}', type=2, weight=weight)

    # Edges: Model relationships and interactions

    # 1. Valve-to-movement constraint edges. The floor is 0.3 here vs 0.2 on
    #    the constraint node itself (kept as in the original design).
    for node in nodes:
        G.add_edge(f'valve_{node}', f'movement_constraint_{node}',
                   weight=movement_tightness(node, 0.3))

    # 2./3. Valuable valves to the timing and coordination constraints.
    #       Higher-flow valves are more time-sensitive and need more
    #       coordination.
    for node, flow_rate in zip(nodes, flow_rates):
        if flow_rate > 0:
            valve_node = f'valve_{node}'
            relative_flow = flow_rate / max_flow
            G.add_edge(valve_node, 'timing_constraint',
                       weight=relative_flow * 0.8 + 0.2)
            G.add_edge(valve_node, 'coordination_constraint',
                       weight=relative_flow)

    # 4. Network connectivity edges (valve to valve via movement), weighted
    #    by the pair's combined flow potential
    for node, neighbours in connections.items():
        if node not in flow_of:  # Safety check (both endpoints)
            continue
        for neighbour in neighbours:
            if neighbour not in flow_of:
                continue
            combined_flow = flow_of[node] + flow_of[neighbour]
            if max_flow > 0:
                flow_factor = combined_flow / (2 * max_flow)
                weight = max(0.2, min(flow_factor + 0.3, 1.0))
            else:
                weight = 0.3
            G.add_edge(f'valve_{node}', f'valve_{neighbour}', weight=weight)

    # 5. Time resource connections: exp(minute/horizon - 1), which rises
    #    towards 1 as the horizon approaches
    for minute in range(1, horizon):
        urgency = math.exp(minute / horizon - 1)
        G.add_edge(f'time_{minute}', f'time_{minute + 1}',
                   weight=min(urgency, 1.0))

    # 6. High-value valves (flow >= 60% of max) create competition. Requiring
    #    flow > 0 avoids a division by zero when every flow rate is 0.
    top_tier = [
        (f'valve_{node}', flow_rate)
        for node, flow_rate in zip(nodes, flow_rates)
        if flow_rate > 0 and flow_rate >= max_flow * 0.6
    ]

    # Conflict edges between every pair of top-tier valves. If a pair is
    # already adjacent, networkx updates that edge's weight instead.
    for (valve1, flow1), (valve2, flow2) in combinations(top_tier, 2):
        conflict_intensity = ((flow1 + flow2) / (2 * max_flow)) * 0.6
        G.add_edge(valve1, valve2, weight=conflict_intensity)

    return G


def main():
    """Command-line entry point: build the graph and print its size.

    Expects exactly three arguments: <mzn_file> <dzn_file> <json_file>.
    The DZN path is accepted for interface compatibility but not read; the
    JSON file supplies the instance data passed to build_graph.
    On a wrong argument count, prints usage to stderr and exits with status 1.
    """
    if len(sys.argv) != 4:
        # Usage errors belong on stderr so stdout stays clean for output.
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>",
              file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # Load JSON data (explicit encoding: JSON is UTF-8 by specification)
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    # Build graph
    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()