#!/usr/bin/env python3
"""
Graph converter for Generalized Balanced Academic Curriculum (GBAC) problem.
Converter created with subagent_prompt.md v_02

This problem involves assigning courses to academic periods while satisfying:
- Curriculum requirements (courses belonging to curricula)
- Precedence constraints (prerequisites)  
- Load balancing across periods for each curriculum
- Avoiding undesirable course-period assignments

Key challenges: Managing curriculum load balance, precedence dependencies,
and minimizing violations of undesirable assignments.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def parse_dzn_data(dzn_file):
    """
    Parse a DZN file to extract the structures missing from the JSON data.

    Three assignments are looked for:
    - ``courses_of = [{...}, {...}, ...];`` (one set of course ids per entry)
    - ``precedes = array2d(<index>, <index>, [...])``
    - ``undesirable = array2d(<index>, <index>, [...])``

    MiniZinc comments (``% ...`` to end of line and ``/* ... */``) are removed
    before parsing, so annotated data files are accepted.

    Args:
        dzn_file: Path of the DZN file to read.

    Returns:
        A ``(courses_of, precedes, undesirable)`` tuple:
        ``courses_of`` is a list of lists of ints (one list per ``{...}`` set,
        empty list for an empty set); ``precedes`` and ``undesirable`` are
        lists of ``[a, b]`` int pairs taken from the flat array in order.
        An assignment that is not found yields an empty list.

    Raises:
        OSError: If the file cannot be opened or read.
        ValueError: If a value inside one of the arrays is not an integer.
    """
    import re

    with open(dzn_file, 'r') as f:
        content = f.read()

    # Strip comments in a single left-to-right pass so that whichever comment
    # opener comes first wins; each comment is replaced by a blank to keep
    # neighbouring tokens separated.
    content = re.sub(r'/\*.*?\*/|%[^\n]*', ' ', content, flags=re.DOTALL)

    def int_list(text):
        """Convert a comma-separated string into ints, skipping blank items."""
        return [int(item) for item in (raw.strip() for raw in text.split(',')) if item]

    def pair_rows(name):
        """Return the ``[a, b]`` rows of the ``array2d`` assigned to ``name``."""
        match = re.search(
            r'\b' + name + r'\s*=\s*array2d\([^,]+,[^,]+,\s*\[(.*?)\]\s*\)',
            content,
            re.DOTALL,
        )
        if not match:
            return []
        values = int_list(match.group(1))
        # A trailing unpaired value (odd count) is ignored.
        return [[values[i], values[i + 1]] for i in range(0, len(values) - 1, 2)]

    # courses_of: curricula -> courses mapping, one ``{...}`` set per curriculum.
    courses_of = []
    courses_of_match = re.search(r'\bcourses_of\s*=\s*\[(.*?)\]\s*;', content, re.DOTALL)
    if courses_of_match:
        for set_body in re.findall(r'\{([^}]*)\}', courses_of_match.group(1)):
            courses_of.append(int_list(set_body))

    # precedes: precedence relationships; undesirable: course-period pairs.
    precedes = pair_rows('precedes')
    undesirable = pair_rows('undesirable')

    return courses_of, precedes, undesirable


def _load_matching_dzn(mzn_file, n_curricula, n_precedences, n_undesirables):
    """
    Find the DZN data whose sizes match the counts given in the JSON data.

    The ``*.dzn`` files in the directory of ``mzn_file`` are tried in sorted
    (hence reproducible) order. The first one whose parsed ``courses_of``,
    ``precedes`` and ``undesirable`` lengths equal the given counts wins.

    Returns:
        A ``(courses_of, precedes, undesirable)`` tuple, or ``None`` when no
        file in that directory matches.
    """
    for candidate in sorted(Path(mzn_file).parent.glob("*.dzn")):
        try:
            courses_of, precedes, undesirable = parse_dzn_data(candidate)
        except (OSError, ValueError):
            # Unreadable or unparsable candidate: best effort, try the next one.
            continue
        if (len(courses_of) == n_curricula and
                len(precedes) == n_precedences and
                len(undesirable) == n_undesirables):
            return courses_of, precedes, undesirable
    return None


def build_graph(mzn_file, json_data):
    """
    Build graph representation of GBAC problem instance.

    Args:
        mzn_file: Path to .mzn file; its directory is scanned for the DZN
            file that supplies ``courses_of``, ``precedes`` and ``undesirable``.
        json_data: Dict containing parsed DZN data (scalar sizes, ``w1``,
            ``w2`` and ``course_load``); missing keys fall back to defaults.

    Returns:
        An undirected ``networkx.Graph`` whose nodes carry ``type`` and
        ``weight`` attributes and whose edges carry ``weight``.

    Node types:
    - Type 0: Courses (``course_<id>``)
    - Type 1: Constraints (``curriculum_<k>``, ``precedence_<k>``,
      ``undesirable_<k>``)
    - Type 2: Periods (``period_<p>``)

    Edges link courses to the constraints they take part in and to every
    period. Additional course-course "conflict" edges join high-load courses
    that share a curriculum, so the graph is not strictly bipartite.

    A ``RuntimeWarning`` is issued when the JSON data announces curricula,
    precedences or undesirables but no matching DZN file is found; the graph
    is then built without that data.
    """
    import warnings

    # Get basic parameters
    n_courses = json_data.get('n_courses', 0)
    n_periods = json_data.get('n_periods', 0)
    n_curricula = json_data.get('n_curricula', 0)
    n_precedences = json_data.get('n_precedences', 0)
    n_undesirables = json_data.get('n_undesirables', 0)
    min_courses = json_data.get('min_courses', 0)
    max_courses = json_data.get('max_courses', 1)
    course_load = json_data.get('course_load', [])
    w1 = json_data.get('w1', 1)  # load balancing weight
    w2 = json_data.get('w2', 1)  # undesirable penalty weight

    # Parse additional data from a DZN file (since JSON is incomplete)
    dzn_data = _load_matching_dzn(mzn_file, n_curricula, n_precedences, n_undesirables)
    if dzn_data is None:
        if n_curricula or n_precedences or n_undesirables:
            warnings.warn(
                f"No DZN file next to {mzn_file} matches the instance sizes; "
                "the graph will lack curriculum membership, precedence and "
                "undesirable-assignment data.",
                RuntimeWarning,
                stacklevel=2,
            )
        courses_of, precedes, undesirable = [], [], []
    else:
        courses_of, precedes, undesirable = dzn_data

    G = nx.Graph()

    # Normaliser for all load ratios; fall back to 1 when there are no loads
    # or every load is 0, so the divisions below cannot fail.
    max_load = max(course_load, default=0)
    if max_load <= 0:
        max_load = 1

    n_loads = len(course_load)

    def load_of(course_id):
        """Load of a 1-based course id; 1 when the id has no recorded load."""
        # Explicit lower bound: an id <= 0 must not wrap around to the end
        # of the list through negative indexing.
        return course_load[course_id - 1] if 1 <= course_id <= n_loads else 1

    # Set views of the curricula for O(1) membership tests.
    curriculum_sets = [set(members) for members in courses_of]

    # Course nodes (Type 0) - weighted by load and curriculum membership
    for course_id in range(1, n_courses + 1):
        load_weight = load_of(course_id) / max_load
        # Count how many curricula this course belongs to (centrality)
        curricula_count = sum(1 for members in curriculum_sets if course_id in members)
        centrality_weight = curricula_count / max(n_curricula, 1)
        # Linear blend of both factors, capped at 1.0
        weight = 0.6 * load_weight + 0.4 * centrality_weight
        G.add_node(f'course_{course_id}', type=0, weight=min(weight, 1.0))

    # Period nodes (Type 2) - every period gets the same scarcity weight:
    # average load per period relative to the max/min course-count spread.
    period_capacity = max_courses - min_courses
    avg_period_load = sum(course_load) / n_periods if n_periods > 0 else 0
    scarcity = min(avg_period_load / max(period_capacity, 1), 1.0)
    for p in range(1, n_periods + 1):
        G.add_node(f'period_{p}', type=2, weight=scarcity)

    # Curriculum constraint nodes (Type 1) - one per curriculum
    for c in range(n_curricula):
        node = f'curriculum_{c+1}'
        if c >= len(courses_of):
            # No membership data for this curriculum: neutral weight.
            G.add_node(node, type=1, weight=0.5)
            continue
        curriculum_load = sum(course_load[cid - 1] for cid in courses_of[c]
                              if 1 <= cid <= n_loads)
        ideal_load_per_period = curriculum_load / n_periods if n_periods > 0 else 0
        # NOTE(review): for n_periods > 0 this deviation is (load/n)*n - load,
        # i.e. zero up to float rounding, so the weight is effectively 0.3.
        # Kept unchanged because the intended tightness formula is unknown.
        max_deviation = abs(curriculum_load - ideal_load_per_period * n_periods)
        tightness = min(max_deviation / max(curriculum_load, 1), 1.0)
        G.add_node(node, type=1, weight=0.3 + 0.7 * tightness)

    # Precedence constraint nodes (Type 1) - one per precedence,
    # weighted by the combined load of the two constrained courses
    for k, (course1, course2) in enumerate(precedes, start=1):
        combined_load = (load_of(course1) + load_of(course2)) / (2 * max_load)
        G.add_node(f'precedence_{k}', type=1, weight=0.5 + 0.5 * combined_load)

    # Undesirable constraint nodes (Type 1) - weighted by penalty
    undesirable_weight = w2 / max(w1 + w2, 1)  # Relative importance of undesirable penalties
    for k, (course, _period) in enumerate(undesirable, start=1):
        load_factor = load_of(course) / max_load
        G.add_node(f'undesirable_{k}', type=1,
                   weight=0.2 + 0.8 * undesirable_weight * load_factor)

    # Edges: Course-Curriculum participation
    for c in range(min(n_curricula, len(courses_of))):
        for course_id in courses_of[c]:
            if 1 <= course_id <= n_courses:
                participation_strength = load_of(course_id) / max_load
                G.add_edge(f'course_{course_id}', f'curriculum_{c+1}',
                           weight=0.3 + 0.7 * participation_strength)

    # Edges: Course-Precedence participation
    for k, (course1, course2) in enumerate(precedes, start=1):
        if 1 <= course1 <= n_courses:
            G.add_edge(f'course_{course1}', f'precedence_{k}', weight=0.8)
        if 1 <= course2 <= n_courses:
            G.add_edge(f'course_{course2}', f'precedence_{k}', weight=0.8)

    # Edges: Course-Undesirable participation
    for k, (course, _period) in enumerate(undesirable, start=1):
        if 1 <= course <= n_courses:
            G.add_edge(f'course_{course}', f'undesirable_{k}', weight=1.0)

    # Edges: Course-Period assignment possibilities
    # All courses can potentially be assigned to all periods; the weight
    # depends only on the course, so it is computed once per course.
    for course_id in range(1, n_courses + 1):
        stress = load_of(course_id) / max(period_capacity, 1)
        assignment_weight = min(0.2 + 0.8 * stress, 1.0)
        for p in range(1, n_periods + 1):
            G.add_edge(f'course_{course_id}', f'period_{p}', weight=assignment_weight)

    # Conflict edges between high-load courses (load above 70% of the maximum)
    # that share at least one curriculum. Only the first n_courses loads are
    # considered so that no edge can create a course node without attributes.
    threshold = max_load * 0.7
    high_load_courses = [(idx + 1, load)
                         for idx, load in enumerate(course_load[:n_courses])
                         if load > threshold]
    for pos, (course1, load1) in enumerate(high_load_courses):
        for course2, load2 in high_load_courses[pos + 1:]:
            if any(course1 in members and course2 in members
                   for members in curriculum_sets):
                conflict_strength = (load1 + load2) / (2 * max_load)
                G.add_edge(f'course_{course1}', f'course_{course2}',
                           weight=0.3 + 0.7 * conflict_strength)

    return G


def main():
    """
    Command-line entry point: build the graph for one instance and report its size.

    Expects exactly three arguments: the .mzn file, the .dzn file and the
    JSON file. On any other argument count the usage line is printed and the
    process exits with status 1. The .dzn argument is accepted but not
    forwarded; only the .mzn path and the loaded JSON are given to
    ``build_graph``.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) != 3:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    mzn_file, _dzn_file, json_file = cli_args

    # Read the instance description produced from the DZN data
    with open(json_file, 'r') as handle:
        json_data = json.load(handle)

    graph = build_graph(mzn_file, json_data)

    # Summarise the result on stdout
    node_count = graph.number_of_nodes()
    edge_count = graph.number_of_edges()
    print(f"Graph built: {node_count} nodes, {edge_count} edges")


# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()