"""
Enhanced QISK with advanced drift detection ensemble and quantum kernel ensemble.
This version implements state-of-the-art techniques for significant performance improvements.
"""

import numpy as np
from typing import List, Dict, Any, Tuple, Optional
import warnings
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import balanced_accuracy_score, f1_score
from sklearn.cluster import MiniBatchKMeans
from sklearn.gaussian_process.kernels import RBF
import scipy.stats as stats
warnings.filterwarnings('ignore')

from physically_correct_quantum_kernel import PhysicallyCorrectQuantumKernel


class AdvancedDriftDetector:
    """
    Multi-detector ensemble for drift detection.

    Three scores are computed for every batch and averaged:

    * ``statistical``  - ``1 - mean(p)`` where ``p`` are the per-feature
      two-sample Kolmogorov-Smirnov p-values against the reference window.
    * ``distribution`` - mean per-feature 1-D Wasserstein distance against the
      reference window.
    * ``error_rate``   - change of the batch error rate relative to the first
      error rate ever observed (the baseline).

    Drift is flagged when the mean of the three scores exceeds
    ``ensemble_threshold``. A detector that cannot run (no reference data yet,
    or no labels/predictions supplied) contributes ``0.0`` to that mean.
    """

    # Maximum number of rows kept in each reference window.
    MAX_REFERENCE_SIZE = 500

    def __init__(self):
        self.detectors = {
            'statistical': self._statistical_drift_detector(),
            'distribution': self._distribution_drift_detector(),
            'error_rate': self._error_rate_drift_detector()
        }
        # Not written by this class; kept for external bookkeeping.
        self.drift_scores = []
        self.ensemble_threshold = 0.6

    def _statistical_drift_detector(self):
        """Initial state for the Kolmogorov-Smirnov based detector."""
        return {
            'window_data': [],
            'reference_data': None,
            'type': 'statistical'
        }

    def _distribution_drift_detector(self):
        """Initial state for the Wasserstein distance based detector."""
        return {
            'window_data': [],
            'reference_data': None,
            'type': 'distribution'
        }

    def _error_rate_drift_detector(self):
        """Initial state for the error rate change detector."""
        return {
            'errors': [],
            'recent_error_rate': 0.0,
            'baseline_error_rate': 0.0,
            # Explicit flag so that a genuine baseline of 0.0 (a model that
            # made no mistakes on its first batch) is not mistaken for
            # "baseline not recorded yet".
            'baseline_initialized': False,
            'type': 'error_rate'
        }

    def detect_drift(self, X_new: np.ndarray, y_true: np.ndarray = None,
                     y_pred: np.ndarray = None) -> Dict[str, Any]:
        """
        Score a new batch against the stored reference data.

        Args:
            X_new: 2-D array (samples x features) to test.
            y_true: Optional true labels for the batch.
            y_pred: Optional predicted labels for the batch. The error-rate
                detector only runs when both label arrays are given.

        Returns:
            Dict with the float scores ``'statistical'``, ``'distribution'``,
            ``'error_rate'`` and ``'ensemble'`` (their mean), plus the bool
            ``'drift_detected'`` (``ensemble > ensemble_threshold``).
        """
        X_new = np.asarray(X_new)
        scores = {}

        # Statistical drift: per-feature two-sample KS test.
        stat_ref = self.detectors['statistical']['reference_data']
        if stat_ref is not None:
            p_values = [
                stats.ks_2samp(stat_ref[:, j], X_new[:, j]).pvalue
                for j in range(X_new.shape[1])
            ]
            # Small p-values mean "different distributions" -> higher score.
            scores['statistical'] = float(1.0 - np.mean(p_values))
        else:
            scores['statistical'] = 0.0

        # Distribution drift: per-feature 1-D Wasserstein distance.
        dist_ref = self.detectors['distribution']['reference_data']
        if dist_ref is not None:
            distances = [
                stats.wasserstein_distance(dist_ref[:, j], X_new[:, j])
                for j in range(X_new.shape[1])
            ]
            scores['distribution'] = float(np.mean(distances))
        else:
            scores['distribution'] = 0.0

        # Error-rate drift: relative change against the first observed rate.
        error_state = self.detectors['error_rate']
        if y_true is not None and y_pred is not None:
            # asarray so that plain lists are compared element-wise.
            current_error_rate = float(
                np.mean(np.asarray(y_true) != np.asarray(y_pred))
            )
            error_state['recent_error_rate'] = current_error_rate

            # A positive baseline set directly by a caller also counts as
            # initialised.
            has_baseline = (
                error_state.get('baseline_initialized', False)
                or error_state['baseline_error_rate'] > 0
            )
            if has_baseline:
                baseline = error_state['baseline_error_rate']
                # +0.01 keeps the ratio finite for a zero baseline.
                scores['error_rate'] = (
                    abs(current_error_rate - baseline) / (baseline + 0.01)
                )
            else:
                error_state['baseline_error_rate'] = current_error_rate
                error_state['baseline_initialized'] = True
                scores['error_rate'] = 0.0
        else:
            scores['error_rate'] = 0.0

        # Ensemble decision: plain mean of the three detector scores.
        ensemble_score = float(np.mean(list(scores.values())))
        scores['ensemble'] = ensemble_score
        scores['drift_detected'] = bool(ensemble_score > self.ensemble_threshold)

        return scores

    def update_reference(self, X: np.ndarray, y: np.ndarray = None):
        """
        Append ``X`` to the reference windows of both feature-based detectors.

        Each window keeps at most ``MAX_REFERENCE_SIZE`` rows, the most recent
        ones. The cap is applied on the very first update as well.

        Args:
            X: 2-D array (samples x features) to add to the reference.
            y: Accepted for interface compatibility; currently unused.
        """
        X = np.asarray(X)
        limit = self.MAX_REFERENCE_SIZE

        for detector_name in ('statistical', 'distribution'):
            state = self.detectors[detector_name]
            previous = state['reference_data']
            if previous is None:
                combined = X.copy()
            else:
                combined = np.vstack([previous, X])
            # Negative slice keeps the newest rows; a no-op when short enough.
            state['reference_data'] = combined[-limit:]


class QuantumKernelEnsemble:
    """
    Weighted combination of several ``PhysicallyCorrectQuantumKernel`` instances.

    The ensemble kernel is ``sum_i weights[i] * K_i``. Weights start uniform
    and can be re-estimated from kernel-target alignment (KTA) with
    ``optimize_kernel_weights``.
    """

    def __init__(self, n_qubits: int = 4, n_kernels: int = 3):
        self.n_qubits = n_qubits
        self.n_kernels = n_kernels
        self.kernels = []
        self.weights = np.ones(n_kernels) / n_kernels

        # Members differ only in feature scaling (1.0, 1.5, 2.0, ...) and seed.
        for i in range(n_kernels):
            kernel = PhysicallyCorrectQuantumKernel(
                n_qubits=n_qubits,
                feature_scaling=1.0 + 0.5 * i,
                random_seed=42 + i
            )
            self.kernels.append(kernel)

    def _kernel_matrices(self, X1: np.ndarray, X2: np.ndarray) -> List[np.ndarray]:
        """Evaluate every member kernel once on ``(X1, X2)``."""
        return [kernel.compute_kernel_matrix(X1, X2) for kernel in self.kernels]

    def _combine(self, kernel_matrices: List[np.ndarray]) -> np.ndarray:
        """Return ``sum_i weights[i] * kernel_matrices[i]``."""
        first = np.asarray(kernel_matrices[0])
        # Promote to at least float so that integer-valued kernel matrices
        # do not break the in-place accumulation.
        combined = np.zeros_like(first, dtype=np.result_type(first, float))
        for weight, K in zip(self.weights, kernel_matrices):
            combined += weight * K
        return combined

    def compute_kernel_matrix(self, X1: np.ndarray, X2: np.ndarray = None) -> np.ndarray:
        """
        Compute the weighted ensemble kernel matrix.

        Args:
            X1: First set of samples.
            X2: Second set of samples; defaults to ``X1``.

        Returns:
            The weighted sum of the member kernel matrices for ``(X1, X2)``.
        """
        if X2 is None:
            X2 = X1
        return self._combine(self._kernel_matrices(X1, X2))

    def optimize_kernel_weights(self, X: np.ndarray, y: np.ndarray,
                                sample_weights: np.ndarray = None) -> float:
        """
        Set the ensemble weights proportional to each kernel's weighted KTA.

        Every member kernel matrix is evaluated exactly once and reused, both
        for the per-kernel alignments and for the final ensemble alignment.

        Args:
            X: Training samples.
            y: Labels; see ``_weighted_kta`` for how they are encoded.
            sample_weights: Optional per-sample weights; uniform if omitted.

        Returns:
            Weighted KTA of the ensemble kernel under the updated weights.
        """
        if sample_weights is None:
            sample_weights = np.ones(len(X)) / len(X)

        kernel_matrices = self._kernel_matrices(X, X)
        ktas = np.array([
            self._weighted_kta(K, y, sample_weights) for K in kernel_matrices
        ])

        # Floor at 0.01 so that negatively aligned kernels get a small
        # positive weight rather than a negative one.
        ktas = np.maximum(ktas, 0.01)
        self.weights = ktas / np.sum(ktas)

        return self._weighted_kta(self._combine(kernel_matrices), y, sample_weights)

    def _weighted_kta(self, K: np.ndarray, y: np.ndarray, weights: np.ndarray) -> float:
        """
        Compute the weighted, centered kernel-target alignment of ``K`` and ``y``.

        Labels are mapped through ``2 * y - 1`` (so 0/1 labels become -1/+1)
        and then centered with the normalised sample weights. ``1e-12`` terms
        guard the divisions against zero denominators.
        """
        y_pm = 2 * np.array(y) - 1
        weights = np.asarray(weights, dtype=float)
        w = weights / (np.sum(weights) + 1e-12)

        # Weighted centering of the kernel matrix.
        K_mean = np.sum(w[:, None] * K, axis=0)
        K_centered = K - K_mean[None, :] - K_mean[:, None] + np.sum(w * K_mean)

        # Weighted centering of the target, then the ideal kernel y y^T.
        y_centered = y_pm - np.sum(w * y_pm)
        Y = np.outer(y_centered, y_centered)

        # Weighted Frobenius inner product, normalised by both norms.
        W = np.outer(w, w)
        num = np.sum(W * K_centered * Y)
        den = np.sqrt(np.sum(W * K_centered * K_centered) * np.sum(W * Y * Y)) + 1e-12

        return num / den


class AdvancedDROLite:
    """
    Importance weighting from an ensemble of density ratio estimators.

    Supported method names are ``'logistic'``, ``'kmm'`` and ``'residual'``.
    Any other name contributes a constant ratio of 1.0.
    """

    def __init__(self, methods: Optional[List[str]] = None):
        self.methods = methods or ['logistic', 'kmm', 'residual']
        self.estimators = {}
        # Uniform averaging over the configured methods.
        self.ensemble_weights = np.ones(len(self.methods)) / len(self.methods)

    def estimate_density_ratios(self, X_prev: np.ndarray, X_curr: np.ndarray) -> np.ndarray:
        """
        Estimate one density ratio per row of ``X_curr``.

        Args:
            X_prev: Samples from the previous distribution.
            X_curr: Samples from the current distribution.

        Returns:
            1-D array of length ``len(X_curr)``: the weighted average of the
            configured estimators, clipped to ``[0.1, 10.0]``.
        """
        dispatch = {
            'logistic': self._logistic_density_ratios,
            'kmm': self._kmm_density_ratios,
            'residual': self._residual_density_ratios,
        }
        n_curr = len(X_curr)

        ensemble_ratios = np.zeros(n_curr)
        for weight, method in zip(self.ensemble_weights, self.methods):
            estimator = dispatch.get(method)
            if estimator is None:
                ratios = np.ones(n_curr)  # unknown method: neutral weights
            else:
                ratios = estimator(X_prev, X_curr)
            ensemble_ratios += weight * ratios

        # Clip so that no single sample dominates or vanishes.
        return np.clip(ensemble_ratios, 0.1, 10.0)

    def _logistic_density_ratios(self, X_prev: np.ndarray, X_curr: np.ndarray) -> np.ndarray:
        """
        Density ratios from a logistic "previous vs. current" discriminator.

        The discriminator's odds ``p / (1 - p)`` equal
        ``(n_curr / n_prev) * p_curr(x) / p_prev(x)``. The odds are therefore
        multiplied by ``n_prev / n_curr`` to remove the sample-size prior.
        """
        from sklearn.linear_model import LogisticRegression

        # Label previous samples 0 and current samples 1.
        X_combined = np.vstack([X_prev, X_curr])
        y_combined = np.hstack([np.zeros(len(X_prev)), np.ones(len(X_curr))])

        discriminator = LogisticRegression(max_iter=1000, random_state=42)
        discriminator.fit(X_combined, y_combined)

        probs = discriminator.predict_proba(X_curr)[:, 1]
        odds = probs / (1 - probs + 1e-12)

        prior_correction = len(X_prev) / len(X_curr)
        return odds * prior_correction

    def _kmm_density_ratios(self, X_prev: np.ndarray, X_curr: np.ndarray) -> np.ndarray:
        """
        Simplified Kernel Mean Matching.

        Solves a ridge-regularised linear system for weights on the previous
        samples, then maps them to the current samples through the RBF
        cross-kernel. Falls back to uniform ratios if the system is singular.
        """
        from sklearn.metrics.pairwise import rbf_kernel

        K_prev_prev = rbf_kernel(X_prev, X_prev)
        K_curr_prev = rbf_kernel(X_curr, X_prev)

        n_prev = len(X_prev)
        n_curr = len(X_curr)

        try:
            # Regularised least squares solution.
            A = K_prev_prev + 1e-6 * np.eye(n_prev)
            b = np.mean(K_curr_prev, axis=0) * n_prev / n_curr
            weights = np.linalg.solve(A, b)
        except np.linalg.LinAlgError:
            # Singular system: fall back to uniform weights.
            return np.ones(n_curr)

        # Map weights to current samples and floor them.
        ratios = np.mean(K_curr_prev * weights[None, :], axis=1)
        return np.maximum(ratios, 0.1)

    def _residual_density_ratios(self, X_prev: np.ndarray, X_curr: np.ndarray) -> np.ndarray:
        """
        Density ratios from k-nearest-neighbour distance estimates.

        Each density is taken as the inverse distance to the k-th nearest
        neighbour, and the ratio is ``density_curr / density_prev`` at every
        current sample. Returns all ones when either set is too small to
        provide at least one neighbour.
        """
        # k must fit both sets. The current set needs k + 1 points because
        # each point is excluded from its own neighbour list below.
        k = min(10, len(X_prev) // 4, len(X_curr) - 1)
        if k < 1:
            return np.ones(len(X_curr))

        from sklearn.neighbors import NearestNeighbors

        nn_prev = NearestNeighbors(n_neighbors=k).fit(X_prev)
        nn_curr = NearestNeighbors(n_neighbors=k).fit(X_curr)

        dists_prev, _ = nn_prev.kneighbors(X_curr)
        # No query argument: neighbours of the indexed points themselves,
        # with each point excluded from its own result. Querying with X_curr
        # would return distance 0 (the point itself) as first neighbour and
        # inflate the current density.
        dists_curr, _ = nn_curr.kneighbors()

        # Inverse of the k-th nearest neighbour distance.
        density_prev = 1.0 / (dists_prev[:, -1] + 1e-12)
        density_curr = 1.0 / (dists_curr[:, -1] + 1e-12)

        return density_curr / (density_prev + 1e-12)


class EnhancedQISK:
    """
    Nyström quantum-kernel classifier with optional drift adaptation.

    With ``advanced_features=True`` the model uses a ``QuantumKernelEnsemble``,
    an ``AdvancedDriftDetector`` and ``AdvancedDROLite`` importance weights.
    Otherwise it uses a single ``PhysicallyCorrectQuantumKernel`` and uniform
    sample weights. In both modes a linear SVM is trained on Nyström features
    built from ``n_anchors`` k-means centres.
    """

    def __init__(self, n_qubits: int = 4, n_anchors: int = 32,
                 advanced_features: bool = True):
        self.n_qubits = n_qubits
        self.n_anchors = n_anchors
        self.advanced_features = advanced_features

        # Core components
        if advanced_features:
            self.quantum_kernels = QuantumKernelEnsemble(n_qubits=n_qubits, n_kernels=3)
            self.drift_detector = AdvancedDriftDetector()
            self.dro_lite = AdvancedDROLite(['logistic', 'kmm', 'residual'])
        else:
            # Fallback to basic QISK
            self.quantum_kernel = PhysicallyCorrectQuantumKernel(n_qubits=n_qubits)

        # Nyström approximation
        self.anchors = None
        self.anchor_weights = None
        self.scaler = StandardScaler()

        # Training history (each list is capped to the last 5 fits)
        self.training_history = {
            'X': [],
            'y': [],
            'sample_weights': [],
            'drift_scores': [],
            'kta_scores': []
        }

        # SVM classifier, created in fit()
        self.svm = None

    def fit(self, X: np.ndarray, y: np.ndarray,
            X_prev: np.ndarray = None) -> 'EnhancedQISK':
        """
        Fit the model on ``(X, y)``, adapting to drift relative to ``X_prev``.

        The scaler is re-fitted on ``X`` on every call.
        NOTE(review): anchors retained from an earlier ``fit`` were computed
        under that call's scaling - confirm this is intended.

        Args:
            X: Current training data.
            y: Current labels.
            X_prev: Previous data. Only used when ``advanced_features`` is
                enabled: for drift detection and, if drift is flagged, for
                density-ratio sample weights.

        Returns:
            self
        """
        X_scaled = self.scaler.fit_transform(X)

        sample_weights = np.ones(len(X))
        drift_info = {'ensemble': 0.0}

        if self.advanced_features and X_prev is not None:
            X_prev_scaled = self.scaler.transform(X_prev)

            # On the first fit the detector has no reference window and would
            # report "no drift" unconditionally; seed it with the previous
            # data so that X_prev actually takes part in the detection.
            if self.drift_detector.detectors['statistical']['reference_data'] is None:
                self.drift_detector.update_reference(X_prev_scaled)

            drift_info = self.drift_detector.detect_drift(X_scaled, y_true=None)

            # Importance weights only when drift is flagged.
            if drift_info['drift_detected']:
                sample_weights = self.dro_lite.estimate_density_ratios(
                    X_prev_scaled, X_scaled
                )

            self.drift_detector.update_reference(X_scaled, y)

        # Select or update Nyström anchors
        self._update_anchors(X_scaled, sample_weights)

        # Re-weight the kernel ensemble first so that the (expensive) feature
        # computation runs once, with the final weights.
        kta_score = 0.0
        if self.advanced_features:
            kta_score = self.quantum_kernels.optimize_kernel_weights(
                X_scaled, y, sample_weights
            )
        K_features = self._compute_kernel_features(X_scaled)

        # Train a linear SVM on the Nyström features
        self.svm = SVC(kernel='linear', probability=True, random_state=42)
        try:
            self.svm.fit(K_features, y, sample_weight=sample_weights)
        except Exception:
            # Best-effort fallback: retry without sample weights.
            self.svm.fit(K_features, y)

        # Store training information
        self.training_history['X'].append(X_scaled.copy())
        self.training_history['y'].append(y.copy())
        self.training_history['sample_weights'].append(sample_weights.copy())
        self.training_history['drift_scores'].append(drift_info['ensemble'])
        self.training_history['kta_scores'].append(kta_score)

        # Limit history size
        max_history = 5
        for key in self.training_history:
            if len(self.training_history[key]) > max_history:
                self.training_history[key] = self.training_history[key][-max_history:]

        return self

    def _features_for(self, X: np.ndarray) -> np.ndarray:
        """Scale ``X`` with the fitted scaler and map it to Nyström features."""
        if self.svm is None:
            raise ValueError("Model not fitted yet")
        return self._compute_kernel_features(self.scaler.transform(X))

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class labels for ``X``.

        Raises:
            ValueError: If the model has not been fitted.
        """
        return self.svm.predict(self._features_for(X))

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class probabilities for ``X``.

        Raises:
            ValueError: If the model has not been fitted.
        """
        return self.svm.predict_proba(self._features_for(X))

    def _fit_anchors(self, X: np.ndarray, sample_weights: np.ndarray,
                     **kmeans_options) -> np.ndarray:
        """Run weighted mini-batch k-means and return the cluster centres."""
        kmeans = MiniBatchKMeans(
            n_clusters=min(self.n_anchors, len(X)),
            random_state=42,
            **kmeans_options
        )
        kmeans.fit(X, sample_weight=sample_weights)
        return kmeans.cluster_centers_

    def _update_anchors(self, X: np.ndarray, sample_weights: np.ndarray):
        """
        Select Nyström anchors with weighted k-means.

        Anchors are (re)computed when none exist or fewer than ``n_anchors``
        are stored. Otherwise they are refreshed with probability 0.1 per
        call, drawn from NumPy's global random state.
        """
        if self.anchors is None or len(self.anchors) < self.n_anchors:
            self.anchors = self._fit_anchors(
                X, sample_weights, batch_size=min(100, len(X))
            )
        elif np.random.random() < 0.1:  # 10% chance to refresh
            self.anchors = self._fit_anchors(X, sample_weights)

    def _compute_kernel_features(self, X: np.ndarray) -> np.ndarray:
        """
        Map ``X`` to Nyström features ``K_xa @ K_aa^{-1/2}``.

        The result always has shape ``(len(X), n_anchors_used)``, so features
        from training and prediction batches of different sizes are
        compatible. The inner products of the features reproduce the Nyström
        approximation ``K_xa @ K_aa^{-1} @ K_xa.T``.

        Raises:
            ValueError: If anchors have not been initialised.
        """
        if self.anchors is None:
            raise ValueError("Anchors not initialized")

        kernel = self.quantum_kernels if self.advanced_features else self.quantum_kernel
        K_xa = kernel.compute_kernel_matrix(X, self.anchors)
        K_aa = np.asarray(kernel.compute_kernel_matrix(self.anchors, self.anchors))

        jitter = 1e-6
        try:
            # Symmetrise against round-off, regularise, then take the inverse
            # square root through the eigendecomposition.
            K_aa_sym = 0.5 * (K_aa + K_aa.T) + jitter * np.eye(len(K_aa))
            eigvals, eigvecs = np.linalg.eigh(K_aa_sym)
            # Floor tiny or negative eigenvalues so the inverse root stays bounded.
            eigvals = np.maximum(eigvals, jitter)
            K_aa_inv_sqrt = (eigvecs / np.sqrt(eigvals)) @ eigvecs.T
            return K_xa @ K_aa_inv_sqrt
        except np.linalg.LinAlgError:
            # Fallback to direct kernel features (same shape as above).
            return K_xa

    def get_training_diagnostics(self) -> Dict[str, Any]:
        """Return fit count, the last three drift/KTA scores, anchor count and mode."""
        history = self.training_history
        return {
            'n_training_rounds': len(history['drift_scores']),
            'recent_drift_scores': history['drift_scores'][-3:],
            'recent_kta_scores': history['kta_scores'][-3:],
            'anchor_count': len(self.anchors) if self.anchors is not None else 0,
            'advanced_features_enabled': self.advanced_features
        }