#!/usr/bin/env python
"""Debug QISK wrapper to see why it's getting 50% accuracy."""

import numpy as np
from simple_qisk_wrapper import create_enhanced_qisk
from real_world_datasets import get_dataset_by_name
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score

# Section 1 — sanity check on a simple synthetic problem.
# Labels follow the linear rule x0 + x1 > 0, so comparing QISK against the
# RBF-SVM baseline here helps separate wrapper problems from dataset problems.
print("🔍 Debugging QISK wrapper...")

print("\n1. Testing with simple synthetic data:")
np.random.seed(42)  # fixed seed so repeated debug runs see the same points
X_simple = np.random.randn(100, 2)
y_simple = (X_simple.sum(axis=1) > 0).astype(int)

# Models are created in this exact order on purpose: create_enhanced_qisk may
# draw from the global NumPy RNG (unverified), so the call sequence is kept.
qisk = create_enhanced_qisk()
svm_baseline = SVC(kernel='rbf', C=1.0, gamma='scale')

# First 80 rows train, remaining 20 rows test.
n_train = 80
X_train, y_train = X_simple[:n_train], y_simple[:n_train]
X_test, y_test = X_simple[n_train:], y_simple[n_train:]

# Fit QISK first, then the baseline; predict in the same order.
for model in (qisk, svm_baseline):
    model.fit(X_train, y_train)

qisk_pred, svm_pred = [model.predict(X_test) for model in (qisk, svm_baseline)]

qisk_acc = accuracy_score(y_test, qisk_pred)
svm_acc = accuracy_score(y_test, svm_pred)

print(f"QISK accuracy: {qisk_acc:.3f}")
print(f"SVM accuracy: {svm_acc:.3f}")
print(f"QISK predictions: {qisk_pred}")
print(f"True labels: {y_test}")

# Test 2: SEA dataset
# Train/test on the first 80 stream samples and compare QISK with an RBF-SVM.
# Class counts for labels AND predictions are printed: a classifier that has
# collapsed to one class scores exactly the majority-class rate, which only the
# prediction counts reveal.
print("\n2. Testing with SEA dataset:")
dataset = get_dataset_by_name('sea', n_samples=200, drift_points=[100], noise_level=0.1)
stream_data = list(dataset.stream())

# Only the first 80 samples are used; with drift_points=[100] these presumably
# all precede the drift (assumes drift_points are sample indices — TODO confirm).
# Assumes stream() yields (x, y) pairs, as the original unpacking did.
window = stream_data[:80]
X_window = np.array([x for x, _ in window])
y_window = np.array([y for _, y in window])

X_train_sea = X_window[:40]
y_train_sea = y_window[:40]
X_test_sea = X_window[40:80]
y_test_sea = y_window[40:80]

print(f"SEA data shape: X_train {X_train_sea.shape}, y_train {y_train_sea.shape}")
print(f"SEA labels: {np.unique(y_train_sea, return_counts=True)}")
print(f"SEA test labels: {np.unique(y_test_sea, return_counts=True)}")

# Train on SEA (QISK first, then the baseline, as in the original order).
qisk_sea = create_enhanced_qisk()
svm_sea = SVC(kernel='rbf', C=1.0, gamma='scale')

qisk_sea.fit(X_train_sea, y_train_sea)
svm_sea.fit(X_train_sea, y_train_sea)

# Test on SEA
qisk_pred_sea = qisk_sea.predict(X_test_sea)
svm_pred_sea = svm_sea.predict(X_test_sea)

qisk_acc_sea = accuracy_score(y_test_sea, qisk_pred_sea)
svm_acc_sea = accuracy_score(y_test_sea, svm_pred_sea)

print(f"QISK SEA accuracy: {qisk_acc_sea:.3f}")
print(f"SVM SEA accuracy: {svm_acc_sea:.3f}")
# Mirrors Test 1, which prints QISK's raw predictions: show the predicted class
# distribution so a constant-prediction collapse is immediately visible.
print(f"QISK SEA prediction counts: {np.unique(qisk_pred_sea, return_counts=True)}")
print(f"SVM SEA prediction counts: {np.unique(svm_pred_sea, return_counts=True)}")

# Test 3: Check internal classifier
# Inspect the SEA-trained QISK wrapper from Test 2. Attribute reads are guarded
# so that a missing attribute is reported rather than aborting the script
# before the direct-prediction check runs.
print("\n3. Checking QISK internal state:")
print(f"QISK is_fitted: {getattr(qisk_sea, 'is_fitted', '<missing attribute>')}")

if hasattr(qisk_sea, 'base_classifier'):
    base_clf = qisk_sea.base_classifier
    print(f"QISK base_classifier: {type(base_clf)}")
    if hasattr(base_clf, 'classes_'):
        print(f"Base classifier classes: {base_clf.classes_}")

    # Test direct prediction on the raw test features.
    # NOTE(review): if the wrapper transforms features before delegating to
    # base_classifier, raw X_test_sea may not be valid input here — confirm
    # against simple_qisk_wrapper.
    try:
        direct_pred = base_clf.predict(X_test_sea)
        direct_acc = accuracy_score(y_test_sea, direct_pred)
        print(f"Direct base classifier accuracy: {direct_acc:.3f}")
    except Exception as e:  # debug boundary: report any failure and continue
        print(f"Direct prediction error: {type(e).__name__}: {e}")
else:
    print("QISK base_classifier: <missing attribute>")