"""
Physically correct quantum kernel implementation.
Replaces the fake "entanglement" with a proper RY product-state feature map.
"""

import numpy as np
from typing import Union, Tuple, Optional
from sklearn.base import BaseEstimator
from sklearn.cluster import MiniBatchKMeans
from scipy.linalg import svd

class PhysicallyCorrectQuantumKernel(BaseEstimator):
    """
    A physically correct quantum kernel using an RY product-state feature map.

    Each sample x is encoded, one RY rotation per qubit, as the product state
        |ψ(x)⟩ = ⊗_i ( cos(θᵢ(x)/2)|0⟩ + sin(θᵢ(x)/2)|1⟩ ),
    with angles θᵢ(x) = feature_scale * trainable_params[i] * xᵢ.

    The kernel is the state fidelity |⟨ψ(x)|ψ(z)⟩|², which for product
    states factorizes as
        |⟨ψ(x)|ψ(z)⟩|² = ∏_i cos²((θᵢ(x) - θᵢ(z)) / 2).

    Inputs with more than ``n_qubits`` features are projected with a PCA that
    is fitted on the first such batch and reused afterwards; inputs with fewer
    features are zero-padded (a padded qubit always contributes a factor of 1).
    """

    def __init__(self, n_qubits: int = 4, feature_scale: float = 1.0,
                 trainable_params: Optional[np.ndarray] = None):
        """
        Initialize physically correct quantum kernel.

        Args:
            n_qubits: Number of qubits, i.e. feature dimensions after preprocessing.
            feature_scale: Global scale factor applied to every rotation angle.
            trainable_params: Optional per-qubit multiplicative angle parameters.
                Stored as given (not copied), following scikit-learn's
                convention that ``__init__`` does not alter parameters.
                Defaults to all ones.
        """
        self.n_qubits = n_qubits
        self.feature_scale = feature_scale
        self.trainable_params = trainable_params
        if self.trainable_params is None:
            # Multiplicative parameters start at 1.0 (identity scaling per qubit).
            self.trainable_params = np.ones(n_qubits)

    def _preprocess_features(self, X: np.ndarray) -> np.ndarray:
        """Map input features to exactly ``n_qubits`` columns.

        Wider inputs are projected with a PCA fitted on the first wide batch
        this kernel sees and reused for all later batches. Narrower inputs, and
        PCA outputs with fewer than ``n_qubits`` components, are zero-padded.
        """
        X = np.atleast_2d(X)

        if X.shape[1] > self.n_qubits:
            if getattr(self, '_pca', None) is None:
                from sklearn.decomposition import PCA
                # PCA cannot extract more than min(n_samples, n_features)
                # components; cap it so a small first batch (e.g. a few anchors)
                # does not raise, and zero-pad the missing components below.
                n_components = min(self.n_qubits, X.shape[0], X.shape[1])
                self._pca = PCA(n_components=n_components, random_state=42)
                X = self._pca.fit_transform(X)
            else:
                X = self._pca.transform(X)

        if X.shape[1] < self.n_qubits:
            padding = np.zeros((X.shape[0], self.n_qubits - X.shape[1]))
            X = np.concatenate([X, padding], axis=1)

        return X

    def _compute_angles(self, X: np.ndarray) -> np.ndarray:
        """Compute RY rotation angles θᵢ = feature_scale * xᵢ * trainable_params[i]."""
        X = self._preprocess_features(X)
        # asarray so list-valued parameters (e.g. passed to __init__) also work.
        params = np.asarray(self.trainable_params, dtype=float)
        # Multiplicative parameters make angle *differences* depend on params.
        return self.feature_scale * (X * params[None, :])

    def compute_kernel_matrix(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Compute the quantum kernel matrix using the product-state fidelity.

        Args:
            X: Input data matrix (n_samples, n_features)
            Y: Optional second data matrix (m_samples, n_features); defaults to X.

        Returns:
            Kernel matrix of shape (n_samples, m_samples), or
            (n_samples, n_samples) when Y is omitted. Entries lie in [1e-12, 1].
        """
        angles_X = self._compute_angles(X)
        # For a Gram matrix reuse X's angles: this halves the work and keeps the
        # result exactly symmetric with a unit diagonal, even on the call that
        # first fits the PCA (fit_transform and transform may differ by round-off).
        if Y is None or Y is X:
            angles_Y = angles_X
        else:
            angles_Y = self._compute_angles(Y)

        # Pairwise angle differences, shape (n_samples_X, n_samples_Y, n_qubits).
        angle_diffs = angles_X[:, None, :] - angles_Y[None, :, :]

        # Product-state fidelity: ∏ᵢ cos²(Δθᵢ/2), shape (n_samples_X, n_samples_Y).
        fidelities = np.prod(np.cos(0.5 * angle_diffs) ** 2, axis=2)

        # Elementwise clip for numerical safety only; positive semi-definiteness
        # comes from the fidelity formula itself, not from this clip.
        return np.clip(fidelities, 1e-12, 1.0)

    def update_parameters(self, new_params: np.ndarray):
        """Replace the trainable parameters (e.g. after an SPSA step); input is copied."""
        self.trainable_params = np.array(new_params)

    def get_parameters(self) -> np.ndarray:
        """Return a copy of the current trainable parameters as an ndarray."""
        return np.array(self.trainable_params)

class StreamingNystromApproximation:
    """Streaming Nyström approximation with physically correct quantum kernel.

    Approximates the kernel matrix as K̃ = K_XZ K_ZZ⁺ K_XZᵀ using a small set of
    anchor points Z. It also exposes the matching explicit feature map
    Φ(X) = K_XZ (K_ZZ⁺)^{1/2}, so that Φ(X) Φ(X)ᵀ = K̃.

    ``K_ZZ_inv`` is computed in ``fit`` with the kernel's parameters at that
    moment. If the kernel's trainable parameters change afterwards, call ``fit``
    again.
    """

    _ANCHOR_STRATEGIES = ('kmeans', 'uniform')

    def __init__(self, quantum_kernel: 'PhysicallyCorrectQuantumKernel',
                 n_anchors: int = 16, anchor_strategy: str = 'kmeans'):
        """
        Initialize Nyström approximation.

        Args:
            quantum_kernel: Kernel object exposing ``compute_kernel_matrix(X, Y)``.
            n_anchors: Number of anchor points for approximation.
            anchor_strategy: Strategy for anchor selection ('kmeans' or 'uniform').
        """
        self.quantum_kernel = quantum_kernel
        self.n_anchors = n_anchors
        self.anchor_strategy = anchor_strategy
        self.anchors = None
        self.K_ZZ_inv = None

    def select_anchors(self, X: np.ndarray) -> np.ndarray:
        """Select anchor points from data using the configured strategy.

        Raises:
            ValueError: if ``anchor_strategy`` is unknown. The strategy is
                checked on every call, even when no selection is needed.
        """
        if self.anchor_strategy not in self._ANCHOR_STRATEGIES:
            raise ValueError(f"Unknown anchor strategy: {self.anchor_strategy}")

        n_samples = X.shape[0]
        if n_samples <= self.n_anchors:
            # Few enough points: every sample becomes an anchor.
            return X.copy()

        if self.anchor_strategy == 'kmeans':
            # MiniBatchKMeans centroids as anchors; these need not be data points.
            kmeans = MiniBatchKMeans(n_clusters=self.n_anchors, random_state=42,
                                     batch_size=min(100, n_samples))
            kmeans.fit(X)
            return kmeans.cluster_centers_

        # 'uniform': distinct rows sampled with the global NumPy RNG.
        anchor_indices = np.random.choice(n_samples, self.n_anchors, replace=False)
        return X[anchor_indices]

    def fit(self, X: np.ndarray):
        """Select anchors from X and store the pseudoinverse of K_ZZ."""
        self.anchors = self.select_anchors(X)

        K_ZZ = self.quantum_kernel.compute_kernel_matrix(self.anchors, self.anchors)

        # Small ridge keeps K_ZZ away from exact singularity (e.g. duplicate anchors).
        K_ZZ += 1e-8 * np.eye(K_ZZ.shape[0])

        try:
            self.K_ZZ_inv = np.linalg.pinv(K_ZZ)
        except np.linalg.LinAlgError:
            # Fallback: explicit SVD pseudoinverse, dropping tiny singular values.
            U, s, Vt = svd(K_ZZ)
            s_inv = np.where(s > 1e-10, 1.0 / s, 0.0)
            self.K_ZZ_inv = Vt.T @ np.diag(s_inv) @ U.T

    def _sqrt_K_ZZ_inv(self) -> np.ndarray:
        """Return the symmetric PSD square root of ``K_ZZ_inv``.

        ``K_ZZ_inv`` is symmetric, so an eigendecomposition yields the root
        directly. Round-off negative eigenvalues are clipped to zero. An SVD
        would give a non-symmetric U·√Σ·Vᵀ in that case.
        """
        sym = 0.5 * (self.K_ZZ_inv + self.K_ZZ_inv.T)
        eigvals, eigvecs = np.linalg.eigh(sym)
        return (eigvecs * np.sqrt(np.clip(eigvals, 0.0, None))) @ eigvecs.T

    def transform(self, X: np.ndarray) -> np.ndarray:
        """Return the Nyström feature map Φ(X) = K_XZ (K_ZZ⁺)^{1/2}.

        Shape is (n_samples, n_anchors_selected), and Φ(X) Φ(X)ᵀ equals
        K_XZ K_ZZ_inv K_XZᵀ.

        Raises:
            ValueError: if called before ``fit``.
        """
        if self.anchors is None or self.K_ZZ_inv is None:
            raise ValueError("Must call fit() before transform()")

        K_XZ = self.quantum_kernel.compute_kernel_matrix(X, self.anchors)
        return K_XZ @ self._sqrt_K_ZZ_inv()

    def get_kernel_matrix(self, X: np.ndarray) -> np.ndarray:
        """Return the Nyström-approximated kernel matrix for X as Φ(X) Φ(X)^T."""
        Phi = self.transform(X)
        return Phi @ Phi.T

    def kernel_fidelity(self, X: np.ndarray,
                        random_state: Optional[Union[int, np.random.Generator]] = None) -> float:
        """Return 1 - relative Frobenius error of the approximation on a subset of X.

        Up to 50 rows of X are sampled without replacement, and the result is
        clipped to be at least 0.0. Returns 0.0 if ``fit`` has not been called.

        Args:
            X: Data matrix (n_samples, n_features).
            random_state: Optional seed or Generator for reproducible subset
                sampling. ``None`` keeps the original behavior of using the
                global NumPy RNG.
        """
        if self.anchors is None or self.K_ZZ_inv is None:
            return 0.0

        n_test = min(50, X.shape[0])
        if random_state is None:
            test_indices = np.random.choice(X.shape[0], n_test, replace=False)
        else:
            rng = np.random.default_rng(random_state)
            test_indices = rng.choice(X.shape[0], n_test, replace=False)
        X_test = X[test_indices]

        K_exact = self.quantum_kernel.compute_kernel_matrix(X_test, X_test)

        K_XZ = self.quantum_kernel.compute_kernel_matrix(X_test, self.anchors)
        K_approx = K_XZ @ self.K_ZZ_inv @ K_XZ.T

        error = np.linalg.norm(K_exact - K_approx, 'fro')
        exact_norm = np.linalg.norm(K_exact, 'fro')

        fidelity = 1.0 - (error / (exact_norm + 1e-10))
        return max(0.0, fidelity)

    def approximation_fidelity(self, X: np.ndarray,
                               random_state: Optional[Union[int, np.random.Generator]] = None) -> float:
        """Alias for kernel_fidelity for backward compatibility."""
        return self.kernel_fidelity(X, random_state=random_state)

def test_physically_correct_kernel():
    """Smoke-test the quantum kernel and its Nyström approximation.

    Prints diagnostics and asserts the basic invariants:
    - output shapes;
    - unit diagonal (a state's fidelity with itself is 1);
    - symmetry;
    - positive semi-definiteness;
    - the Nyström feature-map shape.
    """
    print("Testing Physically Correct Quantum Kernel")
    print("=" * 50)

    # Deterministic test data.
    np.random.seed(42)
    X = np.random.randn(20, 4)
    Y = np.random.randn(15, 4)

    kernel = PhysicallyCorrectQuantumKernel(n_qubits=4, feature_scale=1.0)

    K_XX = kernel.compute_kernel_matrix(X, X)
    K_XY = kernel.compute_kernel_matrix(X, Y)

    # K_XX is real symmetric: eigvalsh returns real eigenvalues, whereas the
    # general eigvals may return complex values with spurious imaginary parts.
    min_eig = np.linalg.eigvalsh(K_XX).min()
    is_psd = bool(min_eig >= -1e-10)

    print(f"K_XX shape: {K_XX.shape}")
    print(f"K_XY shape: {K_XY.shape}")
    print(f"K_XX diagonal (should be 1.0): {K_XX.diagonal()[:5]}")
    print(f"K_XX range: [{K_XX.min():.6f}, {K_XX.max():.6f}]")
    print(f"K_XX is PSD: {is_psd}")

    assert K_XX.shape == (20, 20)
    assert K_XY.shape == (20, 15)
    assert np.allclose(K_XX.diagonal(), 1.0)
    assert np.allclose(K_XX, K_XX.T)
    assert is_psd, f"K_XX has a negative eigenvalue: {min_eig:.3e}"

    print("\nTesting Nyström Approximation:")
    nystrom = StreamingNystromApproximation(kernel, n_anchors=8)
    nystrom.fit(X)

    feature_map = nystrom.transform(X)
    print(f"Feature map shape: {feature_map.shape}")
    # One column per anchor; 20 samples > 8 anchors, so exactly 8 are selected.
    assert feature_map.shape == (20, 8)

    fidelity = nystrom.kernel_fidelity(X)
    print(f"Approximation fidelity: {fidelity:.3f}")
    assert 0.0 <= fidelity <= 1.0

    print("✅ Physically correct kernel tests passed!")

if __name__ == "__main__":
    test_physically_correct_kernel()