#!/usr/bin/env python3
"""
Graph converter for sparse_mds problem.
Created using subagent_prompt.md version: v_02

This problem is about building sparse decision trees for binary classification.
The goal is to find a minimal decision tree that classifies items with few misclassifications.
Key challenges: balancing tree size vs accuracy, feature selection, data sparsity.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _expand_mzn_set(value):
    """Return ``value`` as a plain list, expanding a MiniZinc JSON set literal.

    MiniZinc exports sets to JSON as ``{"set": [...]}``, where each element is
    either a single value or a two-element ``[lo, hi]`` inclusive range.
    Any other value (e.g. an ordinary list) is returned unchanged.
    """
    if isinstance(value, dict) and 'set' in value:
        members = []
        for elem in value['set']:
            if isinstance(elem, (list, tuple)) and len(elem) == 2:
                lo, hi = elem
                members.extend(range(lo, hi + 1))
            else:
                members.append(elem)
        return members
    return value


def _flatten_matrix(db):
    """Return ``db`` as a flat row-major list.

    Accepts an already-flattened list (returned unchanged) or a list of rows,
    which is how a 2-D MiniZinc array is exported to JSON. Assumes every row
    has one entry per feature.
    """
    if db and isinstance(db[0], (list, tuple)):
        return [value for row in db for value in row]
    return db


def _bernoulli_entropy(count, total):
    """Binary entropy (bits) of a ``count``-out-of-``total`` split.

    Degenerate splits (``count <= 0`` or ``count >= total``) give 0.0.
    """
    if 0 < count < total:
        p = count / total
        return -p * math.log2(p) - (1 - p) * math.log2(1 - p)
    return 0.0


def _item_entropies(db, n_items, n_features):
    """Entropy of each item's feature vector (row ``i`` of the flat ``db``).

    Rows that run past the end of ``db`` are treated as all-zero (entropy 0).
    """
    entropies = []
    for i in range(n_items):
        start = i * n_features
        end = start + n_features
        row = db[start:end] if end <= len(db) else ()
        entropies.append(_bernoulli_entropy(sum(row), n_features))
    return entropies


def _feature_entropies(db, n_items, n_features):
    """Entropy of each feature's value split across items (column of ``db``).

    Only entries actually present in ``db`` are counted.
    """
    entropies = []
    for f_idx in range(n_features):
        column = [
            db[i * n_features + f_idx]
            for i in range(n_items)
            if i * n_features + f_idx < len(db)
        ]
        entropies.append(_bernoulli_entropy(sum(column), len(column)))
    return entropies


def _normalize(values):
    """Scale non-negative ``values`` by their maximum; all 0.5 if the max is 0."""
    peak = max(values) if values else 0
    return [v / peak if peak > 0 else 0.5 for v in values]


def _add_conflict_edges(G, db, n_items, n_features, threshold=0.7, max_items=20):
    """Add item-item edges between items with very different feature vectors.

    Only the first ``max_items`` items are compared, which bounds the number of
    pairs. An edge weighted by the normalised Hamming distance is added when
    that distance exceeds ``threshold``. Items whose row is not fully present
    in ``db`` are skipped.
    """
    if n_features <= 0:
        return
    limit = min(n_items, max_items)
    complete_rows = [
        (i, db[i * n_features:(i + 1) * n_features])
        for i in range(limit)
        if (i + 1) * n_features <= len(db)
    ]
    for pos, (i1, row1) in enumerate(complete_rows):
        for i2, row2 in complete_rows[pos + 1:]:
            differences = sum(1 for a, b in zip(row1, row2) if a != b)
            conflict = differences / n_features
            if conflict > threshold:
                G.add_edge(f'item_{i1}', f'item_{i2}', weight=conflict)


def build_graph(mzn_file, json_data):
    """
    Build a graph representation of a sparse_mds (sparse decision tree) instance.

    Args:
        mzn_file: Path to the .mzn model file. Not read; kept for interface
            compatibility.
        json_data: Dict of instance data parsed from the DZN file. Keys read:
            ``I`` (number of data items), ``db`` (item-by-feature matrix,
            either flattened row-major or as a list of rows) and ``FEATURE``
            (list of feature identifiers, or a MiniZinc JSON set literal
            ``{"set": [...]}``).

    Returns:
        A networkx.Graph whose nodes carry ``type`` and ``weight`` attributes:
        - type 0 (variables): ``item_i`` data items and ``tree_node_j``
          decision-tree nodes
        - type 1 (constraints): ``classify_i`` per item, ``tree_structure_j``
          per tree node, and one global ``sparsity_constraint``
        - type 2 (resources): ``feature_f`` per available feature

        Edges link:
        - each item to its classify constraint and to every tree node;
        - each tree node to its structure constraint and to the sparsity
          constraint;
        - every feature to every tree node;
        - pairs of strongly differing items ("conflict" edges).

        The graph is therefore not strictly bipartite. If there are no items
        or no features, a graph with a single ``dummy`` node is returned.

    Item weights are the normalised entropy of each item's feature vector
    (balanced items are assumed harder to classify). Feature weights are the
    normalised entropy of each feature's split across items (a 50/50 split is
    the most informative). Tree-node weights decay exponentially from the root.
    """
    n_items = json_data.get('I', 0)  # number of data items
    db = _flatten_matrix(json_data.get('db', []))  # row-major I x F matrix
    features = _expand_mzn_set(json_data.get('FEATURE', []))
    n_features = len(features) if features else 0
    n_tree_nodes = 5  # fixed in MiniZinc model

    G = nx.Graph()

    if n_items <= 0 or n_features == 0:
        # Degenerate instance - add minimal structure
        G.add_node('dummy', type=0, weight=0.5)
        return G

    item_weights = _normalize(_item_entropies(db, n_items, n_features))
    feature_weights = _normalize(_feature_entropies(db, n_items, n_features))
    # Root and early tree nodes are more critical: exponential decay from root.
    node_decay = [math.exp(-j / 2.0) for j in range(n_tree_nodes)]
    # Slower decay for item routing: any item can reach the root, deeper
    # nodes are more specialised.
    routing_strength = [math.exp(-j / 3.0) for j in range(n_tree_nodes)]

    # --- Variable nodes (type 0) ---
    for i, w in enumerate(item_weights):
        G.add_node(f'item_{i}', type=0, weight=w)
    for j, decay in enumerate(node_decay):
        G.add_node(f'tree_node_{j}', type=0, weight=decay)

    # --- Resource nodes (type 2) ---
    for f, w in zip(features, feature_weights):
        G.add_node(f'feature_{f}', type=2, weight=w)

    # --- Constraint nodes (type 1) ---
    for i, w in enumerate(item_weights):
        G.add_node(f'classify_{i}', type=1, weight=w)
    for j, decay in enumerate(node_decay):
        G.add_node(f'tree_structure_{j}', type=1, weight=decay)
    # More pressure to keep the tree small for larger datasets.
    sparsity_pressure = min(1.0, n_items / (50.0 * n_tree_nodes))
    G.add_node('sparsity_constraint', type=1, weight=sparsity_pressure)

    # --- Edges ---
    for i in range(n_items):
        G.add_edge(f'item_{i}', f'classify_{i}', weight=1.0)

    for j, decay in enumerate(node_decay):
        G.add_edge(f'tree_node_{j}', f'tree_structure_{j}', weight=1.0)
        G.add_edge(f'tree_node_{j}', 'sparsity_constraint', weight=decay)

    # Feature-selection edges: informative features at early nodes weigh most.
    for f, quality in zip(features, feature_weights):
        for j, decay in enumerate(node_decay):
            G.add_edge(f'feature_{f}', f'tree_node_{j}', weight=(quality + decay) / 2.0)

    # Potential classification paths: every item may pass through every node.
    for i in range(n_items):
        for j, strength in enumerate(routing_strength):
            G.add_edge(f'item_{i}', f'tree_node_{j}', weight=strength)

    # Conflicts between very different items capture how hard it is to find
    # splits that separate them.
    _add_conflict_edges(G, db, n_items, n_features)

    return G


def main():
    """Command-line entry point.

    Usage: converter.py <mzn_file> <dzn_file> <json_file>

    Loads the instance data from <json_file>, builds the graph with
    build_graph and prints its node and edge counts. <dzn_file> is accepted
    for interface compatibility but is not read. On a wrong argument count,
    prints usage to stderr and exits with status 1.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>", file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # JSON text is UTF-8 by specification; don't depend on the platform's
    # default locale encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


# Run the command-line interface only when executed as a script, so that
# importing this module (e.g. to call build_graph directly) has no side effects.
if __name__ == "__main__":
    main()