#!/usr/bin/env python3
"""
Graph converter for Harmony problem.
Created using subagent_prompt.md version: v_02

This problem is about four-part harmony composition where voices (Soprano, Alto, Tenor, Bass)
must follow musical rules while playing chord progressions that harmonize a given melody.
Key challenges: voice range constraints, avoiding parallel fifths/octaves, proper chord voicing,
cadence requirements, and melodic constraints across voices.
"""

import json
import math
import re
import sys
from collections import Counter
from pathlib import Path

import networkx as nx


def extract_melody_from_dzn(dzn_file):
    """Extract the melody and key from a MiniZinc data (.dzn) file.

    The JSON form of the instance does not carry the enum-valued data this
    converter needs, so the raw DZN text is parsed with regular expressions.

    Args:
        dzn_file: Path (str or path-like) to the .dzn file.

    Returns:
        A ``(melody, key)`` tuple. ``melody`` is the list of integers found as
        ``Pitch(n)`` entries inside the ``melody = [...]`` array (empty if no
        such array exists). ``key`` is the identifier assigned to ``key``
        (``'C'`` if absent). If the file cannot be opened or decoded, the
        defaults ``([], 'C')`` are returned.
    """
    try:
        with open(dzn_file, 'r', encoding='utf-8') as f:
            content = f.read()
    except (OSError, UnicodeDecodeError):
        # Best-effort: a missing or unreadable data file yields the defaults
        # instead of aborting graph construction.
        return [], 'C'

    # Drop MiniZinc line comments so commented-out assignments are ignored.
    content = re.sub(r'%[^\n]*', '', content)

    # \b keeps identifiers that merely end in "melody"/"key" (for example
    # "counter_melody" or "monkey") from being matched.
    melody_match = re.search(r'\bmelody\s*=\s*\[(.*?)\]', content, re.DOTALL)
    if melody_match:
        # Pitches are written as Pitch(n); keep the integer n.
        melody = [int(p) for p in re.findall(r'Pitch\((\d+)\)', melody_match.group(1))]
    else:
        melody = []

    key_match = re.search(r'\bkey\s*=\s*(\w+)', content)
    key = key_match.group(1) if key_match else 'C'

    return melody, key


def build_graph(mzn_file, json_data, dzn_file=None):
    """
    Build graph representation of the harmony problem instance.

    Args:
        mzn_file: Path to the .mzn model file. It is only used to derive the
            data-file path when ``dzn_file`` is not given.
        json_data: Dict containing parsed DZN data. The keys read are
            ``min_perfect``, ``min_plagal``, ``min_imperfect``,
            ``min_interrupted``, ``max_stationary`` and ``enforce_cadences``;
            missing keys fall back to defaults.
        dzn_file: Optional path to the instance's .dzn file, from which the
            melody is read. When omitted, a trailing ``.mzn`` suffix on
            ``mzn_file`` is swapped for ``.dzn``.

    Returns:
        An undirected ``networkx.Graph``. Every node carries ``type`` and
        ``weight`` attributes, and every edge carries ``weight``.

    Strategy: model the four-part harmony as a layered graph with:
    - Voice-time position nodes (type 0) representing decision points
    - Musical constraint nodes (type 1) for voice ranges, intervals, cadences
    - Chord nodes (type 2) representing harmonic resources
    Type-0 nodes connect to type-1 and type-2 nodes. Type-0 nodes at the same
    time step are also linked by conflict edges, so the graph is not
    bipartite.

    Key aspects that make instances hard:
    - Length of melody (more time steps = more variables)
    - Cadence requirements (stricter = more constraints)
    - Voice range overlaps (tighter ranges = more conflicts)
    - Stationary limits (forcing movement)
    """
    # Resolve the instance data file. An explicit path wins; otherwise only a
    # trailing ".mzn" suffix is replaced, not every ".mzn" substring in the
    # path.
    if dzn_file is None:
        mzn_path = Path(mzn_file)
        dzn_file = mzn_path.with_suffix('.dzn') if mzn_path.suffix == '.mzn' else mzn_path
    melody, _key = extract_melody_from_dzn(dzn_file)  # key is not used yet

    # Constraint parameters from JSON
    min_perfect = json_data.get('min_perfect', 0)
    min_plagal = json_data.get('min_plagal', 0)
    min_imperfect = json_data.get('min_imperfect', 0)
    min_interrupted = json_data.get('min_interrupted', 0)
    max_stationary = json_data.get('max_stationary', 3)
    enforce_cadences = json_data.get('enforce_cadences', False)

    # Problem dimensions; fall back to 8 time steps when no melody was found.
    melody_length = len(melody) if melody else 8
    num_voices = 4  # Soprano, Alto, Tenor, Bass

    # Voice ranges as inclusive (low, high) MIDI note numbers
    voice_ranges = {
        'Soprano': (60, 79),    # C4 to G5
        'Alto': (55, 72),       # G3 to C5
        'Tenor': (48, 67),      # C3 to G4
        'Bass': (41, 60),       # F2 to C4
    }
    voice_names = ['Soprano', 'Alto', 'Tenor', 'Bass']
    chord_names = ['I', 'ii', 'iii', 'IV', 'V', 'V7', 'vi']

    G = nx.Graph()

    # === TYPE 0 NODES: Voice-time positions (decision variables) ===
    for t in range(melody_length):
        for voice in voice_names:
            if not melody:
                # No melody available: uniform default weight.
                weight = 0.6
            elif voice == 'Soprano':
                # Soprano carries the given melody, so it is an easy decision.
                weight = 0.2
            else:
                melody_note = melody[t]
                range_min, range_max = voice_ranges[voice]
                if melody_note < range_min or melody_note > range_max:
                    # Melody note outside this voice's range forces a
                    # specific voicing.
                    weight = 0.9
                else:
                    # Notes near the range edges are weighted higher
                    # (0.4 at the centre, up to 0.8 at an edge).
                    half_range = (range_max - range_min) / 2
                    distance_from_center = abs(melody_note - (range_min + range_max) / 2)
                    weight = 0.4 + 0.4 * (distance_from_center / half_range)

            G.add_node(f'voice_{voice}_{t}', type=0, weight=min(weight, 1.0))

    # === TYPE 1 NODES: Musical constraints ===

    # 1. Voice range constraints (one per voice, moderately restrictive)
    for voice in voice_names:
        G.add_node(f'range_{voice}', type=1, weight=0.7)

    # 2. Voice crossing constraints (one per time step, heavier near the start)
    for t in range(melody_length):
        crossing_weight = 0.8 + 0.2 * math.exp(-t / melody_length)
        G.add_node(f'no_crossing_{t}', type=1, weight=crossing_weight)

    # 3. Parallel 5ths/8ves between adjacent time steps, weighted by the
    #    number of voice pairs to check.
    interval_weight = min(0.7 + 0.2 * (num_voices * (num_voices - 1) / 2) / 10, 1.0)
    for t in range(melody_length - 1):
        G.add_node(f'no_parallel_{t}', type=1, weight=interval_weight)

    # 4. Chord progression constraint (single shared node)
    G.add_node('chord_progression', type=1, weight=0.6)

    # 5. Cadence constraints (only when enforced and at least one is required)
    if enforce_cadences:
        total_cadence_requirements = min_perfect + min_plagal + min_imperfect + min_interrupted
        if total_cadence_requirements > 0:
            # Required cadences relative to the number of 4-step phrases
            cadence_density = total_cadence_requirements / max(1, melody_length // 4)
            cadence_weight = 0.5 + 0.4 * min(cadence_density, 1.0)
            G.add_node('cadence_requirements', type=1, weight=cadence_weight)

    # 6. Stationary constraints: tighter limits relative to length weigh more
    if max_stationary < melody_length:
        stationary_tightness = 1.0 - (max_stationary / melody_length)
        stationary_weight = 0.4 + 0.4 * stationary_tightness
        for voice in voice_names[1:]:  # Soprano is fixed by the melody
            G.add_node(f'movement_{voice}', type=1, weight=stationary_weight)

    # === TYPE 2 NODES: Harmonic resources (chords) ===
    for chord in chord_names:
        if chord in ('I', 'V', 'IV'):  # Primary chords
            weight = 0.8
        elif chord == 'V7':  # Dominant seventh, important for cadences
            weight = 0.7
        else:  # Secondary chords
            weight = 0.5
        G.add_node(f'chord_{chord}', type=2, weight=weight)

    # === EDGES ===

    # 1. Voice-time -> range constraints
    for t in range(melody_length):
        for voice in voice_names:
            G.add_edge(f'voice_{voice}_{t}', f'range_{voice}', weight=0.9)

    # 2. Voice-time -> crossing constraint of the same step
    for t in range(melody_length):
        crossing_node = f'no_crossing_{t}'
        for voice in voice_names:
            G.add_edge(f'voice_{voice}_{t}', crossing_node, weight=0.8)

    # 3. Both time steps of every voice -> parallel-motion constraint
    for t in range(melody_length - 1):
        parallel_node = f'no_parallel_{t}'
        for voice in voice_names:
            G.add_edge(f'voice_{voice}_{t}', parallel_node, weight=0.7)
            G.add_edge(f'voice_{voice}_{t+1}', parallel_node, weight=0.7)

    # 4. Every voice-time -> chord progression
    for t in range(melody_length):
        for voice in voice_names:
            G.add_edge(f'voice_{voice}_{t}', 'chord_progression', weight=0.6)

    # 5. Cadence node -> every 4th position (t = 3, 7, 11, ...)
    # NOTE(review): melodies shorter than 4 steps leave this node isolated.
    if 'cadence_requirements' in G:
        for t in range(3, melody_length, 4):
            for voice in voice_names:
                G.add_edge(f'voice_{voice}_{t}', 'cadence_requirements', weight=0.8)

    # 6. Movement constraints -> leading positions of each non-soprano voice
    # NOTE(review): only the first max_stationary + 1 positions are linked;
    # confirm this is intended rather than the whole sequence.
    for voice in voice_names[1:]:
        movement_node = f'movement_{voice}'
        if movement_node in G:
            for t in range(min(melody_length, max_stationary + 1)):
                position_weight = 0.5 + 0.3 * (t / melody_length)
                G.add_edge(f'voice_{voice}_{t}', movement_node, weight=position_weight)

    # 7. Voice-time -> every chord, weighted by the voice's harmonic role
    voice_chord_weight = {'Bass': 0.8, 'Soprano': 0.7}  # inner voices: 0.5
    for t in range(melody_length):
        for voice in voice_names:
            chord_weight = voice_chord_weight.get(voice, 0.5)
            for chord in chord_names:
                G.add_edge(f'voice_{voice}_{t}', f'chord_{chord}', weight=chord_weight)

    # 8. Conflict edges between voices at the same time step. The range overlap
    #    depends only on the voice pair, so the weights are computed once.
    # NOTE(review): overlap is measured as high - low, so ranges sharing a
    # single pitch (Soprano/Bass at 60) count as non-overlapping.
    conflict_weights = {}
    for i, voice1 in enumerate(voice_names):
        for voice2 in voice_names[i + 1:]:
            (low1, high1), (low2, high2) = voice_ranges[voice1], voice_ranges[voice2]
            overlap = max(0, min(high1, high2) - max(low1, low2))
            if overlap > 0:
                # More overlap (in octaves of 12 semitones) = more conflict
                conflict_weights[(voice1, voice2)] = min(0.3 + 0.4 * (overlap / 12), 1.0)

    for t in range(melody_length):
        for (voice1, voice2), conflict_weight in conflict_weights.items():
            G.add_edge(f'voice_{voice1}_{t}', f'voice_{voice2}_{t}', weight=conflict_weight)

    return G


def main():
    """Command-line entry point: build the graph and print summary statistics.

    Usage: ``python converter.py <mzn_file> <dzn_file> <json_file>``.
    Exits with status 1 when the argument count is wrong.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    mzn_file, dzn_file, json_file = sys.argv[1:4]

    with open(json_file, 'r') as f:
        json_data = json.load(f)

    # Pass the instance data file explicitly so the melody is read from the
    # instance, not from a path derived from the model file.
    G = build_graph(mzn_file, json_data, dzn_file=dzn_file)

    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

    # Node type distribution for verification (-1 marks a missing type)
    type_counts = Counter(node_type for _, node_type in G.nodes(data='type', default=-1))
    print(f"Node types: {dict(sorted(type_counts.items()))}")


if __name__ == "__main__":
    main()