#!/usr/bin/env python3
"""
Graph converter for Waste Water Treatment Plant Scheduling Problem (wwtpp).
Created using subagent_prompt.md version: v_02

This problem is about scheduling waste water treatment from multiple industries
across 24 time periods while respecting tank capacities and treatment plant
capacity constraints.

Key challenges: Balancing competing demands across time, managing buffer
storage, and handling complex temporal constraints on when different
industries can discharge waste.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _positive_max(values):
    """Return the largest entry of *values*, or 1 if it is not a safe divisor.

    An empty sequence, or one whose maximum is zero or negative, yields 1 so
    the result can always be used as a normalisation denominator.
    """
    largest = max(values, default=0)
    return largest if largest > 0 else 1


def _value_or_default(values, index, default=1000):
    """Return ``values[index]``, or *default* when the list is too short."""
    return values[index] if index < len(values) else default


def _capped_ratio(numerator, denominator, scale=1.0):
    """Return ``numerator / denominator * scale`` capped at 1.0.

    A non-positive denominator cannot be used for normalisation, so a neutral
    weight of 0.5 is returned instead.
    """
    if denominator > 0:
        return min((numerator / denominator) * scale, 1.0)
    return 0.5


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the waste water treatment scheduling instance.

    Args:
        mzn_file: Path to .mzn file (for reference only; not read here)
        json_data: Dict containing parsed DZN data. Keys read (with the
            defaults used when a key is absent): 'INDUSTRIES' (8),
            'TIMESTEPS' (24), 'max_capacitat' (2000), 'TankFlow' ([]),
            'TankCapacity' ([]) and 'd' ([]). 'd' is indexed as a flat,
            industry-major list (index = industry * TIMESTEPS + timestep);
            missing entries count as zero demand. Industries missing from
            'TankFlow' / 'TankCapacity' use 1000.

    Returns:
        An undirected ``networkx.Graph`` whose nodes carry ``type`` and
        ``weight`` attributes and whose edges carry a ``weight`` attribute.

    Graph layout:
    - 'var_i{i}_t{t}' (type 0): one node per industry and timestep
    - 'capacity_t{t}' (type 1): one treatment-capacity node per timestep
    - 'tank_capacity_i{i}' (type 1): one tank-capacity node per industry
    - 'tank_flow_i{i}' (type 1): one tank-flow node per industry
    - variable-to-constraint edges exist only where demand is positive
    - consecutive timesteps of one industry are linked when both have demand
    - at timesteps whose total demand exceeds 80% of the plant capacity, the
      (up to) four largest demanders are linked pairwise

    Node and edge weights are capped at 1.0, except the tank-capacity node
    weight, which is a tanh value. When a divisor is not positive a neutral
    weight of 0.5 is used.
    """
    # Access data from json_data dict
    industries = json_data.get('INDUSTRIES', 8)
    timesteps = json_data.get('TIMESTEPS', 24)
    max_capacity = json_data.get('max_capacitat', 2000)
    tank_flow = json_data.get('TankFlow', [])
    tank_capacity = json_data.get('TankCapacity', [])
    demand = json_data.get('d', [])

    # Reshape the flattened demand array into one row per industry,
    # padding with 0 wherever the flat list is too short.
    demand_matrix = [
        [
            demand[i * timesteps + t] if i * timesteps + t < len(demand) else 0
            for t in range(timesteps)
        ]
        for i in range(industries)
    ]

    # Per-industry limits, with the 1000 fallback for missing entries.
    tank_caps = [_value_or_default(tank_capacity, i) for i in range(industries)]
    tank_flows = [_value_or_default(tank_flow, i) for i in range(industries)]

    # Problem statistics used for normalisation
    total_demand_per_timestep = [
        sum(row[t] for row in demand_matrix) for t in range(timesteps)
    ]
    max_total_demand = max(total_demand_per_timestep, default=1)
    demand_scale = max(max_total_demand, 1)
    # Guard against all-zero lists: the maxima are used as divisors below.
    max_tank_capacity = _positive_max(tank_capacity)
    max_tank_flow = _positive_max(tank_flow)

    G = nx.Graph()

    # Industry-timestep variable nodes (type 0)
    for i, row in enumerate(demand_matrix):
        capacity_pressure = tank_caps[i] / max_tank_capacity
        flow_pressure = tank_flows[i] / max_tank_flow
        saturation = 1.0 - math.exp(-2.0 * capacity_pressure)

        for t, demand_val in enumerate(row):
            demand_pressure = demand_val / demand_scale if max_total_demand > 0 else 0
            # Combine demand, capacity and flow factors with non-linear scaling
            weight = math.sqrt(demand_pressure) * saturation * flow_pressure
            G.add_node(f'var_i{i}_t{t}', type=0, weight=min(weight, 1.0))

    # Capacity constraint nodes (type 1) - one per timestep
    for t, total_demand_t in enumerate(total_demand_per_timestep):
        if max_capacity > 0:
            # Approaches 1 as total demand grows relative to plant capacity
            tightness = 1.0 - math.exp(-3.0 * (total_demand_t / max_capacity))
        else:
            tightness = 0.5
        G.add_node(f'capacity_t{t}', type=1, weight=min(tightness, 1.0))

    # Tank capacity constraint nodes (type 1) - one per industry
    for i, row in enumerate(demand_matrix):
        if tank_caps[i] > 0:
            # tanh saturates at 1.0 for large demand-to-capacity ratios
            tightness = math.tanh(sum(row) / tank_caps[i])
        else:
            tightness = 0.5
        G.add_node(f'tank_capacity_i{i}', type=1, weight=tightness)

    # Tank flow constraint nodes (type 1) - one per industry
    for i, row in enumerate(demand_matrix):
        peak_demand = max(row) if row else 0
        if tank_flows[i] > 0:
            tightness = 1.0 - math.exp(-2.0 * (peak_demand / tank_flows[i]))
        else:
            tightness = 0.5
        G.add_node(f'tank_flow_i{i}', type=1, weight=min(tightness, 1.0))

    # Bipartite edges: variables to per-timestep capacity constraints
    for t in range(timesteps):
        for i, row in enumerate(demand_matrix):
            if row[t] > 0:  # Only connect if there's actual demand
                G.add_edge(
                    f'var_i{i}_t{t}',
                    f'capacity_t{t}',
                    weight=_capped_ratio(row[t], max_capacity, 2.0),
                )

    # Bipartite edges: variables to tank capacity constraints
    for i, row in enumerate(demand_matrix):
        for t, demand_val in enumerate(row):
            if demand_val > 0:
                G.add_edge(
                    f'var_i{i}_t{t}',
                    f'tank_capacity_i{i}',
                    weight=_capped_ratio(demand_val, tank_caps[i], 3.0),
                )

    # Bipartite edges: variables to tank flow constraints
    for i, row in enumerate(demand_matrix):
        for t, demand_val in enumerate(row):
            if demand_val > 0:
                G.add_edge(
                    f'var_i{i}_t{t}',
                    f'tank_flow_i{i}',
                    weight=_capped_ratio(demand_val, tank_flows[i]),
                )

    # Temporal dependency edges (type 0 to type 0): link consecutive
    # timesteps of the same industry when both carry demand.
    for i, row in enumerate(demand_matrix):
        for t in range(timesteps - 1):
            current_demand, next_demand = row[t], row[t + 1]
            if current_demand > 0 and next_demand > 0:
                coupling = (current_demand + next_demand) / (2 * demand_scale)
                G.add_edge(
                    f'var_i{i}_t{t}',
                    f'var_i{i}_t{t+1}',
                    weight=min(math.sqrt(coupling), 1.0),
                )

    # Competition edges between industries at contested timesteps.
    # Edge weights divide by max_capacity, so nothing is added without a
    # positive capacity.
    if max_capacity > 0:
        for t, total_demand_t in enumerate(total_demand_per_timestep):
            if not total_demand_t > max_capacity * 0.8:  # 80% of capacity
                continue

            # Largest demanders first; the sort is stable, so ties keep
            # industry order. Only the top four compete.
            top_competitors = sorted(
                ((i, row[t]) for i, row in enumerate(demand_matrix) if row[t] > 0),
                key=lambda entry: entry[1],
                reverse=True,
            )[:4]

            for position, (i1, demand1) in enumerate(top_competitors):
                for i2, demand2 in top_competitors[position + 1:]:
                    competition = (demand1 + demand2) / max_capacity
                    G.add_edge(
                        f'var_i{i1}_t{t}',
                        f'var_i{i2}_t{t}',
                        weight=min(competition, 1.0),
                    )

    return G


def main():
    """Command-line entry point: load the instance JSON and build its graph.

    Expects exactly three arguments: the .mzn file, the .dzn file and the
    JSON file. Only the JSON file is read; the .mzn path is passed through to
    build_graph and the .dzn path is accepted but unused. With any other
    argument count a usage line is printed and the process exits with
    status 1.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    # The .dzn path is part of the calling convention but is not needed here.
    mzn_file, _dzn_file, json_file = sys.argv[1:]

    # Load JSON data; decode explicitly as UTF-8 rather than relying on the
    # platform's default encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    # Build graph
    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


# Run the converter only when this file is executed as a script, so that
# importing the module does not parse command-line arguments.
if __name__ == "__main__":
    main()