#!/usr/bin/env python3
"""
test_opt_phase1.py — Demonstrates the new Constraint-Potential Diffusion for Phase 1 of DANCE-ST.

This script creates a simulated turbine blade surface and shows how a localized region
approaching a physical constraint generates a "potential" that diffuses across the
system graph, highlighting systemically relevant components.
"""
from __future__ import annotations
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

def constraint_potential_diffusion(G, potential_vector, alpha=0.85, max_iters=10, tol=None):
    """Diffuse a local constraint potential over a graph (core step of Phase 1).

    Runs the fixed-point iteration

        Λ ← (1 - alpha) · Φ + alpha · W Λ,   with  W = I - L_sym,

    starting from Λ = Φ, where ``L_sym`` is the symmetric normalized Laplacian
    of ``G`` (so ``W`` is the symmetrically normalized adjacency matrix).

    Args:
        G: networkx graph. Row ``i`` of ``potential_vector`` (and of the result)
            corresponds to the ``i``-th node of ``list(G)``.
        potential_vector: Array-like of shape ``(n,)``, or ``(n, k)`` for ``k``
            independent potentials, where ``n == len(G)``. It is not modified.
        alpha: Weight of the diffused term versus the original potential;
            ``alpha=0`` returns a copy of Φ.
        max_iters: Maximum number of iterations to run.
        tol: Optional early-stop threshold. When given, iteration stops once
            the largest absolute change between successive iterates is below
            ``tol``. ``None`` (default) always runs ``max_iters`` iterations.

    Returns:
        Float ndarray with the same shape as ``potential_vector``: the
        relevance scores Λ.

    Raises:
        ValueError: if ``potential_vector`` is not 1-D/2-D with one row per node.
    """
    phi = np.asarray(potential_vector, dtype=float)
    n_nodes = G.number_of_nodes()
    if phi.ndim not in (1, 2) or phi.shape[0] != n_nodes:
        raise ValueError(
            f"potential_vector must have shape ({n_nodes},) or ({n_nodes}, k); "
            f"got {phi.shape}"
        )
    if n_nodes == 0:
        return phi.copy()

    # Keep the Laplacian sparse: W @ r == r - L_sym @ r, so the dense n x n
    # matrix I - L_sym never has to be materialized.
    L_sym = nx.normalized_laplacian_matrix(G, nodelist=list(G))

    relevance_scores = phi.copy()
    for _ in range(max_iters):
        updated = (1 - alpha) * phi + alpha * (relevance_scores - L_sym @ relevance_scores)
        converged = tol is not None and np.max(np.abs(updated - relevance_scores)) < tol
        relevance_scores = updated
        if converged:
            break

    return relevance_scores

def _normalize_max(values):
    """Return ``values`` scaled so its maximum is 1; returned unchanged if the max is not positive."""
    peak = values.max()
    return values / peak if peak > 0 else values


def _local_constraint_potential(state, limit, *, band=0.2, eps=1e-6):
    """Return the max-normalized local constraint potential Φ for each node.

    A node whose slack ``limit - state`` is below ``band * limit`` gets potential
    ``1 / (slack + eps)``; every other node gets 0. Slack is clamped at 0, so a
    node that already exceeds the limit receives the maximal potential rather
    than a negative one.
    """
    slack = np.maximum(limit - np.asarray(state, dtype=float), 0.0)
    # slack >= 0 here, so the reciprocal is always finite.
    potential = np.where(slack < band * limit, 1.0 / (slack + eps), 0.0)
    return _normalize_max(potential)


def _plot_results(G, pos, panels, save_path):
    """Draw one node-coloured panel per ``(title, values, cmap)`` and save the figure."""
    fig, axes = plt.subplots(1, len(panels), figsize=(6 * len(panels), 5), squeeze=False)
    fig.suptitle("Phase 1: Constraint-Potential Diffusion in Action", fontsize=16)
    for ax, (title, values, cmap) in zip(axes[0], panels):
        ax.set_title(title)
        nx.draw_networkx_nodes(G, pos, node_color=values, cmap=cmap, node_size=50, ax=ax)
        ax.set_aspect('equal')
    fig.savefig(save_path)
    plt.close(fig)  # pyplot keeps figures alive until explicitly closed


def main():
    """Run the Phase 1 demo end to end and save the visualization to disk.

    The demo performs these steps:

    - Build a grid graph standing in for a turbine blade surface.
    - Place a hot spot near a temperature limit.
    - Turn the remaining slack into a local constraint potential Φ.
    - Diffuse Φ with ``constraint_potential_diffusion``.
    - Plot temperature, Φ and the resulting relevance Λ side by side.
    """
    print("=== DANCE-ST Phase 1: Constraint-Potential Diffusion Demo ===")

    # 1. Create a simulated turbine blade surface as a grid graph.
    grid_size = 15
    G = nx.grid_2d_graph(grid_size, grid_size)
    # grid_2d_graph yields (row, col) nodes row by row, so relabelling in
    # iteration order gives node index i = row * grid_size + col.
    G = nx.relabel_nodes(G, {node: i for i, node in enumerate(G.nodes())})
    pos = {i: (i % grid_size, i // grid_size) for i in range(grid_size * grid_size)}
    n_nodes = G.number_of_nodes()
    print(f"Created a {grid_size}x{grid_size} grid graph representing a turbine surface.")

    # 2. Simulate a system state (temperature) with a hot spot in the grid centre.
    system_state = np.full(n_nodes, 800.0)  # baseline temperature
    mid = grid_size // 2
    hot_spot_center = mid * grid_size + mid
    # Heated neighbours: left, right and previous row. Indexing via row/col
    # keeps them graph-adjacent to the centre (no wrap-around to another row).
    system_state[hot_spot_center] = 1195.0  # very close to the limit
    system_state[hot_spot_center - 1] = 1150.0
    system_state[hot_spot_center + 1] = 1145.0
    system_state[hot_spot_center - grid_size] = 1140.0
    print(f"Simulated a 'hot spot' where one node is at {system_state[hot_spot_center]}°C.")

    # 3. Physical constraint -> local potential vector (Φ), normalized to max 1.
    temp_limit = 1200.0
    potential_vector = _local_constraint_potential(system_state, temp_limit, band=0.2, eps=1e-6)
    print("Calculated initial constraint potential (Φ). Only the hot spot has high potential.")

    # 4. Run the Constraint-Potential Diffusion algorithm.
    final_relevance = _normalize_max(constraint_potential_diffusion(G, potential_vector))
    print("Ran diffusion algorithm to get final relevance scores (Λ).")

    # 5. Visualize the results.
    save_path = "phase1_diffusion_demo.png"
    _plot_results(
        G,
        pos,
        [
            ("System State (Temperature)", system_state, plt.cm.hot),
            ("Initial Constraint Potential (Φ)", potential_vector, plt.cm.viridis),
            ("Final Relevance (Λ) after Diffusion", final_relevance, plt.cm.viridis),
        ],
        save_path,
    )
    print(f"\nVisualization saved to: {save_path}")

if __name__ == "__main__":
    # Entry point: only run the demo when executed as a script, never on import.
    main()