#!/usr/bin/env python3
"""
Graph converter for mznc2017_aes_opt problem.
Created using subagent_prompt.md version: v_02

This problem is about finding optimal differential characteristics in AES encryption.
Key challenges: Complex S-box constraints via DDT, permutation dependencies, round structure.
The problem involves finding the best differential path through R rounds of AES.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


# Bit permutation used by the permutation constraints: xp[r, P[j]] = x[r, j].
# The table is equivalent to P[j] == 16 * (j % 4) + j // 4 for j in 0..63.
_BIT_PERMUTATION = (
    0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51,
    4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55,
    8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59,
    12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63,
)


def _round_criticality(r, R):
    """Return a score that peaks (1.0) at the middle round ``R / 2``.

    For ``0 <= r <= R`` and ``R > 1`` the result lies in (0, 1]. For
    ``R <= 1`` there is no meaningful middle round, so a constant 0.8 is used.
    """
    if R > 1:
        return 1.0 - abs(r - R / 2) / (R / 2 + 1)
    return 0.8


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the AES differential cryptanalysis problem.

    Args:
        mzn_file: Path to .mzn file (kept for interface compatibility; not read).
        json_data: Dict containing parsed DZN data. Only ``'R'`` (number of
            rounds, default 1) is used.

    Returns:
        networkx.Graph whose nodes carry ``type`` (0 = variable,
        1 = constraint) and ``weight`` attributes; every edge carries ``weight``.

    Strategy: Model the round-based structure of AES differential analysis
    - State variables (x, xp) as variable nodes
    - S-box constraints and permutation constraints as constraint nodes
    - Round-based connectivity reflecting the AES structure
    - Weights based on constraint complexity and round position
    """
    R = json_data.get('R', 1)
    bl = 64  # Block length from MZN file
    n_sbox = bl // 4  # 4-bit S-boxes -> 16 per round
    P = _BIT_PERMUTATION

    G = nx.Graph()

    # ---- Variable nodes (type 0) ----

    # x[r, j] for r = 0..R. With the edges below, x[r] feeds the permutation
    # into xp[r], and x[r + 1] is the output side of round r's S-boxes.
    for r in range(R + 1):
        # Middle rounds are treated as more critical for differential paths.
        position_weight = 0.3 + 0.7 * _round_criticality(r, R)
        for j in range(bl):
            G.add_node(f'x_{r}_{j}', type=0, weight=position_weight)

    # xp[r, j] for r = 0..R-1: state after the bit permutation.
    for r in range(R):
        position_weight = 0.3 + 0.7 * _round_criticality(r, R)
        for j in range(bl):
            G.add_node(f'xp_{r}_{j}', type=0, weight=position_weight)

    # prb[i]: one probability variable per S-box per round.
    n_prb = n_sbox * R
    for i in range(n_prb):
        round_criticality = _round_criticality(i // n_sbox, R)
        # First- and last-round prb variables get a higher weight; the original
        # notes say they carry an extra "!= 3" restriction in the model
        # (NOTE(review): confirm against the .mzn file).
        if i < n_sbox or i >= (R - 1) * n_sbox:
            constraint_weight = 0.9
        else:
            constraint_weight = 0.6
        weight = 0.4 + 0.6 * round_criticality * constraint_weight
        G.add_node(f'prb_{i}', type=0, weight=weight)

    # ---- Constraint nodes (type 1) ----

    # One permutation constraint per round and bit position (fixed structure).
    for r in range(R):
        for j in range(bl):
            G.add_node(f'perm_constraint_{r}_{j}', type=1, weight=0.7)

    # One S-box (DDT lookup) constraint per S-box per round.
    complexity_weight = 0.9
    for r in range(R):
        round_criticality = _round_criticality(r, R)
        weight = 0.5 + 0.5 * round_criticality * complexity_weight
        for sbox in range(n_sbox):
            G.add_node(f'sbox_constraint_{r}_{sbox}', type=1, weight=weight)

    # Objective (sum of prb) and the special first/last round constraints.
    G.add_node('objective_constraint', type=1, weight=1.0)
    G.add_node('first_round_constraint', type=1, weight=0.8)
    G.add_node('last_round_constraint', type=1, weight=0.8)

    # ---- Edges ----

    # Permutation: x[r, j] and xp[r, P[j]] share a constraint node.
    for r in range(R):
        for j in range(bl):
            perm_constraint = f'perm_constraint_{r}_{j}'
            G.add_edge(f'x_{r}_{j}', perm_constraint, weight=1.0)
            G.add_edge(f'xp_{r}_{P[j]}', perm_constraint, weight=1.0)

    # S-box: 4 input bits xp[r, 4s..4s+3], 4 output bits x[r+1, 4s..4s+3],
    # and the matching prb variable.
    for r in range(R):
        for sbox in range(n_sbox):
            sbox_constraint = f'sbox_constraint_{r}_{sbox}'
            for bit in range(4):
                j = 4 * sbox + bit
                G.add_edge(f'xp_{r}_{j}', sbox_constraint, weight=0.8)
                G.add_edge(f'x_{r + 1}_{j}', sbox_constraint, weight=0.8)
            G.add_edge(f'prb_{n_sbox * r + sbox}', sbox_constraint, weight=1.0)

    # The objective touches every prb variable.
    for i in range(n_prb):
        G.add_edge(f'prb_{i}', 'objective_constraint', weight=0.9)

    # First round: bounded by n_prb so that R < 1 does not make add_edge
    # implicitly create prb nodes that lack 'type'/'weight' attributes.
    for i in range(min(n_sbox, n_prb)):
        G.add_edge(f'prb_{i}', 'first_round_constraint', weight=0.9)

    # Last round (empty range when R < 1).
    for i in range((R - 1) * n_sbox, n_prb):
        G.add_edge(f'prb_{i}', 'last_round_constraint', weight=0.9)

    # Direct round-to-round state links xp[r, j] -- x[r+1, j]. These are
    # variable-variable edges (both endpoints are type 0), added on top of the
    # S-box constraint that already joins the same pair.
    for r in range(R):
        distance_factor = math.exp(-0.5 * abs(r - R / 2)) if R > 1 else 1.0
        weight = 0.4 + 0.4 * distance_factor
        for j in range(bl):
            G.add_edge(f'xp_{r}_{j}', f'x_{r + 1}_{j}', weight=weight)

    return G


def main():
    """Command-line entry point.

    Usage: ``python converter.py <mzn_file> <dzn_file> <json_file>``

    Loads the instance data from ``json_file``, builds the graph and prints
    its node and edge counts. On a wrong argument count, prints the usage
    line to stderr and exits with status 1.
    """
    if len(sys.argv) != 4:
        # Diagnostics belong on stderr so stdout stays clean for results.
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>",
              file=sys.stderr)
        sys.exit(1)

    # The .dzn path is accepted for CLI compatibility but not read: the
    # instance data comes from the pre-parsed JSON file.
    mzn_file, _dzn_file, json_file = sys.argv[1:]

    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    main()