"""
Pytest smoke tests for quantum kernel and KTA functions.
These tests verify basic functionality and mathematical properties.
"""

import numpy as np
import pytest
from physically_correct_quantum_kernel import PhysicallyCorrectQuantumKernel, StreamingNystromApproximation
from qisk_implementation import weighted_kernel_target_alignment

class TestPhysicallyCorrectQuantumKernel:
    """Test suite for PhysicallyCorrectQuantumKernel.

    Every test gets a freshly constructed 4-qubit kernel plus seeded random
    inputs: ``X_test`` (10 x 4) and ``Y_test`` (8 x 4).
    """

    def setup_method(self):
        """Set up a fresh kernel and seeded random inputs for each test."""
        np.random.seed(42)
        self.kernel = PhysicallyCorrectQuantumKernel(n_qubits=4, feature_scale=1.0)
        self.X_test = np.random.randn(10, 4)
        self.Y_test = np.random.randn(8, 4)

    def test_kernel_matrix_shape(self):
        """Test kernel matrix has shape (len(X), len(Y))."""
        K_XX = self.kernel.compute_kernel_matrix(self.X_test, self.X_test)
        K_XY = self.kernel.compute_kernel_matrix(self.X_test, self.Y_test)

        assert K_XX.shape == (10, 10), f"Expected (10, 10), got {K_XX.shape}"
        assert K_XY.shape == (10, 8), f"Expected (10, 8), got {K_XY.shape}"

    def test_kernel_diagonal_ones(self):
        """Test kernel diagonal elements are 1.0 (self-similarity)."""
        K_XX = self.kernel.compute_kernel_matrix(self.X_test, self.X_test)
        diagonal = np.diag(K_XX)

        # Quantum fidelity should be 1 for identical states
        np.testing.assert_allclose(diagonal, 1.0, rtol=1e-10,
                                   err_msg="Kernel diagonal should be all ones")

    def test_kernel_positive_semidefinite(self):
        """Test kernel matrix is positive semidefinite."""
        K_XX = self.kernel.compute_kernel_matrix(self.X_test, self.X_test)
        # eigvalsh is the symmetric/Hermitian eigensolver. It always returns
        # real eigenvalues. eigvals uses the general solver and can return
        # complex values with round-off imaginary parts, which makes the ">="
        # comparison below ill-defined. eigvalsh reads only one triangle, so
        # symmetry itself is verified separately in test_kernel_symmetric.
        eigenvals = np.linalg.eigvalsh(K_XX)

        assert np.all(eigenvals >= -1e-10), f"Negative eigenvalues found: {eigenvals.min()}"

    def test_kernel_bounded(self):
        """Test kernel values are bounded in [0, 1]."""
        K_XY = self.kernel.compute_kernel_matrix(self.X_test, self.Y_test)

        assert np.all(K_XY >= 0), f"Found negative kernel values: {K_XY.min()}"
        assert np.all(K_XY <= 1), f"Found kernel values > 1: {K_XY.max()}"

    def test_kernel_changes_with_parameters(self):
        """Test that kernel actually changes when parameters are updated."""
        # Compute initial kernel matrix
        K1 = self.kernel.compute_kernel_matrix(self.X_test, self.X_test)

        # Update parameters (multiplicative factors)
        new_params = np.array([0.5, 2.0, 1.5, 0.8])
        self.kernel.update_parameters(new_params)

        # Compute kernel matrix with new parameters
        K2 = self.kernel.compute_kernel_matrix(self.X_test, self.X_test)

        # Kernels should be different now
        assert not np.allclose(K1, K2, rtol=1e-10), \
            "Kernel matrix should change when parameters are updated"

        # Diagonal should still be ones (self-similarity preserved)
        np.testing.assert_allclose(np.diag(K2), 1.0, rtol=1e-10)

    def test_kernel_symmetric(self):
        """Test kernel matrix is symmetric."""
        K_XX = self.kernel.compute_kernel_matrix(self.X_test, self.X_test)

        np.testing.assert_allclose(K_XX, K_XX.T, rtol=1e-12,
                                   err_msg="Kernel matrix should be symmetric")

    def test_parameter_getter_setter(self):
        """Test that get_parameters returns what update_parameters stored."""
        new_params = np.array([0.1, 0.2, 0.3, 0.4])
        self.kernel.update_parameters(new_params)
        retrieved_params = self.kernel.get_parameters()

        np.testing.assert_allclose(retrieved_params, new_params, rtol=1e-12,
                                   err_msg="Parameter get/set should be consistent")

class TestStreamingNystromApproximation:
    """Test suite for StreamingNystromApproximation.

    Each test wraps a fresh 4-qubit kernel in a Nyström approximation with
    at most 8 anchors and uses 20 seeded random samples of 4 features.
    """

    def setup_method(self):
        """Set up a fresh kernel, Nyström wrapper and seeded inputs."""
        np.random.seed(42)
        self.kernel = PhysicallyCorrectQuantumKernel(n_qubits=4, feature_scale=1.0)
        self.nystrom = StreamingNystromApproximation(self.kernel, n_anchors=8)
        self.X_test = np.random.randn(20, 4)

    def test_fit_and_transform_shapes(self):
        """Test Nyström fit and transform produce correct shapes."""
        self.nystrom.fit(self.X_test)

        # Check anchors were selected
        assert self.nystrom.anchors is not None, "Anchors should be selected after fit"
        assert self.nystrom.anchors.shape[0] <= 8, f"Too many anchors: {self.nystrom.anchors.shape[0]}"

        # Check transform shape: one row per sample, at most one column per anchor
        feature_map = self.nystrom.transform(self.X_test)
        assert feature_map.shape[0] == 20, f"Wrong number of samples: {feature_map.shape[0]}"
        assert feature_map.shape[1] <= 8, f"Too many features: {feature_map.shape[1]}"

    def test_kernel_matrix_reconstruction(self):
        """Test Nyström kernel matrix has correct shape and properties."""
        self.nystrom.fit(self.X_test)
        K_approx = self.nystrom.get_kernel_matrix(self.X_test)

        assert K_approx.shape == (20, 20), f"Expected (20, 20), got {K_approx.shape}"

        # Should be approximately symmetric
        symmetry_error = np.max(np.abs(K_approx - K_approx.T))
        assert symmetry_error < 1e-10, f"Kernel matrix should be symmetric, error: {symmetry_error}"

        # Should be approximately PSD. The matrix is symmetric (checked above),
        # so use the symmetric eigensolver. It returns real eigenvalues, unlike
        # eigvals, whose output may carry spurious complex parts.
        eigenvals = np.linalg.eigvalsh(K_approx)
        assert np.all(eigenvals >= -1e-8), f"Negative eigenvalues: {eigenvals.min()}"

    def test_fidelity_bounds(self):
        """Test approximation fidelity is in reasonable range."""
        self.nystrom.fit(self.X_test)
        fidelity = self.nystrom.kernel_fidelity(self.X_test)

        assert 0.0 <= fidelity <= 1.0, f"Fidelity out of bounds: {fidelity}"
        # With 8 anchors for 20 samples, fidelity should be reasonable
        assert fidelity > 0.1, f"Fidelity too low: {fidelity}"

class TestWeightedKTA:
    """Test suite for weighted kernel-target alignment (qisk_implementation)."""

    def setup_method(self):
        """Set up a small kernel, two label encodings and uniform weights."""
        np.random.seed(42)
        # Symmetric and positive definite: leading principal minors are
        # 1, 0.36 and 0.182, all > 0.
        self.K = np.array([[1.0, 0.8, 0.6],
                           [0.8, 1.0, 0.7],
                           [0.6, 0.7, 1.0]])
        self.y_binary = np.array([0, 1, 0])  # 0/1 labels
        self.y_pm1 = np.array([-1, 1, -1])  # ±1 labels (same classes as y_binary)
        self.weights = np.array([1.0, 1.0, 1.0])  # Uniform weights

    def test_kta_bounded(self):
        """Test KTA returns finite values for both label encodings."""
        kta_binary = weighted_kernel_target_alignment(self.K, self.y_binary, self.weights)
        kta_pm1 = weighted_kernel_target_alignment(self.K, self.y_pm1, self.weights)

        assert np.isfinite(kta_binary), f"KTA should be finite: {kta_binary}"
        assert np.isfinite(kta_pm1), f"KTA should be finite: {kta_pm1}"

    def test_kta_label_encoding_effect(self):
        """Test that 0/1 and ±1 encodings of the same labels give comparable KTA."""
        kta_binary = weighted_kernel_target_alignment(self.K, self.y_binary, self.weights)
        kta_pm1 = weighted_kernel_target_alignment(self.K, self.y_pm1, self.weights)

        # y_pm1 == 2 * y_binary - 1, an affine map of the same labels. If the
        # implementation centers the targets, the two encodings should give
        # (near-)identical alignment. The loose bound only guards against
        # wildly divergent results and does not pin that equivalence.
        assert abs(kta_binary - kta_pm1) < 1.0, "KTA values should be comparable"

    def test_kta_perfect_alignment(self):
        """Test KTA with perfectly aligned kernel and labels."""
        # Create kernel that matches label pattern
        y = np.array([0, 1, 0, 1])
        # Kernel with high similarity for same labels, low for different
        K_perfect = np.array([[1.0, 0.1, 0.9, 0.1],
                              [0.1, 1.0, 0.1, 0.9],
                              [0.9, 0.1, 1.0, 0.1],
                              [0.1, 0.9, 0.1, 1.0]])
        weights = np.ones(4)

        kta = weighted_kernel_target_alignment(K_perfect, y, weights)

        # Should be positive for aligned patterns
        assert kta > 0, f"KTA should be positive for aligned patterns: {kta}"

    def test_kta_numerical_stability(self):
        """Test KTA stays finite for degenerate inputs (zero kernel, tiny weights)."""
        # Zero kernel: any normalization by the kernel norm would divide by zero
        K_zero = np.zeros((3, 3))
        kta_zero = weighted_kernel_target_alignment(K_zero, self.y_binary, self.weights)
        assert np.isfinite(kta_zero), f"Should handle zero kernel: {kta_zero}"

        # Very small weights
        tiny_weights = np.array([1e-12, 1e-12, 1e-12])
        kta_tiny = weighted_kernel_target_alignment(self.K, self.y_binary, tiny_weights)
        assert np.isfinite(kta_tiny), f"Should handle tiny weights: {kta_tiny}"

class TestEndToEndWorkflow:
    """Test end-to-end workflow including training and evaluation.

    Workflow objects are resolved from ``drqka_lite_implementation``, the
    module this file also imports ``ExperimentRunner`` and
    ``center_kernel_weighted`` from. If the module is not importable, every
    test here is skipped. If a specific object is missing, only the tests
    that need it are skipped, each with an explicit reason.

    NOTE(review): ``DRQKALite`` and ``DataGenerator`` are assumed to live in
    ``drqka_lite_implementation`` — confirm against the implementation.
    """

    def setup_method(self):
        """Seed the RNG and locate the workflow module (skip if unavailable)."""
        np.random.seed(42)
        import sys
        # The implementation modules may live one directory up. Add it only
        # once so repeated per-test setup does not keep growing sys.path.
        if '..' not in sys.path:
            sys.path.append('..')
        self._drqka = pytest.importorskip("drqka_lite_implementation")

    def _require(self, name):
        """Return ``drqka_lite_implementation.<name>`` or skip the calling test."""
        obj = getattr(self._drqka, name, None)
        if obj is None:
            pytest.skip(f"drqka_lite_implementation has no attribute {name!r}")
        return obj

    def test_drqka_train_eval_cycle(self):
        """Test full training and evaluation cycle works without errors."""
        DataGenerator = self._require("DataGenerator")
        DRQKALite = self._require("DRQKALite")

        # Create small dataset
        X_data, y_data, _ = DataGenerator.generate_sea(n_samples=200, seed=42)

        # Create two windows
        X_window1, y_window1 = X_data[:100], y_data[:100]
        X_window2, y_window2 = X_data[100:], y_data[100:]

        # Initialize DRQKA-Lite
        model = DRQKALite()

        # Test first window (no history)
        train_results1 = model.train_window(X_window1, y_window1, X_previous=None)
        assert "kta" in train_results1
        assert "training_time" in train_results1
        assert np.isfinite(train_results1["kta"])

        # Test evaluation on first window
        eval_results1 = model.evaluate_window(X_window1[:20], y_window1[:20],
                                              train_results1['classifier'], None, X_window1)
        assert "balanced_accuracy" in eval_results1
        assert "macro_f1" in eval_results1
        assert 0 <= eval_results1["balanced_accuracy"] <= 1

        # Test second window (with history)
        train_results2 = model.train_window(X_window2, y_window2, X_previous=X_window1)
        assert "kta" in train_results2
        assert np.isfinite(train_results2["kta"])

        # Test evaluation on second window
        eval_results2 = model.evaluate_window(X_window2[:20], y_window2[:20],
                                              train_results2['classifier'], None, X_window2)
        assert "balanced_accuracy" in eval_results2
        assert 0 <= eval_results2["balanced_accuracy"] <= 1

    def test_kernel_centering_invariance(self):
        """Test kernel centering is invariant to constant shifts."""
        center_kernel_weighted = self._require("center_kernel_weighted")

        # Create test data
        K = np.array([[1.0, 0.8, 0.6],
                      [0.8, 1.0, 0.7],
                      [0.6, 0.7, 1.0]])
        weights = np.array([1.0, 1.0, 1.0])

        # Add constant to all entries
        K_shifted = K + 0.5

        # Center both
        K_centered = center_kernel_weighted(K, weights)
        K_shifted_centered = center_kernel_weighted(K_shifted, weights)

        # Should be identical after centering
        np.testing.assert_allclose(K_centered, K_shifted_centered, rtol=1e-12,
                                   err_msg="Centering should be invariant to constant shifts")

    def test_scaler_fitting_with_history(self):
        """Test scaler fitting works correctly with historical data."""
        model = self._require("DRQKALite")()

        # Create test data
        X_history = np.random.randn(50, 4)
        X_current = np.random.randn(30, 4) + 2  # Different distribution
        y_current = np.random.randint(0, 2, 30)

        # Training must not crash (regression check for the scaler bug fix).
        # Only the call under test goes inside the try, so failing assertions
        # below are reported as themselves and not as "training failed".
        try:
            train_results = model.train_window(X_current, y_current, X_previous=X_history)
        except Exception as e:
            pytest.fail(f"Training with history failed: {e!r}")

        assert "kta" in train_results
        assert np.isfinite(train_results["kta"])

    def test_experiment_runner_seeding(self):
        """Test experiment runner produces consistent results with seeding."""
        ExperimentRunner = self._require("ExperimentRunner")

        # Run experiments with same seed twice
        results1 = ExperimentRunner().run_comprehensive_experiments(seed=42)
        results2 = ExperimentRunner().run_comprehensive_experiments(seed=42)

        # Identical seeds must yield the same set of methods...
        assert set(results1) == set(results2), \
            f"Method sets differ: {sorted(set(results1) ^ set(results2))}"

        # ...and identical accuracies for each of them
        for method in results1:
            assert abs(results1[method]["mean_accuracy"] -
                       results2[method]["mean_accuracy"]) < 1e-10, \
                f"Seeding failed for {method}"

class TestKTAAlignment:
    """Extra checks on the weighted KTA exported by drqka_lite_implementation."""

    def setup_method(self):
        """Bind the drqka_lite_implementation KTA helpers onto the test instance."""
        from drqka_lite_implementation import (
            weighted_kernel_target_alignment as kta_fn,
            center_kernel_weighted as centering_fn,
        )
        self.weighted_kta = kta_fn
        self.center_kernel_weighted = centering_fn

    def test_kta_monotonicity(self):
        """Test KTA decreases when kernel-target mismatch increases."""
        labels = np.array([0, 0, 1, 1])
        uniform = np.ones(labels.size)

        # Block kernel derived from the labels: 0.9 between points of the
        # same class, 0.1 across classes, 1.0 self-similarity on the diagonal.
        same_class = labels[:, None] == labels[None, :]
        K_aligned = np.where(same_class, 0.9, 0.1)
        np.fill_diagonal(K_aligned, 1.0)

        kta_aligned = self.weighted_kta(K_aligned, labels, uniform)

        # Break the block structure symmetrically: points 0 and 2 have
        # different labels but are now made to look similar.
        K_misaligned = K_aligned.copy()
        K_misaligned[0, 2] = K_misaligned[2, 0] = 0.8

        kta_misaligned = self.weighted_kta(K_misaligned, labels, uniform)

        assert kta_aligned > kta_misaligned, \
            f"Aligned KTA ({kta_aligned}) should be > misaligned KTA ({kta_misaligned})"

if __name__ == "__main__":
    # pytest.main returns the exit status instead of exiting. Propagate it so
    # running this file directly (e.g. from CI scripts) fails when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))