#!/usr/bin/env python3
"""
Graph converter for Waste Water Treatment Plant Problem.
Created using subagent_prompt.md version: v_02

This problem is about scheduling wastewater treatment from multiple industries
across time periods while respecting plant capacity and tank constraints.
Key challenges: capacity conflicts, temporal dependencies, flow constraints.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _lookup(seq, idx, default=1):
    """Return ``seq[idx]`` when *idx* is within range, otherwise *default*."""
    return seq[idx] if idx < len(seq) else default


def _demand_matrix(d_raw, industries, timesteps):
    """Return demand as an ``industries x timesteps`` list of rows.

    Accepts either a row-major flat list (``d_raw[i * timesteps + j]``) or a
    nested list of per-industry rows (the shape MiniZinc's JSON format uses
    for 2-D arrays). Missing entries are padded with 0.
    """
    if d_raw and isinstance(d_raw[0], (list, tuple)):
        rows = [_lookup(d_raw, i, []) for i in range(industries)]
        return [[_lookup(row, j, 0) for j in range(timesteps)] for row in rows]
    return [
        [_lookup(d_raw, i * timesteps + j, 0) for j in range(timesteps)]
        for i in range(industries)
    ]


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the WWTPP instance.

    Args:
        mzn_file: Path to .mzn file (for reference only; it is not read here)
        json_data: Dict containing parsed DZN data. Keys read: INDUSTRIES,
            TIMESTEPS, max_capacitat, TankFlow, TankCapacity and d. ``d`` may
            be a row-major flat list or a nested list of per-industry rows.
            Missing TankFlow/TankCapacity entries default to 1; missing
            demand entries default to 0.

    Returns:
        networkx.Graph whose nodes carry ``type`` (0 = industry-time
        variable, 1 = constraint) and ``weight`` attributes, and whose edges
        carry a ``weight`` attribute.

    Strategy:
    - Industry-time decision variables (type 0)
    - Constraint nodes (type 1): capacity per timestep, buffer and flow per
      industry-time pair
    - Variable-constraint edges for constraint participation
    - Variable-variable edges for temporal dependencies (same industry,
      adjacent timesteps) and for capacity conflicts at oversubscribed
      timesteps, so the graph is not strictly bipartite
    """
    industries = json_data.get('INDUSTRIES', 0)
    timesteps = json_data.get('TIMESTEPS', 0)
    max_capacity = json_data.get('max_capacitat', 1)
    tank_flow = json_data.get('TankFlow', [])
    tank_capacity = json_data.get('TankCapacity', [])

    d = _demand_matrix(json_data.get('d', []), industries, timesteps)

    # Loop invariants reused by several passes below.
    flows = [_lookup(tank_flow, i) for i in range(industries)]
    caps = [_lookup(tank_capacity, i) for i in range(industries)]
    total_demand = [sum(d[i][j] for i in range(industries)) for j in range(timesteps)]

    G = nx.Graph()

    # Variable nodes: industry-time pairs (type 0).
    # Weight by demand relative to tank capacity, with non-linear saturation.
    for i in range(industries):
        cap = caps[i]
        for j in range(timesteps):
            demand_ratio = min(d[i][j] / cap, 1.0) if cap > 0 else 0.5
            weight = 1.0 - math.exp(-3.0 * demand_ratio)
            G.add_node(f'ind_{i}_t_{j}', type=0, weight=weight)

    # Constraint nodes (type 1)

    # 1. Capacity constraints per timestep, weighted by demand/capacity tightness.
    for j in range(timesteps):
        tightness = min(total_demand[j] / max_capacity, 1.0) if max_capacity > 0 else 0.5
        weight = math.exp(-2.0 * (1.0 - tightness)) if tightness < 1.0 else 1.0
        G.add_node(f'capacity_t_{j}', type=1, weight=weight)

    # 2. Buffer balance constraints per industry-time, weighted by flow/capacity.
    for i in range(industries):
        flow_ratio = flows[i] / caps[i] if caps[i] > 0 else 0.5
        # Clamp at 0 so a malformed negative flow cannot make sqrt raise.
        weight = min(math.sqrt(max(flow_ratio, 0.0)), 1.0)
        for j in range(timesteps):
            G.add_node(f'buffer_{i}_t_{j}', type=1, weight=weight)

    # 3. Flow limit constraints per industry-time, weighted by tank utilization.
    for i in range(industries):
        utilization = min(flows[i] / caps[i], 1.0) if caps[i] > 0 else 0.5
        weight = 1.0 - math.exp(-2.0 * utilization)
        for j in range(timesteps):
            G.add_node(f'flow_{i}_t_{j}', type=1, weight=weight)

    # Variable participation in capacity constraints (only positive demand).
    for j in range(timesteps):
        total = total_demand[j]
        for i in range(industries):
            demand = d[i][j]
            if demand > 0:
                contribution = demand / total if total > 0 else 0.5
                # Scale up so small contributors remain visible.
                G.add_edge(f'ind_{i}_t_{j}', f'capacity_t_{j}',
                           weight=min(contribution * 2.0, 1.0))

    # Every variable participates in its own buffer constraint.
    for i in range(industries):
        for j in range(timesteps):
            G.add_edge(f'ind_{i}_t_{j}', f'buffer_{i}_t_{j}', weight=0.8)

    # Variable participation in flow constraints, weighted by flow-limit stress.
    for i in range(industries):
        flow = flows[i]
        for j in range(timesteps):
            demand = d[i][j]
            if demand > 0:
                stress = min(demand / flow, 1.0) if flow > 0 else 0.5
                G.add_edge(f'ind_{i}_t_{j}', f'flow_{i}_t_{j}', weight=math.sqrt(stress))

    # Temporal dependency edges: the buffer at time j depends on time j-1.
    # Weighted by the demand change between the two periods.
    for i in range(industries):
        cap = caps[i]
        for j in range(1, timesteps):
            demand_change = abs(d[i][j] - d[i][j - 1])
            if demand_change > 0:
                change_ratio = min(demand_change / cap, 1.0) if cap > 0 else 0.5
                G.add_edge(f'ind_{i}_t_{j-1}', f'ind_{i}_t_{j}',
                           weight=math.sqrt(change_ratio))

    # Conflict edges between the top consumers at significantly
    # oversubscribed timesteps.
    for j in range(timesteps):
        if total_demand[j] > max_capacity * 1.2:
            # sorted() is stable, so ties keep industry order.
            top = sorted(
                ((i, d[i][j]) for i in range(industries) if d[i][j] > 0),
                key=lambda pair: pair[1],
                reverse=True,
            )[:4]
            for pos, (i1, demand1) in enumerate(top):
                for i2, demand2 in top[pos + 1:]:
                    if demand1 + demand2 > max_capacity * 0.8:  # Can't both be fully treated
                        if max_capacity > 0:
                            conflict_strength = (demand1 + demand2) / (max_capacity * 2.0)
                            weight = min(conflict_strength, 1.0)
                        else:
                            # No capacity at all: any pair conflicts maximally
                            # (previously a ZeroDivisionError).
                            weight = 1.0
                        G.add_edge(f'ind_{i1}_t_{j}', f'ind_{i2}_t_{j}', weight=weight)

    return G


def main():
    """Command-line entry point: build the graph for one instance and report its size.

    Usage: python converter.py <mzn_file> <dzn_file> <json_file>

    Only <json_file> is read here. <mzn_file> is passed through to
    build_graph, and <dzn_file> is accepted for interface compatibility but
    not used. Exits with status 1 (usage on stderr) on a wrong argument count.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>", file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    G = build_graph(mzn_file, json_data)

    # The graph is returned by build_graph for direct feature extraction;
    # the CLI only reports its size.
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    main()