#!/usr/bin/env python3
"""
Graph converter for Lot Sizing CP problem.
Created using subagent_prompt.md version: v_02

This problem is about production scheduling and inventory management.
Key challenges: 
- Balancing setup/changeover costs with inventory holding costs
- Meeting due dates while minimizing total cost
- Managing production sequences across different item types
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _flatten_costs(values):
    """Return ``values`` as a flat row-major list.

    The change-cost matrix may arrive either already flattened or as a list of
    rows (the usual JSON encoding of a 2-D array); nested lists/tuples are
    flattened recursively so both shapes are handled identically.
    """
    flat = []
    for value in values:
        if isinstance(value, (list, tuple)):
            flat.extend(_flatten_costs(value))
        else:
            flat.append(value)
    return flat


def build_graph(mzn_file, json_data):
    """
    Build a graph representation of one lot sizing problem instance.

    Args:
        mzn_file: Path to the .mzn model file. Not read; kept for interface
            compatibility.
        json_data: Dict of instance data parsed from the DZN file. Keys read:
            ``nb_orders``, ``nb_periods``, ``inventory_cost`` (default 1),
            ``due_period``, ``nb_of_orders`` and ``change_cost`` (a flat
            row-major list or a list of rows; it is only used when it has
            exactly ``(nb_orders + 1) ** 2`` entries).

    Returns:
        networkx.Graph whose nodes carry ``type`` and ``weight`` attributes and
        whose edges carry ``weight``. An empty graph is returned when any of
        ``nb_orders``, ``nb_periods``, ``due_period`` or ``nb_of_orders`` is
        missing, zero or empty.

    Graph layout:
        - Type 0 ``order_<k>``: orders, weighted by due-date urgency and
          inventory-cost exposure.
        - Type 1 constraint nodes: ``due_constraint_<k>``,
          ``cardinality_constraint``, ``setup_constraint_<p>`` and
          ``item_ordering_<t>``, weighted by tightness/complexity.
        - Type 2 ``period_<p>``: time periods, weighted by demand scarcity.
        - Order-order edges link orders of different item types whose
          changeover cost exceeds a threshold.
    """

    # Extract problem parameters
    nb_orders = json_data.get('nb_orders', 0)
    nb_periods = json_data.get('nb_periods', 0)
    inventory_cost = json_data.get('inventory_cost', 1)
    due_period = json_data.get('due_period', [])
    nb_of_orders = json_data.get('nb_of_orders', [])
    change_cost_flat = _flatten_costs(json_data.get('change_cost') or [])

    if not all([nb_orders, nb_periods, due_period, nb_of_orders]):
        return nx.Graph()

    # Reconstruct order -> item type from nb_of_orders. This assumes orders are
    # numbered contiguously by item type (TODO confirm against the model).
    # Index 0 is a dummy order with type 0.
    item_type = [0]
    for item_idx, count in enumerate(nb_of_orders):
        item_type.extend([item_idx + 1] * count)

    # Reshape the row-major change_cost matrix into {(i, j): cost} over
    # 0..nb_orders; a matrix of any other size is ignored (left empty).
    size = nb_orders + 1
    change_cost = {}
    if len(change_cost_flat) == size * size:
        for i in range(size):
            for j in range(size):
                change_cost[(i, j)] = change_cost_flat[i * size + j]

    G = nx.Graph()

    # Problem-wide metrics (due_period is non-empty thanks to the guard above)
    max_due = max(due_period)
    # Divisor for urgency/tightness ratios; avoids dividing by a non-positive max_due.
    due_scale = max_due if max_due > 0 else 1
    max_change_cost = max(change_cost_flat) if change_cost_flat else 1
    avg_change_cost = sum(change_cost_flat) / len(change_cost_flat) if change_cost_flat else 0

    # Inventory-cost exposure depends only on max_due, so it is the same for
    # every order. Guard the ratio: both terms can be 0 (all orders due in
    # period 1 or zero inventory cost, together with all-zero change costs).
    potential_inventory_cost = (max_due - 1) * inventory_cost
    cost_denominator = max_change_cost + potential_inventory_cost
    cost_impact = potential_inventory_cost / cost_denominator if cost_denominator else 0.0

    # Type 0: Order nodes (decision variables)
    for order in range(1, nb_orders + 1):
        due = due_period[order - 1] if order - 1 < len(due_period) else max_due

        # Earlier due dates = higher urgency
        urgency = 1.0 - (due - 1) / due_scale

        # Combine urgency and cost impact, then apply exponential scaling
        weight = 0.6 * urgency + 0.4 * cost_impact
        weight = min(math.exp(weight * 2.0 - 1.0) / math.exp(1.0), 1.0)

        G.add_node(f'order_{order}', type=0, weight=weight)

    # Type 2: Time period nodes (resources)
    for period in range(1, nb_periods + 1):
        # Orders still due at or after this period vs. the periods remaining
        orders_due_by_period = sum(1 for due in due_period if due >= period)
        remaining_periods = nb_periods - period + 1
        utilization = min(orders_due_by_period / max(remaining_periods, 1), 1.0)

        # Non-linear scarcity weight
        G.add_node(f'period_{period}', type=2, weight=math.sqrt(utilization))

    # Type 1: Constraint nodes

    # 1. Due date constraints (one per order), linked to each feasible period
    for order in range(1, nb_orders + 1):
        due = due_period[order - 1] if order - 1 < len(due_period) else max_due
        tightness = 1.0 - (due - 1) / due_scale
        G.add_node(f'due_constraint_{order}', type=1, weight=tightness)

        G.add_edge(f'order_{order}', f'due_constraint_{order}', weight=1.0)

        # Periods 1..min(due, nb_periods); the loop only runs when due >= 1,
        # so dividing by due is safe. Earlier periods get higher weight.
        for period in range(1, min(due + 1, nb_periods + 1)):
            feasibility = 1.0 - (period - 1) / due
            G.add_edge(f'due_constraint_{order}', f'period_{period}', weight=feasibility)

    # 2. Global cardinality constraint (each order exactly once)
    cardinality_tightness = nb_orders / nb_periods  # Higher when fewer free periods
    G.add_node('cardinality_constraint', type=1, weight=min(cardinality_tightness, 1.0))

    for order in range(1, nb_orders + 1):
        G.add_edge(f'order_{order}', 'cardinality_constraint', weight=0.8)

    # 3. Setup/changeover constraints (one per adjacent period pair). The
    # weight is the mean positive change cost normalized by the maximum; it
    # does not depend on the period, so compute it once.
    positive_costs = [cost for cost in change_cost.values() if cost > 0]
    avg_setup_cost = sum(positive_costs) / len(positive_costs) if positive_costs else 0
    setup_weight = avg_setup_cost / max_change_cost if max_change_cost > 0 else 0.5

    for period in range(1, nb_periods):
        G.add_node(f'setup_constraint_{period}', type=1, weight=setup_weight)
        G.add_edge(f'setup_constraint_{period}', f'period_{period}', weight=0.9)
        G.add_edge(f'setup_constraint_{period}', f'period_{period + 1}', weight=0.9)

    # 4. Item type grouping constraints (only for types with several orders)
    max_type_count = max(nb_of_orders)
    current_order = 1
    for item_idx, count in enumerate(nb_of_orders):
        if count > 1:
            # Log-scaled complexity; count > 1 implies max_type_count >= 2
            complexity = math.log(count + 1) / math.log(max_type_count + 1)
            G.add_node(f'item_ordering_{item_idx + 1}', type=1, weight=complexity)

            for i in range(count):
                order = current_order + i
                if order <= nb_orders:
                    weight = 1.0 - i / count  # Earlier orders in type have higher weight
                    G.add_edge(f'order_{order}', f'item_ordering_{item_idx + 1}', weight=weight)

        current_order += count

    # Setup-cost conflict edges between orders of different item types, only
    # for high-cost transitions to keep the graph sparse.
    high_cost_threshold = avg_change_cost * 1.5 if avg_change_cost > 0 else max_change_cost * 0.7

    for order1 in range(1, nb_orders + 1):
        item1 = item_type[order1] if order1 < len(item_type) else 1
        for order2 in range(order1 + 1, nb_orders + 1):
            item2 = item_type[order2] if order2 < len(item_type) else 1
            if item1 == item2:
                continue

            max_cost = max(change_cost.get((order1, order2), 0),
                           change_cost.get((order2, order1), 0))
            if max_cost > high_cost_threshold:
                # max_cost > threshold >= 0 here, so max_change_cost > 0
                conflict_weight = min(max_cost / max_change_cost, 1.0)
                G.add_edge(f'order_{order1}', f'order_{order2}', weight=conflict_weight)

    return G


def main():
    """Command-line entry point: build the graph for one instance and report its size.

    Usage: ``python converter.py <mzn_file> <dzn_file> <json_file>``.
    The DZN path is accepted for interface compatibility but is not read;
    instance data is loaded from the JSON file. Exits with status 1 on a
    wrong argument count.
    """
    if len(sys.argv) != 4:
        # Usage errors go to stderr so they don't mix with normal stdout output.
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>", file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # Explicit encoding so decoding does not depend on the platform locale.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    main()