"""
QISK: Quantum-Inspired Streaming Kernels
Optimized implementation with proper streaming baselines,
cached Nyström anchors, and cleaned up code issues.
"""

import numpy as np
from typing import List, Dict, Any, Tuple, Optional
import json
import os
from datetime import datetime
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import balanced_accuracy_score, f1_score
from sklearn.model_selection import cross_val_score
import warnings
warnings.filterwarnings('ignore')

# Import our modules
from physically_correct_quantum_kernel import PhysicallyCorrectQuantumKernel, StreamingNystromApproximation
# Optional River-based baselines
try:
    from streaming_baselines import get_streaming_baselines, evaluate_streaming_baseline
    RIVER_AVAILABLE = True
except ImportError:
    RIVER_AVAILABLE = False
    print("Warning: River library not available, using simple baselines")
from real_world_datasets import get_real_world_datasets

# Import enhanced utilities
try:
    from alignment_utils import centered_weighted_alignment, try_kta_update
    from dro_utils import DROLiteWeighting, EMAState, dro_lite_weights
    from anchor_utils import AnchorManager, refresh_anchors
    ENHANCED_UTILS_AVAILABLE = True
except ImportError:
    ENHANCED_UTILS_AVAILABLE = False
    print("Warning: Enhanced utilities not available, using fallback implementations")



def center_kernel_weighted(K: np.ndarray, w: np.ndarray) -> np.ndarray:
    """
    Double-center kernel K under sample weights w.

    The weights are rescaled internally to sum to one, so any non-negative
    weights are accepted. With ``p`` the normalized weights, the result has
    zero weighted row means and zero weighted column means (up to
    floating-point error): ``Kc @ p ~= 0`` and ``p @ Kc ~= 0``. With uniform
    weights this is the classic ``H K H`` centering.

    Args:
        K: Square kernel matrix (n x n).
        w: Sample weights (n,).

    Returns:
        ``K - r 1^T - 1 c^T + p^T K p`` where ``r = K p`` holds the weighted
        mean of each row and ``c = p^T K`` the weighted mean of each column.
    """
    w = np.asarray(w, dtype=float)
    w = w / (w.sum() + 1e-12)
    row_means = K @ w              # (n,): weighted mean of each row
    col_means = w @ K              # (n,): weighted mean of each column
    grand_mean = float(w @ K @ w)  # scalar: weighted mean of all entries
    # Subtract row means along rows, column means along columns, add back the
    # grand mean that was removed twice.
    return K - row_means[:, None] - col_means[None, :] + grand_mean


def weighted_kernel_target_alignment(K: np.ndarray, y: np.ndarray, weights: np.ndarray) -> float:
    """
    Weighted, centered kernel-target alignment (KTA) between ``K`` and ``y``.

    The kernel and the ideal target kernel ``t t^T`` are both centered under
    the normalized sample weights. Their weighted Frobenius inner product is
    then divided by the product of their weighted Frobenius norms.

    Args:
        K: Kernel matrix (n x n)
        y: Target labels (n,); mapped through ``2 * y - 1`` ({0, 1} -> {-1, +1})
        weights: Sample weights (n,); rescaled to sum to one

    Returns:
        Weighted KTA score, or 0.0 when ``y`` is empty.
    """
    if len(y) == 0:
        return 0.0

    p = np.asarray(weights, dtype=float)
    p = p / (p.sum() + 1e-12)

    # Map labels to signed form, then remove their weighted mean so the
    # target kernel is centered the same way as K.
    signed = 2 * np.asarray(y) - 1
    target = signed - float(p @ signed)

    K_centered = center_kernel_weighted(K, p)
    target_kernel = np.outer(target, target)
    pair_weights = np.outer(p, p)

    numerator = np.sum(pair_weights * K_centered * target_kernel)
    norm_K_sq = np.sum(pair_weights * K_centered * K_centered)
    norm_T_sq = np.sum(pair_weights * target_kernel * target_kernel)
    denominator = np.sqrt(norm_K_sq * norm_T_sq) + 1e-12
    return float(numerator / denominator)


class QISK:
    """
    QISK: Quantum-Inspired Streaming Kernels with cached anchors and proper streaming evaluation.

    Per window:
    1. ``fit_window`` standardizes features (the scaler is fitted on the first
       window seen and reused afterwards) and weights the current window
       against recent history with a logistic-regression density-ratio
       estimate. It then selects Nyström anchors and tunes the quantum
       kernel's trainable parameters by SPSA to maximize the weighted
       kernel-target alignment (KTA) of the Nyström kernel. Finally it trains
       an SVC on the resulting precomputed kernel.
    2. ``evaluate_window`` builds the test kernel from the same anchors and
       the same anchor-Gram pseudo-inverse that produced the training kernel.

    Key optimizations:
    1. Nyström anchors are selected once per window and reused by every SPSA
       objective evaluation
    2. Streaming protocol: train on the head of a window, test on its tail
    """

    # Diagonal jitter added to the anchor Gram matrix before pseudo-inversion.
    _JITTER = 1e-8

    def __init__(self,
                 n_qubits: int = 4,
                 n_anchors: int = 16,
                 dro_strength: float = 0.1,
                 anchor_strategy: str = 'kmeans',
                 spsa_iterations: int = 50):
        """
        Initialize QISK (Quantum-Inspired Streaming Kernels).

        Args:
            n_qubits: Number of qubits for quantum kernel
            n_anchors: Number of Nyström anchors
            dro_strength: Strength of DRO-Lite importance weighting
            anchor_strategy: Strategy for anchor selection
            spsa_iterations: Number of SPSA optimization iterations
        """
        self.n_qubits = n_qubits
        self.n_anchors = n_anchors
        # NOTE(review): stored but not read by any method of this class.
        self.dro_strength = dro_strength
        self.spsa_iterations = spsa_iterations

        # Initialize components
        self.quantum_kernel = PhysicallyCorrectQuantumKernel(n_qubits=n_qubits)
        self.nystrom = StreamingNystromApproximation(
            quantum_kernel=self.quantum_kernel,
            n_anchors=n_anchors,
            anchor_strategy=anchor_strategy
        )
        self.scaler = StandardScaler()

        # Training history
        self.training_history = []
        self.fitted = False

        # Previous window's importance weights (before normalization), used
        # for EMA smoothing in fit_window.
        self._prev_weights = None
        # Anchors and regularized anchor-Gram pseudo-inverse behind the most
        # recent training kernel; evaluate_window reuses them.
        self._anchors = None
        self._K_ZZ_inv = None

    def preprocess_features(self, X: np.ndarray) -> np.ndarray:
        """Standardize X: fit the scaler on the first call, only transform afterwards."""
        if not hasattr(self.scaler, 'mean_'):
            return self.scaler.fit_transform(X)
        return self.scaler.transform(X)

    def estimate_density_ratios(self, X_target: np.ndarray, X_source: np.ndarray) -> np.ndarray:
        """
        Estimate density ratios p_target(x) / p_source(x) by logistic discrimination (DRO-Lite).

        A logistic-regression classifier is trained to separate target samples
        (label 1) from source samples (label 0). For a calibrated discriminator,
        its odds p(1|x) / p(0|x) equal the density ratio times the sample-size
        ratio n_target / n_source. The odds are therefore rescaled by
        n_source / n_target, so the ratio and the [0.1, 10] clipping range do
        not depend on how much history is available.

        Args:
            X_target: Target distribution samples
            X_source: Source distribution samples (historical data)

        Returns:
            Clipped density-ratio estimates, one per target sample. The result
            is all ones when either sample set is empty.
        """
        from sklearn.linear_model import LogisticRegression

        n_target, n_source = len(X_target), len(X_source)
        if n_source == 0 or n_target == 0:
            # Nothing to discriminate against; a one-class fit would fail.
            return np.ones(n_target)

        # Create discrimination dataset
        X_combined = np.vstack([X_target, X_source])
        y_combined = np.hstack([np.ones(n_target), np.zeros(n_source)])

        # Train discriminator
        discriminator = LogisticRegression(random_state=42, max_iter=1000)
        discriminator.fit(X_combined, y_combined)

        target_probs = discriminator.predict_proba(X_target)[:, 1]  # p(target | x)
        # Avoid division by zero
        source_probs = np.clip(1 - target_probs, 1e-8, 1 - 1e-8)   # p(source | x)

        # Undo the class-prior factor baked into the odds.
        ratios = (target_probs / source_probs) * (n_source / n_target)

        # Clip ratios for stability
        return np.clip(ratios, 0.1, 10.0)

    def _regularized_anchor_gram(self, anchors: np.ndarray) -> np.ndarray:
        """
        Return K(Z, Z) + jitter * I for the given anchors.

        The jitter is added out of place so the array returned by the kernel
        object is never mutated.
        """
        K_ZZ = self.quantum_kernel.compute_kernel_matrix(anchors, anchors)
        return K_ZZ + self._JITTER * np.eye(len(anchors))

    def fit_window(self,
                   X_current: np.ndarray,
                   y_current: np.ndarray,
                   X_history: Optional[List[np.ndarray]] = None) -> Tuple[Dict[str, Any], Any]:
        """
        Fit QISK on the current window with SPSA-tuned kernel parameters.

        Args:
            X_current: Current window features (raw, unscaled)
            y_current: Current window labels; the SVC fit needs at least two classes
            X_history: Recent raw windows used as the "source" distribution for
                importance weighting; ``None`` or empty means uniform weights

        Returns:
            ``(results, classifier)``. ``results`` is a dict of training
            metrics (``train_accuracy``, ``best_kta``, ``final_weights_mean``,
            ``final_weights_std``, ``n_anchors_used``, ``spsa_iterations``).
            ``classifier`` is the fitted ``SVC`` over a precomputed kernel.
        """
        # Preprocess current window
        X_processed = self.preprocess_features(X_current)
        kernel = self.quantum_kernel.compute_kernel_matrix

        # Estimate importance weights using DRO-Lite
        if X_history is not None and len(X_history) > 0:
            X_hist_combined = np.vstack([self.preprocess_features(X_hist) for X_hist in X_history])
            weights = self.estimate_density_ratios(X_processed, X_hist_combined)

            # EMA smoothing with the previous window's weights. The blend pairs
            # samples by position, so it is applied only when both windows have
            # the same size; otherwise the element-wise blend would fail.
            prev_weights = getattr(self, '_prev_weights', None)
            if prev_weights is not None and prev_weights.shape == weights.shape:
                alpha = 0.7
                weights = alpha * weights + (1 - alpha) * prev_weights
            self._prev_weights = weights.copy()
        else:
            weights = np.ones(len(X_current))

        # Normalize weights to probability distribution (sum=1)
        weights = weights / (np.sum(weights) + 1e-12)

        # Fit Nyström approximation (this selects and caches anchors)
        self.nystrom.fit(X_processed)
        anchors = self.nystrom.anchors.copy()
        # Fallback inverse, computed by nystrom.fit with the pre-SPSA parameters.
        cached_K_ZZ_inv = self.nystrom.K_ZZ_inv.copy()

        def cached_objective(theta_params: np.ndarray) -> float:
            """
            Return the negative weighted KTA of the Nyström kernel at theta_params.

            The kernel's previous parameters are always restored afterwards.
            Any failure scores 1.0, the worst attainable value (KTA = -1).
            """
            old_params = self.quantum_kernel.trainable_params.copy()
            self.quantum_kernel.update_parameters(theta_params)
            try:
                # Anchors are fixed; only the kernel values change with theta.
                K_XZ = kernel(X_processed, anchors)
                K_ZZ = self._regularized_anchor_gram(anchors)
                try:
                    K_ZZ_inv = np.linalg.pinv(K_ZZ)
                except np.linalg.LinAlgError:
                    K_ZZ_inv = cached_K_ZZ_inv  # SVD did not converge
                K_approx = K_XZ @ K_ZZ_inv @ K_XZ.T
                return -weighted_kernel_target_alignment(K_approx, y_current, weights)
            except Exception:
                # Deliberate best-effort: treat numerical failures as bad parameters.
                return 1.0
            finally:
                self.quantum_kernel.update_parameters(old_params)

        # SPSA optimization with cached anchors
        a = 0.1   # Step-size scale
        c = 0.01  # Perturbation size
        current_params = np.array(self.quantum_kernel.trainable_params, dtype=float)
        best_params = current_params.copy()
        # Score the starting point so the result is never worse than it.
        best_kta = -cached_objective(best_params)
        if not np.isfinite(best_kta):
            best_kta = -np.inf

        for iteration in range(self.spsa_iterations):
            # Random +/-1 perturbation direction
            delta = 2 * np.random.binomial(1, 0.5, len(current_params)) - 1

            loss_plus = cached_objective(current_params + c * delta)
            loss_minus = cached_objective(current_params - c * delta)

            # SPSA gradient estimate; with delta_i = +/-1, 1/delta_i == delta_i.
            gradient_approx = (loss_plus - loss_minus) / (2 * c) * delta

            # Decaying gain a / (k + 1) ** 0.602
            step_size = a / ((iteration + 1) ** 0.602)
            current_params -= step_size * gradient_approx

            current_kta = -cached_objective(current_params)
            if current_kta > best_kta:
                best_kta = current_kta
                best_params = current_params.copy()

        # Set best parameters
        self.quantum_kernel.update_parameters(best_params)

        # Final Nyström kernel for SVM training
        K_XZ_final = kernel(X_processed, anchors)
        K_ZZ_inv_final = np.linalg.pinv(self._regularized_anchor_gram(anchors))
        K_train = K_XZ_final @ K_ZZ_inv_final @ K_XZ_final.T

        # Keep the exact pieces behind K_train so evaluate_window can build a
        # test kernel that is consistent with it.
        self._anchors = anchors
        self._K_ZZ_inv = K_ZZ_inv_final

        # Train SVM classifier
        classifier = SVC(kernel='precomputed', random_state=42)
        classifier.fit(K_train, y_current)

        self.fitted = True

        # Training metrics
        train_predictions = classifier.predict(K_train)
        train_accuracy = np.mean(train_predictions == y_current)

        results = {
            'train_accuracy': train_accuracy,
            'best_kta': best_kta,
            'final_weights_mean': np.mean(weights),
            'final_weights_std': np.std(weights),
            'n_anchors_used': len(anchors),
            'spsa_iterations': self.spsa_iterations
        }

        self.training_history.append(results)
        return results, classifier

    def evaluate_window(self,
                        X_test: np.ndarray,
                        y_test: np.ndarray,
                        classifier,
                        X_train: np.ndarray) -> Dict[str, Any]:
        """
        Evaluate a classifier returned by fit_window on test data.

        Args:
            X_test: Test features (raw, unscaled)
            y_test: Test labels
            classifier: SVC returned by fit_window (precomputed kernel)
            X_train: The exact raw training samples passed to fit_window, in
                the same order; they index the columns of the test kernel

        Returns:
            Dict with ``accuracy``, ``balanced_accuracy``, ``macro_f1`` and
            ``predictions``.
        """
        # Preprocess with the already fitted scaler
        X_test_processed = self.preprocess_features(X_test)
        X_train_processed = self.preprocess_features(X_train)
        kernel = self.quantum_kernel.compute_kernel_matrix

        anchors = getattr(self, '_anchors', None)
        K_ZZ_inv = getattr(self, '_K_ZZ_inv', None)
        if anchors is not None and K_ZZ_inv is not None:
            # Same formula and same pseudo-inverse as K_train in fit_window.
            # self.nystrom was fitted before SPSA changed the kernel
            # parameters, and nothing here refreshes its cached K_ZZ_inv.
            K_test = kernel(X_test_processed, anchors) @ K_ZZ_inv @ kernel(X_train_processed, anchors).T
        else:
            # No cached training pieces (e.g. fit_window never ran): fall back
            # to the Nyström feature map.
            Phi_train = self.nystrom.transform(X_train_processed, kernel)
            Phi_test = self.nystrom.transform(X_test_processed, kernel)
            K_test = Phi_test @ Phi_train.T

        # Make predictions
        predictions = classifier.predict(K_test)

        # Compute metrics
        balanced_acc = balanced_accuracy_score(y_test, predictions)
        macro_f1 = f1_score(y_test, predictions, average='macro')
        accuracy = np.mean(predictions == y_test)

        return {
            'accuracy': accuracy,
            'balanced_accuracy': balanced_acc,
            'macro_f1': macro_f1,
            'predictions': predictions
        }


class ComprehensiveExperimentRunner:
    """
    Comprehensive experiment runner with proper streaming baselines,
    real-world datasets, and statistical testing.

    Results are written to ``<output_dir>/comprehensive_results.json``.
    """

    # Per-method metrics aggregated across seeds by _compute_statistical_summary.
    _SUMMARY_METRICS = (
        'mean_accuracy',
        'macro_f1',
        'worst_window_accuracy',
        'kta_accuracy_correlation',
    )

    def __init__(self, output_dir: Optional[str] = None):
        """
        Initialize experiment runner.

        Args:
            output_dir: Directory for saving results; defaults to a
                timestamped ``comprehensive_results_<timestamp>`` directory.
                It is created if missing.
        """
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.output_dir = output_dir or f"comprehensive_results_{self.timestamp}"
        os.makedirs(self.output_dir, exist_ok=True)

        self.results = {}

    @staticmethod
    def _result_key(name: str) -> str:
        """Turn a display name into a results key ('Hoeffding Tree' -> 'hoeffding_tree')."""
        return name.lower().replace(' ', '_')

    def run_streaming_baseline_comparison(self,
                                          datasets: List,
                                          window_size: int = 200,
                                          n_seeds: int = 5) -> Dict[str, Any]:
        """
        Run comprehensive comparison with proper streaming baselines.

        Streaming baselines are evaluated only when the optional
        ``streaming_baselines`` import succeeded (``RIVER_AVAILABLE``).
        Otherwise only QISK is run. A failing method is reported and skipped
        for that seed.

        Args:
            datasets: List of datasets to evaluate on
            window_size: Size of evaluation windows
            n_seeds: Number of random seeds for statistical reliability

        Returns:
            Per-dataset, per-method statistical summary (see
            _compute_statistical_summary)
        """
        print("Running comprehensive streaming baseline comparison...")
        if not RIVER_AVAILABLE:
            # get_streaming_baselines / evaluate_streaming_baseline are
            # undefined when the optional import failed.
            print("Warning: streaming baselines unavailable, evaluating QISK only")

        all_results = {}

        for dataset in datasets:
            print(f"\nEvaluating on {dataset.name}...")
            dataset_results = {}

            for seed in range(n_seeds):
                print(f"  Seed {seed + 1}/{n_seeds}")
                np.random.seed(seed)

                X, y = dataset.load_data()

                # Shuffle with the global NumPy RNG seeded above.
                # NOTE(review): shuffling removes any temporal drift order in
                # the stream; confirm this is intended for drift experiments.
                indices = np.random.permutation(len(X))
                X, y = X[indices], y[indices]

                seed_results = {}

                if RIVER_AVAILABLE:
                    for baseline in get_streaming_baselines():
                        try:
                            seed_results[self._result_key(baseline.name)] = evaluate_streaming_baseline(
                                baseline, X, y, window_size
                            )
                        except Exception as e:
                            print(f"    Warning: {baseline.name} failed: {e}")

                try:
                    seed_results['qisk'] = self._evaluate_qisk_streaming(X, y, window_size)
                except Exception as e:
                    print(f"    Warning: QISK failed: {e}")

                dataset_results[f'seed_{seed}'] = seed_results

            all_results[self._result_key(dataset.name)] = dataset_results

        # Compute statistical summaries
        summary_results = self._compute_statistical_summary(all_results)

        # Save comprehensive results
        results_path = os.path.join(self.output_dir, 'comprehensive_results.json')
        with open(results_path, 'w') as f:
            json.dump({
                'summary': summary_results,
                'detailed': all_results,
                'metadata': {
                    'timestamp': self.timestamp,
                    'window_size': window_size,
                    'n_seeds': n_seeds,
                    'datasets': [d.name for d in datasets]
                }
            }, f, indent=2, default=str)

        print(f"\nResults saved to {results_path}")
        return summary_results

    def _evaluate_qisk_streaming(self,
                                 X: np.ndarray,
                                 y: np.ndarray,
                                 window_size: int = 200) -> Dict[str, Any]:
        """
        Evaluate QISK with a streaming protocol.

        The data is cut into consecutive non-overlapping windows of
        ``window_size``; a trailing partial window is dropped. Each window
        trains on its first 80% and tests on the remaining 20%. The training
        parts of the last three evaluated windows act as history for
        importance weighting. Windows whose training split has fewer than two
        classes are skipped.

        Raises:
            ValueError: if no window could be evaluated.
        """
        qisk = QISK()

        n_windows = len(X) // window_size
        window_results = []
        X_history = []

        for i in range(n_windows):
            start_idx = i * window_size
            end_idx = start_idx + window_size

            X_window = X[start_idx:end_idx]
            y_window = y[start_idx:end_idx]

            # Split window for training and testing
            split_idx = int(0.8 * len(X_window))
            X_train_win, X_test_win = X_window[:split_idx], X_window[split_idx:]
            y_train_win, y_test_win = y_window[:split_idx], y_window[split_idx:]

            # The SVC needs two classes in the training split itself; a
            # two-class window can still have a one-class head.
            if len(np.unique(y_train_win)) < 2:
                continue

            # Train on current window with history
            train_results, classifier = qisk.fit_window(
                X_train_win, y_train_win, X_history
            )

            # Evaluate on test portion
            eval_results = qisk.evaluate_window(
                X_test_win, y_test_win, classifier, X_train_win
            )

            window_results.append({
                'window': i,
                'train_size': len(X_train_win),
                'test_size': len(X_test_win),
                'accuracy': eval_results['accuracy'],
                'macro_f1': eval_results['macro_f1'],
                'kta': train_results['best_kta']
            })

            # Add to history (keep last 3 windows)
            X_history.append(X_train_win)
            if len(X_history) > 3:
                X_history.pop(0)

        if not window_results:
            raise ValueError(
                f"no evaluable windows: need at least {window_size} samples and "
                "a training split containing two classes"
            )

        # Aggregate results
        accuracies = [r['accuracy'] for r in window_results]
        f1_scores = [r['macro_f1'] for r in window_results]
        kta_scores = [r['kta'] for r in window_results]

        return {
            'method': 'QISK',
            'mean_accuracy': np.mean(accuracies),
            'macro_f1': np.mean(f1_scores),
            'worst_window_accuracy': np.min(accuracies),
            'best_window_accuracy': np.max(accuracies),
            'window_accuracies': accuracies,
            'window_f1_scores': f1_scores,
            'kta_accuracy_correlation': np.corrcoef(kta_scores, accuracies)[0, 1] if len(kta_scores) > 1 else 0,
            'n_windows': len(window_results)
        }

    @staticmethod
    def _describe(values: List[float]) -> Dict[str, Any]:
        """
        Return mean/std/se and a 95% Student-t confidence interval for values.

        With fewer than two values the spread is undefined, so std, se and
        the interval bounds are None rather than NaN (NaN is not valid JSON).
        With zero spread the interval collapses to the mean.
        """
        from scipy import stats

        values_array = np.asarray(values, dtype=float)
        n = len(values_array)
        mean_val = float(np.mean(values_array))

        if n < 2:
            std_val = se_val = ci_lower = ci_upper = None
        else:
            std_val = float(np.std(values_array, ddof=1))
            se_val = std_val / np.sqrt(n)
            if se_val > 0:
                ci_lower, ci_upper = stats.t.interval(0.95, n - 1, loc=mean_val, scale=se_val)
                ci_lower, ci_upper = float(ci_lower), float(ci_upper)
            else:
                # scipy returns NaN for scale == 0; the interval is a point.
                ci_lower = ci_upper = mean_val

        return {
            'mean': mean_val,
            'std': std_val,
            'se': se_val,
            'ci_95_lower': ci_lower,
            'ci_95_upper': ci_upper,
            'values': values_array.tolist()
        }

    def _compute_statistical_summary(self, all_results: Dict) -> Dict[str, Any]:
        """
        Compute statistical summary with confidence intervals.

        Args:
            all_results: ``{dataset: {seed_key: {method: result_dict}}}``

        Returns:
            ``{dataset: {method: {metric: stats}}}`` for the metrics in
            ``_SUMMARY_METRICS``. Metrics a method does not report are
            omitted. Missing and non-finite per-seed values (e.g. a NaN
            correlation from constant accuracies) are skipped.
        """
        summary = {}

        for dataset_name, dataset_results in all_results.items():
            # method -> metric -> per-seed values
            method_results: Dict[str, Dict[str, List[float]]] = {}
            for seed_results in dataset_results.values():
                for method_name, method_result in seed_results.items():
                    collected = method_results.setdefault(
                        method_name, {metric: [] for metric in self._SUMMARY_METRICS}
                    )
                    for metric in self._SUMMARY_METRICS:
                        value = method_result.get(metric)
                        if value is None:
                            continue
                        value = float(value)
                        if np.isfinite(value):
                            collected[metric].append(value)

            summary[dataset_name] = {
                method_name: {
                    metric: self._describe(values)
                    for metric, values in metrics.items()
                    if values
                }
                for method_name, metrics in method_results.items()
            }

        return summary


if __name__ == "__main__":
    # Example usage: a quick run on a small subset of the real-world datasets.
    # get_real_world_datasets is already imported unconditionally at the top
    # of the module, so no re-import is needed here.
    datasets = get_real_world_datasets()[:2]  # First 2 datasets for testing

    # Run comprehensive experiments
    runner = ComprehensiveExperimentRunner()
    results = runner.run_streaming_baseline_comparison(
        datasets=datasets,
        window_size=200,
        n_seeds=3,  # Reduced for testing
    )

    print("\nExperiment completed!")
    print(f"Results saved in: {runner.output_dir}")