#!/usr/bin/env python
"""
Quick test to verify the kernel parameterization fix.
"""

import numpy as np
from physically_correct_quantum_kernel import PhysicallyCorrectQuantumKernel

def test_kernel_fix():
    """Check that the kernel matrix changes when the kernel parameters change.

    Builds a ``PhysicallyCorrectQuantumKernel`` with ``n_qubits=4``. It
    computes the kernel matrix of a fixed 5x4 random data set against itself,
    updates the parameters to a hard-coded vector, and recomputes the matrix.
    It then prints whether the two matrices differ according to
    ``np.allclose`` with its default tolerances.

    Returns:
        bool: True if the kernel matrix changed after the parameter update,
        False otherwise.

    NOTE(review): this function is named ``test_*`` but has no ``assert`` and
    returns a value. Under pytest it therefore can never fail, and newer
    pytest versions emit ``PytestReturnNotNoneWarning``. The boolean return
    is kept for script callers; consider a separate asserting pytest wrapper.
    """
    print("Testing Kernel Parameterization Fix")
    print("=" * 40)

    # Use a local generator. It yields exactly the same values as
    # np.random.seed(42) followed by np.random.randn(5, 4), but it does not
    # mutate NumPy's global RNG state for other code running in this process.
    rng = np.random.RandomState(42)
    X = rng.randn(5, 4)

    kernel = PhysicallyCorrectQuantumKernel(n_qubits=4)

    print("Initial parameters:", kernel.get_parameters())

    # Snapshot the initial matrix with an explicit copy. If
    # compute_kernel_matrix ever returns a cached/internal array that is
    # updated in place, comparing against the live object would mask a change.
    K1 = np.array(kernel.compute_kernel_matrix(X, X), copy=True)
    print("Initial kernel diagonal:", K1.diagonal())
    print("Initial kernel off-diagonal example:", K1[0, 1])

    # Hard-coded parameter vector; its length is assumed to match what
    # get_parameters() returns for this kernel (TODO confirm with the class).
    new_params = np.array([0.5, 2.0, 1.5, 0.8])
    kernel.update_parameters(new_params)
    print("\nUpdated parameters:", kernel.get_parameters())

    K2 = kernel.compute_kernel_matrix(X, X)
    print("Updated kernel diagonal:", K2.diagonal())
    print("Updated kernel off-diagonal example:", K2[0, 1])

    # Any element differing beyond np.allclose's default rtol/atol counts as
    # a change.
    kernel_changed = not np.allclose(K1, K2)
    print(f"\nKernel changed with parameter update: {kernel_changed}")

    if kernel_changed:
        print("✅ FIX SUCCESSFUL: Kernel now changes with parameters!")
        max_change = np.max(np.abs(K1 - K2))
        print(f"Maximum absolute change: {max_change:.6f}")
    else:
        print("❌ FIX FAILED: Kernel still doesn't change with parameters")

    return kernel_changed

if __name__ == "__main__":
    # Turn the check's boolean result into the process exit status
    # (0 = kernel changed, 1 = it did not), so shell/CI callers can detect a
    # failed check. Previously the script always exited 0.
    raise SystemExit(0 if test_kernel_fix() else 1)