#!/usr/bin/env python3
"""
Graph converter for Vaccine Trial Optimization problem.
Created using subagent_prompt.md version: v_02

This problem is about optimizing assignment of demographic groups to vaccine trials.
Key challenges: balancing demographic requirements, meeting minimum group sizes,
ensuring fair distribution across vaccines while maximizing information value.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _clamp_unit(value):
    """Clamp *value* into the closed interval [0.0, 1.0]."""
    return max(0.0, min(1.0, value))


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the vaccine trial optimization problem.

    Args:
        mzn_file: Path to .mzn file (for reference only; it is not read here)
        json_data: Dict containing parsed DZN data. Keys read: 'm', 'minsize',
            'max_people_diff', 'max_share_vaccines', 'size', 'age_group_min',
            'age_group_max', 'health_information', 'exposure_information'.
            Missing keys fall back to neutral defaults.

    Returns:
        An undirected networkx Graph. Every node carries 'type' and 'weight'
        attributes and every edge carries a 'weight' attribute.

    Strategy: Create bipartite graph with groups as variables and constraints as explicit nodes.
    - Group nodes (type 0): 'group_<g>', one per demographic group
    - Constraint nodes (type 1): 'minsize_constraint_<g>', 'age_constraint_<a>',
      'people_balance_constraint', 'gender_balance_constraint',
      'sharing_constraint_<g1>_<g2>', 'information_optimization'
    - Vaccine nodes (type 2): 'vaccine_<v>', the resources being allocated
    - Edge weights reflect constraint tightness and group importance

    Degenerate data (all group sizes zero, or zero information values) yields
    zero-valued ratios instead of raising ZeroDivisionError.
    """
    # Extract data from JSON
    m = json_data.get('m', 0)  # number of groups
    minsize = json_data.get('minsize', 1)
    max_people_diff = json_data.get('max_people_diff', 0)
    max_share_vaccines = json_data.get('max_share_vaccines', 0)
    size = json_data.get('size', [])
    age_group_min = json_data.get('age_group_min', [])
    age_group_max = json_data.get('age_group_max', [])
    health_information = json_data.get('health_information', [])
    exposure_information = json_data.get('exposure_information', [])

    # The data carries no explicit vaccine count, so it is estimated from the
    # number of groups and kept within 3..7.
    n_vaccines = max(3, min(7, m // 3))

    G = nx.Graph()

    # Statistics used for normalization
    total_population = sum(size) if size else 1
    max_group_size = max(size) if size else 1
    avg_group_size = total_population / m if m > 0 else 1

    # A largest group of size 0 would make every size ratio 0/0; divide by 1
    # instead so those ratios come out as 0.
    size_scale = max_group_size if max_group_size != 0 else 1

    # Resolve each group's size once; groups missing from 'size' get the average.
    group_sizes = [size[g] if g < len(size) else avg_group_size for g in range(m)]

    # How many vaccines each group could serve given the minimum sub-group size.
    if minsize > 0:
        group_capacity = [max(1, gs // minsize) for gs in group_sizes]
    else:
        group_capacity = [1] * m

    # Add group nodes (type 0) - the decision variables
    for g in range(m):
        size_importance = group_sizes[g] / size_scale

        # Groups with more vaccine options are more flexible (lower weight = easier)
        flexibility = math.exp(-2.0 * group_capacity[g] / n_vaccines)
        weight = 0.3 + 0.7 * (size_importance * flexibility)

        G.add_node(f'group_{g}', type=0, weight=min(weight, 1.0))

    # Add vaccine resource nodes (type 2); all vaccines are equally important
    for v in range(n_vaccines):
        G.add_node(f'vaccine_{v}', type=2, weight=0.8)

    # Add constraint nodes (type 1)

    # 1. Minimum size constraints, one per group
    for g in range(m):
        # Tightness based on how restrictive the minimum size is
        tightness = 1.0 - min(1.0, group_capacity[g] / n_vaccines)
        G.add_node(f'minsize_constraint_{g}', type=1, weight=0.4 + 0.6 * tightness)
        G.add_edge(f'group_{g}', f'minsize_constraint_{g}', weight=0.8)

    # 2. Age distribution constraints (global cardinality constraints)
    n_age_groups = len(age_group_min)
    # The JSON has no per-group age, so every group is linked to every age
    # constraint with a uniform 1/n_age_groups likelihood.
    age_edge_weight = (1.0 / n_age_groups) * 0.6 if n_age_groups else 0.0
    for age, min_req in enumerate(age_group_min):
        max_req = age_group_max[age] if age < len(age_group_max) else m

        if min_req > 0 or max_req < m:  # Only create constraint if it's restrictive
            # Tightness based on how narrow the range is. Clamped so bounds
            # wider than m (or inverted bounds) cannot push the node weight
            # outside [0.3, 1.0].
            range_ratio = (max_req - min_req) / m if m > 0 else 1.0
            tightness = _clamp_unit(1.0 - range_ratio)

            G.add_node(f'age_constraint_{age}', type=1, weight=0.3 + 0.7 * tightness)

            for g in range(m):
                G.add_edge(f'group_{g}', f'age_constraint_{age}',
                           weight=age_edge_weight)

    # 3. People difference constraints between vaccines
    if max_people_diff > 0:
        per_vaccine_population = total_population / n_vaccines
        if per_vaccine_population != 0:
            balance_tightness = _clamp_unit(
                1.0 - (max_people_diff / per_vaccine_population)
            )
        else:
            # No people to distribute: any positive allowed difference is slack.
            balance_tightness = 0.0

        G.add_node('people_balance_constraint', type=1, weight=0.5 + 0.5 * balance_tightness)

        # Connect to all groups since they all affect population balance
        for g in range(m):
            impact = group_sizes[g] / total_population if total_population != 0 else 0.0
            G.add_edge(f'group_{g}', 'people_balance_constraint', weight=impact)

    # 4. Gender balance constraints
    # Each vaccine should have same number of groups per gender
    G.add_node('gender_balance_constraint', type=1, weight=0.7)
    for g in range(m):
        G.add_edge(f'group_{g}', 'gender_balance_constraint', weight=0.5)

    # 5. Shared vaccine constraints between groups
    if max_share_vaccines > 0 and max_share_vaccines < n_vaccines:
        sharing_tightness = 1.0 - (max_share_vaccines / n_vaccines)
        sharing_node_weight = 0.4 + 0.6 * sharing_tightness

        # Only the first 10 groups are paired, to avoid too many constraint nodes
        n_paired_groups = min(m, 10)
        for g1 in range(n_paired_groups):
            for g2 in range(g1 + 1, n_paired_groups):
                constraint_name = f'sharing_constraint_{g1}_{g2}'
                G.add_node(constraint_name, type=1, weight=sharing_node_weight)

                # Connect to both groups
                G.add_edge(f'group_{g1}', constraint_name, weight=0.7)
                G.add_edge(f'group_{g2}', constraint_name, weight=0.7)

    # Add edges between groups and vaccines (potential assignments)
    for g in range(m):
        # Larger groups are more valuable for trials
        size_value = group_sizes[g] / size_scale
        for v in range(n_vaccines):
            # Deterministic pseudo-variation per (group, vaccine) pair
            diversity_factor = 0.5 + 0.5 * math.sin(g * v + 1)

            assignment_weight = 0.3 + 0.4 * size_value + 0.3 * diversity_factor
            G.add_edge(f'group_{g}', f'vaccine_{v}', weight=min(assignment_weight, 1.0))

    # Add information value complexity node (global constraint)
    if health_information and exposure_information:
        max_combined_info = max(health_information) * max(exposure_information)

        # Assume uniform distribution of health/exposure for weighting
        avg_info_potential = max_combined_info / 4
        if max_combined_info != 0:
            info_weight = min(1.0, avg_info_potential / max_combined_info)
        else:
            # Zero maximum information: treat the potential as zero, not 0/0.
            info_weight = 0.0
        info_edge_weight = 0.5 + 0.5 * info_weight

        # This constraint is complex because it involves the optimization objective
        G.add_node('information_optimization', type=1, weight=0.9)

        # The weight is the same for every group because it is derived only
        # from the global maxima.
        for g in range(m):
            G.add_edge(f'group_{g}', 'information_optimization', weight=info_edge_weight)

    return G


def main():
    """
    Command-line entry point: load the instance JSON and build its graph.

    Expects exactly three positional arguments: the .mzn model path, the .dzn
    data path and the JSON data path. Only the JSON file is read here; the
    model path is forwarded to build_graph and the .dzn path is accepted but
    not used. With any other argument count, prints a usage message and exits
    with status 1.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # Decode explicitly as UTF-8 rather than the platform's locale encoding,
    # so non-ASCII instance data loads the same way on every system.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    # Build graph
    G = build_graph(mzn_file, json_data)

    # Report the size of the graph that was built
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


# Run the converter only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()