"""
GPU-Accelerated Canonical Correlation Analysis (CCA) Implementations

This module provides multiple GPU-accelerated CCA implementations:
1. PyTorch-based CCA with numerical stability (recommended)
2. CuPy-based CCA for maximum NumPy compatibility
3. Deep CCA loss function for end-to-end learning

References:
- Hotelling, H. (1936). Relations between two sets of variates
- Andrew et al. (2013). Deep Canonical Correlation Analysis (DCCA)
- Raghu et al. (2017). SVCCA: Singular Vector CCA for Deep Learning
"""

import numpy as np
import warnings
from typing import Optional, Tuple, Union

try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    warnings.warn("PyTorch not available. GPU CCA will not work.")

try:
    import cupy as cp
    CUPY_AVAILABLE = True
except ImportError:
    CUPY_AVAILABLE = False
    warnings.warn("CuPy not available. CuPy-based CCA will not work.")


class TorchCCA:
    """
    Regularized linear CCA computed with PyTorch (GPU when available).

    The model whitens each view with its regularized covariance matrix and
    takes the SVD of the whitened cross-covariance. The singular values are
    the canonical correlations; the singular vectors, mapped back through the
    whitening matrices, are the projection weights.

    Features:
    - GPU acceleration via PyTorch
    - Automatic fallback to CPU if GPU unavailable
    - Numerical stability via ridge regularization of the covariances

    Attributes set by ``fit``:
        wx_, wy_: projection weights, tensors of shape
            (n_features, n_components) on ``self.device``
        mean_x_, mean_y_: per-feature training means, shape (1, n_features)
        canonical_correlations_: tensor with the top ``n_components``
            canonical correlations, in descending order

    Example:
        >>> cca = TorchCCA(n_components=64, reg_param=1e-4, device='cuda')
        >>> cca.fit(X_train, Y_train)
        >>> X_transformed, Y_transformed = cca.transform(X_test, Y_test)
    """

    def __init__(
        self,
        n_components: int = 64,
        reg_param: float = 1e-4,
        device: Optional[str] = None,
        dtype: torch.dtype = torch.float32,
        max_iter: int = 500,
        tol: float = 1e-6
    ):
        """
        Initialize TorchCCA.

        Args:
            n_components: Number of CCA components to compute
            reg_param: Ridge term added to the diagonal of both
                auto-covariance matrices
            device: 'cuda', 'cpu', or None (auto-detect). A CUDA device that
                is requested but unavailable falls back to CPU with a warning.
            dtype: PyTorch dtype for computations
            max_iter: Stored but not used by the closed-form solver; kept for
                API compatibility
            tol: Stored but not used by the closed-form solver; kept for API
                compatibility

        Raises:
            ImportError: If PyTorch is not installed
        """
        if not TORCH_AVAILABLE:
            raise ImportError("PyTorch is required for TorchCCA")

        self.n_components = n_components
        self.reg_param = reg_param
        self.dtype = dtype
        self.max_iter = max_iter
        self.tol = tol

        # Auto-detect device
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)
            if self.device.type == 'cuda' and not torch.cuda.is_available():
                # Honour the documented CPU fallback instead of failing later
                # on the first tensor transfer.
                warnings.warn("CUDA requested but not available; falling back to CPU")
                self.device = torch.device('cpu')

        # Learned parameters
        self.wx_ = None  # X projection weights
        self.wy_ = None  # Y projection weights
        self.mean_x_ = None
        self.mean_y_ = None
        self.canonical_correlations_ = None
        self.fitted_ = False

        print(f"TorchCCA initialized on device: {self.device}")

    def _to_tensor(self, X: np.ndarray) -> torch.Tensor:
        """Convert an array-like to a torch tensor on the configured device/dtype."""
        # ascontiguousarray also accepts lists and arrays with negative
        # strides, which torch.from_numpy rejects.
        return torch.from_numpy(np.ascontiguousarray(X)).to(
            dtype=self.dtype, device=self.device
        )

    def _to_numpy(self, X: torch.Tensor) -> np.ndarray:
        """Convert torch tensor to numpy array."""
        return X.detach().cpu().numpy()

    def _whitening_matrix(self, C: torch.Tensor) -> torch.Tensor:
        """
        Return K such that K^T @ C @ K = I for a symmetric positive-definite C.

        With the Cholesky factorization C = L @ L^T, K = inv(L)^T satisfies
        K^T C K = inv(L) L L^T inv(L)^T = I. If the factorization fails, the
        symmetric inverse square root from ``_matrix_inverse_sqrt_svd`` is
        used instead (it satisfies the same identity on the retained
        subspace).
        """
        try:
            L = torch.linalg.cholesky(C)
        except RuntimeError:
            warnings.warn("Cholesky decomposition failed, using SVD fallback")
            return self._matrix_inverse_sqrt_svd(C)
        return torch.linalg.inv(L).T

    def fit(self, X: np.ndarray, Y: np.ndarray) -> 'TorchCCA':
        """
        Fit CCA model using training data.

        Steps:
        1. Center the data
        2. Compute cross-covariance and regularized auto-covariance matrices
        3. Whiten both views (Cholesky, with an SVD-based fallback)
        4. SVD of the whitened cross-covariance
        5. Keep the top n_components

        Note: if ``n_components`` exceeds what the data supports it is reduced
        (with a warning) and the reduced value is stored on the instance.

        Args:
            X: Training data for first view (n_samples, n_features_x)
            Y: Training data for second view (n_samples, n_features_y)

        Returns:
            self: Fitted CCA model

        Raises:
            ValueError: If the inputs are not 2-D, have different numbers of
                samples, or contain fewer than two samples
            RuntimeError: If the final SVD fails
        """
        X = np.asarray(X)
        Y = np.asarray(Y)
        if X.ndim != 2 or Y.ndim != 2:
            raise ValueError(
                f"X and Y must be 2-D arrays. Got shapes {X.shape} and {Y.shape}"
            )
        if X.shape[0] != Y.shape[0]:
            raise ValueError(f"X and Y must have same number of samples. Got {X.shape[0]} and {Y.shape[0]}")

        n_samples = X.shape[0]
        if n_samples < 2:
            # The unbiased covariance divides by n_samples - 1.
            raise ValueError(f"At least 2 samples are required. Got {n_samples}")

        n_components = min(
            self.n_components,
            X.shape[1],
            Y.shape[1],
            n_samples
        )

        if n_components < self.n_components:
            warnings.warn(
                f"n_components reduced to {n_components} "
                f"(min of n_samples={n_samples}, n_features_x={X.shape[1]}, "
                f"n_features_y={Y.shape[1]})"
            )
            self.n_components = n_components

        # Convert to tensors on the target device
        X_tensor = self._to_tensor(X)
        Y_tensor = self._to_tensor(Y)

        # Center the data
        self.mean_x_ = X_tensor.mean(dim=0, keepdim=True)
        self.mean_y_ = Y_tensor.mean(dim=0, keepdim=True)

        X_centered = X_tensor - self.mean_x_
        Y_centered = Y_tensor - self.mean_y_

        # Unbiased covariance estimates
        dof = n_samples - 1
        Cxy = (X_centered.T @ Y_centered) / dof
        Cxx = (X_centered.T @ X_centered) / dof
        Cyy = (Y_centered.T @ Y_centered) / dof

        # Ridge regularization keeps the auto-covariances positive definite
        Cxx += self.reg_param * torch.eye(
            Cxx.shape[0], dtype=self.dtype, device=self.device
        )
        Cyy += self.reg_param * torch.eye(
            Cyy.shape[0], dtype=self.dtype, device=self.device
        )

        # Whitening matrices: Kx^T Cxx Kx = I and Ky^T Cyy Ky = I
        Kx = self._whitening_matrix(Cxx)
        Ky = self._whitening_matrix(Cyy)

        # Whitened cross-covariance; its singular values are the canonical
        # correlations.
        M = Kx.T @ Cxy @ Ky

        try:
            U, S, Vt = torch.linalg.svd(M, full_matrices=False)
        except RuntimeError as e:
            raise RuntimeError(f"SVD failed: {e}. Try increasing reg_param.") from e

        # Take top n_components (singular values are returned in descending order)
        U = U[:, :self.n_components]
        Vt = Vt[:self.n_components, :]
        S = S[:self.n_components]

        # Map singular vectors back to the original feature spaces
        self.wx_ = Kx @ U
        self.wy_ = Ky @ Vt.T
        self.canonical_correlations_ = S

        self.fitted_ = True

        return self

    def _matrix_inverse_sqrt_svd(self, M: torch.Tensor) -> torch.Tensor:
        """
        Compute matrix inverse square root using SVD.
        M^(-1/2) = U @ S^(-1/2) @ U^T

        Singular values below a cutoff relative to the largest one are
        treated as zero (pseudo-inverse behaviour) rather than inverted.

        Args:
            M: Symmetric positive semi-definite matrix

        Returns:
            M^(-1/2)
        """
        U, S, _ = torch.linalg.svd(M, full_matrices=False)

        # Relative cutoff so the filter does not depend on the data scale
        cutoff = torch.finfo(M.dtype).eps * max(M.shape) * S.max()
        S_inv_sqrt = torch.where(
            S > cutoff,
            1.0 / torch.sqrt(S),
            torch.zeros_like(S)
        )

        # Scaling the columns of U is equivalent to U @ diag(S_inv_sqrt)
        return (U * S_inv_sqrt) @ U.T

    def transform(
        self,
        X: np.ndarray,
        Y: Optional[np.ndarray] = None
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """
        Transform data to CCA space.

        Data is centered with the training means and projected with the
        learned weights.

        Args:
            X: Data for first view (n_samples, n_features_x)
            Y: Data for second view (n_samples, n_features_y), optional

        Returns:
            If Y is None: X_transformed
            If Y is provided: (X_transformed, Y_transformed)

        Raises:
            RuntimeError: If the model has not been fitted
        """
        if not self.fitted_:
            raise RuntimeError("CCA model must be fitted before transform")

        X_tensor = self._to_tensor(X)
        X_centered = X_tensor - self.mean_x_
        X_transformed = X_centered @ self.wx_

        if Y is None:
            return self._to_numpy(X_transformed)

        Y_tensor = self._to_tensor(Y)
        Y_centered = Y_tensor - self.mean_y_
        Y_transformed = Y_centered @ self.wy_

        return self._to_numpy(X_transformed), self._to_numpy(Y_transformed)

    def fit_transform(
        self,
        X: np.ndarray,
        Y: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Fit CCA model and transform data.

        Args:
            X: Training data for first view
            Y: Training data for second view

        Returns:
            (X_transformed, Y_transformed)
        """
        self.fit(X, Y)
        return self.transform(X, Y)

    def score(self, X: np.ndarray, Y: np.ndarray) -> float:
        """
        Compute average canonical correlation on test data.

        Both views are projected, then the Pearson correlation between each
        pair of matching components is averaged. A component with zero
        variance yields NaN, which propagates to the result.

        Args:
            X: Test data for first view
            Y: Test data for second view

        Returns:
            Average canonical correlation
        """
        X_c, Y_c = self.transform(X, Y)

        # Accumulate in float64 regardless of the model dtype
        X_c = X_c.astype(np.float64)
        Y_c = Y_c.astype(np.float64)
        X_c -= X_c.mean(axis=0)
        Y_c -= Y_c.mean(axis=0)

        # Column-wise Pearson correlation for all components at once
        covariance = (X_c * Y_c).sum(axis=0)
        norm = np.sqrt((X_c ** 2).sum(axis=0) * (Y_c ** 2).sum(axis=0))
        with np.errstate(divide='ignore', invalid='ignore'):
            correlations = covariance / norm

        return float(np.mean(correlations))


class CupyCCA:
    """
    Regularized linear CCA computed with CuPy (NumPy-compatible GPU arrays).

    Each view is whitened with its regularized covariance matrix and the SVD
    of the whitened cross-covariance gives the canonical correlations
    (singular values) and directions (singular vectors).

    Attributes set by ``fit`` (all CuPy arrays):
        wx_, wy_: projection weights of shape (n_features, n_components)
        mean_x_, mean_y_: per-feature training means, shape (1, n_features)
        canonical_correlations_: top ``n_components`` canonical correlations

    Example:
        >>> cca = CupyCCA(n_components=64, reg_param=1e-4)
        >>> cca.fit(X_train, Y_train)
        >>> X_transformed, Y_transformed = cca.transform(X_test, Y_test)
    """

    def __init__(
        self,
        n_components: int = 64,
        reg_param: float = 1e-4,
        max_iter: int = 500,
        tol: float = 1e-6
    ):
        """
        Initialize CupyCCA.

        Args:
            n_components: Number of CCA components to compute
            reg_param: Ridge term added to the diagonal of both
                auto-covariance matrices
            max_iter: Stored but not used by the closed-form solver; kept for
                API compatibility
            tol: Stored but not used by the closed-form solver; kept for API
                compatibility

        Raises:
            ImportError: If CuPy is not installed
        """
        if not CUPY_AVAILABLE:
            raise ImportError("CuPy is required for CupyCCA")

        self.n_components = n_components
        self.reg_param = reg_param
        self.max_iter = max_iter
        self.tol = tol

        self.wx_ = None
        self.wy_ = None
        self.mean_x_ = None
        self.mean_y_ = None
        self.canonical_correlations_ = None
        self.fitted_ = False

        print("CupyCCA initialized with GPU acceleration")

    def _whitening_matrix(self, C):
        """
        Return K such that K^T @ C @ K = I for a symmetric positive-definite C.

        With C = L @ L^T (Cholesky), K = inv(L)^T gives
        K^T C K = inv(L) L L^T inv(L)^T = I. Falls back to the symmetric
        inverse square root when the factorization is unusable.
        """
        try:
            L = cp.linalg.cholesky(C)
        except np.linalg.LinAlgError:
            L = None

        # NOTE(review): per the CuPy documentation, cholesky does not raise on
        # non-positive-definite input by default and returns NaN instead, so
        # the factor is checked explicitly.
        if L is None or not bool(cp.isfinite(L).all()):
            warnings.warn("Cholesky failed, using SVD")
            return self._matrix_inverse_sqrt_svd(C)

        return cp.linalg.inv(L).T

    def fit(self, X: np.ndarray, Y: np.ndarray) -> 'CupyCCA':
        """
        Fit CCA model using CuPy GPU arrays.

        Note: if ``n_components`` exceeds what the data supports it is reduced
        (with a warning) and the reduced value is stored on the instance.

        Args:
            X: Training data for first view (n_samples, n_features_x)
            Y: Training data for second view (n_samples, n_features_y)

        Returns:
            self: Fitted CCA model

        Raises:
            ValueError: If the inputs are not 2-D, have different numbers of
                samples, or contain fewer than two samples
        """
        # Move to GPU
        X_gpu = cp.asarray(X)
        Y_gpu = cp.asarray(Y)

        if X_gpu.ndim != 2 or Y_gpu.ndim != 2:
            raise ValueError(
                f"X and Y must be 2-D arrays. Got shapes {X_gpu.shape} and {Y_gpu.shape}"
            )
        if X_gpu.shape[0] != Y_gpu.shape[0]:
            raise ValueError(
                f"X and Y must have same number of samples. "
                f"Got {X_gpu.shape[0]} and {Y_gpu.shape[0]}"
            )

        n_samples = X_gpu.shape[0]
        if n_samples < 2:
            # The unbiased covariance divides by n_samples - 1.
            raise ValueError(f"At least 2 samples are required. Got {n_samples}")

        n_components = min(
            self.n_components,
            X_gpu.shape[1],
            Y_gpu.shape[1],
            n_samples
        )

        if n_components < self.n_components:
            warnings.warn(f"n_components reduced to {n_components}")
            self.n_components = n_components

        # Center data
        self.mean_x_ = X_gpu.mean(axis=0, keepdims=True)
        self.mean_y_ = Y_gpu.mean(axis=0, keepdims=True)

        X_centered = X_gpu - self.mean_x_
        Y_centered = Y_gpu - self.mean_y_

        # Unbiased covariances; the ridge term keeps Cxx / Cyy positive definite
        dof = n_samples - 1
        Cxy = (X_centered.T @ Y_centered) / dof
        Cxx = (X_centered.T @ X_centered) / dof + self.reg_param * cp.eye(X_gpu.shape[1])
        Cyy = (Y_centered.T @ Y_centered) / dof + self.reg_param * cp.eye(Y_gpu.shape[1])

        # Whitening matrices: Kx^T Cxx Kx = I and Ky^T Cyy Ky = I
        Kx = self._whitening_matrix(Cxx)
        Ky = self._whitening_matrix(Cyy)

        # Whitened cross-covariance; its singular values are the canonical
        # correlations.
        M = Kx.T @ Cxy @ Ky
        U, S, Vt = cp.linalg.svd(M, full_matrices=False)

        # Take top components
        U = U[:, :self.n_components]
        Vt = Vt[:self.n_components, :]
        S = S[:self.n_components]

        # Map singular vectors back to the original feature spaces
        self.wx_ = Kx @ U
        self.wy_ = Ky @ Vt.T
        self.canonical_correlations_ = S

        self.fitted_ = True
        return self

    def _matrix_inverse_sqrt_svd(self, M):
        """
        Compute M^(-1/2) using SVD with CuPy.

        Singular values below a cutoff relative to the largest one are
        treated as zero rather than inverted.
        """
        U, S, _ = cp.linalg.svd(M, full_matrices=False)
        # Relative cutoff so the filter does not depend on the data scale
        cutoff = cp.finfo(M.dtype).eps * max(M.shape) * S.max()
        S_inv_sqrt = cp.where(S > cutoff, 1.0 / cp.sqrt(S), 0.0)
        # Scaling the columns of U is equivalent to U @ diag(S_inv_sqrt)
        return (U * S_inv_sqrt) @ U.T

    def transform(
        self,
        X: np.ndarray,
        Y: Optional[np.ndarray] = None
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """
        Transform data to CCA space.

        Args:
            X: Data for first view (n_samples, n_features_x)
            Y: Data for second view (n_samples, n_features_y), optional

        Returns:
            If Y is None: X_transformed (NumPy array)
            If Y is provided: (X_transformed, Y_transformed)

        Raises:
            RuntimeError: If the model has not been fitted
        """
        if not self.fitted_:
            raise RuntimeError("CCA model must be fitted before transform")

        X_gpu = cp.asarray(X)
        X_centered = X_gpu - self.mean_x_
        X_transformed = cp.asnumpy(X_centered @ self.wx_)

        if Y is None:
            return X_transformed

        Y_gpu = cp.asarray(Y)
        Y_centered = Y_gpu - self.mean_y_
        Y_transformed = cp.asnumpy(Y_centered @ self.wy_)

        return X_transformed, Y_transformed

    def fit_transform(
        self,
        X: np.ndarray,
        Y: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Fit the model on (X, Y) and return both views projected to CCA space."""
        self.fit(X, Y)
        return self.transform(X, Y)


class CCALoss(torch.nn.Module):
    """
    CCA as a differentiable loss function for Deep CCA.

    This can be used as a loss function when training neural networks
    to learn CCA projections end-to-end. The loss is the negative sum of the
    singular values of the whitened cross-covariance of the two views.

    Based on Andrew et al. (2013) "Deep Canonical Correlation Analysis"

    Example:
        >>> cca_loss = CCALoss(outdim_size=64, device='cuda')
        >>> loss = cca_loss(H1, H2)  # H1, H2 are network outputs
        >>> loss.backward()
    """

    def __init__(
        self,
        outdim_size: int,
        use_all_singular_values: bool = False,
        device: str = 'cuda',
        reg_param: float = 1e-3
    ):
        """
        Args:
            outdim_size: Number of leading singular values summed when
                ``use_all_singular_values`` is False
            use_all_singular_values: Sum every singular value instead of only
                the top ``outdim_size``
            device: Stored for backward compatibility. ``forward`` runs on the
                device of its inputs and does not read this value.
            reg_param: Ridge term added to the diagonal of both
                auto-covariance matrices
        """
        super().__init__()
        self.outdim_size = outdim_size
        self.use_all_singular_values = use_all_singular_values
        self.device = device
        self.reg_param = reg_param

    def forward(self, H1: torch.Tensor, H2: torch.Tensor) -> torch.Tensor:
        """
        Compute CCA loss (negative correlation).

        Args:
            H1: First view representations (batch_size, dim1)
            H2: Second view representations (batch_size, dim2)

        Returns:
            Negative correlation loss (minimize to maximize correlation)

        Raises:
            ValueError: If the inputs are not 2-D, have different batch sizes,
                or the batch holds fewer than two samples
        """
        if H1.dim() != 2 or H2.dim() != 2:
            raise ValueError(
                f"H1 and H2 must be 2-D. Got shapes {tuple(H1.shape)} and {tuple(H2.shape)}"
            )
        if H1.size(0) != H2.size(0):
            raise ValueError(
                f"H1 and H2 must have same batch size. Got {H1.size(0)} and {H2.size(0)}"
            )

        m = H1.size(0)
        if m < 2:
            # The unbiased covariance divides by m - 1.
            raise ValueError(f"Batch size must be at least 2. Got {m}")

        o1 = H1.size(1)
        o2 = H2.size(1)

        # Center each feature over the batch
        H1bar = H1 - H1.mean(dim=0, keepdim=True)
        H2bar = H2 - H2.mean(dim=0, keepdim=True)

        # Covariance estimates. The identity matrices follow the inputs'
        # device and dtype so the loss works wherever the network runs.
        scale = 1.0 / (m - 1)
        SigmaHat12 = scale * (H1bar.t() @ H2bar)
        SigmaHat11 = scale * (H1bar.t() @ H1bar) + self.reg_param * torch.eye(
            o1, dtype=H1.dtype, device=H1.device
        )
        SigmaHat22 = scale * (H2bar.t() @ H2bar) + self.reg_param * torch.eye(
            o2, dtype=H2.dtype, device=H2.device
        )

        # Symmetric eigendecompositions of the auto-covariances
        D1, V1 = torch.linalg.eigh(SigmaHat11)
        D2, V2 = torch.linalg.eigh(SigmaHat22)

        # Floor the eigenvalues so the inverse square root stays finite
        eps = 1e-9
        D1_inv_sqrt = torch.clamp(D1, min=eps).rsqrt()
        D2_inv_sqrt = torch.clamp(D2, min=eps).rsqrt()

        # Sigma^(-1/2) = V diag(D^(-1/2)) V^T; scaling columns avoids diag()
        SigmaHat11_inv_sqrt = (V1 * D1_inv_sqrt) @ V1.t()
        SigmaHat22_inv_sqrt = (V2 * D2_inv_sqrt) @ V2.t()

        # Whitened cross-covariance; its singular values are the correlations
        Tval = SigmaHat11_inv_sqrt @ SigmaHat12 @ SigmaHat22_inv_sqrt
        singular_values = torch.linalg.svdvals(Tval)

        if self.use_all_singular_values:
            corr = torch.sum(singular_values)
        else:
            # svdvals returns values in descending order, so this is the top-k
            corr = torch.sum(singular_values[:self.outdim_size])

        # Return negative correlation as loss
        return -corr


def benchmark_cca_implementations(
    X: np.ndarray,
    Y: np.ndarray,
    n_components: int = 64,
    n_runs: int = 5
) -> dict:
    """
    Benchmark different CCA implementations.

    Each available implementation (PyTorch, CuPy, sklearn) is constructed,
    fitted and applied ``n_runs`` times; only fit + transform is timed.
    Progress and results are also printed to stdout.

    Args:
        X: First view data
        Y: Second view data
        n_components: Number of CCA components
        n_runs: Number of runs for timing (must be at least 1)

    Returns:
        Dictionary keyed by implementation ('pytorch_gpu', 'cupy_gpu',
        'sklearn_cpu') with 'mean_time' and 'std_time' in seconds. The
        PyTorch entry also holds 'correlation' (mean canonical correlation on
        the input data) and 'device' (the device actually used).

    Raises:
        ValueError: If ``n_runs`` is less than 1
    """
    import time

    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1. Got {n_runs}")

    results = {}

    # Test PyTorch CCA
    if TORCH_AVAILABLE:
        print("\n=== Testing PyTorch CCA ===")
        torch_times = []

        for _ in range(n_runs):
            # device=None lets TorchCCA pick CUDA when present and CPU
            # otherwise, so the benchmark also runs on CPU-only hosts.
            cca = TorchCCA(n_components=n_components)
            start = time.perf_counter()
            cca.fit(X, Y)
            cca.transform(X, Y)
            torch_times.append(time.perf_counter() - start)

        # Scored outside the timed region, on the last fitted model
        score = cca.score(X, Y)

        results['pytorch_gpu'] = {
            'mean_time': np.mean(torch_times),
            'std_time': np.std(torch_times),
            'correlation': score,
            'device': str(cca.device)
        }
        print(f"PyTorch GPU: {results['pytorch_gpu']['mean_time']:.4f}s ± {results['pytorch_gpu']['std_time']:.4f}s")
        print(f"Correlation: {score:.4f}")

    # Test CuPy CCA
    if CUPY_AVAILABLE:
        print("\n=== Testing CuPy CCA ===")
        cupy_times = []

        for _ in range(n_runs):
            cca = CupyCCA(n_components=n_components)
            start = time.perf_counter()
            cca.fit(X, Y)
            cca.transform(X, Y)
            cupy_times.append(time.perf_counter() - start)

        results['cupy_gpu'] = {
            'mean_time': np.mean(cupy_times),
            'std_time': np.std(cupy_times)
        }
        print(f"CuPy GPU: {results['cupy_gpu']['mean_time']:.4f}s ± {results['cupy_gpu']['std_time']:.4f}s")

    # Test sklearn CCA (CPU baseline)
    try:
        from sklearn.cross_decomposition import CCA as SklearnCCA
    except ImportError:
        print("sklearn not available for baseline comparison")
    else:
        print("\n=== Testing sklearn CCA (CPU baseline) ===")
        sklearn_times = []

        for _ in range(n_runs):
            cca = SklearnCCA(n_components=n_components)
            start = time.perf_counter()
            cca.fit(X, Y)
            cca.transform(X, Y)
            sklearn_times.append(time.perf_counter() - start)

        results['sklearn_cpu'] = {
            'mean_time': np.mean(sklearn_times),
            'std_time': np.std(sklearn_times)
        }
        print(f"sklearn CPU: {results['sklearn_cpu']['mean_time']:.4f}s ± {results['sklearn_cpu']['std_time']:.4f}s")

    return results


if __name__ == "__main__":
    # Demo entry point: build two synthetic views that share a latent signal
    # in their first 100 columns, benchmark every available implementation,
    # and print a summary.
    RULE = "=" * 60

    print(RULE)
    print("GPU-Accelerated CCA Implementation Test")
    print(RULE)

    # Fixed seed so repeated runs see the same data
    np.random.seed(42)
    n_samples, n_features_x, n_features_y = 5000, 512, 768
    n_shared = 100

    print("\nGenerating synthetic data:")
    print(f"  - Samples: {n_samples}")
    print(f"  - Features X: {n_features_x}")
    print(f"  - Features Y: {n_features_y}")

    # Draw order (X, Y, shared) determines the generated values
    X = np.random.randn(n_samples, n_features_x).astype(np.float32)
    Y = np.random.randn(n_samples, n_features_y).astype(np.float32)
    shared = np.random.randn(n_samples, n_shared).astype(np.float32)

    # Inject the common signal into the leading columns of both views
    shared_half = shared * 0.5
    X[:, :n_shared] += shared_half
    Y[:, :n_shared] += shared_half

    print("\n" + RULE)
    print("Benchmarking CCA Implementations")
    print(RULE)

    results = benchmark_cca_implementations(X, Y, n_components=64, n_runs=3)

    print("\n" + RULE)
    print("Summary")
    print(RULE)

    for method, res in results.items():
        print(f"\n{method}:")
        print(f"  Time: {res['mean_time']:.4f}s ± {res['std_time']:.4f}s")
        if 'correlation' in res:
            print(f"  Correlation: {res['correlation']:.4f}")

    # Report the speedup only when both timings were collected
    if {'sklearn_cpu', 'pytorch_gpu'} <= results.keys():
        baseline_time = results['sklearn_cpu']['mean_time']
        torch_time = results['pytorch_gpu']['mean_time']
        speedup = baseline_time / torch_time
        print(f"\n🚀 GPU Speedup (PyTorch vs sklearn): {speedup:.2f}x")
