"""
Basic test script to verify CoBET package functionality.
Run this after installing the package with: pip install -e .
"""
import numpy as np
import sys

def test_imports():
    """Check that the public names of the ``cobet`` package import cleanly.

    Returns:
        bool: ``True`` when every listed name can be imported; ``False``
        otherwise, after printing the exception message.
    """
    print("Testing imports...")
    try:
        # The names are imported only to prove they exist; nothing uses them.
        from cobet import (  # noqa: F401
            run_test,
            CoBET,
            dCoBET,
            wa_dCoBET,
            clayton_copula_sample_nd,
            apply_transform,
        )
        print("✓ All imports successful")
        outcome = True
    except Exception as exc:
        print(f"✗ Import failed: {exc}")
        outcome = False
    return outcome


def test_cobet_basic():
    """Smoke-test ``CoBET`` on two independently drawn Gaussian samples.

    Builds ``CoBET(K=3, d=2, alpha=0.05, seed=42)``, runs ``.test`` on two
    50x2 standard-normal samples and checks that the returned dict-like
    result exposes the expected keys.

    Returns:
        bool: ``True`` if the run completes and the result has every
        expected key; ``False`` otherwise (the error message and its
        traceback are printed).

    NOTE(review): written for the ``__main__`` runner. Because every
    exception is caught and turned into ``False``, pytest (which also
    collects ``test_*`` functions) would report a failure here as a pass.
    """
    print("\nTesting CoBET...")
    try:
        from cobet import CoBET

        # Create instance
        test = CoBET(K=3, d=2, alpha=0.05, seed=42)

        # Seed the global NumPy RNG so the generated data are reproducible.
        np.random.seed(42)
        X = np.random.randn(50, 2)
        Y = np.random.randn(50, 2)

        # Run test
        result = test.test(X, Y)

        # Explicit check instead of ``assert``. Assert statements are removed
        # under ``python -O``, which would let this check pass vacuously, and
        # a bare assert leaves the printed failure message empty.
        expected_keys = ('statistic', 'p_value', 'reject', 'Z')
        missing = [key for key in expected_keys if key not in result]
        if missing:
            raise AssertionError(f"CoBET result is missing keys: {missing}")

        print(f"✓ CoBET test passed (p-value: {result['p_value']:.4f})")
        return True
    except Exception as e:
        print(f"✗ CoBET test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def test_dcobet_basic():
    """Smoke-test ``dCoBET`` on two independently drawn Gaussian samples.

    Builds ``dCoBET(K=3, d=2, alpha=0.05, seed=42, reuse_J=True)``, runs
    ``.test`` on two 50x2 standard-normal samples and checks that the
    returned dict-like result exposes ``statistic`` and ``p_value``.

    Returns:
        bool: ``True`` if the run completes and the result has every
        expected key; ``False`` otherwise (the error message and its
        traceback are printed).

    NOTE(review): written for the ``__main__`` runner. Because every
    exception is caught and turned into ``False``, pytest (which also
    collects ``test_*`` functions) would report a failure here as a pass.
    """
    print("\nTesting dCoBET...")
    try:
        from cobet import dCoBET

        # Create instance
        test = dCoBET(K=3, d=2, alpha=0.05, seed=42, reuse_J=True)

        # Seed the global NumPy RNG so the generated data are reproducible.
        np.random.seed(42)
        X = np.random.randn(50, 2)
        Y = np.random.randn(50, 2)

        # Run test
        result = test.test(X, Y)

        # Explicit check instead of ``assert``. Assert statements are removed
        # under ``python -O``, which would let this check pass vacuously, and
        # a bare assert leaves the printed failure message empty.
        expected_keys = ('statistic', 'p_value')
        missing = [key for key in expected_keys if key not in result]
        if missing:
            raise AssertionError(f"dCoBET result is missing keys: {missing}")

        print(f"✓ dCoBET test passed (p-value: {result['p_value']:.4f})")
        return True
    except Exception as e:
        print(f"✗ dCoBET test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def test_wa_dcobet_basic():
    """Smoke-test ``wa_dCoBET`` on a linearly dependent pair of samples.

    Builds ``wa_dCoBET(K=3, d=2, alpha=0.05, seed=42, n_folds=5)`` and runs
    ``.test`` on a 50x2 standard-normal ``X`` and ``Y = 0.3*X + 0.7*noise``.
    It then checks that the dict-like result exposes the statistic, the
    p-value and the ``w_identity`` / ``w_J`` entries.

    Returns:
        bool: ``True`` if the run completes and the result has every
        expected key; ``False`` otherwise (the error message and its
        traceback are printed).

    NOTE(review): written for the ``__main__`` runner. Because every
    exception is caught and turned into ``False``, pytest (which also
    collects ``test_*`` functions) would report a failure here as a pass.
    """
    print("\nTesting wa_dCoBET...")
    try:
        from cobet import wa_dCoBET

        # Create instance
        test = wa_dCoBET(K=3, d=2, alpha=0.05, seed=42, n_folds=5)

        # Seed the global NumPy RNG so the generated data are reproducible;
        # Y is built from X so the two samples are dependent.
        np.random.seed(42)
        X = np.random.randn(50, 2)
        Y = 0.3 * X + 0.7 * np.random.randn(50, 2)

        # Run test
        result = test.test(X, Y)

        # Explicit check instead of ``assert``. Assert statements are removed
        # under ``python -O``, which would let this check pass vacuously, and
        # a bare assert leaves the printed failure message empty.
        expected_keys = ('statistic', 'p_value', 'w_identity', 'w_J')
        missing = [key for key in expected_keys if key not in result]
        if missing:
            raise AssertionError(
                f"wa_dCoBET result is missing keys: {missing}"
            )

        print(f"✓ wa_dCoBET test passed (p-value: {result['p_value']:.4f}, "
              f"w_id: {result['w_identity']:.2f})")
        return True
    except Exception as e:
        print(f"✗ wa_dCoBET test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def test_data_generation():
    """Smoke-test ``CoBET.generate_data`` output shapes.

    Builds ``CoBET(K=3, d=3, theta=2, seed=42)`` and asks it for
    ``n=100`` samples with ``transform_key='linear'`` and ``b=0.1``. It then
    checks that both returned arrays have shape ``(100, 3)``.

    Returns:
        bool: ``True`` if generation succeeds with the expected shapes;
        ``False`` otherwise (the error message and its traceback are
        printed).

    NOTE(review): written for the ``__main__`` runner. Because every
    exception is caught and turned into ``False``, pytest (which also
    collects ``test_*`` functions) would report a failure here as a pass.
    """
    print("\nTesting data generation...")
    try:
        from cobet import CoBET

        test = CoBET(K=3, d=3, theta=2, seed=42)
        X, Y = test.generate_data(n=100, transform_key='linear', b=0.1)

        # Explicit check instead of ``assert``. Assert statements are removed
        # under ``python -O``, which would let this check pass vacuously, and
        # a bare assert leaves the printed failure message empty.
        expected_shape = (100, 3)
        if X.shape != expected_shape or Y.shape != expected_shape:
            raise AssertionError(
                f"unexpected shapes: X={X.shape}, Y={Y.shape}, "
                f"expected {expected_shape} for both"
            )

        print(f"✓ Data generation test passed (shapes: X={X.shape}, Y={Y.shape})")
        return True
    except Exception as e:
        print(f"✗ Data generation test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("CoBET Package Basic Tests")
    print(banner)

    # Display label paired with the check that produces it; run in order.
    checks = [
        ("Imports", test_imports),
        ("CoBET", test_cobet_basic),
        ("dCoBET", test_dcobet_basic),
        ("wa_dCoBET", test_wa_dcobet_basic),
        ("Data Generation", test_data_generation),
    ]
    results = [(label, check()) for label, check in checks]

    # Summary
    print("\n" + banner)
    print("Test Summary")
    print(banner)

    for label, ok in results:
        print(f"{label:20s}: {'✓ PASS' if ok else '✗ FAIL'}")

    total = len(results)
    passed = sum(bool(ok) for _, ok in results)
    print(f"\nTotal: {passed}/{total} tests passed")
    print(banner)

    # Exit status 0 only when every check passed, so callers can detect failure.
    sys.exit(0 if passed == total else 1)
