#!/usr/bin/env python3
"""
Graph converter for the roster_model problem.
Created using subagent_prompt.md version: v_02

The problem is employee shift scheduling across multiple weeks.
Key challenges: meeting the daily shift requirements while avoiding problematic
patterns such as evening-to-morning transitions and isolated rest days.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


# Shift types, in the order of the rows of the ``reqt`` requirement table.
_SHIFT_NAMES = ('Rest', 'Morn', 'Day', 'Eve', 'Joker')


def _reqt_matrix(reqt_raw):
    """Return the daily shift requirements as a 5 (shifts) x 7 (days) list of lists.

    ``reqt_raw`` may be a flat row-major list indexed ``shift * 7 + day`` or a
    nested list with one row per shift (a 2D DZN array may be serialised to
    JSON either way). Missing entries, and a ``None``/empty input, count as 0.
    """
    reqt_raw = reqt_raw or []
    n_shifts = len(_SHIFT_NAMES)
    if reqt_raw and isinstance(reqt_raw[0], (list, tuple)):
        rows = [reqt_raw[s] if s < len(reqt_raw) else [] for s in range(n_shifts)]
    else:
        # Slices past the end are simply shorter; they are zero-padded below.
        rows = [reqt_raw[s * 7:(s + 1) * 7] for s in range(n_shifts)]
    return [[row[d] if d < len(row) else 0 for d in range(7)] for row in rows]


def _day_node(day_idx):
    """Name of the day-variable node for absolute (week-major) day index ``day_idx``."""
    week, day = divmod(day_idx, 7)
    return f'day_{week}_{day}'


def _add_day_nodes(G, weeks):
    """Add one variable node (type 0) per day of every week."""
    for week in range(weeks):
        # Weight by centrality in the schedule (middle days/weeks weighted higher).
        week_centrality = 1.0 - abs(week - weeks/2) / (weeks/2 + 1)
        for day in range(7):
            day_centrality = 1.0 - abs(day - 3) / 4.0  # Middle of week
            G.add_node(f'day_{week}_{day}', type=0,
                       weight=(week_centrality + day_centrality) / 2)


def _add_requirement_constraints(G, weeks, reqt):
    """Add one node per (shift, weekday) with a positive requirement and link it to that weekday in every week."""
    for shift_type, shift_name in enumerate(_SHIFT_NAMES):
        for day_of_week in range(7):
            total_requirement = reqt[shift_type][day_of_week] * weeks
            if total_requirement > 0:
                # Higher requirements = more constrained; normalised by a nominal max of 3/day.
                tightness = min(total_requirement / (weeks * 3.0), 1.0)
                G.add_node(f'req_{shift_name}_{day_of_week}', type=1, weight=tightness)

    for week in range(weeks):
        for day in range(7):
            day_var = f'day_{week}_{day}'
            for shift_name in _SHIFT_NAMES:
                req_node = f'req_{shift_name}_{day}'
                # A req node only exists when its requirement is positive, so
                # every such day contributes an equal 1/weeks share.
                if req_node in G:
                    G.add_edge(day_var, req_node, weight=1.0 / weeks)


def _add_enough_rest_constraints(G, total_days):
    """Add one 'enough rest' node per 7-day window; windows wrap around the end of the schedule."""
    for start_day in range(total_days):
        # Grows from 1/7 on the first day of a week to 1.0 on the last day.
        window_pressure = (start_day % 7 + 1) / 7.0
        node = f'enough_rest_{start_day}'
        G.add_node(node, type=1, weight=0.7 + 0.3 * window_pressure)
        for offset in range(7):
            # Middle days of the window are weighted slightly higher.
            criticality = 1.0 - abs(offset / 6.0 - 0.5) * 0.4
            G.add_edge(_day_node((start_day + offset) % total_days), node,
                       weight=criticality)


def _add_max_rest_constraints(G, total_days):
    """Add one 'too much rest' node per full in-range 4-day window (max 3 consecutive rest days)."""
    # range(total_days - 3): every window has exactly 4 in-range days, so
    # the constraint can actually be violated.
    for start_day in range(total_days - 3):
        # Middle of the schedule weighted higher.
        position_weight = 1.0 - abs(start_day - total_days/2) / (total_days/2)
        node = f'max_rest_{start_day}'
        G.add_node(node, type=1, weight=0.6 + 0.4 * position_weight)
        for offset in range(4):
            G.add_edge(_day_node(start_day + offset), node, weight=0.8)


def _add_eve_morn_constraints(G, total_days):
    """Add one evening->morning transition node per day, linking it to the next day (the last day wraps to day 0)."""
    for day in range(total_days):
        # Transitions starting on weekday index 5 or 6 get a lower weight.
        weekend_factor = 0.8 if day % 7 >= 5 else 1.0
        node = f'eve_morn_{day}'
        G.add_node(node, type=1, weight=0.8 * weekend_factor)
        G.add_edge(_day_node(day), node, weight=0.9)
        G.add_edge(_day_node((day + 1) % total_days), node, weight=0.9)


def _add_isolated_rest_constraints(G, total_days):
    """Add one isolated-rest node per interior day, linked to the day (1.0) and its two neighbours (0.7)."""
    for day in range(1, total_days - 1):
        # Middle of the schedule weighted higher.
        isolation_criticality = 1.0 - abs(day - total_days/2) / (total_days/2)
        node = f'isolated_rest_{day}'
        G.add_node(node, type=1, weight=0.7 + 0.3 * isolation_criticality)
        G.add_edge(_day_node(day), node, weight=1.0)
        G.add_edge(_day_node(day - 1), node, weight=0.7)
        G.add_edge(_day_node(day + 1), node, weight=0.7)


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the roster scheduling problem instance.

    Args:
        mzn_file: Path to .mzn file (for reference only; it is not read here)
        json_data: Dict containing parsed DZN data. Reads ``weeks`` (default 1)
            and ``reqt`` (default empty). ``reqt`` may be a flat row-major list
            indexed ``shift * 7 + day`` or a nested 5 x 7 list of lists;
            missing entries count as 0.

    Returns:
        networkx.Graph whose nodes carry ``type`` (0 = day variable,
        1 = constraint) and ``weight`` attributes, and whose edges carry a
        ``weight`` attribute.

    Strategy: Model as bipartite graph with days as variables and constraints as explicit nodes.
    - Day nodes (type 0): Each day in each week that needs a shift assignment
    - Requirement constraint nodes (type 1): Daily requirements for each shift type
    - Pattern constraint nodes (type 1): enough rest (7-day windows, wrapping),
      too much rest (4-day windows), eve-morn transitions (wrapping), isolated rest
    - Weights are heuristic estimates of constraint tightness / scheduling difficulty
    """
    weeks = json_data.get('weeks', 1)
    reqt = _reqt_matrix(json_data.get('reqt', []))
    total_days = weeks * 7

    G = nx.Graph()
    # Each helper adds its family's nodes and edges. Families are added in a
    # fixed order so node order and per-node adjacency order stay deterministic.
    _add_day_nodes(G, weeks)
    _add_requirement_constraints(G, weeks, reqt)
    _add_enough_rest_constraints(G, total_days)
    _add_max_rest_constraints(G, total_days)
    _add_eve_morn_constraints(G, total_days)
    _add_isolated_rest_constraints(G, total_days)
    return G


def main():
    """Command-line entry point: build the graph for one instance and report its size.

    Usage: python converter.py <mzn_file> <dzn_file> <json_file>

    Only the JSON file is read here. The .mzn path is passed through to
    ``build_graph``. The .dzn path is accepted for interface compatibility
    but is not used.
    """
    if len(sys.argv) != 4:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # JSON text is UTF-8; read it explicitly rather than relying on the locale default.
    json_data = json.loads(Path(json_file).read_text(encoding='utf-8'))

    G = build_graph(mzn_file, json_data)

    # The graph is returned by build_graph for direct feature extraction;
    # the CLI only reports its size.
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    main()