#!/usr/bin/env python3
"""
Graph converter for Generalized Balanced Academic Curriculum (gbac_1) problem.
Converter created with subagent_prompt.md v_02

This problem is about scheduling courses across periods while balancing curricula loads,
respecting precedences, and avoiding undesirable period assignments.
Key challenges: load balancing across curricula, complex precedence relationships, 
and multiple conflicting objectives (load balance vs undesirable violations).
"""

import json
import math
import re
import sys
from collections import Counter
from pathlib import Path

import networkx as nx


def _dzn_assignment_rhs(content, name):
    """Return the text between ``name =`` and the next ``;`` (or end of text), or None."""
    # \b stops e.g. ``precedes`` from matching inside a longer identifier.
    match = re.search(r'\b' + re.escape(name) + r'\s*=\s*([^;]*)', content)
    return match.group(1) if match else None


def _dzn_int_pairs(content, name):
    """Return consecutive integer pairs from the array literal assigned to *name*.

    Works for both ``array2d(l..u, 1..2, [a1, b1, a2, b2])`` and the 2D literal
    ``[| a1, b1 | a2, b2 |]``: only the numbers after the first ``[`` are used.
    A trailing unpaired number is ignored. Returns [] when *name* is absent.
    """
    rhs = _dzn_assignment_rhs(content, name)
    if rhs is None or '[' not in rhs:
        return []
    numbers = [int(n) for n in re.findall(r'\d+', rhs.split('[', 1)[1])]
    return list(zip(numbers[0::2], numbers[1::2]))


def _expand_dzn_set(text):
    """Expand a DZN set element such as ``{1, 3..5}`` or ``3..5`` into a list of ints."""
    values = []
    for low, high, single in re.findall(r'(\d+)\s*\.\.\s*(\d+)|(\d+)', text):
        if single:
            values.append(int(single))
        else:
            values.extend(range(int(low), int(high) + 1))
    return values


def parse_dzn_arrays(dzn_file):
    """Parse the set/2D arrays of a GBAC DZN file that aren't in the JSON.

    Three assignments are read from the MiniZinc data file:

    - ``courses_of``: array of sets of course ids, one per curriculum. Elements
      may be ``{1, 2, 3}``, ``{1..3, 5}``, ``{}`` or a bare range ``1..3``;
      ranges are expanded.
    - ``precedes``: pairs of course ids, written either with ``array2d(...)``
      or as a 2D literal ``[| a, b | ... |]``.
    - ``undesirable``: integer pairs in the same formats as ``precedes``.

    ``%`` line comments are ignored. A missing assignment yields an empty result.

    Args:
        dzn_file: Path (str or Path) to the .dzn file.

    Returns:
        Tuple ``(courses_of, precedes, undesirable)``: ``courses_of`` maps the
        1-based curriculum index to a list of course ids; the other two are
        lists of ``(int, int)`` tuples.

    Raises:
        OSError: if the file cannot be read.
    """
    with open(dzn_file, 'r', encoding='utf-8') as f:
        content = f.read()

    # Strip MiniZinc line comments so commented-out data is never picked up.
    content = re.sub(r'%[^\n]*', '', content)

    courses_of = {}
    courses_str = _dzn_assignment_rhs(content, 'courses_of')
    if courses_str is not None:
        # Skip an ``array1d(l..u, `` prefix so its index range is not read as an element.
        body = courses_str.split('[', 1)[1] if '[' in courses_str else courses_str
        elements = re.findall(r'\{[^}]*\}|\d+\s*\.\.\s*\d+', body)
        for index, element in enumerate(elements, start=1):
            courses_of[index] = _expand_dzn_set(element)

    precedes = _dzn_int_pairs(content, 'precedes')
    undesirable = _dzn_int_pairs(content, 'undesirable')
    return courses_of, precedes, undesirable


def _course_load_of(course_load, course_id):
    """Return the load of 1-based *course_id*, or 1 when *course_load* has no entry for it."""
    if 1 <= course_id <= len(course_load):
        return course_load[course_id - 1]
    return 1


def _build_fallback_graph(n_courses, n_periods, course_load, min_courses, max_courses):
    """Build the simplified course/period graph used when no .dzn file is available.

    Courses (type 0) are weighted by load / max load. Periods (type 2) share a
    tightness of 1 / (1 + (max_courses - min_courses) / average period load).
    Every course is linked to every period with weight min(load / max_courses, 1).
    """
    G = nx.Graph()
    max_load = max(course_load)

    for course_id in range(1, n_courses + 1):
        load = _course_load_of(course_load, course_id)
        # All-zero loads would divide by zero; treat them as zero relative load.
        weight = load / max_load if max_load > 0 else 0.0
        G.add_node(f'course_{course_id}', type=0, weight=weight)

    if n_periods > 0:
        avg_period_load = sum(course_load) / n_periods
        capacity_ratio = (max_courses - min_courses) / avg_period_load if avg_period_load > 0 else 0.5
        tightness = 1.0 / (1.0 + capacity_ratio)
        for p in range(1, n_periods + 1):
            G.add_node(f'period_{p}', type=2, weight=tightness)

    for course_id in range(1, n_courses + 1):
        load = _course_load_of(course_load, course_id)
        load_ratio = load / max_courses if max_courses > 0 else 0.5
        for p in range(1, n_periods + 1):
            G.add_edge(f'course_{course_id}', f'period_{p}', weight=min(load_ratio, 1.0))

    return G


def _dzn_matches_json(path, arrays, json_data):
    """Return True if the .dzn at *path* looks like the instance described by *json_data*.

    The parsed array lengths must equal n_curricula / n_precedences /
    n_undesirables. When the file also assigns ``course_load``, its values must
    equal the JSON's ``course_load``.
    """
    courses_of, precedes, undesirable = arrays
    expected = (
        json_data.get('n_curricula', 0),
        json_data.get('n_precedences', 0),
        json_data.get('n_undesirables', 0),
    )
    if (len(courses_of), len(precedes), len(undesirable)) != expected:
        return False
    try:
        text = Path(path).read_text(encoding='utf-8')
    except (OSError, ValueError):
        return False
    text = re.sub(r'%[^\n]*', '', text)
    match = re.search(r'\bcourse_load\s*=\s*([^;]*)', text)
    if match is None or '[' not in match.group(1):
        # No course_load to compare against; the length check is all we have.
        return True
    loads = [int(n) for n in re.findall(r'\d+', match.group(1).split('[', 1)[1])]
    return loads == list(json_data.get('course_load', []))


def _load_instance_arrays(candidates, json_data):
    """Parse candidate .dzn files and return ``(courses_of, precedes, undesirable)``.

    The first candidate that matches *json_data* (see _dzn_matches_json) wins.
    Otherwise the first candidate that parsed at all is used. If none parse,
    empty arrays are returned so the caller still builds course/period structure.
    """
    first_parsed = None
    for path in candidates:
        try:
            arrays = parse_dzn_arrays(path)
        except (OSError, ValueError, IndexError):
            # Unreadable or malformed file: try the next candidate.
            continue
        if _dzn_matches_json(path, arrays, json_data):
            return arrays
        if first_parsed is None:
            first_parsed = arrays
    return first_parsed if first_parsed is not None else ({}, [], [])


def build_graph(mzn_file, json_data, dzn_file=None):
    """
    Build graph representation of the GBAC problem instance.

    Args:
        mzn_file: Path to the .mzn model file. When *dzn_file* is not given,
            the ``*.dzn`` files in the same directory are searched (in sorted
            order) for this instance's array data. The file whose array sizes
            and ``course_load`` match *json_data* is preferred.
        json_data: Dict of instance data. Reads n_courses, n_periods,
            n_curricula, n_precedences, n_undesirables, course_load,
            min_courses, max_courses and w2.
        dzn_file: Optional explicit path to the instance .dzn file; when given
            it is used instead of the directory search.

    Returns:
        networkx.Graph whose nodes carry ``type`` and ``weight`` attributes
        and whose edges carry ``weight``. Empty if course_load is missing or
        empty. If no .dzn file is found, a simplified course/period-only graph
        is returned.

    Nodes:
    - Courses (type 0): tanh(2 * load / max load)
    - Curriculum constraints (type 1): average per-period curriculum load
      relative to the max course load (0.5 if the curriculum was not parsed)
    - Precedence constraints (type 1): combined load of the course pair
    - Undesirable constraints (type 1): course load, period crowding, scaled by w2
    - Period resources (type 2): capacity tightness

    Edges: course-period (all pairs), course-curriculum, course-precedence,
    course/period-undesirable, and a direct course-course edge for precedence
    pairs that share a curriculum (so the graph is not strictly bipartite).
    """
    course_load = json_data.get('course_load', [])
    if not course_load:
        return nx.Graph()

    n_courses = json_data.get('n_courses', 0)
    n_periods = json_data.get('n_periods', 0)
    n_curricula = json_data.get('n_curricula', 0)
    min_courses = json_data.get('min_courses', 0)
    max_courses = json_data.get('max_courses', 0)
    w2 = json_data.get('w2', 1)

    if dzn_file is not None:
        candidates = [Path(dzn_file)]
    else:
        # sorted() makes the choice deterministic; glob order is filesystem-dependent.
        candidates = sorted(Path(mzn_file).parent.glob('*.dzn'))
    if not candidates:
        return _build_fallback_graph(n_courses, n_periods, course_load, min_courses, max_courses)

    courses_of, precedes, undesirable = _load_instance_arrays(candidates, json_data)
    max_load = max(course_load)

    G = nx.Graph()

    # Course nodes (type 0): tanh saturates, so weights grow smoothly toward tanh(2).
    for course_id in range(1, n_courses + 1):
        relative = _course_load_of(course_load, course_id) / max_load if max_load > 0 else 0.0
        G.add_node(f'course_{course_id}', type=0, weight=math.tanh(relative * 2))

    # Period nodes (type 2): all periods share one tightness score comparing
    # max_courses with the average number of courses per period.
    if n_periods > 0:
        expected_courses = n_courses / n_periods
        capacity_ratio = max_courses / expected_courses if expected_courses > 0 else 1.0
        tightness = 1.0 - min(capacity_ratio / 2.0, 1.0)
        for p in range(1, n_periods + 1):
            G.add_node(f'period_{p}', type=2, weight=tightness)

    # Curriculum constraint nodes (type 1) and their course edges.
    for curr_id in range(1, n_curricula + 1):
        node = f'curriculum_{curr_id}'
        if curr_id not in courses_of:
            # No course list parsed for this curriculum: neutral weight, no edges.
            G.add_node(node, type=1, weight=0.5)
            continue
        members = courses_of[curr_id]
        curriculum_load = sum(course_load[c - 1] for c in members if 1 <= c <= len(course_load))
        per_period_load = curriculum_load / n_periods if n_periods > 0 else 0
        difficulty = min(per_period_load / max_load, 1.0) if max_load > 0 else 0.5
        G.add_node(node, type=1, weight=difficulty)
        for course in members:
            if 1 <= course <= n_courses:
                # Edge weight: the course's share of the curriculum load, doubled and capped.
                share = _course_load_of(course_load, course) / curriculum_load if curriculum_load > 0 else 0.5
                G.add_edge(f'course_{course}', node, weight=min(share * 2, 1.0))

    # Precedence constraint nodes (type 1); numbering follows position in `precedes`.
    for index, (before, after) in enumerate(precedes, start=1):
        if not (1 <= before <= n_courses and 1 <= after <= n_courses):
            continue
        pair_load = _course_load_of(course_load, before) + _course_load_of(course_load, after)
        weight = min(pair_load / (2 * max_load), 1.0) if max_load > 0 else 0.5
        node = f'precedence_{index}'
        G.add_node(node, type=1, weight=weight)
        G.add_edge(f'course_{before}', node, weight=0.8)
        G.add_edge(f'course_{after}', node, weight=0.8)

    # Undesirable assignment nodes (type 1); crowding counts every listed pair.
    period_pressure = Counter(period for _, period in undesirable)
    for index, (course, period) in enumerate(undesirable, start=1):
        if not (1 <= course <= n_courses and 1 <= period <= n_periods):
            continue
        if max_load > 0:
            load_term = _course_load_of(course_load, course) / max_load
            crowd_term = math.log(period_pressure[period] + 1) / 5
            weight = min((load_term + crowd_term) * w2 / 10, 1.0)
        else:
            weight = 0.3
        node = f'undesirable_{index}'
        G.add_node(node, type=1, weight=weight)
        G.add_edge(f'course_{course}', node, weight=0.6)
        G.add_edge(f'period_{period}', node, weight=0.6)

    # Course-period edges for every possible assignment.
    for course_id in range(1, n_courses + 1):
        pressure = _course_load_of(course_load, course_id) / max_courses if max_courses > 0 else 0.5
        edge_weight = min(pressure, 1.0)
        for p in range(1, n_periods + 1):
            G.add_edge(f'course_{course_id}', f'period_{p}', weight=edge_weight)

    # Direct course-course edges for precedence pairs that share a curriculum.
    curriculum_sets = [set(members) for members in courses_of.values()]
    for before, after in precedes:
        if not (1 <= before <= n_courses and 1 <= after <= n_courses):
            continue
        if any(before in members and after in members for members in curriculum_sets):
            pair_load = _course_load_of(course_load, before) + _course_load_of(course_load, after)
            load_conflict = pair_load / (2 * max_load) if max_load > 0 else 0.5
            G.add_edge(f'course_{before}', f'course_{after}', weight=min(load_conflict * 1.5, 1.0))

    return G


def main():
    """Command-line entry point: build the graph for one instance and report its size.

    Usage: converter.py <mzn_file> <dzn_file> <json_file>

    Loads the JSON data, calls build_graph, and prints the node and edge counts.
    On a wrong argument count it writes the usage line to stderr and exits with status 1.
    """
    if len(sys.argv) != 4:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")

    mzn_file, dzn_file, json_file = sys.argv[1:4]

    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    # NOTE(review): dzn_file is accepted but not forwarded; build_graph locates
    # the .dzn by searching next to mzn_file. Forward it once build_graph
    # accepts an explicit path.
    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    main()