#!/usr/bin/env python3
"""
Graph converter for jp-encoding problem.
Created using subagent_prompt.md version: v_02

This problem is about recovering original character encodings from mixed byte streams.
Key challenges: Distinguishing between EUC-JP, SJIS, and UTF-8 encodings based on byte patterns
and character frequency statistics. The solver must assign encoding types while minimizing
encoding-specific penalty scores.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


# Highest lead byte allowed in UTF-8 (0xF4). RFC 3629 forbids 0xF5-0xF7, so
# they must not be treated as UTF-8 sequence starts anywhere in this module.
_UTF8_MAX_LEAD = 244

# Greedy lead-byte rules as inclusive (lo, hi, sequence_length) triples, used
# to estimate how many characters of each encoding the stream could contain.
_UTF8_LEAD_RULES = ((194, 223, 2), (224, 239, 3), (240, _UTF8_MAX_LEAD, 4))
_EUCJP_LEAD_RULES = ((161, 168, 2), (173, 173, 2), (176, 244, 2), (249, 252, 2))
_SJIS_LEAD_RULES = ((161, 223, 1), (129, 159, 2), (224, 252, 2))


def _normalize_stream(stream, length):
    """Return exactly ``length`` byte values: ``stream`` truncated or zero-padded."""
    values = list(stream[:length])
    values.extend([0] * (length - len(values)))
    return values


def _is_utf8_byte(b):
    """Return True for UTF-8 continuation bytes (128-191) or valid lead bytes."""
    return 128 <= b <= 191 or 194 <= b <= _UTF8_MAX_LEAD


def _is_eucjp_byte(b):
    """Return True for bytes in the EUC-JP multi-byte range (161-254)."""
    return 161 <= b <= 254


def _is_sjis_byte(b):
    """Return True if ``b`` falls in the byte ranges this converter treats as Shift_JIS."""
    return (64 <= b <= 126 or 129 <= b <= 159
            or 161 <= b <= 223 or 224 <= b <= 252)


def _encoding_possibilities(b):
    """Count how many of ASCII / UTF-8 / EUC-JP / Shift_JIS admit byte ``b`` (0-4)."""
    return sum((b < 128, _is_utf8_byte(b), _is_eucjp_byte(b), _is_sjis_byte(b)))


def _count_sequences(data, lead_rules):
    """Greedily count characters whose lead byte matches one of ``lead_rules``.

    Scans left to right. A matching lead consumes ``sequence_length`` bytes;
    any other byte consumes one. A multi-byte lead in the final position (no
    following byte at all) is not counted, but a single-byte character there is.
    """
    count = 0
    i = 0
    n = len(data)
    while i < n:
        step = 1
        for lo, hi, seq_len in lead_rules:
            if lo <= data[i] <= hi:
                if seq_len == 1 or i + 1 < n:
                    count += 1
                step = seq_len
                break
        i += step
    return count


def _could_pair(cur, nxt):
    """Return True if (cur, nxt) could start a multi-byte character in any encoding."""
    return ((194 <= cur <= _UTF8_MAX_LEAD and 128 <= nxt <= 191)          # UTF-8
            or (161 <= cur <= 252 and 161 <= nxt <= 254)                 # EUC-JP
            or ((129 <= cur <= 159 or 224 <= cur <= 252)                 # Shift_JIS
                and (64 <= nxt <= 126 or 128 <= nxt <= 252)))


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the jp-encoding problem instance.

    Args:
        mzn_file: Path to .mzn file (accepted for interface compatibility; unused).
        json_data: Dict containing parsed DZN data. Reads ``len`` (stream
            length) and ``stream`` (list of integer byte values).

    Returns:
        networkx.Graph whose nodes carry ``type`` (0 = byte position,
        1 = constraint) and ``weight`` attributes and whose edges carry
        ``weight``. When ``len`` is missing or non-positive, or ``stream`` is
        empty, the graph holds a single ``dummy`` node.

    Strategy:
    - Each byte position is a variable node (type 0) weighted by how many
      encodings admit the byte and by its position in the stream.
    - Constraint nodes (type 1) summarise ASCII content, greedy multi-byte
      sequence counts per encoding, heuristic per-encoding penalties, and
      overall instance complexity.
    - Every byte is linked to every constraint node, and adjacent bytes are
      linked too, so the graph is not strictly bipartite.

    The stream is normalised to exactly ``len`` values (truncated, or padded
    with 0) before any statistic is computed, so every weight describes the
    same bytes and none exceeds 1.0.
    """
    length = json_data.get('len', 0)
    stream = json_data.get('stream', [])

    # A non-positive length would make every ratio below meaningless.
    if length <= 0 or not stream:
        G = nx.Graph()
        G.add_node('dummy', type=0, weight=0.5)
        return G

    # Positions past the end of `stream` read as byte 0 everywhere.
    data = _normalize_stream(stream, length)

    # Placeholder penalty scale. The per-byte penalties below are simplified
    # heuristics, not the MZN model's real score tables.
    max_score = 139

    G = nx.Graph()

    # --- Byte position nodes (type 0) ---
    for i, b in enumerate(data):
        # More admissible encodings -> harder to decide.
        ambiguity_weight = min(_encoding_possibilities(b) / 4.0, 1.0)
        # Rises from 0 at the first byte towards 1 later in the stream.
        position_factor = 1.0 - math.exp(-i / (length * 0.1))
        final_weight = (ambiguity_weight + position_factor) / 2.0
        G.add_node(f'byte_{i}', type=0, weight=min(final_weight, 1.0))

    # --- Constraint nodes (type 1) ---
    # One ASCII node weighted by the share of ASCII bytes (at most 1.0).
    ascii_count = sum(1 for b in data if b < 128)
    if ascii_count > 0:
        G.add_node('constraint_ascii', type=1, weight=ascii_count / length)

    # Greedy sequence counts; the divisor is a rough average character
    # length in bytes for the encoding (SJIS mixes 1- and 2-byte chars).
    sequence_specs = (
        ('constraint_utf8_sequences', _UTF8_LEAD_RULES, 3.0),
        ('constraint_eucjp_sequences', _EUCJP_LEAD_RULES, 2.0),
        ('constraint_sjis_sequences', _SJIS_LEAD_RULES, 1.5),
    )
    for name, rules, avg_bytes in sequence_specs:
        n_sequences = _count_sequences(data, rules)
        if n_sequences > 0:
            density = n_sequences / (length / avg_bytes)
            G.add_node(name, type=1, weight=min(density, 1.0))

    # Heuristic average penalty if every byte belonged to one encoding. The
    # offsets only keep the three placeholder penalty curves distinct.
    for encoding, offset in (('eucjp', 0), ('sjis', 50), ('utf8', 25)):
        penalties = [max_score - ((b + offset) % 100) for b in data if b < 256]
        if penalties:
            avg_penalty = sum(penalties) / len(penalties)
            G.add_node(f'constraint_{encoding}_scoring', type=1,
                       weight=min(avg_penalty / max_score, 1.0))

    # Overall difficulty: instance length (log-scaled) and byte diversity.
    byte_diversity = len(set(data)) / 256.0
    complexity_weight = (math.log(length + 1) / math.log(1000)) * 0.7 + byte_diversity * 0.3
    G.add_node('constraint_global_complexity', type=1, weight=min(complexity_weight, 1.0))

    # --- Byte-to-constraint edges ---
    if 'constraint_ascii' in G:
        for i, b in enumerate(data):
            # Strong for ASCII bytes, weak (mutual exclusion) otherwise.
            G.add_edge(f'byte_{i}', 'constraint_ascii', weight=0.9 if b < 128 else 0.1)

    sequence_membership = (
        ('constraint_utf8_sequences', _is_utf8_byte),
        ('constraint_eucjp_sequences', _is_eucjp_byte),
        ('constraint_sjis_sequences', _is_sjis_byte),
    )
    for name, belongs in sequence_membership:
        if name in G:
            for i, b in enumerate(data):
                G.add_edge(f'byte_{i}', name, weight=0.8 if belongs(b) else 0.5)

    # Edge weight reflects how "typical" the byte is for the encoding.
    scoring_typical = (
        ('eucjp', _is_eucjp_byte),
        ('sjis', lambda b: 129 <= b <= 252),
        ('utf8', lambda b: b >= 128),
    )
    for encoding, typical in scoring_typical:
        name = f'constraint_{encoding}_scoring'
        if name in G:
            for i, b in enumerate(data):
                G.add_edge(f'byte_{i}', name, weight=0.8 if typical(b) else 0.3)

    # Global complexity node is always present; the connection decays with position.
    for i in range(length):
        position_decay = math.exp(-i / (length * 0.5))
        G.add_edge(f'byte_{i}', 'constraint_global_complexity',
                   weight=0.3 + 0.4 * position_decay)

    # --- Sequential byte-byte edges ---
    # Stronger when the pair could form a multi-byte character.
    for i in range(length - 1):
        weight = 0.7 if _could_pair(data[i], data[i + 1]) else 0.2
        G.add_edge(f'byte_{i}', f'byte_{i + 1}', weight=weight)

    return G


def main():
    """Command-line entry point: build the graph for one instance and report its size.

    Expects exactly three arguments: ``<mzn_file> <dzn_file> <json_file>``.
    The DZN path is accepted for interface compatibility but never read; the
    instance data comes from the JSON file. On a wrong argument count, prints
    usage to stderr and exits with status 1.
    """
    if len(sys.argv) != 4:
        # Usage errors go to stderr so stdout stays clean for callers.
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>", file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # JSON text is UTF-8 (RFC 8259); don't rely on the platform's locale encoding.
    json_data = json.loads(Path(json_file).read_text(encoding='utf-8'))

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    main()