#!/usr/bin/env python
"""
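Ablation studies for QADRIFT components.
Tests: +/- DRO-Lite, fixed vs. learned kernel parameters, and a classical RBF
baseline with DRO-Lite. (No +/- Nyström arm is implemented here: every
quantum-kernel run uses the streaming Nyström approximation.)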
"""

import numpy as np
import json
import time
from qisk_implementation import (
    QISK, weighted_kernel_target_alignment, 
    center_kernel_weighted
)
from real_world_datasets import get_real_world_datasets
from physically_correct_quantum_kernel import PhysicallyCorrectQuantumKernel, StreamingNystromApproximation
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import balanced_accuracy_score, f1_score

class AblationRunner:
    """Run ablation studies to isolate component contributions.

    Every ``run_*`` method takes a list of ``(X_train, y_train, X_test, y_test)``
    window tuples, as produced by :meth:`prepare_data_windows`, and returns a
    dict with ``mean_accuracy``, ``worst_window_accuracy``, ``macro_f1`` and
    ``per_window_results`` (plus ``kta_correlation`` for the DRO-Lite arms).
    """

    def __init__(self, n_windows=10, window_size=200, seed=42):
        """Store the windowing configuration and seed NumPy's global RNG.

        Args:
            n_windows: maximum number of windows to cut from the stream.
            window_size: samples per window (train + test).
            seed: seed for ``np.random`` and for the data generators.
        """
        self.n_windows = n_windows
        self.window_size = window_size
        self.seed = seed
        # Global seed: the kernel-parameter search below draws from np.random.
        np.random.seed(seed)

    # ------------------------------------------------------------------
    # Shared per-window helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _scale_window(scaler, window_index, X_train, X_test):
        """Scale one window; the scaler is fitted on window 0 only and reused afterwards."""
        if window_index == 0:
            X_train_scaled = scaler.fit_transform(X_train)
        else:
            X_train_scaled = scaler.transform(X_train)
        return X_train_scaled, scaler.transform(X_test)

    @staticmethod
    def _dro_weights(dro_weighting, scaler, X_train_scaled, X_previous, n_samples):
        """DRO-Lite weights against the previous window, or uniform weights if there is none."""
        if X_previous is None or len(X_previous) == 0:
            return np.ones(n_samples)
        # X_previous is stored unscaled, so scale it with the current scaler first.
        X_prev_scaled = scaler.transform(X_previous)
        return dro_weighting.compute_weights(X_train_scaled, X_prev_scaled)

    @staticmethod
    def _nystrom_wkta(kernel, nystrom, params, X_train_scaled, y_train, weights):
        """Install ``params`` on ``kernel`` and return the weighted KTA of the training kernel.

        Side effect: ``kernel`` is left holding ``params``.
        """
        kernel.update_parameters(params)
        K_train = nystrom.get_kernel_matrix(X_train_scaled)
        K_centered = center_kernel_weighted(K_train, weights)
        return weighted_kernel_target_alignment(K_centered, y_train, weights)

    @staticmethod
    def _nystrom_svc_predict(nystrom, X_train_scaled, y_train, X_test_scaled, sample_weight=None):
        """Fit a precomputed-kernel SVC on Nyström features and predict the test split.

        Returns:
            ``(predictions, K_train)`` where ``K_train`` is the Gram matrix of the
            Nyström training features that the SVC was fitted on.
        """
        Phi_train = nystrom.transform(X_train_scaled)
        K_train = Phi_train @ Phi_train.T
        classifier = SVC(kernel='precomputed', random_state=42)
        # sample_weight=None is sklearn's default, i.e. unweighted training.
        classifier.fit(K_train, y_train, sample_weight=sample_weight)
        Phi_test = nystrom.transform(X_test_scaled)
        return classifier.predict(Phi_test @ Phi_train.T), K_train

    @staticmethod
    def _window_record(window, y_test, predictions, kta):
        """Score one window's predictions and package them as a JSON-friendly dict."""
        return {
            'window': window,
            'balanced_accuracy': float(balanced_accuracy_score(y_test, predictions)),
            'macro_f1': float(f1_score(y_test, predictions, average='macro')),
            'kta': float(kta),
        }

    @staticmethod
    def _aggregate_results(results, include_kta_correlation=True):
        """Summarise per-window records into mean / worst-window / macro-F1 metrics.

        Args:
            results: list of dicts from :meth:`_window_record`.
            include_kta_correlation: add the Pearson correlation between per-window
                KTA and balanced accuracy (0.0 when there is a single window).

        Raises:
            ValueError: if ``results`` is empty (no window was evaluated).
        """
        if not results:
            raise ValueError("No windows were evaluated; cannot aggregate results.")
        bal_accs = [r['balanced_accuracy'] for r in results]
        f1_scores = [r['macro_f1'] for r in results]
        summary = {
            'mean_accuracy': float(np.mean(bal_accs)),
            'worst_window_accuracy': float(np.min(bal_accs)),
            'macro_f1': float(np.mean(f1_scores)),
        }
        if include_kta_correlation:
            if len(results) > 1:
                ktas = [r['kta'] for r in results]
                summary['kta_correlation'] = float(np.corrcoef(ktas, bal_accs)[0, 1])
            else:
                summary['kta_correlation'] = 0.0
        summary['per_window_results'] = results
        return summary

    # ------------------------------------------------------------------
    # Ablation arms
    # ------------------------------------------------------------------

    def run_drqka_no_dro(self, X_windows, y_windows=None):
        """DRQKA without DRO-Lite (uniform weights).

        Kernel parameters are tuned per window by a small random-search hill climb
        on weighted kernel-target alignment (uniform weights). The tuned parameters
        carry over as the starting point for the next window.

        Args:
            X_windows: sequence of ``(X_train, y_train, X_test, y_test)`` tuples.
            y_windows: unused (labels travel inside ``X_windows``); optional so the
                single-argument call in :meth:`run_all_ablations` works.

        Returns:
            Aggregate metrics dict (no ``kta_correlation`` key).
        """
        print("Running DRQKA without DRO-Lite...")

        kernel = PhysicallyCorrectQuantumKernel(n_qubits=4, trainable_params=np.ones(4))
        nystrom = StreamingNystromApproximation(kernel, n_anchors=16, anchor_strategy='kmeans')
        scaler = StandardScaler()

        results = []
        for i, (X_train, y_train, X_test, y_test) in enumerate(X_windows):
            X_train_scaled, X_test_scaled = self._scale_window(scaler, i, X_train, X_test)
            nystrom.fit(X_train_scaled)

            # Uniform weights: this is the "no DRO" arm.
            weights = np.ones(len(y_train))

            # Hill climb: score the current parameters, then try perturbations
            # around the best parameters found so far (reduced iterations for speed).
            best_params = np.array(kernel.get_parameters(), dtype=float)
            baseline_kta = self._nystrom_wkta(
                kernel, nystrom, best_params, X_train_scaled, y_train, weights)
            # Floor at -1 (the minimum alignment); the comparison also maps NaN to -1.
            best_kta = baseline_kta if baseline_kta > -1.0 else -1.0

            for _ in range(5):  # Reduced from 10
                perturbation = np.random.normal(0, 0.1, best_params.shape)
                test_params = np.maximum(0.1, best_params + perturbation)
                kta = self._nystrom_wkta(
                    kernel, nystrom, test_params, X_train_scaled, y_train, weights)
                if kta > best_kta:
                    best_kta = kta
                    best_params = test_params.copy()

            kernel.update_parameters(best_params)
            nystrom.fit(X_train_scaled)  # Refit with final params

            predictions, _ = self._nystrom_svc_predict(
                nystrom, X_train_scaled, y_train, X_test_scaled)
            results.append(self._window_record(i, y_test, predictions, best_kta))

        return self._aggregate_results(results, include_kta_correlation=False)

    def run_drqka_fixed_params(self, X_windows, y_windows=None):
        """DRQKA with fixed (non-trainable) parameters, DRO-Lite sample weights.

        Args:
            X_windows: sequence of ``(X_train, y_train, X_test, y_test)`` tuples.
            y_windows: unused; optional for backward compatibility.

        Returns:
            Aggregate metrics dict including ``kta_correlation``.
        """
        print("Running DRQKA with fixed parameters...")

        # Parameters stay at their initial values for every window (no learning).
        kernel = PhysicallyCorrectQuantumKernel(n_qubits=4, trainable_params=np.ones(4))
        nystrom = StreamingNystromApproximation(kernel, n_anchors=16, anchor_strategy='kmeans')
        scaler = StandardScaler()

        from drqka_lite_implementation import DROLiteWeighting
        dro_weighting = DROLiteWeighting()

        results = []
        X_previous = None
        for i, (X_train, y_train, X_test, y_test) in enumerate(X_windows):
            X_train_scaled, X_test_scaled = self._scale_window(scaler, i, X_train, X_test)
            weights = self._dro_weights(
                dro_weighting, scaler, X_train_scaled, X_previous, len(y_train))

            nystrom.fit(X_train_scaled)
            predictions, K_final = self._nystrom_svc_predict(
                nystrom, X_train_scaled, y_train, X_test_scaled, sample_weight=weights)

            # KTA of the final training kernel, reported for analysis only.
            K_centered = center_kernel_weighted(K_final, weights)
            kta = weighted_kernel_target_alignment(K_centered, y_train, weights)
            results.append(self._window_record(i, y_test, predictions, kta))

            X_previous = X_train.copy()

        return self._aggregate_results(results, include_kta_correlation=True)

    def run_classical_rbf_dro(self, X_windows, y_windows=None, gamma=1.0):
        """Classical RBF kernel with DRO-Lite sample weights and WKTA diagnostics.

        Args:
            X_windows: sequence of ``(X_train, y_train, X_test, y_test)`` tuples.
            y_windows: unused; optional for backward compatibility.
            gamma: RBF bandwidth passed to ``rbf_kernel`` (default 1.0).

        Returns:
            Aggregate metrics dict including ``kta_correlation``.
        """
        print("Running Classical RBF with DRO-Lite...")

        from sklearn.metrics.pairwise import rbf_kernel
        from drqka_lite_implementation import DROLiteWeighting

        scaler = StandardScaler()
        dro_weighting = DROLiteWeighting()

        results = []
        X_previous = None
        for i, (X_train, y_train, X_test, y_test) in enumerate(X_windows):
            X_train_scaled, X_test_scaled = self._scale_window(scaler, i, X_train, X_test)
            weights = self._dro_weights(
                dro_weighting, scaler, X_train_scaled, X_previous, len(y_train))

            K_train = rbf_kernel(X_train_scaled, X_train_scaled, gamma=gamma)
            classifier = SVC(kernel='precomputed', random_state=42)
            classifier.fit(K_train, y_train, sample_weight=weights)
            K_test = rbf_kernel(X_test_scaled, X_train_scaled, gamma=gamma)
            predictions = classifier.predict(K_test)

            # KTA of the training kernel, reported for analysis only.
            K_centered = center_kernel_weighted(K_train, weights)
            kta = weighted_kernel_target_alignment(K_centered, y_train, weights)
            results.append(self._window_record(i, y_test, predictions, kta))

            X_previous = X_train.copy()

        return self._aggregate_results(results, include_kta_correlation=True)

    # ------------------------------------------------------------------
    # Data and orchestration
    # ------------------------------------------------------------------

    def prepare_data_windows(self, dataset_name='sea'):
        """Prepare data windows for ablation experiments.

        Slides a ``window_size`` window over a 2000-sample stream with stride
        ``window_size // 2`` (50% overlap). Each window is split 80/20 into
        train/test in stream order. At most ``n_windows`` windows are returned,
        fewer if the stream runs out. Any ``dataset_name`` other than ``'sea'``
        selects the rotating-hyperplane stream.

        Returns:
            list of ``(X_train, y_train, X_test, y_test)`` tuples.
        """
        # NOTE(review): ``DataGenerator`` does not appear to be imported in this
        # file, so this presumably raises NameError as written. TODO: import it
        # from the module that defines generate_sea / generate_rotating_hyperplane.
        if dataset_name == 'sea':
            X, y, _drift_points = DataGenerator.generate_sea(n_samples=2000, seed=self.seed)
        else:
            X, y, _drift_points = DataGenerator.generate_rotating_hyperplane(
                n_samples=2000, seed=self.seed)

        stride = self.window_size // 2  # 50% overlap
        train_size = int(0.8 * self.window_size)  # 80/20 train/test split

        windows = []
        for i in range(self.n_windows):
            start = i * stride
            end = start + self.window_size
            if end > len(X):
                break
            split = start + train_size
            windows.append((X[start:split], y[start:split], X[split:end], y[split:end]))

        return windows

    def run_all_ablations(self, dataset='sea'):
        """Run all ablation studies on one dataset.

        Arm 1 (full DRQKA) is loaded from
        ``experimental_results/comprehensive_results.json`` rather than recomputed.
        If the file or key is missing, a placeholder with accuracy 0.0 is stored.
        Arms 2-4 run here. An arm that raises is reported and omitted from the
        result, so the remaining arms still run.

        Returns:
            dict mapping arm name to its result dict.
        """
        print(f"\n{'='*50}")
        print(f"Running Ablation Studies on {dataset.upper()} Dataset")
        print(f"{'='*50}")

        X_windows = self.prepare_data_windows(dataset)
        print(f"Prepared {len(X_windows)} windows")

        ablation_results = {}

        # 1. Full DRQKA (baseline from main results)
        print("\n1. Loading Full DRQKA results from comprehensive_results.json...")
        try:
            with open('experimental_results/comprehensive_results.json', 'r',
                      encoding='utf-8') as f:
                main_results = json.load(f)
            full_drqka = main_results[dataset]['drqka_lite']
            summary_line = (
                f"   Worst-window accuracy: {full_drqka['worst_window_accuracy']['mean']:.3f}")
        except (OSError, ValueError, KeyError, TypeError) as e:
            # Missing/unreadable file, invalid JSON, or unexpected result schema.
            print(f"   Could not load main results ({e}), skipping...")
            ablation_results['full_drqka'] = {'worst_window_accuracy': {'mean': 0.0}}
        else:
            ablation_results['full_drqka'] = full_drqka
            print(summary_line)

        # 2-4. Ablation arms computed in this script.
        arms = (
            ('drqka_no_dro', "DRQKA without DRO-Lite", self.run_drqka_no_dro),
            ('drqka_fixed_params', "DRQKA with fixed parameters", self.run_drqka_fixed_params),
            ('classical_rbf_dro', "Classical RBF with DRO-Lite", self.run_classical_rbf_dro),
        )
        for number, (key, label, run_arm) in enumerate(arms, start=2):
            try:
                arm_results = run_arm(X_windows)
            except Exception as e:  # best-effort: one failing arm must not stop the rest
                print(f"\n{number}. {label} failed: {e}")
                continue
            ablation_results[key] = arm_results
            print(f"\n{number}. {label}:")
            print(f"   Worst-window accuracy: {arm_results['worst_window_accuracy']:.3f}")

        return ablation_results

def main():
    """Run ablation studies on both datasets, save them as JSON, and print a summary.

    Results are written to ``experimental_results/ablation_results.json``.
    The directory is created if it does not exist.
    """
    import os

    runner = AblationRunner(n_windows=8, window_size=200, seed=42)  # Smaller scale for speed

    datasets = ['sea', 'rotating_hyperplane']
    # Runs sequentially in list order; the runner's RNG state carries over between datasets.
    all_ablations = {name: runner.run_all_ablations(name) for name in datasets}

    # Save results (open(..., 'w') fails if the directory is missing).
    os.makedirs('experimental_results', exist_ok=True)
    with open('experimental_results/ablation_results.json', 'w') as f:
        json.dump(all_ablations, f, indent=2)

    # Print summary
    print(f"\n{'='*60}")
    print("ABLATION STUDY SUMMARY")
    print(f"{'='*60}")

    for dataset in datasets:
        print(f"\n{dataset.upper()} Dataset:")
        for method_name, results in all_ablations[dataset].items():
            if 'worst_window_accuracy' not in results:
                continue
            worst_acc = results['worst_window_accuracy']
            # Loaded main results store {'mean': ...}; locally computed arms store a float.
            if isinstance(worst_acc, dict):
                worst_acc = worst_acc['mean']
            print(f"  {method_name:20}: {worst_acc:.3f}")

    print("\nResults saved to: experimental_results/ablation_results.json")


if __name__ == "__main__":
    main()