#!/usr/bin/env python3
"""
Graph converter for step1_aes problem.
Created using subagent_prompt.md version: v_02

This problem is about chosen key differential cryptanalysis for AES.
It models the propagation of differences through multiple rounds of AES encryption.
Key challenges: Managing complex relationships between state differences, key differences,
and the MixColumns constraints across multiple rounds.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _byte_positions(rounds, columns):
    """Yield every (round, column, row) index triple for the given rounds.

    Rows always run over 0..3; the iteration order is round-major, then
    column, then row, which fixes the node/edge insertion order of the graph.
    """
    for r in rounds:
        for j in range(columns):
            for i in range(4):
                yield r, j, i


def _key_schedule_columns(n, bc, kc):
    """Yield (col, r, j, is_sbox) for each key column produced by the schedule.

    ``col`` is the flat column index over all round keys, ``r``/``j`` are its
    round and column inside that round, and ``is_sbox`` is True when ``col``
    is a multiple of ``kc`` (those columns get a ``KS_SB`` node, the others a
    ``KS_XOR`` node).
    """
    # NOTE(review): for kc > 6 (256-bit keys) FIPS-197 also applies SubWord
    # when col % kc == 4; that case is treated as a plain XOR column here --
    # confirm against the .mzn model.
    for col in range(kc, n * bc):
        yield col, col // bc, col % bc, col % kc == 0


def _add_variable_nodes(G, n, bc):
    """Add the type-0 nodes: deltaY, deltaX, deltaSR and deltaK bytes."""
    # deltaY: state before ARK (n-1 rounds); weight grows with the round index.
    for r, j, i in _byte_positions(range(n - 1), bc):
        G.add_node(f'deltaY_{r}_{j}_{i}', type=0, weight=(r + 1) / n)

    # deltaX: state after ARK (n rounds); weight also grows with the row index.
    for r, j, i in _byte_positions(range(n), bc):
        round_weight = (r + 1) / n
        byte_weight = 0.7 + 0.3 * i / 3
        G.add_node(f'deltaX_{r}_{j}_{i}', type=0, weight=round_weight * byte_weight)

    # deltaSR: state after ShiftRows (n rounds).
    for r, j, i in _byte_positions(range(n), bc):
        G.add_node(f'deltaSR_{r}_{j}_{i}', type=0, weight=(r + 1) / n)

    # deltaK: round-key differences (n rounds), uniform weight.
    for r, j, i in _byte_positions(range(n), bc):
        G.add_node(f'deltaK_{r}_{j}_{i}', type=0, weight=0.9)


def _add_constraint_nodes(G, n, bc, kc, mc_weight):
    """Add the type-1 nodes: ARK, SR, MC, key-schedule and EQ constraints."""
    # ARK (AddRoundKey): one node per byte for rounds 1..n-1.
    for r, j, i in _byte_positions(range(1, n), bc):
        G.add_node(f'ARK_{r}_{j}_{i}', type=1, weight=0.6)

    # SR (ShiftRows): one node per byte for every round.
    for r, j, i in _byte_positions(range(n), bc):
        G.add_node(f'SR_{r}_{j}_{i}', type=1, weight=0.7)

    # MC (MixColumns): one node per column for rounds 0..n-2.
    for r in range(n - 1):
        for j in range(bc):
            G.add_node(f'MC_{r}_{j}', type=1, weight=mc_weight)

    # KS (key schedule): one node per byte of every derived key column.
    for _col, r, j, is_sbox in _key_schedule_columns(n, bc, kc):
        for i in range(4):
            if is_sbox:
                G.add_node(f'KS_SB_{r}_{j}_{i}', type=1, weight=0.9)
            else:
                G.add_node(f'KS_XOR_{r}_{j}_{i}', type=1, weight=0.7)

    # EQ: one node per (unordered pair of key columns, row), numbered
    # consecutively.
    # NOTE(review): no edge is ever attached to these nodes in this module,
    # so they stay isolated -- confirm that is intended.
    total_columns = n * bc
    eq_count = 0
    for first in range(total_columns):
        for _second in range(first + 1, total_columns):
            for _row in range(4):
                G.add_node(f'EQ_{eq_count}', type=1, weight=0.8)
                eq_count += 1


def _add_constraint_edges(G, n, bc, kc):
    """Connect each variable node to the constraint nodes it takes part in."""
    # ARK: deltaY of the previous round, the round key and the resulting deltaX.
    for r, j, i in _byte_positions(range(1, n), bc):
        ark_node = f'ARK_{r}_{j}_{i}'
        G.add_edge(f'deltaY_{r-1}_{j}_{i}', ark_node, weight=0.9)
        G.add_edge(f'deltaK_{r}_{j}_{i}', ark_node, weight=0.9)
        G.add_edge(f'deltaX_{r}_{j}_{i}', ark_node, weight=0.9)

    # SR: row i is read from the column i positions further along (mod bc).
    for r, j, i in _byte_positions(range(n), bc):
        sr_node = f'SR_{r}_{j}_{i}'
        shifted_j = (j + i) % bc
        G.add_edge(f'deltaX_{r}_{shifted_j}_{i}', sr_node, weight=0.8)
        G.add_edge(f'deltaSR_{r}_{j}_{i}', sr_node, weight=0.8)

    # MC: all four input and output bytes of a column share one MC node.
    for r, j, i in _byte_positions(range(n - 1), bc):
        mc_node = f'MC_{r}_{j}'
        G.add_edge(f'deltaSR_{r}_{j}_{i}', mc_node, weight=1.0)
        G.add_edge(f'deltaY_{r}_{j}_{i}', mc_node, weight=1.0)

    # KS: each derived key byte depends on the column just before it and on
    # the column kc positions earlier.
    for col, r, j, is_sbox in _key_schedule_columns(n, bc, kc):
        left_r, left_j = (col - 1) // bc, (col + bc - 1) % bc
        back_r, back_j = (col - kc) // bc, (col - kc) % bc
        for i in range(4):
            if is_sbox:
                ks_node = f'KS_SB_{r}_{j}_{i}'
                # The preceding column is read one row further down (mod 4).
                rotated_i = (i + 1) % 4
                G.add_edge(f'deltaK_{left_r}_{left_j}_{rotated_i}', ks_node, weight=0.9)
                G.add_edge(f'deltaK_{r}_{j}_{i}', ks_node, weight=0.9)
                G.add_edge(f'deltaK_{back_r}_{back_j}_{i}', ks_node, weight=0.8)
            else:
                ks_node = f'KS_XOR_{r}_{j}_{i}'
                G.add_edge(f'deltaK_{back_r}_{back_j}_{i}', ks_node, weight=0.8)
                G.add_edge(f'deltaK_{left_r}_{left_j}_{i}', ks_node, weight=0.8)
                G.add_edge(f'deltaK_{r}_{j}_{i}', ks_node, weight=0.8)


def _add_mixcolumns_conflict_edges(G, n, bc, conflict_weight):
    """Link every pair of bytes that share a MixColumns input or output column."""
    for r in range(n - 1):
        for j in range(bc):
            sr_vars = [f'deltaSR_{r}_{j}_{i}' for i in range(4)]
            y_vars = [f'deltaY_{r}_{j}_{i}' for i in range(4)]
            for i1 in range(4):
                for i2 in range(i1 + 1, 4):
                    G.add_edge(sr_vars[i1], sr_vars[i2], weight=conflict_weight)
                    G.add_edge(y_vars[i1], y_vars[i2], weight=conflict_weight)


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the AES differential cryptanalysis problem.

    Args:
        mzn_file: Path to .mzn file (for reference only; it is never read)
        json_data: Dict containing parsed DZN data. The keys read are
            'n' (rounds, default 5), 'objective' (default 0) and
            'KEY_BITS' (default 128).

    Returns:
        An undirected networkx Graph. Variable nodes carry ``type=0``,
        constraint nodes ``type=1``; every node and edge has a ``weight``.

    Raises:
        ValueError: if KEY_BITS yields fewer than one 32-bit key column.

    Strategy: Model the AES rounds as a bipartite graph where:
    - Variable nodes represent state/key differences at different positions
    - Constraint nodes represent the various cryptographic operations (ARK, SR, MC, KS)
    - Edges connect variables to the constraints they participate in
    - Weights reflect the complexity and criticality of operations
    On top of that, variable nodes sharing a MixColumns column are linked
    directly to each other.
    """
    n = json_data.get('n', 5)  # Number of rounds
    objective = json_data.get('objective', 0)  # Target objective value
    key_bits = json_data.get('KEY_BITS', 128)  # Key size

    BLOCK_BITS = 128
    KC = key_bits // 32  # Number of 32-bit columns in the key
    BC = BLOCK_BITS // 32  # Number of columns per round (always 4 for AES)

    # The key schedule indexes columns modulo KC, so KC must be positive;
    # fail early with a clear message instead of a ZeroDivisionError (KC == 0)
    # or a graph with bogus negative-index nodes (KC < 0).
    if KC <= 0:
        raise ValueError(
            f"KEY_BITS must be at least 32 to form a key column (got {key_bits!r})"
        )

    # Both objective-derived weights are loop-invariant, so compute them once.
    # Flooring the objective at 0 keeps the MC weight inside [0, 1); the
    # conflict weight is unaffected because it was already capped at 1.0.
    bounded_objective = max(objective, 0)
    mc_weight = 1.0 - math.exp(-bounded_objective / 20.0)
    conflict_weight = min(0.5 * (1.0 + math.exp(-bounded_objective / 15.0)), 1.0)

    G = nx.Graph()
    _add_variable_nodes(G, n, BC)
    _add_constraint_nodes(G, n, BC, KC, mc_weight)
    _add_constraint_edges(G, n, BC, KC)
    _add_mixcolumns_conflict_edges(G, n, BC, conflict_weight)
    return G


def main():
    """Command-line entry point: load a JSON instance and build its graph.

    Expects exactly three positional arguments: the .mzn path, the .dzn path
    and the JSON path. With any other argument count it prints a usage line
    and exits with status 1. On success it prints the node and edge counts of
    the graph that was built.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    # The .dzn path is part of the command line but is not read here; only
    # the JSON file supplies the instance data.
    mzn_file, _dzn_file, json_file = sys.argv[1:]

    # Load JSON data. Decode explicitly as UTF-8 so the result does not depend
    # on the platform's default locale encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    # Build graph
    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


# Run the converter only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()