#!/usr/bin/env python3
"""
Graph converter for Network Flow Cost (NFC) problem.
Created using subagent_prompt.md version: v_02

This problem is about optimal worker scheduling with network flow constraints.
Workers are scheduled in shifts that each span several consecutive time periods,
and worker demand must be met in every period while minimizing cost.

Key challenges: 
- Balancing worker allocation across overlapping shifts
- Meeting fluctuating demand across time periods  
- Network flow constraints create complex interdependencies
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path
import re


def parse_dzn_file(dzn_file):
    """Parse the ``worker_count`` array from a DZN file.

    The JSON conversion of the instance does not capture ``worker_count``, so
    it is read directly from the DZN text. Both the ``array1d`` form
    (``worker_count = array1d(1..n, [a, b, ...]);``) and the plain list form
    (``worker_count = [a, b, ...];``) are accepted. The array may span several
    lines and may carry a trailing comma.

    Args:
        dzn_file: Path to the .dzn file.

    Returns:
        List of ints, or an empty list if no ``worker_count`` assignment is found.

    Raises:
        OSError: if the file cannot be read.
        ValueError: if an array element is not an integer literal.
    """
    with open(dzn_file, 'r', encoding='utf-8') as f:
        content = f.read()

    # \b keeps identifiers such as ``max_worker_count`` from matching; the
    # optional group accepts the ``array1d(<index set>, [...])`` wrapper.
    match = re.search(
        r'\bworker_count\s*=\s*(?:array1d\s*\([^,\]]+,\s*)?\[([^\]]*)\]',
        content,
    )
    if not match:
        return []
    tokens = (token.strip() for token in match.group(1).split(','))
    # Skip empty tokens produced by trailing commas or an empty array.
    return [int(token) for token in tokens if token]


def _resolve_worker_count(json_data, n_periods, dzn_file):
    """Return per-period worker demand as a list of exactly ``n_periods`` values.

    Source priority: the DZN file (when given and it yields values), then a
    list-valued ``worker_count`` entry in ``json_data``, then a uniform default
    of 50 per period. A shorter list is padded with 0, matching how missing
    periods are treated elsewhere. A longer list is truncated, so statistics
    only cover periods that exist in the graph.
    """
    worker_count = parse_dzn_file(dzn_file) if dzn_file else []
    if not worker_count:
        from_json = json_data.get('worker_count')
        if isinstance(from_json, (list, tuple)):
            worker_count = list(from_json)
    if not worker_count:
        worker_count = [50] * n_periods  # Default demand
    worker_count = list(worker_count[:n_periods])
    worker_count += [0] * (n_periods - len(worker_count))
    return worker_count


def build_graph(mzn_file, json_data, dzn_file=None):
    """
    Build graph representation of the NFC problem instance.

    Args:
        mzn_file: Path to .mzn file (for reference only; it is not read).
        json_data: Dict containing parsed DZN data. ``n_periods`` (default 12)
            and ``shift_periods`` (default 2) are read from it. ``worker_count``
            is used as a fallback when the DZN file yields no values.
        dzn_file: Optional path to the .dzn file to read ``worker_count`` from.
            When omitted, ``sys.argv[2]`` is used if present (legacy CLI
            behavior).

    Returns:
        networkx.Graph whose nodes carry ``type`` (0 = variable-like,
        1 = constraint-like) and ``weight``; edges carry ``weight``.

    Model (not strictly bipartite: shift-shift and period-period edges exist):
    - Period nodes (type 0): time periods needing workers
    - Shift nodes (type 0): one per start period (decision variables)
    - Demand constraint nodes (type 1): one per period
    - A single flow constraint node (type 1), linked to every shift
    - Shift overlap edges model temporal dependencies between consecutive shifts
    - Period conflict edges link nearby above-average-demand periods
    Shifts wrap cyclically: shift ``s`` covers periods
    ``(s + i) % n_periods`` for ``i in range(shift_periods)``.
    """
    n_periods = json_data.get('n_periods', 12)
    shift_periods = json_data.get('shift_periods', 2)

    if dzn_file is None and len(sys.argv) > 2:
        # Legacy behavior: the CLI passes the DZN path as the second argument.
        dzn_file = sys.argv[2]
    # Always exactly n_periods entries, so indexing by period is safe below.
    worker_count = _resolve_worker_count(json_data, n_periods, dzn_file)

    G = nx.Graph()

    max_demand = max(worker_count) if worker_count else 1
    total_demand = sum(worker_count) if worker_count else n_periods
    mean_demand = total_demand / n_periods if n_periods > 0 else 0

    def coverage(start):
        """Periods covered by the shift that starts at ``start`` (cyclic)."""
        return [(start + i) % n_periods for i in range(shift_periods)]

    # Period nodes (type 0) - weighted by demand relative to the peak period.
    for t in range(n_periods):
        demand_weight = worker_count[t] / max_demand if max_demand > 0 else 0.5
        G.add_node(f'period_{t}', type=0, weight=demand_weight)

    # Shift decision nodes (type 0) - weighted by the demand they can cover.
    coverage_denominator = shift_periods * max_demand
    for s in range(n_periods):
        covered_demand = sum(worker_count[p] for p in coverage(s))
        # Guards shift_periods <= 0 as well as an all-zero demand profile.
        if coverage_denominator > 0:
            coverage_weight = covered_demand / coverage_denominator
        else:
            coverage_weight = 0.5
        G.add_node(f'shift_{s}', type=0, weight=min(coverage_weight, 1.0))

    # Demand constraint nodes (type 1) - weighted by tightness.
    # Shift s covers period t iff (t - s) % n_periods < shift_periods, so every
    # period is covered by exactly min(shift_periods, n_periods) shifts.
    # Fewer covering options = tighter constraint.
    covering_shifts = min(max(shift_periods, 0), n_periods)
    tightness = 1.0 - covering_shifts / n_periods if n_periods > 0 else 0.5
    for t in range(n_periods):
        G.add_node(f'demand_constraint_{t}', type=1, weight=tightness)

    # Network flow constraint node (type 1) - represents global flow balance.
    # More periods and longer shifts = more complex; clamped to [0, 1].
    complexity = min(1.0, (n_periods * shift_periods) / 100.0)
    G.add_node('flow_constraint', type=1, weight=complexity)

    # Period -> its own demand constraint.
    for t in range(n_periods):
        G.add_edge(f'period_{t}', f'demand_constraint_{t}', weight=1.0)

    # Shift -> demand constraints of the periods it covers (equal split).
    contribution = 1.0 / shift_periods if shift_periods > 0 else 0.0
    for s in range(n_periods):
        for t in coverage(s):
            G.add_edge(f'shift_{s}', f'demand_constraint_{t}', weight=contribution)

    # Every shift participates in the global flow constraint.
    for s in range(n_periods):
        G.add_edge(f'shift_{s}', 'flow_constraint', weight=0.8)

    # Consecutive shifts whose coverage overlaps are temporally dependent.
    for s in range(n_periods):
        next_shift = (s + 1) % n_periods
        overlap_periods = set(coverage(s)) & set(coverage(next_shift))
        if overlap_periods:
            # More overlapped periods = stronger dependency.
            overlap_weight = len(overlap_periods) / shift_periods
            G.add_edge(f'shift_{s}', f'shift_{next_shift}', weight=overlap_weight)

    # Above-average-demand periods that are close (cyclically) compete for
    # the same shift coverage.
    high_demand_periods = [t for t in range(n_periods) if worker_count[t] > mean_demand]
    for i, t1 in enumerate(high_demand_periods):
        for t2 in high_demand_periods[i + 1:]:
            gap = abs(t1 - t2)
            circular_gap = min(gap, n_periods - gap)
            # NOTE(review): one shift covers both periods only when
            # circular_gap < shift_periods; `<=` also links periods exactly
            # shift_periods apart. Kept as-is; confirm this is intended.
            if circular_gap <= shift_periods:
                demand_ratio = (
                    (worker_count[t1] + worker_count[t2]) / (2 * max_demand)
                    if max_demand > 0 else 0.5
                )
                conflict_weight = min(demand_ratio, 1.0)
                G.add_edge(f'period_{t1}', f'period_{t2}', weight=conflict_weight)

    return G


def main():
    """Command-line entry point: build the graph and report its size.

    Usage: python converter.py <mzn_file> <dzn_file> <json_file>

    Exits with status 1 (usage message on stderr) if the argument count is wrong.
    """
    if len(sys.argv) != 4:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")

    # The DZN path (argv[2]) is picked up by build_graph from sys.argv itself.
    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    main()