#!/usr/bin/env python3
"""
Graph converter for Unit-Commitment problem.
Created using subagent_prompt.md version: v_02

This problem is about scheduling power generators over time to meet electricity demand
while minimizing costs (generation, startup, shutdown, load shedding).
Key challenges: startup/shutdown costs, ramping constraints, minimum downtime, demand balance
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def build_graph(mzn_file, json_data):
    """
    Build a weighted graph representation of a unit commitment instance.

    Args:
        mzn_file: Path to the .mzn file. Kept for interface compatibility;
            it is not read by this function.
        json_data: Dict containing parsed DZN data. Keys read here:
            'time_horizon', 'num_of_generators', 'num_of_loads', 'gen_max',
            'gen_min', 'demand', 'startup_cost', 'shed_cost' and
            'max_ramp_rate'. Missing scalars default to 0 and missing arrays
            to an empty list. 'gen_max', 'gen_min' and 'demand' are indexed
            as flattened 2D arrays with index ``row * time_horizon + time``.

    Returns:
        An undirected ``networkx.Graph`` whose nodes carry ``type`` and
        ``weight`` attributes and whose edges carry a ``weight`` attribute.

    Graph layout (not strictly bipartite: gen-gen and time-time edges exist):
    - Generator nodes ``gen_{g}`` (type 0): weighted by capacity and ramp
      flexibility
    - Load nodes ``load_{l}`` (type 0): weighted by shedding cost
    - Time period nodes ``time_{t}`` (type 2): weighted by total demand
    - Constraint nodes (type 1): ``capacity_g{g}_t{t}``, ``balance_t{t}``,
      ``ramp_g{g}_t{t}`` and ``startup_g{g}``
    - Edges model participation in constraints, temporal coupling between
      consecutive periods, and conflicts between high-startup-cost generators
    """
    # Extract problem dimensions
    time_horizon = json_data.get('time_horizon', 0)
    num_generators = json_data.get('num_of_generators', 0)
    num_loads = json_data.get('num_of_loads', 0)

    # Extract the data arrays that actually feed node/edge weights
    gen_max = json_data.get('gen_max', [])
    gen_min = json_data.get('gen_min', [])
    demand = json_data.get('demand', [])
    startup_cost = json_data.get('startup_cost', [])
    shed_cost = json_data.get('shed_cost', [])
    max_ramp_rate = json_data.get('max_ramp_rate', [])

    G = nx.Graph()

    def get_2d_value(arr, row, time, default=0):
        """Read arr[row][time] from a flattened array, or default if too short."""
        idx = row * time_horizon + time
        return arr[idx] if idx < len(arr) else default

    def get_1d_value(arr, idx, default=0):
        """Read arr[idx], or default when the array is too short."""
        return arr[idx] if idx < len(arr) else default

    # Normalizing factors (fall back to 1 when the array is missing/empty)
    max_capacity = max(gen_max) if gen_max else 1
    max_demand = max(demand) if demand else 1
    max_startup = max(startup_cost) if startup_cost else 1
    max_shed = max(shed_cost) if shed_cost else 1

    # Total demand per period is needed for time nodes, balance nodes and
    # temporal edges, so compute it once.
    total_demand = [
        sum(get_2d_value(demand, load, t) for load in range(num_loads))
        for t in range(time_horizon)
    ]
    # Upper bound on a period's total demand; zero when there are no loads.
    demand_scale = max_demand * num_loads

    # Generator nodes (type 0) - weight by capacity and flexibility
    for g in range(num_generators):
        # Average capacity across time
        avg_capacity = sum(get_2d_value(gen_max, g, t) for t in range(time_horizon)) / max(time_horizon, 1)
        # Ramp rate flexibility (higher = more flexible)
        ramp_flexibility = get_1d_value(max_ramp_rate, g)
        capacity_weight = avg_capacity / max_capacity if max_capacity > 0 else 0.1
        flexibility_weight = min(ramp_flexibility / max_capacity, 1.0) if max_capacity > 0 else 0.1
        weight = (capacity_weight + flexibility_weight) / 2.0
        G.add_node(f'gen_{g}', type=0, weight=weight)

    # Load nodes (type 0) - weight by criticality, proportional to shed cost
    for load in range(num_loads):
        # Higher shed cost = more critical = higher weight; a load without a
        # shed cost entry is treated as maximally critical.
        criticality = get_1d_value(shed_cost, load, max_shed)
        weight = criticality / max_shed if max_shed > 0 else 0.5
        G.add_node(f'load_{load}', type=0, weight=weight)

    # Time period nodes (type 2) - weight by total demand
    for t in range(time_horizon):
        # Guard the whole denominator: with num_loads == 0 it is zero even
        # though max_demand defaults to 1, which used to raise
        # ZeroDivisionError.
        weight = total_demand[t] / demand_scale if demand_scale > 0 else 0.5
        G.add_node(f'time_{t}', type=2, weight=weight)

    # Constraint nodes (type 1) with varied weights based on constraint type

    # 1. Capacity constraints (gen_max/gen_min) - one per generator-time pair
    for g in range(num_generators):
        for t in range(time_horizon):
            max_cap = get_2d_value(gen_max, g, t)
            min_cap = get_2d_value(gen_min, g, t)
            # Tightness = how constraining the min requirement is
            tightness = min_cap / max(max_cap, 1) if max_cap > 0 else 0.5
            G.add_node(f'capacity_g{g}_t{t}', type=1, weight=tightness)

    # 2. Power balance constraints - one per time period
    for t in range(time_horizon):
        total_capacity_t = sum(get_2d_value(gen_max, g, t) for g in range(num_generators))
        # Higher demand relative to capacity = higher weight (capped at 1)
        balance_tightness = total_demand[t] / max(total_capacity_t, 1) if total_capacity_t > 0 else 1.0
        G.add_node(f'balance_t{t}', type=1, weight=min(balance_tightness, 1.0))

    # 3. Ramping constraints - for generator transitions between time periods
    for g in range(num_generators):
        ramp_rate = get_1d_value(max_ramp_rate, g)
        # Inverse of ramp capability (less flexible = higher weight)
        ramp_constraint_weight = 1.0 - min(ramp_rate / max_capacity, 1.0) if max_capacity > 0 else 0.8
        for t in range(1, time_horizon):
            G.add_node(f'ramp_g{g}_t{t}', type=1, weight=ramp_constraint_weight)

    # 4. Startup cost constraints - weight by startup cost magnitude
    for g in range(num_generators):
        startup = get_1d_value(startup_cost, g)
        startup_weight = startup / max_startup if max_startup > 0 else 0.1
        # Exponential scaling saturates the weight towards 1 for costly units
        startup_weight = 1.0 - math.exp(-3.0 * startup_weight)
        G.add_node(f'startup_g{g}', type=1, weight=startup_weight)

    # Add edges for constraint participation

    # Generator-capacity constraint participation
    for g in range(num_generators):
        for t in range(time_horizon):
            max_cap = get_2d_value(gen_max, g, t)
            edge_weight = max_cap / max_capacity if max_capacity > 0 else 0.5
            G.add_edge(f'gen_{g}', f'capacity_g{g}_t{t}', weight=edge_weight)

    # Generator-balance constraint participation
    for g in range(num_generators):
        for t in range(time_horizon):
            max_cap = get_2d_value(gen_max, g, t)
            # Weight by how much this generator can contribute to balance
            contribution_weight = max_cap / max_capacity if max_capacity > 0 else 0.5
            G.add_edge(f'gen_{g}', f'balance_t{t}', weight=contribution_weight)

    # Load-balance constraint participation
    for load in range(num_loads):
        for t in range(time_horizon):
            load_demand = get_2d_value(demand, load, t)
            demand_weight = load_demand / max_demand if max_demand > 0 else 0.5
            G.add_edge(f'load_{load}', f'balance_t{t}', weight=demand_weight)

    # Generator-ramping constraint participation
    for g in range(num_generators):
        for t in range(1, time_horizon):
            # Connect generator to its ramping constraints
            G.add_edge(f'gen_{g}', f'ramp_g{g}_t{t}', weight=0.8)
            # Connect consecutive time periods through ramping
            G.add_edge(f'time_{t-1}', f'ramp_g{g}_t{t}', weight=0.6)
            G.add_edge(f'time_{t}', f'ramp_g{g}_t{t}', weight=0.6)

    # Generator-startup constraint participation
    for g in range(num_generators):
        startup = get_1d_value(startup_cost, g)
        startup_weight = min(startup / max_startup, 1.0) if max_startup > 0 else 0.3
        G.add_edge(f'gen_{g}', f'startup_g{g}', weight=startup_weight)

    # Add temporal dependencies between time periods
    for t in range(1, time_horizon):
        # Weight by demand change - higher change = stronger temporal coupling
        demand_change = abs(total_demand[t] - total_demand[t - 1]) / max(demand_scale, 1)
        temporal_weight = min(demand_change * 2.0, 1.0)
        # Floor at 0.2 so consecutive periods always stay coupled
        G.add_edge(f'time_{t-1}', f'time_{t}', weight=max(temporal_weight, 0.2))

    # Add conflict edges between generators with high startup costs.
    # The edge set does not depend on the time period, so it is built once
    # instead of once per period; as before, nothing is added when the
    # horizon contains no periods.
    if time_horizon > 0:
        high_cost_gens = []
        for g in range(num_generators):
            startup = get_1d_value(startup_cost, g)
            if startup > max_startup * 0.5:  # High startup cost generators
                high_cost_gens.append((g, startup))

        for i, (g1, cost1) in enumerate(high_cost_gens):
            for g2, cost2 in high_cost_gens[i + 1:]:
                # Conflict strength based on combined startup costs
                conflict_weight = (cost1 + cost2) / (2 * max_startup) if max_startup > 0 else 0.3
                G.add_edge(f'gen_{g1}', f'gen_{g2}', weight=conflict_weight)

    return G


def main():
    """
    Command-line entry point.

    Expects exactly three arguments (<mzn_file> <dzn_file> <json_file>);
    otherwise prints the usage line and exits with status 1. The JSON file
    is parsed and handed to build_graph together with the .mzn path, then a
    one-line summary of the resulting graph is printed. The .dzn path is
    accepted but not used here.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) != 3:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    mzn_path, _dzn_path, json_path = cli_args

    # Parse the instance data
    with open(json_path, 'r') as handle:
        instance = json.load(handle)

    # build_graph returns the graph for direct feature extraction
    graph = build_graph(mzn_path, instance)

    node_count = graph.number_of_nodes()
    edge_count = graph.number_of_edges()
    print(f"Graph built: {node_count} nodes, {edge_count} edges")


# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()