#!/usr/bin/env python3
"""
Graph converter for parity-learning problem.
Created using subagent_prompt.md version: v_02

This problem is about learning the parity function from noisy samples.
Given Boolean input/output samples, find a subset of variables whose parity
matches most sample outputs, minimizing disagreement errors.

Key challenges: 
- High-dimensional Boolean space makes exhaustive search intractable
- Noise in samples creates uncertainty about correct parity function
- Combinatorial explosion of possible variable subsets to consider
- Sample distribution and error patterns affect difficulty
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the parity-learning problem instance.

    Only the three scalars 'num_vars', 'num_samples' and 'max_errors' are
    read from json_data. The actual sample contents ('sample_inputs' /
    'sample_outputs') are not consulted, so the structure and weights of the
    graph depend purely on the instance sizes and the error budget.

    Args:
        mzn_file: Path to .mzn file (for reference only; not read here)
        json_data: Dict containing parsed DZN data

    Returns:
        An undirected networkx Graph whose nodes carry 'type' and 'weight'
        attributes and whose edges carry a 'weight' attribute. If num_vars
        or num_samples is missing, null or not positive, the graph holds a
        single 'dummy' node (type 0, weight 0.5).

    Node types:
    - type 0: variable nodes 'var_<i>'
    - type 1: constraint nodes - one 'sample_<j>' per sample plus the global
      'error_constraint' node
    - type 2: the single 'problem_difficulty' summary node

    Edges:
    - every variable is connected to every sample node
    - every variable is connected to 'error_constraint'
    - 'problem_difficulty' is connected to 'error_constraint', to the
      highest-weighted variables and to the highest-weighted samples
    - variable-variable "conflict" edges among the first 10 variables when
      the error tolerance is below 0.2
    """
    # Missing or null entries are treated like 0 so they take the fallback
    # path instead of failing later in range() / arithmetic.
    num_vars = json_data.get('num_vars') or 0
    num_samples = json_data.get('num_samples') or 0
    max_errors = json_data.get('max_errors') or 0

    G = nx.Graph()

    if num_vars <= 0 or num_samples <= 0:
        # Fallback for missing or degenerate data
        G.add_node('dummy', type=0, weight=0.5)
        return G

    # Fraction of samples allowed to disagree, clamped to [0, 1] so that an
    # out-of-range max_errors (negative, or larger than num_samples) cannot
    # push node/edge weights outside the unit interval.
    error_tolerance = min(max(max_errors / num_samples, 0.0), 1.0)
    # Tightness is the complement: 1.0 means no errors are tolerated.
    error_tightness = 1.0 - error_tolerance

    # Variable nodes (type 0): weight decays exponentially with the distance
    # of the index from the middle index, ranging within (0.3, 1.0].
    half = num_vars / 2
    var_weights = [
        0.3 + 0.7 * math.exp(-2.0 * abs(v - half) / max(half, 1))
        for v in range(num_vars)
    ]
    for v, weight in enumerate(var_weights):
        G.add_node(f'var_{v}', type=0, weight=weight)

    # Sample constraint nodes (type 1): weight grows linearly with the sample
    # index, from just above 0.4 up to 1.0 for the last sample.
    sample_weights = [
        0.4 + 0.6 * ((s + 1) / num_samples)
        for s in range(num_samples)
    ]
    for s, weight in enumerate(sample_weights):
        G.add_node(f'sample_{s}', type=1, weight=weight)

    # Global error constraint node (type 1): weighted by how tight the error
    # budget is relative to the number of samples.
    G.add_node('error_constraint', type=1, weight=error_tightness)

    # Problem difficulty node (type 2): mean of three components in [0, 1].
    # - var_complexity saturates towards 1 as num_vars grows
    # - sample sparsity is 1 - density, where density compares num_samples
    #   with 2**num_vars (exponent capped at 10)
    # - error tightness as computed above
    var_complexity = 1.0 - math.exp(-num_vars / 10.0)
    sample_density = min(num_samples / (2 ** min(num_vars, 10)), 1.0)
    overall_difficulty = (
        var_complexity + error_tightness + (1.0 - sample_density)
    ) / 3.0
    G.add_node('problem_difficulty', type=2, weight=overall_difficulty)

    # Complete bipartite edges between variables and samples; the edge weight
    # is the mean of the two endpoint weights. Weights come from the
    # precomputed lists to avoid node-attribute lookups in this O(V*S) loop.
    for v, var_weight in enumerate(var_weights):
        var_name = f'var_{v}'
        for s, sample_weight in enumerate(sample_weights):
            G.add_edge(
                var_name,
                f'sample_{s}',
                weight=(var_weight + sample_weight) / 2.0,
            )

    # Every variable is tied to the global error constraint, scaled by how
    # tight that constraint is.
    for v, var_weight in enumerate(var_weights):
        G.add_edge(
            f'var_{v}',
            'error_constraint',
            weight=var_weight * error_tightness,
        )

    # Difficulty node -> error constraint
    G.add_edge('problem_difficulty', 'error_constraint', weight=error_tightness)

    # Difficulty node -> top 25% of variables by weight (at least 3, or all
    # of them when there are fewer). sorted() is stable, so ties keep index
    # order.
    num_critical = max(3, num_vars // 4)
    ranked_vars = sorted(
        enumerate(var_weights), key=lambda pair: pair[1], reverse=True
    )
    for v, var_weight in ranked_vars[:num_critical]:
        G.add_edge(
            'problem_difficulty',
            f'var_{v}',
            weight=var_weight * overall_difficulty,
        )

    # Difficulty node -> top 25% of samples by weight (at least 5, or all of
    # them when there are fewer).
    num_critical_samples = max(5, num_samples // 4)
    ranked_samples = sorted(
        enumerate(sample_weights), key=lambda pair: pair[1], reverse=True
    )
    for s, sample_weight in ranked_samples[:num_critical_samples]:
        G.add_edge(
            'problem_difficulty',
            f'sample_{s}',
            weight=sample_weight * overall_difficulty,
        )

    # Conflict edges between variables when the error budget is very small.
    # Only the first 10 variables are paired up to bound the edge count, and
    # only pairs whose combined strength exceeds 0.5 get an edge.
    if error_tolerance < 0.2 and num_vars > 1:
        limit = min(num_vars, 10)
        for v1 in range(limit):
            for v2 in range(v1 + 1, limit):
                conflict_strength = (
                    (var_weights[v1] + var_weights[v2]) / 2.0
                    * (1.0 - error_tolerance)
                )
                if conflict_strength > 0.5:
                    G.add_edge(
                        f'var_{v1}', f'var_{v2}', weight=conflict_strength
                    )

    return G


def main():
    """
    Command-line entry point.

    Expects exactly three arguments: <mzn_file> <dzn_file> <json_file>.
    Loads the JSON file, builds the graph with build_graph() and prints the
    resulting node and edge counts. With any other argument count, prints a
    usage line and exits with status 1.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    # The .dzn path is part of the command-line contract but is not read;
    # all instance data comes from the JSON file.
    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # Load JSON data. The encoding is given explicitly so decoding does not
    # depend on the platform's locale default.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    # Build graph
    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()