#!/usr/bin/env python3
"""
Graph converter for step1_aes problem.
Created using subagent_prompt.md version: v_02

This problem is about chosen key differential cryptanalysis for AES.
The goal is to find differential characteristics that minimize the number
of active S-boxes and key bytes. Key challenges: complex cryptographic
constraints, key schedule dependencies, and MixColumns interactions.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _column_centrality(j, bc):
    """Return a column score that peaks at 1.0 for column bc // 2 and falls off linearly."""
    return 1.0 - abs(j - bc // 2) / (bc // 2 + 1)


def _add_variable_nodes(G, n, bc, kc):
    """Add the type-0 byte variables deltaX, deltaY, deltaSR and deltaK, in that order."""
    # deltaX nodes (state after AddRoundKey) - one per byte per round.
    for r in range(n):
        for j in range(bc):
            for i in range(4):  # 4 bytes per column
                round_weight = (r + 1) / n  # Later rounds weigh more
                col_centrality = _column_centrality(j, bc)  # Central columns weigh more
                byte_weight = (i + 1) / 4  # Byte position within column
                weight = (round_weight + col_centrality + byte_weight) / 3
                G.add_node(f'deltaX_{r}_{j}_{i}', type=0, weight=weight)

    # deltaY nodes (state before AddRoundKey) - only n-1 rounds.
    for r in range(n - 1):
        for j in range(bc):
            for i in range(4):
                # n - 1 >= 1 whenever this body runs, so the division is safe.
                round_weight = (r + 1) / (n - 1)
                col_centrality = _column_centrality(j, bc)
                byte_weight = (i + 1) / 4
                weight = (round_weight + col_centrality + byte_weight) / 3
                G.add_node(f'deltaY_{r}_{j}_{i}', type=0, weight=weight)

    # deltaSR nodes (state after ShiftRows).
    for r in range(n):
        for j in range(bc):
            for i in range(4):
                round_weight = (r + 1) / n
                # Bytes whose ShiftRows source column is further from j score lower.
                shift_impact = 1.0 - (abs((j + i) % bc - j) / bc)
                byte_weight = (i + 1) / 4
                weight = (round_weight + shift_impact + byte_weight) / 3
                G.add_node(f'deltaSR_{r}_{j}_{i}', type=0, weight=weight)

    # deltaK nodes (round-key differences).
    for r in range(n):
        for j in range(bc):
            for i in range(4):
                round_weight = (r + 1) / n
                # The last column of each kc-column key group gets full importance.
                key_col_importance = 1.0 if (r * bc + j) % kc == kc - 1 else 0.7
                byte_weight = (i + 1) / 4
                weight = (round_weight + key_col_importance + byte_weight) / 3
                G.add_node(f'deltaK_{r}_{j}_{i}', type=0, weight=weight)


def _add_constraint_nodes(G, n, bc, kc):
    """Add the type-1 constraint nodes ARK, SR, MC and KS, in that order."""
    # AddRoundKey (ARK): one per byte for rounds 1..n-1; weight = sqrt(round fraction).
    for r in range(1, n):
        constraint_weight = math.sqrt((r + 1) / n)
        for j in range(bc):
            for i in range(4):
                G.add_node(f'ARK_{r}_{j}_{i}', type=1, weight=constraint_weight)

    # ShiftRows (SR): one per byte per round; weight decays exponentially with row i.
    for r in range(n):
        for j in range(bc):
            for i in range(4):
                constraint_weight = math.exp(-0.5 * i / 4) * 0.8
                G.add_node(f'SR_{r}_{j}_{i}', type=1, weight=constraint_weight)

    # MixColumns (MC): one per column for n-1 rounds; base 0.9 scaled by centrality.
    for r in range(n - 1):
        for j in range(bc):
            constraint_weight = 0.9 * _column_centrality(j, bc)
            G.add_node(f'MC_{r}_{j}', type=1, weight=constraint_weight)

    # Key Schedule (KS): one per key column; column 0 of each kc-group is treated
    # as the S-box position and weighted higher.
    for r in range(n):
        for j in range(bc):
            constraint_weight = 0.95 if (r * bc + j) % kc == 0 else 0.7
            G.add_node(f'KS_{r}_{j}', type=1, weight=constraint_weight)


def _add_constraint_edges(G, n, bc, kc):
    """Connect every constraint node to the variable nodes it relates."""
    # ARK: deltaY[r-1] + deltaK[r] = deltaX[r], for r >= 1.
    for r in range(1, n):
        for j in range(bc):
            for i in range(4):
                ark = f'ARK_{r}_{j}_{i}'
                G.add_edge(f'deltaY_{r-1}_{j}_{i}', ark, weight=0.9)
                G.add_edge(f'deltaK_{r}_{j}_{i}', ark, weight=0.9)
                G.add_edge(f'deltaX_{r}_{j}_{i}', ark, weight=0.9)

    # ShiftRows: deltaSR byte (j, i) is linked to deltaX byte ((j + i) % bc, i).
    for r in range(n):
        for j in range(bc):
            for i in range(4):
                sr = f'SR_{r}_{j}_{i}'
                G.add_edge(f'deltaX_{r}_{(j + i) % bc}_{i}', sr, weight=0.8)
                G.add_edge(f'deltaSR_{r}_{j}_{i}', sr, weight=0.8)

    # MixColumns: each MC node ties a whole deltaSR column to the same deltaY column.
    for r in range(n - 1):
        for j in range(bc):
            mc = f'MC_{r}_{j}'
            for i in range(4):
                G.add_edge(f'deltaSR_{r}_{j}_{i}', mc, weight=0.85)
                G.add_edge(f'deltaY_{r}_{j}_{i}', mc, weight=0.85)

    # Key schedule: link the current key column and, from round 1 on, the
    # previous round's key column to its left (wrapping around).
    for r in range(n):
        for j in range(bc):
            ks = f'KS_{r}_{j}'
            for i in range(4):
                G.add_edge(f'deltaK_{r}_{j}_{i}', ks, weight=0.7)

            if r > 0:
                prev_j = (j - 1) % bc  # Python's % already wraps j == 0 to bc - 1
                weight = 0.8 if (r * bc + j) % kc == 0 else 0.6  # S-box positions stronger
                for i in range(4):
                    G.add_edge(f'deltaK_{r-1}_{prev_j}_{i}', ks, weight=weight)


def _add_objective_conflicts(G, n, bc, kc, objective):
    """Add conflict edges between the first 8 key bytes that feed the objective."""
    from itertools import combinations

    max_terms = n * bc * 4  # Maximum possible number of objective terms
    # max_terms <= 0 (no rounds) would otherwise divide by zero below.
    if objective > 0 and max_terms > 0:
        objective_density = min(objective / max_terms, 1.0)
        if objective_density > 0.5:
            # Key bytes in the last column of each kc-column key group.
            key_objective_vars = [
                f'deltaK_{r}_{j}_{i}'
                for r in range(n)
                for j in range(bc)
                if (r * bc + j) % kc == kc - 1
                for i in range(4)
            ]
            conflict_weight = objective_density * 0.6
            # Only the first 8 such bytes are pairwise connected.
            for u, v in combinations(key_objective_vars[:8], 2):
                G.add_edge(u, v, weight=conflict_weight)


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the AES differential cryptanalysis instance.

    Args:
        mzn_file: Path to .mzn file (for reference only; it is not read here)
        json_data: Dict containing parsed DZN data. Keys read: 'n' (number of
            rounds, default 3), 'objective' (default 1) and 'KEY_BITS'
            (key size in bits, default 128).

    Returns:
        A networkx.Graph. Every node has a 'type' attribute (0 = variable,
        1 = constraint) and a 'weight' attribute. Every edge has a 'weight'.

    Raises:
        ValueError: if KEY_BITS is below 32. That would give zero key columns
            per round, and the key-schedule weighting is undefined.

    Strategy: Model the AES rounds, state variables, and cryptographic constraints
    - State bytes are variables (deltaX, deltaY, deltaSR, deltaK)
    - Constraints include ARK (AddRoundKey), SR (ShiftRows), MC (MixColumns), KS (KeySchedule)
    - Weights based on position importance and round number
    - Higher rounds and key schedule positions get higher weights due to complexity
    """
    n = json_data.get('n', 3)  # Number of rounds
    objective = json_data.get('objective', 1)  # Objective value
    key_bits = json_data.get('KEY_BITS', 128)  # Key size

    # AES constants
    BLOCK_BITS = 128
    BC = BLOCK_BITS // 32  # Number of columns per round (4 for AES)
    KC = key_bits // 32    # Number of columns per key round
    if KC < 1:
        raise ValueError(f"KEY_BITS must be at least 32, got {key_bits!r}")

    G = nx.Graph()
    # Node insertion order (variables, then constraints) is kept stable so that
    # downstream feature extraction sees a deterministic node ordering.
    _add_variable_nodes(G, n, BC, KC)
    _add_constraint_nodes(G, n, BC, KC)
    _add_constraint_edges(G, n, BC, KC)
    _add_objective_conflicts(G, n, BC, KC, objective)
    return G


def main():
    """Command-line entry point: build the graph for one instance and report its size.

    Expects exactly three arguments: <mzn_file> <dzn_file> <json_file>.
    Only the JSON file is read. The .mzn path is forwarded to build_graph.
    The .dzn path is accepted for CLI compatibility but is not used.
    """
    if len(sys.argv) != 4:
        # sys.exit(str) writes the message to stderr and exits with status 1.
        sys.exit("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # JSON text is UTF-8 by specification; don't depend on the locale's default encoding.
    json_data = json.loads(Path(json_file).read_text(encoding='utf-8'))

    # Build graph
    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported for build_graph.
    main()