#!/usr/bin/env python3
"""
Graph converter for BIBD (Balanced Incomplete Block Design) problem.
Created using subagent_prompt.md version: v_02

This problem is about constructing a binary incidence matrix where:
- v varieties (rows) each appear in exactly r blocks
- b blocks (columns) each contain exactly k varieties  
- Any two varieties appear together in exactly lambda blocks

Key challenges: The interdependence of row sums, column sums, and pairwise constraints
creates complex constraint interactions that make some parameter combinations impossible.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _clamp_unit(value):
    """Clamp ``value`` into the closed interval [0.0, 1.0]."""
    return max(0.0, min(value, 1.0))


def build_graph(mzn_file, json_data):
    """
    Build the bipartite variable/constraint graph of a BIBD instance.

    Args:
        mzn_file: Path to the .mzn model. Accepted for interface
            compatibility only; it is not read by this function.
        json_data: Dict with the instance parameters ``v`` (number of
            varieties / rows), ``k`` (block size) and ``lambda`` (pairwise
            intersection size). Missing keys default to 0.

    Returns:
        An undirected ``networkx.Graph``. Every node carries a ``type``
        attribute (0 = matrix cell variable, 1 = constraint) and a
        ``weight`` attribute; every edge carries a ``weight`` attribute.

    Graph layout:
        - ``cell_{i}_{j}``: one variable node per incidence-matrix cell.
        - ``row_sum_{i}`` / ``col_sum_{j}``: row and column sum constraints,
          linked to every cell of their row / column.
        - ``pairwise_{a}_{b}``: dot-product constraint for rows a < b,
          linked to every cell of both rows.
        - ``row_lex_{i}`` / ``col_lex_{j}``: symmetry-breaking constraints
          between consecutive rows / columns.
        - ``global_complexity``: a single indicator node linked to the cells
          around the matrix centre (only when the matrix is non-empty).

    Computed weights are clamped into [0, 1], so infeasible parameter sets
    (for example ``k > v``) cannot produce negative weights.
    """
    # Extract parameters
    v = json_data.get('v', 0)  # number of varieties (rows)
    k = json_data.get('k', 0)  # block size (items per column)
    lambda_val = json_data.get('lambda', 0)  # pairwise intersection size

    # Derived parameters (same formulas as the MiniZinc model). k in {0, 1}
    # would divide by zero, so those instances get an empty matrix.
    if k == 0 or k == 1:
        b = 0
        r = 0
    else:
        b = (lambda_val * v * (v - 1)) // (k * (k - 1))  # number of blocks (columns)
        r = (lambda_val * (v - 1)) // (k - 1)  # replication number (ones per row)

    G = nx.Graph()

    # ------------------------------------------------------------------
    # Variable nodes: matrix cells m[i, j] (type 0)
    # ------------------------------------------------------------------
    # Cells nearer the matrix centre get a higher weight; the whole matrix
    # is scaled by how dense the pairwise requirement is relative to k.
    constraint_density = lambda_val / k if k > 0 else 0
    density_scale = 0.5 + 0.5 * constraint_density
    row_mid, col_mid = v // 2, b // 2
    row_span, col_span = max(row_mid, 1), max(col_mid, 1)
    for i in range(v):
        row_centrality = 1.0 - abs(i - row_mid) / row_span
        for j in range(b):
            col_centrality = 1.0 - abs(j - col_mid) / col_span
            cell_weight = (row_centrality + col_centrality) / 2 * density_scale
            G.add_node(f'cell_{i}_{j}', type=0, weight=_clamp_unit(cell_weight))

    # ------------------------------------------------------------------
    # Row sum constraints (type 1): each row must sum to r
    # ------------------------------------------------------------------
    # Tightness peaks when r is close to b/2. It is identical for every row,
    # so it is computed once.
    if b > 0:
        row_tightness = _clamp_unit(1.0 - abs(r - b / 2) / max(b / 2, 1))
    else:
        row_tightness = 0.5
    row_edge_weight = 1.0 / r if r > 0 else 0.5
    for i in range(v):
        G.add_node(f'row_sum_{i}', type=1, weight=row_tightness)
        for j in range(b):
            G.add_edge(f'cell_{i}_{j}', f'row_sum_{i}', weight=row_edge_weight)

    # ------------------------------------------------------------------
    # Column sum constraints (type 1): each column must sum to k
    # ------------------------------------------------------------------
    if v > 0:
        col_tightness = _clamp_unit(1.0 - abs(k - v / 2) / max(v / 2, 1))
    else:
        col_tightness = 0.5
    col_edge_weight = 1.0 / k if k > 0 else 0.5
    for j in range(b):
        G.add_node(f'col_sum_{j}', type=1, weight=col_tightness)
        for i in range(v):
            G.add_edge(f'cell_{i}_{j}', f'col_sum_{j}', weight=col_edge_weight)

    # ------------------------------------------------------------------
    # Pairwise dot-product constraints (type 1)
    # ------------------------------------------------------------------
    # Any two rows must share exactly lambda ones. The node weight decays
    # exponentially in lambda / r, so a small lambda relative to r scores
    # as harder.
    pairwise_count = 0
    if v > 1:
        # Computed only when at least one pair exists, so degenerate
        # parameters cannot raise from math.exp on an otherwise empty graph.
        if r > 0:
            pair_node_weight = _clamp_unit(math.exp(-2.0 * lambda_val / r))
            pair_edge_weight = _clamp_unit(2.0 * lambda_val / r)
        else:
            pair_node_weight = 0.8
            pair_edge_weight = 0.5

        for i_a in range(v):
            for i_b in range(i_a + 1, v):
                constraint_id = f'pairwise_{i_a}_{i_b}'
                G.add_node(constraint_id, type=1, weight=pair_node_weight)
                pairwise_count += 1

                # Every cell of both rows takes part in the intersection.
                for j in range(b):
                    G.add_edge(f'cell_{i_a}_{j}', constraint_id, weight=pair_edge_weight)
                    G.add_edge(f'cell_{i_b}_{j}', constraint_id, weight=pair_edge_weight)

    # ------------------------------------------------------------------
    # Symmetry-breaking constraints (type 1)
    # ------------------------------------------------------------------
    # Auxiliary lexicographic orderings: they only help search, so they get
    # a fixed low node weight and low edge weights.
    lex_weight = 0.3
    lex_edge_weight = 0.2
    for i in range(v - 1):
        G.add_node(f'row_lex_{i}', type=1, weight=lex_weight)
        for j in range(b):
            G.add_edge(f'cell_{i}_{j}', f'row_lex_{i}', weight=lex_edge_weight)
            G.add_edge(f'cell_{i+1}_{j}', f'row_lex_{i}', weight=lex_edge_weight)

    for j in range(b - 1):
        G.add_node(f'col_lex_{j}', type=1, weight=lex_weight)
        for i in range(v):
            G.add_edge(f'cell_{i}_{j}', f'col_lex_{j}', weight=lex_edge_weight)
            G.add_edge(f'cell_{i}_{j+1}', f'col_lex_{j}', weight=lex_edge_weight)

    # ------------------------------------------------------------------
    # Global complexity indicator (type 1)
    # ------------------------------------------------------------------
    # Only meaningful for a non-empty matrix. Both dimensions are tested
    # explicitly because v * b > 0 also holds when both are negative.
    if v > 0 and b > 0:
        total_vars = v * b
        total_constraints = v + b + pairwise_count + (v - 1) + (b - 1)
        # More constraints per variable is treated as harder.
        constraint_ratio = total_constraints / total_vars
        G.add_node('global_complexity', type=1,
                   weight=_clamp_unit(constraint_ratio / 2.0))

        # Link the cells within Manhattan distance `radius` of the centre,
        # with an edge weight that decays with that distance.
        center_i, center_j = v // 2, b // 2
        radius = max(v, b) // 4
        decay_scale = max(radius, 1)
        for i in range(max(0, center_i - radius), min(v, center_i + radius + 1)):
            row_distance = abs(i - center_i)
            for j in range(max(0, center_j - radius), min(b, center_j + radius + 1)):
                distance = row_distance + abs(j - center_j)
                if distance <= radius:
                    G.add_edge(f'cell_{i}_{j}', 'global_complexity',
                               weight=_clamp_unit(math.exp(-distance / decay_scale)))

    return G


def main():
    """
    Command-line entry point.

    Expects exactly three positional arguments: the .mzn model path, the
    .dzn data path and the JSON file holding the parsed instance data.
    With any other argument count it prints a usage line and exits with
    status 1. Otherwise it loads the JSON file, builds the graph and prints
    its node and edge counts. The .dzn path is accepted but not read here.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) != 3:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    mzn_file, _dzn_file, json_file = cli_args

    # Parse the instance parameters from the JSON file
    with open(json_file, 'r') as handle:
        json_data = json.load(handle)

    # build_graph returns the graph directly; only its size is reported here
    graph = build_graph(mzn_file, json_data)
    node_total = graph.number_of_nodes()
    edge_total = graph.number_of_edges()
    print(f"Graph built: {node_total} nodes, {edge_total} edges")


if __name__ == "__main__":
    main()