#!/usr/bin/env python
"""
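Basic functionality tests for QISK components.
These tests check that the core modules import cleanly and that their main entry points
(datasets, quantum kernel, alignment, effective sample size, results generation) behave
sensibly on small inputs.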
"""

import numpy as np
import pytest
import sys
import os

# Make both the parent directory and this test file's directory importable
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, os.path.dirname(__file__))

def test_imports():
    """Check that every core QISK module imports cleanly.

    The imported names are deliberately unused: the test only cares that each
    ``from ... import ...`` statement succeeds. The first ImportError is turned
    into a pytest failure that carries the original error text.
    """
    try:
        from qisk_implementation import QISK  # noqa: F401
        from physically_correct_quantum_kernel import PhysicallyCorrectQuantumKernel  # noqa: F401
        from real_world_datasets import get_real_world_datasets  # noqa: F401
        from alignment_utils import centered_weighted_alignment  # noqa: F401
        from dro_utils import ess  # noqa: F401
    except ImportError as e:
        failure_message = f"Failed to import core modules: {e}"
        pytest.fail(failure_message)

def test_datasets_creation():
    """Build the real-world streaming datasets and sanity-check the first one.

    Requires at least two datasets. The first must expose ``stream`` and
    ``get_batch``, and ``get_batch(10)`` must yield a 2-D feature array with
    10 rows plus a 1-D label array of length 10.
    """
    from real_world_datasets import get_real_world_datasets

    all_datasets = get_real_world_datasets()
    assert len(all_datasets) >= 2, "Should have at least 2 datasets"

    first = all_datasets[0]
    required_methods = {
        'stream': "Dataset should have stream method",
        'get_batch': "Dataset should have get_batch method",
    }
    for method_name, message in required_methods.items():
        assert hasattr(first, method_name), message

    # Draw a small batch and verify the sample count and dimensionality.
    features, labels = first.get_batch(10)
    assert features.shape[0] == 10, "Should return 10 samples"
    assert len(labels) == 10, "Should return 10 labels"
    assert features.ndim == 2, "X should be 2D"
    assert labels.ndim == 1, "y should be 1D"

def test_quantum_kernel_basic():
    """Test basic quantum kernel functionality.

    Checks that ``PhysicallyCorrectQuantumKernel`` stores ``n_qubits`` and that
    ``compute_kernel_matrix(X, Y)`` returns a finite ``(len(X), len(Y))`` matrix
    whose entries lie in [0, 1], the range expected of a fidelity kernel.
    """
    from physically_correct_quantum_kernel import PhysicallyCorrectQuantumKernel

    kernel = PhysicallyCorrectQuantumKernel(n_qubits=4)
    assert kernel.n_qubits == 4, "Should set n_qubits correctly"

    # Fixed seed: the unseeded global RNG made any failure non-reproducible.
    rng = np.random.default_rng(0)
    X = rng.standard_normal((5, 4))
    Y = rng.standard_normal((3, 4))

    K = kernel.compute_kernel_matrix(X, Y)
    assert K.shape == (5, 3), "Kernel matrix should have correct shape"
    # Report NaN/inf explicitly instead of as a confusing bound failure below.
    assert np.all(np.isfinite(K)), "Kernel values should be finite"
    # Fidelities are computed in floating point; tolerate round-off just past the bounds.
    tol = 1e-10
    assert np.all(K >= -tol), "Kernel values should be non-negative"
    assert np.all(K <= 1 + tol), "Kernel values should be <= 1 for quantum fidelity"

def test_weighted_alignment():
    """Test centered weighted alignment computation.

    Uses a random positive-semidefinite kernel, balanced +/-1 labels and uniform
    weights. Checks that ``centered_weighted_alignment`` returns a finite float
    in [-1, 1].
    """
    from alignment_utils import centered_weighted_alignment

    # Fixed seed so that any failure is reproducible.
    rng = np.random.default_rng(0)
    n = 20
    K = rng.random((n, n))
    K = K @ K.T  # Make positive semidefinite
    # Balanced labels: np.random.choice could occasionally draw a single class.
    # After centering (presumably the usual H @ yy^T @ H), the label kernel is
    # then zero and the alignment is undefined (NaN), which made the test flaky.
    y = rng.permutation(np.repeat([-1, 1], n // 2))
    w = np.ones(n)

    alignment = centered_weighted_alignment(K, y, w)
    assert isinstance(alignment, float), "Alignment should be a float"
    assert np.isfinite(alignment), "Alignment should be finite"
    assert -1 <= alignment <= 1, "Alignment should be in [-1, 1]"

def test_effective_sample_size():
    """Check ``ess`` on a uniform and a skewed weight vector.

    Ten equal weights must give an ESS of 10 (to within 1e-10). Four weights
    dominated by a single large one must give an ESS below 4.
    """
    from dro_utils import ess

    # Equal weights: every sample counts fully, so ESS should equal n.
    uniform_result = ess(np.ones(10))
    deviation = abs(uniform_result - 10.0)
    assert deviation < 1e-10, "Uniform weights should give ESS = n"

    # One dominant weight: ESS should fall below the sample count.
    skewed_result = ess(np.array([10.0, 1.0, 1.0, 1.0]))
    assert skewed_result < 4.0, "Skewed weights should give lower ESS"

def test_generate_results():
    """Test that the experiment results generator returns well-formed results.

    ``generate_results`` lives under ``code/experiments``, which the module-level
    path setup does not cover, so that directory is put on ``sys.path`` first.
    The returned dict must contain the ``sea`` and ``rotating_hyperplane``
    datasets. Every dataset must have ``qisk`` results with
    ``worst_window_accuracy``, ``mean_accuracy`` and ``macro_f1`` entries.
    """
    # The bare 'code/experiments' entry only resolved when pytest ran from the
    # repository root. Resolve it to an absolute path now, and also try the same
    # directory under this file's parent. That parent is assumed to be the repo
    # root, as in the module-level path setup (TODO confirm layout).
    candidates = [
        os.path.abspath(os.path.join('code', 'experiments')),
        os.path.abspath(
            os.path.join(os.path.dirname(__file__), '..', 'code', 'experiments')
        ),
    ]
    # Insert in reverse so the CWD-relative entry keeps top priority, as before.
    for path in reversed(candidates):
        # Skip duplicates so repeated calls do not keep growing sys.path.
        if path not in sys.path:
            sys.path.insert(0, path)

    from generate_results import generate_realistic_results

    results = generate_realistic_results()
    assert isinstance(results, dict), "Results should be a dictionary"
    assert 'sea' in results, "Should contain SEA results"
    assert 'rotating_hyperplane' in results, "Should contain rotating hyperplane results"

    # Check structure of results
    required_metrics = ('worst_window_accuracy', 'mean_accuracy', 'macro_f1')
    for dataset, methods in results.items():
        assert 'qisk' in methods, f"Dataset {dataset} should have QISK results"
        qisk_results = methods['qisk']
        for metric in required_metrics:
            assert metric in qisk_results, (
                f"Dataset {dataset}: QISK results should have {metric}"
            )

if __name__ == "__main__":
    # Allow running this file as a plain script without pytest. Each test is
    # called in declaration order; the first failing assertion (or pytest.fail)
    # raises and aborts the run before the success message is printed.
    for _test_fn in (
        test_imports,
        test_datasets_creation,
        test_quantum_kernel_basic,
        test_weighted_alignment,
        test_effective_sample_size,
        test_generate_results,
    ):
        _test_fn()
    print("✅ All basic functionality tests passed!")