#!/usr/bin/env python3
"""
test_opt_weights.py — Unit tests for the Constraint-Potential Diffusion algorithm.

This script serves as a unit test for the training-free diffusion mechanism
introduced in Phase 1.
"""
from __future__ import annotations
import unittest
import numpy as np
import networkx as nx

def constraint_potential_diffusion(G, potential_vector, alpha=0.85, max_iters=10):
    """Diffuse a per-node constraint potential over the graph ``G``.

    Starting from Λ = Φ, runs ``max_iters`` steps of

        Λ ← (1 - alpha) · Φ + alpha · W Λ,   with  W = I - L_sym,

    where ``L_sym`` is the symmetric normalized Laplacian returned by
    ``networkx.normalized_laplacian_matrix(G)``.

    Args:
        G: Graph whose node order (``list(G)``) defines which node each row of
            ``potential_vector`` belongs to.
        potential_vector: Initial potential Φ, shape ``(n_nodes,)`` (or
            ``(n_nodes, k)`` for k stacked potentials). Any array-like of
            numbers is accepted; it is converted to a float array and never
            modified in place.
        alpha: Weight on the diffused term; ``1 - alpha`` is the weight that
            pulls every step back toward Φ.
        max_iters: Number of update steps. With ``max_iters <= 0`` the result
            is a float copy of Φ.

    Returns:
        Float ndarray of relevance scores Λ with the same shape and node order
        as Φ.

    Raises:
        ValueError: If ``potential_vector`` does not have one entry (row) per
            node of ``G``.
    """
    phi = np.asarray(potential_vector, dtype=float)
    n_nodes = G.number_of_nodes()
    if phi.ndim == 0 or phi.shape[0] != n_nodes:
        raise ValueError(
            f"potential_vector must have one entry per node ({n_nodes}), "
            f"got shape {phi.shape}"
        )
    if n_nodes == 0:
        # Nothing to diffuse; networkx cannot build a Laplacian for an empty graph.
        return phi.copy()

    # Keep the Laplacian sparse: W @ Λ == Λ - L_sym @ Λ, so the dense N×N
    # matrix I - L_sym never has to be materialized.
    L_sym = nx.normalized_laplacian_matrix(G)
    relevance_scores = phi.copy()
    for _ in range(max_iters):
        diffused = relevance_scores - L_sym @ relevance_scores  # == W @ Λ
        relevance_scores = (1 - alpha) * phi + alpha * diffused

    return relevance_scores

class TestConstraintPotentialDiffusion(unittest.TestCase):
    """Test suite for the Phase 1 diffusion algorithm."""

    def setUp(self):
        """Build the path graph and constraint settings shared by all tests."""
        # A simple 5-node path graph: 0--1--2--3--4 (node 2 is the middle node).
        self.G = nx.path_graph(5)
        self.n_nodes = len(self.G.nodes)

        # Constraint: value <= 1.0
        self.constraint_limit = 1.0
        # Added to the slack so the barrier potential stays finite at the limit.
        self.epsilon_pot = 1e-6

    def _barrier_potential(self, system_state):
        """Return the inverse-slack potential Φ_i = 1 / (max(limit - x_i, 0) + ε).

        The slack is clamped at zero, so a node sitting on or beyond the limit
        receives the largest possible potential (1/ε). Giving such nodes zero
        potential would hide the most critical nodes from the diffusion.
        """
        state = np.asarray(system_state, dtype=float)
        slack = np.maximum(self.constraint_limit - state, 0.0)
        return 1.0 / (slack + self.epsilon_pot)

    def test_diffusion_logic(self):
        """Verify that relevance correctly diffuses from a critical node."""
        print("\nRunning test_diffusion_logic...")
        # System state where node 2 is critically close to the limit
        system_state = np.array([0.1, 0.5, 0.999, 0.5, 0.1])

        # 1. Calculate initial potential (Φ)
        potential_vector = self._barrier_potential(system_state)

        # Assert that node 2 has the highest initial potential
        self.assertEqual(np.argmax(potential_vector), 2, "Node 2 should have the max initial potential.")

        # 2. Run diffusion
        final_relevance = constraint_potential_diffusion(self.G, potential_vector)

        print(f"Initial Potential (Φ): {np.round(potential_vector/potential_vector.max(), 3)}")
        print(f"Final Relevance (Λ):   {np.round(final_relevance/final_relevance.max(), 3)}")

        # 3. Assert final relevance scores
        # The critical node should still be the most relevant
        self.assertEqual(np.argmax(final_relevance), 2, "Critical node 2 should have the max final relevance.")

        # Its direct neighbors (1 and 3) should be more relevant than the end nodes (0 and 4)
        self.assertGreater(final_relevance[1], final_relevance[0], "Neighbor node 1 should be more relevant than end node 0.")
        self.assertGreater(final_relevance[3], final_relevance[4], "Neighbor node 3 should be more relevant than end node 4.")

        # Nodes 1 and 3 mirror each other on the path and start with equal
        # potential, so their relevance should match.
        self.assertAlmostEqual(final_relevance[1], final_relevance[3], places=5, msg="Symmetrical neighbors should have similar relevance.")

    def test_violated_constraint_is_most_critical(self):
        """Verify that a node past the limit gets the largest potential, not zero."""
        print("\nRunning test_violated_constraint_is_most_critical...")
        # Node 2 exceeds the limit; node 3 is close to it but still feasible.
        system_state = np.array([0.1, 0.5, 1.2, 0.99, 0.1])

        potential_vector = self._barrier_potential(system_state)

        self.assertEqual(np.argmax(potential_vector), 2, "A violated node should have the max initial potential.")
        self.assertAlmostEqual(potential_vector[2], 1.0 / self.epsilon_pot, places=3)

    def test_no_critical_nodes(self):
        """Verify that an all-zero potential diffuses to all-zero relevance."""
        print("\nRunning test_no_critical_nodes...")
        # 1. Zero potential everywhere: there is nothing for the diffusion to spread.
        potential_vector = np.zeros(self.n_nodes)

        # 2. Run diffusion
        final_relevance = constraint_potential_diffusion(self.G, potential_vector)

        # 3. Assert all scores are exactly zero
        np.testing.assert_array_equal(
            final_relevance,
            np.zeros(self.n_nodes),
            err_msg="All relevance scores should be zero when there is no potential.",
        )
        print("Test passed: All relevance scores are zero as expected.")

if __name__ == "__main__":
    # Let unittest parse sys.argv (so flags like -v work) and exit with a
    # non-zero status on failure; exit=False would always report success to
    # the shell/CI.
    unittest.main()