"""
Multi-modal Data Decomposition Baselines

This file provides Python classes for seven multi-modal decomposition methods:
1. CCA (Canonical Correlation Analysis) - Implemented in PyTorch
2. JIVE (Joint and Individual Variation Explained) - Implemented in PyTorch
3. AJIVE (Angle-based Joint and Individual Variation Explained) - Wrapper for 'py-jive'
4. DIVAS (Data Integration Via Analysis of Subspaces), simplified here to a
   concatenated-SVD baseline - Implemented in PyTorch
5. PPD (Product of Projections Decomposition) - Implemented in PyTorch
6. SLIDE (Structural Learning and Integrative DEcomposition) - Implemented in PyTorch
7. ShIndICA (Shared and Individual Independent Component Analysis)

All classes expect a list of 2 PyTorch float tensors as input.
"""

import torch
import numpy as np

# --- Dependencies for AJIVE ---
# We use the 'py-jive' library for the AJIVE class.
# You must install it: pip install py-jive numpy scipy
try:
    from jive.AJIVE import AJIVE as PyAJIVE
except ImportError:
    print("Warning: 'py-jive' library not found.")
    print("The AJIVE class will not work.")
    print("Please install it: pip install py-jive")
    PyAJIVE = None

# --- Utility Functions ---

def _center_data(X):
    """Center data matrix X (features in columns)."""
    mean_X = torch.mean(X, dim=0)
    return X - mean_X, mean_X

def _optimal_hard_threshold_rank(X, noise_std=None):
    """
    Estimates the rank of a matrix using the optimal hard threshold for
    singular values, as proposed by Gavish and Donoho (2014).
    
    Assumes X is (n, p) with n <= p.
    """
    X = X.float()
    n, p = sorted(X.shape)
    if n > p:
        X = X.T
        n, p = X.shape

    # Compute SVD
    try:
        _, S, _ = torch.linalg.svd(X, full_matrices=False)
    except Exception as e:
        print(f"SVD failed: {e}")
        return 1 # Fallback to rank 1

    if noise_std is None:
        # Estimate noise variance
        # Use the median singular value, which is robust
        median_s = torch.median(S)
    else:
        # If noise is known, use it
        # This is more complex, so we'll stick to the median heuristic
        median_s = torch.median(S)
        
    # Calculate the threshold
    beta = n / p
    omega = 0.56 * beta**3 - 0.95 * beta**2 + 1.82 * beta + 1.43
    tau = omega * median_s
    
    # Count singular values above the threshold
    rank = torch.sum(S > tau).item()
    
    # Rank cannot be zero (we assume at least one component)
    return max(1, rank)

def _estimate_signal_rank(X):
    """
    Estimate the signal rank of X after removing its column means.

    Thin convenience wrapper: centers X with `_center_data`, then applies
    `_optimal_hard_threshold_rank` without a known noise level.
    """
    centered, _means = _center_data(X)
    estimated = _optimal_hard_threshold_rank(centered)
    return estimated

# --- 1. Canonical Correlation Analysis (CCA) ---

class CCA:
    """
    Canonical Correlation Analysis (CCA) implemented with PyTorch.

    Finds linear combinations of two modalities (weights U for X, V for Y)
    such that corr(X @ U, Y @ V) is maximized. The number of canonical pairs
    kept is the smallest k whose squared canonical correlations reach
    `energy_threshold` of their total.
    """
    def __init__(self, energy_threshold=0.8, reg=1e-5):
        """
        Args:
            energy_threshold (float): Cumulative energy threshold (0-1) used
                to choose the number of components. Default is 0.8, i.e. 80%
                of the sum of squared canonical correlations.
            reg (float): Ridge added to the diagonals of the within-modality
                covariance matrices before they are inverted, for numerical
                stability (e.g. when n_features > n_samples). Default 1e-5.
        """
        self.energy_threshold = energy_threshold
        self.reg = reg
        self.n_components = None  # Determined in decompose()
        self.U_ = None  # Canonical weights for X, shape (n_features_x, k)
        self.V_ = None  # Canonical weights for Y, shape (n_features_y, k)
        self.correlations_ = None  # Canonical correlations, shape (k,)
        self.mean_X_ = None
        self.mean_Y_ = None

    @staticmethod
    def _inv_sqrt_psd(S):
        """Return S^(-1/2) for a symmetric positive-definite matrix S."""
        evals, evecs = torch.linalg.eigh(S)
        # Guard against tiny negative eigenvalues caused by round-off.
        evals = evals.clamp_min(torch.finfo(S.dtype).tiny)
        return (evecs * evals.rsqrt()) @ evecs.T

    def decompose(self, modalities):
        """
        Takes a list of 2 modalities [X, Y] and performs CCA.

        Args:
            modalities (list[torch.Tensor]): [X, Y]
                X (torch.Tensor): Shape (n_samples, n_features_x)
                Y (torch.Tensor): Shape (n_samples, n_features_y)

        Returns:
            dict: {
                'joint_X': Canonical scores of centered X, shape (n_samples, k)
                'joint_Y': Canonical scores of centered Y, shape (n_samples, k)
            },
            dict: {
                'joint_rank': Number of canonical pairs kept, k (int)
            }
            The canonical correlations are stored in `self.correlations_`
            and the weights in `self.U_` / `self.V_`.

        Raises:
            ValueError: If X and Y have different numbers of samples, or if
                fewer than 2 samples are given.
        """
        X, Y = modalities
        if X.shape[0] != Y.shape[0]:
            raise ValueError("X and Y must have the same number of samples.")

        n_samples = X.shape[0]
        if n_samples < 2:
            raise ValueError("CCA requires at least 2 samples.")

        # 1. Center the data
        X_c, self.mean_X_ = _center_data(X)
        Y_c, self.mean_Y_ = _center_data(Y)
        d1 = X_c.shape[1]
        d2 = Y_c.shape[1]

        # 2. Unbiased sample covariances (1 / (n - 1)), with a ridge on the
        #    within-modality blocks so they can be inverted.
        scale = 1.0 / (n_samples - 1)
        S_xx = scale * (X_c.T @ X_c) + self.reg * torch.eye(
            d1, device=X_c.device, dtype=X_c.dtype)
        S_yy = scale * (Y_c.T @ Y_c) + self.reg * torch.eye(
            d2, device=Y_c.device, dtype=Y_c.dtype)
        S_xy = scale * (X_c.T @ Y_c)

        # 3. Whitened cross-covariance M = S_xx^(-1/2) S_xy S_yy^(-1/2).
        #    Its singular values are the canonical correlations.
        S_xx_inv_sqrt = self._inv_sqrt_psd(S_xx)
        S_yy_inv_sqrt = self._inv_sqrt_psd(S_yy)
        M = S_xx_inv_sqrt @ S_xy @ S_yy_inv_sqrt
        U, D, V_t = torch.linalg.svd(M, full_matrices=False)

        # 4. Number of components from cumulative energy (squared correlations)
        energy = D ** 2
        total_energy = energy.sum()
        if total_energy > 0:
            cumulative_energy = torch.cumsum(energy, dim=0) / total_energy
            # The cumulative sum is non-decreasing, so counting the prefixes
            # still below the threshold equals searchsorted(side='left').
            n_components = int((cumulative_energy < self.energy_threshold).sum().item()) + 1
        else:
            # No cross-correlation at all: keep a single (degenerate) pair
            # instead of dividing by zero.
            n_components = 1
        n_components = min(n_components, D.shape[0])
        self.n_components = n_components

        # 5. Canonical correlations and weights
        self.correlations_ = D[:n_components]
        self.U_ = S_xx_inv_sqrt @ U[:, :n_components]
        self.V_ = S_yy_inv_sqrt @ V_t[:n_components, :].T

        # 6. Canonical scores (projections of the centered data)
        scores_X = X_c @ self.U_
        scores_Y = Y_c @ self.V_

        decomposed_representations = {
            'joint_X': scores_X,
            'joint_Y': scores_Y
        }
        rank_estimates = {
            'joint_rank': n_components
        }

        return decomposed_representations, rank_estimates

# --- 2. Joint and Individual Variation Explained (JIVE) ---

class JIVE:
    """
    Joint and Individual Variation Explained (JIVE) - PyTorch Implementation.

    Decomposes each (column-centered) modality X and Y into:
    X = J_X + I_X + E_X
    Y = J_Y + I_Y + E_Y

    Where J is joint structure, I is individual structure, E is noise.

    This is a simplified, non-iterative variant of JIVE:
    1. Each modality is truncated to its signal rank.
    2. A joint basis is taken from the SVD of the concatenated signal bases
       [U_x, U_y].
    3. Each signal is split into its projection onto that basis (joint) and
       the residual (individual).

    The returned components are centered; the column means are not added back.
    """
    def __init__(self, joint_rank=None, ind_ranks=None):
        """
        Args:
            joint_rank (int, optional): The joint rank. If None, it is
                estimated from the concatenated signal bases, capped at
                min(r_x, r_y).
            ind_ranks (list[int], optional): Signal ranks [r_x, r_y] used to
                truncate X and Y before joint and individual structure are
                separated. Despite the parameter name, these are the total
                (joint + individual) ranks of each modality. If None, they
                are estimated with the optimal hard threshold.
        """
        self.joint_rank = joint_rank
        self.ind_ranks = ind_ranks

    def decompose(self, modalities):
        """
        Performs JIVE decomposition.

        Args:
            modalities (list[torch.Tensor]): [X, Y]
                X (torch.Tensor): Shape (n_samples, n_features_x)
                Y (torch.Tensor): Shape (n_samples, n_features_y)

        Returns:
            dict: {
                'joint_X': Joint component for X (torch.Tensor, centered)
                'joint_Y': Joint component for Y (torch.Tensor, centered)
                'individual_X': Individual component for X (torch.Tensor)
                'individual_Y': Individual component for Y (torch.Tensor)
            },
            dict: {
                'joint_rank': Estimated or given joint rank (int)
                'individual_rank_X': Signal rank of X minus the joint rank,
                    floored at 0 (int)
                'individual_rank_Y': Signal rank of Y minus the joint rank,
                    floored at 0 (int)
                'total_rank_X': Estimated or given signal rank of X (int)
                'total_rank_Y': Estimated or given signal rank of Y (int)
            }
        """
        X, Y = modalities

        # 0. Center each modality (the means are not added back below).
        X_c, _ = _center_data(X)
        Y_c, _ = _center_data(Y)

        # 1. Signal (total) ranks for X and Y
        if self.ind_ranks is None:
            rank_X = _estimate_signal_rank(X_c)
            rank_Y = _estimate_signal_rank(Y_c)
        else:
            rank_X, rank_Y = self.ind_ranks

        # 2. Rank-truncated signal of each modality
        U_x, S_x, V_x_t = torch.linalg.svd(X_c, full_matrices=False)
        X_signal = (U_x[:, :rank_X] * S_x[:rank_X]) @ V_x_t[:rank_X, :]

        U_y, S_y, V_y_t = torch.linalg.svd(Y_c, full_matrices=False)
        Y_signal = (U_y[:, :rank_Y] * S_y[:rank_Y]) @ V_y_t[:rank_Y, :]

        # 3. Concatenate the orthonormal signal bases: (n, rank_X + rank_Y)
        Z = torch.cat([U_x[:, :rank_X], U_y[:, :rank_Y]], dim=1)

        # 4. Joint rank
        if self.joint_rank is None:
            # NOTE(review): each block of Z has orthonormal columns, so the
            # singular values of Z lie in [0, sqrt(2)]. The median-based hard
            # threshold is then applied to a small, tightly bounded spectrum
            # and may be unreliable; consider passing joint_rank explicitly.
            joint_rank = _estimate_signal_rank(Z)
            # A joint component must be present in both modalities.
            joint_rank = min(joint_rank, rank_X, rank_Y)
        else:
            joint_rank = self.joint_rank

        # 5. Joint basis: leading left singular vectors of Z
        U_z, _, _ = torch.linalg.svd(Z, full_matrices=False)
        joint_basis = U_z[:, :joint_rank]
        P_J = joint_basis @ joint_basis.T

        # 6. Joint parts: projections of each signal onto the joint basis
        J_X = P_J @ X_signal
        J_Y = P_J @ Y_signal

        # 7. Individual parts: residuals, orthogonal to the joint basis
        I_X = X_signal - J_X
        I_Y = Y_signal - J_Y

        decomposed_representations = {
            'joint_X': J_X,
            'joint_Y': J_Y,
            'individual_X': I_X,
            'individual_Y': I_Y
        }
        rank_estimates = {
            'joint_rank': joint_rank,
            'individual_rank_X': max(rank_X - joint_rank, 0),
            'individual_rank_Y': max(rank_Y - joint_rank, 0),
            'total_rank_X': rank_X,
            'total_rank_Y': rank_Y
        }

        return decomposed_representations, rank_estimates


# --- 3. Angle-based Joint and Individual Variation Explained (AJIVE) ---

class AJIVE:
    """
    Wrapper for the 'py-jive' Python library's AJIVE implementation.

    This class:
    1. Takes PyTorch tensors and converts them to NumPy.
    2. Runs the AJIVE decomposition.
    3. Returns the components as float32 tensors on the device of the first
       input.

    You MUST install the library:
    pip install py-jive
    """
    def __init__(self, joint_rank=None, ind_ranks=None, random_seed=None):
        """
        Args:
            joint_rank (int, optional): The joint rank. If given, it is passed
                to py-jive; if None, AJIVE estimates it.
            ind_ranks (list[int], optional): Initial signal ranks [r_x, r_y],
                passed to py-jive as `init_signal_ranks`. If None, they are
                estimated with the optimal hard threshold.
            random_seed (int, optional): Seed for NumPy's global RNG. It is
                set right before fitting, for reproducibility.

        Raises:
            ImportError: If the 'py-jive' library is not installed.
        """
        if PyAJIVE is None:
            raise ImportError("AJIVE class requires the 'py-jive' library. "
                              "Please install it: pip install py-jive")

        self.joint_rank = joint_rank
        self.ind_ranks = ind_ranks
        self.random_seed = random_seed

    @staticmethod
    def _to_tensor(arr, device):
        """Convert a NumPy array or pandas object to a float32 tensor on `device`."""
        return torch.as_tensor(np.asarray(arr), dtype=torch.float32, device=device)

    def decompose(self, modalities):
        """
        Performs AJIVE decomposition.

        Args:
            modalities (list[torch.Tensor]): [X, Y]
                X (torch.Tensor): Shape (n_samples, n_features_x)
                Y (torch.Tensor): Shape (n_samples, n_features_y)

        Returns:
            dict: {
                'joint_X': Joint component for X (torch.Tensor)
                'joint_Y': Joint component for Y (torch.Tensor)
                'individual_X': Individual component for X (torch.Tensor)
                'individual_Y': Individual component for Y (torch.Tensor)
            },
            dict: {
                'joint_rank': Estimated (or given) joint rank (int)
                'individual_rank_X': Estimated individual rank for X (int)
                'individual_rank_Y': Estimated individual rank for Y (int)
            }
        """
        X, Y = modalities

        # 1. Convert PyTorch tensors to NumPy. py-jive's AJIVE expects
        #    (n_samples, n_features). detach() allows inputs that require grad.
        X_np = X.detach().cpu().numpy()
        Y_np = Y.detach().cpu().numpy()
        blocks = {'X': X_np, 'Y': Y_np}

        # 2. Initial signal ranks (a required AJIVE parameter)
        if self.ind_ranks is not None:
            init_signal_ranks = {'X': self.ind_ranks[0], 'Y': self.ind_ranks[1]}
        else:
            init_signal_ranks = {
                'X': _optimal_hard_threshold_rank(torch.from_numpy(X_np).float()),
                'Y': _optimal_hard_threshold_rank(torch.from_numpy(Y_np).float())
            }

        # 3. Run AJIVE
        if self.random_seed is not None:
            np.random.seed(self.random_seed)

        ajive_kwargs = {'init_signal_ranks': init_signal_ranks}
        # Forward the user's joint rank only when set, so py-jive estimates
        # it otherwise.
        if self.joint_rank is not None:
            ajive_kwargs['joint_rank'] = self.joint_rank
        ajive_obj = PyAJIVE(**ajive_kwargs)
        ajive_obj.fit(blocks)

        # 4. Extract results. Each entry of ajive_obj.blocks has .joint and
        #    .individual objects whose .full_ may be a NumPy array or a
        #    pandas DataFrame; _to_tensor handles both.
        device = X.device
        x_block = ajive_obj.blocks['X']
        y_block = ajive_obj.blocks['Y']
        decomposed_representations = {
            'joint_X': self._to_tensor(x_block.joint.full_, device),
            'joint_Y': self._to_tensor(y_block.joint.full_, device),
            'individual_X': self._to_tensor(x_block.individual.full_, device),
            'individual_Y': self._to_tensor(y_block.individual.full_, device)
        }

        # 5. Rank estimates
        rank_estimates = {
            'joint_rank': ajive_obj.common.rank,
            'individual_rank_X': x_block.individual.rank,
            'individual_rank_Y': y_block.individual.rank
        }

        return decomposed_representations, rank_estimates


# --- 4. Data Integration Via Analysis of Subspaces (DIVAS), concatenated-SVD baseline ---

class DIVAS:
    """
    DIVAS baseline, interpreted as a "concatenated SVD".

    DIVAS (Data Integration Via Analysis of Subspaces) is approximated here
    by a single SVD of the column-wise concatenation of the centered
    modalities. The number of retained components is chosen with the optimal
    hard threshold unless given explicitly.

    This model finds *only* a joint structure.
    """
    def __init__(self, joint_rank=None):
        """
        Args:
            joint_rank (int, optional): Number of singular components to keep.
                If None, it is estimated with the optimal hard threshold.
        """
        self.joint_rank = joint_rank

    def decompose(self, modalities):
        """
        Perform the concatenated-SVD (DIVAS baseline) decomposition.

        Args:
            modalities (list[torch.Tensor]): [X, Y]
                X (torch.Tensor): Shape (n_samples, n_features_x)
                Y (torch.Tensor): Shape (n_samples, n_features_y)

        Returns:
            dict: {
                'joint_Z': Rank-truncated reconstruction of the centered
                           concatenation [X_c, Y_c] (no means added)
                'joint_X': The X columns of 'joint_Z' plus the column means of X
                'joint_Y': The Y columns of 'joint_Z' plus the column means of Y
            },
            dict: {
                'joint_rank': Estimated or given joint rank (int)
            }
        """
        X, Y = modalities

        # Center each modality; the means are restored on the per-modality parts.
        centered_X, mean_X = _center_data(X)
        centered_Y, mean_Y = _center_data(Y)
        split_at = centered_X.shape[1]

        stacked = torch.cat([centered_X, centered_Y], dim=1)

        joint_rank = self.joint_rank
        if joint_rank is None:
            joint_rank = _optimal_hard_threshold_rank(stacked)

        left, sing, right_t = torch.linalg.svd(stacked, full_matrices=False)
        Z_joint = (
            left[:, :joint_rank]
            @ torch.diag(sing[:joint_rank])
            @ right_t[:joint_rank, :]
        )

        return (
            {
                'joint_Z': Z_joint,
                'joint_X': Z_joint[:, :split_at] + mean_X,
                'joint_Y': Z_joint[:, split_at:] + mean_Y,
            },
            {'joint_rank': joint_rank},
        )

# --- 5. Product of Projections Decomposition (PPD) ---

class PPD:
    """
    Implementation of the Product of Projections Decomposition (PPD) method
    from Sergazinov et al. (2024).

    This method analyzes the spectrum of the product of projection matrices
    (P_X1 @ P_X2) to find the joint rank. It uses a rotational bootstrap
    to estimate the perturbation bound (epsilon_1) and random matrix theory
    to estimate the noise bound (lambda_plus).

    The projection matrices are formed explicitly as (n_samples, n_samples)
    tensors.
    """
    def __init__(self, n_bootstrap=100, random_seed=42):
        """
        Args:
            n_bootstrap (int): Number of bootstrap samples for estimating
                epsilon_1 (must be >= 1).
            random_seed (int): Seed of the private generator used by the
                bootstrap. The global torch RNG is not reseeded.
        """
        self.n_bootstrap = n_bootstrap
        self.random_seed = random_seed

    def _get_signal_estimates(self, Y):
        """
        Estimate the signal rank, low-rank factors and noise level of Y.

        Args:
            Y (torch.Tensor): Data matrix of shape (n, p). `decompose` passes
                column-centered data.

        Returns:
            tuple: (U_hat (n, r), S_hat_diag (r, r), V_hat (p, r),
                    X_hat = U_hat @ S_hat_diag @ V_hat.T (n, p), r (int),
                    noise_std: 0-d tensor, standard deviation of Y - X_hat).
        """
        try:
            U, S, V_t = torch.linalg.svd(Y, full_matrices=False)
        except RuntimeError as e:
            print(f"SVD failed during PPD signal estimation: {e}")
            # Retry once on a slightly jittered copy of Y.
            U, S, V_t = torch.linalg.svd(Y + torch.randn_like(Y) * 1e-6,
                                         full_matrices=False)

        rank = _optimal_hard_threshold_rank(Y)

        U_hat = U[:, :rank]
        S_hat_diag = torch.diag(S[:rank])
        V_hat = V_t[:rank, :].T

        X_hat = U_hat @ S_hat_diag @ V_hat.T
        noise_std = torch.std(Y - X_hat)

        return U_hat, S_hat_diag, V_hat, X_hat, rank, noise_std

    @staticmethod
    def _sample_aligned_bases(n, r1, r2, cos_theta, generator=None,
                              dtype=None, device=None):
        """
        Draw random orthonormal bases U1_b (n, r1) and U2_b (n, r2) whose
        principal-angle cosines equal `cos_theta` (length min(r1, r2)).

        Column j < min(r1, r2) of U2_b is
        cos(theta_j) * U1_b[:, j] + sin(theta_j) * w_j, with w_j orthonormal
        and orthogonal to U1_b. Any remaining columns of U2_b are orthogonal
        to U1_b.
        """
        Q, _ = torch.linalg.qr(
            torch.randn(n, n, generator=generator, dtype=dtype, device=device))
        U1_b = Q[:, :r1]

        r_min = cos_theta.shape[0]
        cos_theta = cos_theta.to(dtype=Q.dtype, device=Q.device).clamp(0.0, 1.0)
        sin_theta = torch.sqrt(1.0 - cos_theta ** 2)

        # Weight of the direction orthogonal to U1_b in each column of U2_b.
        ortho_weight = torch.ones(r2, dtype=Q.dtype, device=Q.device)
        ortho_weight[:r_min] = sin_theta

        # Directions orthogonal to U1_b. If r1 + r2 > n there are only
        # n - r1 of them. They are given to the last columns; the first
        # r1 + r2 - n cosines must then be ~1 (the subspaces intersect), so
        # those columns need no orthogonal part.
        m = min(r2, n - r1)
        ortho = torch.zeros(n, r2, dtype=Q.dtype, device=Q.device)
        if m > 0:
            ortho[:, r2 - m:] = Q[:, r1:r1 + m]

        U2_b = ortho * ortho_weight
        U2_b[:, :r_min] += U1_b[:, :r_min] * cos_theta
        # Re-orthonormalize; in the generic case this only flips column signs.
        U2_b, _ = torch.linalg.qr(U2_b)
        return U1_b, U2_b

    def _run_rotational_bootstrap(self, U1, S1_diag, V1, r1, p1, std1,
                                  U2, S2_diag, V2, r2, p2, std2, device, n):
        """
        Estimate epsilon_1 with the rotational bootstrap
        (based on Section 4 (ii) of Sergazinov et al., 2024).

        Each replicate:
        1. Draws random bases with the same principal angles as the estimated
           subspaces U1 and U2.
        2. Rebuilds signals with the estimated singular values and adds
           Gaussian noise with the estimated noise levels.
        3. Re-estimates the subspaces at ranks (r1, r2).
        4. Records the spectral norm of
           P1 @ (Delta1 + Delta2 + Delta1 @ Delta2) @ P2.

        V1 and V2 are accepted for signature compatibility but are not used:
        the right singular vectors are redrawn at random in each replicate.

        Returns:
            torch.Tensor: 0-d tensor, mean of the replicate values.

        Raises:
            ValueError: If n_bootstrap < 1.
        """
        if self.n_bootstrap < 1:
            raise ValueError("n_bootstrap must be at least 1.")

        dtype = U1.dtype
        # Private generator: reproducible without reseeding the global RNG.
        gen = torch.Generator(device=device)
        gen.manual_seed(self.random_seed)

        def randn(*shape):
            return torch.randn(*shape, generator=gen, dtype=dtype, device=device)

        # Cosines of the principal angles between the estimated subspaces.
        cos_theta = torch.linalg.svdvals(U1.T @ U2)

        epsilons = []
        for _ in range(self.n_bootstrap):
            # 1. Random bases with the observed principal angles
            U1_b, U2_b = self._sample_aligned_bases(
                n, r1, r2, cos_theta, generator=gen, dtype=dtype, device=device)

            # 2. Bootstrap "truth" signals with random right singular vectors
            V1_b, _ = torch.linalg.qr(randn(p1, r1))
            V2_b, _ = torch.linalg.qr(randn(p2, r2))

            # 3. Bootstrap data replicates
            Y1_b = U1_b @ S1_diag @ V1_b.T + randn(n, p1) * std1
            Y2_b = U2_b @ S2_diag @ V2_b.T + randn(n, p2) * std2

            # 4. Re-estimate subspaces with the *original* ranks (r1, r2)
            U1_b_hat = torch.linalg.svd(Y1_b, full_matrices=False)[0][:, :r1]
            U2_b_hat = torch.linalg.svd(Y2_b, full_matrices=False)[0][:, :r2]

            # 5. Projection perturbations
            P_X1_b = U1_b @ U1_b.T
            P_X2_b = U2_b @ U2_b.T
            Delta1_b = P_X1_b - U1_b_hat @ U1_b_hat.T
            Delta2_b = P_X2_b - U2_b_hat @ U2_b_hat.T

            # 6. epsilon_1 for this replicate (spectral norm)
            M_b = P_X1_b @ (Delta1_b + Delta2_b + Delta1_b @ Delta2_b) @ P_X2_b
            epsilons.append(torch.linalg.svdvals(M_b)[0])

        return torch.stack(epsilons).mean()

    def _get_rmt_noise_bound(self, r1, r2, n):
        """
        Noise bound lambda_plus from Random Matrix Theory
        (based on Section 4 (iii) and Equation 5).

        Returns:
            float: sqrt of the upper spectral edge
            q1 + q2 - 2 q1 q2 + 2 sqrt(q1 q2 (1 - q1) (1 - q2)), with
            q_i = r_i / n, capped at 1.0.
        """
        # Keep q strictly inside (0, 1) to avoid degenerate edges.
        q1 = min(max(r1 / n, 1e-9), 1.0 - 1e-9)
        q2 = min(max(r2 / n, 1e-9), 1.0 - 1e-9)

        term1 = q1 + q2 - 2 * q1 * q2
        term2 = 2 * (q1 * q2 * (1 - q1) * (1 - q2)) ** 0.5
        lambda_plus_sq = term1 + term2

        # Always a Python float, so callers need no .item().
        return min(lambda_plus_sq ** 0.5, 1.0)

    def decompose(self, modalities):
        """
        Performs PPD decomposition.

        Args:
            modalities (list[torch.Tensor]): [X, Y]
                X (torch.Tensor): Shape (n_samples, n_features_x)
                Y (torch.Tensor): Shape (n_samples, n_features_y)

        Returns:
            dict: {
                'joint_X': Joint component for X plus the column means of X
                'joint_Y': Joint component for Y plus the column means of Y
                'individual_X': Individual component for X (centered)
                'individual_Y': Individual component for Y (centered)
            },
            dict: {
                'joint_rank': Estimated joint rank (int)
                'individual_rank_X': Estimated individual rank for X (int)
                'individual_rank_Y': Estimated individual rank for Y (int)
                'total_rank_X': Estimated signal rank of X (int)
                'total_rank_Y': Estimated signal rank of Y (int)
                'epsilon_1_hat': Bootstrap perturbation bound (float)
                'lambda_plus_bound': Random-matrix noise bound (float)
                'joint_rank_threshold': max(1 - epsilon_1_hat, lambda_plus) (float)
            }

        Raises:
            ValueError: If X and Y have different numbers of samples.
        """
        X_raw, Y_raw = modalities
        if X_raw.shape[0] != Y_raw.shape[0]:
            raise ValueError("X and Y must have the same number of samples.")
        n, p1 = X_raw.shape
        p2 = Y_raw.shape[1]
        device = X_raw.device

        # 0. Center data
        X, mean_X = _center_data(X_raw)
        Y, mean_Y = _center_data(Y_raw)

        # 1. Estimate marginal ranks and subspaces (Algorithm 1, Step 1)
        U1_hat, S1_diag, V1_hat, X_hat, r1, std1 = self._get_signal_estimates(X)
        U2_hat, S2_diag, V2_hat, Y_hat, r2, std2 = self._get_signal_estimates(Y)

        P_X1_hat = U1_hat @ U1_hat.T
        P_X2_hat = U2_hat @ U2_hat.T

        # 2. Compute perturbation bound epsilon_1 (Algorithm 1, Step 2)
        epsilon_1_hat = float(self._run_rotational_bootstrap(
            U1_hat, S1_diag, V1_hat, r1, p1, std1,
            U2_hat, S2_diag, V2_hat, r2, p2, std2,
            device, n
        ))

        # 3. Compute noise bound lambda_plus (Algorithm 1, Step 3)
        lambda_plus = self._get_rmt_noise_bound(r1, r2, n)

        # 4. Estimate joint rank (Algorithm 1, Step 4)
        M_hat = P_X1_hat @ P_X2_hat
        S_M_hat = torch.linalg.svdvals(M_hat)
        threshold = max(1.0 - epsilon_1_hat, lambda_plus)
        # P_X1 P_X2 has at most min(r1, r2) non-zero singular values. The cap
        # also keeps the individual ranks below non-negative.
        r_J = min(int((S_M_hat > threshold).sum().item()), r1, r2)

        # 5. Estimate joint subspace (Algorithm 1, Step 5): leading
        #    eigenvectors (by value) of the symmetrized product.
        S_sym = 0.5 * (M_hat + M_hat.T)
        _, eigvecs = torch.linalg.eigh(S_sym)  # eigenvalues in ascending order
        U_J_hat = eigvecs.flip(-1)[:, :r_J]
        P_J_hat = U_J_hat @ U_J_hat.T

        # 6. Estimate individual subspaces (Algorithm 1, Step 6)
        r_I1 = r1 - r_J
        r_I2 = r2 - r_J

        # Project each signal subspace onto the complement of the joint space
        I_minus_P_J = torch.eye(n, device=device, dtype=X.dtype) - P_J_hat
        U_I1_hat = torch.linalg.svd(P_X1_hat @ I_minus_P_J)[0][:, :r_I1]
        U_I2_hat = torch.linalg.svd(P_X2_hat @ I_minus_P_J)[0][:, :r_I2]

        # 7. Reconstruct data components
        J_X = P_J_hat @ X_hat + mean_X
        J_Y = P_J_hat @ Y_hat + mean_Y

        I_X = (U_I1_hat @ U_I1_hat.T) @ X_hat
        I_Y = (U_I2_hat @ U_I2_hat.T) @ Y_hat

        decomposed_representations = {
            'joint_X': J_X,
            'joint_Y': J_Y,
            'individual_X': I_X,
            'individual_Y': I_Y
        }
        rank_estimates = {
            'joint_rank': r_J,
            'individual_rank_X': r_I1,
            'individual_rank_Y': r_I2,
            'total_rank_X': r1,
            'total_rank_Y': r2,
            'epsilon_1_hat': epsilon_1_hat,
            'lambda_plus_bound': lambda_plus,
            'joint_rank_threshold': threshold
        }

        return decomposed_representations, rank_estimates

# `_center_data` is defined once, under "Utility Functions" near the top of
# this module; a verbatim redefinition here only shadowed it.

# --- 6. Structural Learning and Integrative Decomposition (SLIDE) ---

class SLIDE:
    """
    Implementation of the SLIDE method from Gaynanova & Li (2017),
    "Structural Learning and Integrative Decomposition of Multi-View Data".

    This implements Algorithm 2 from the paper, which fits the model
    X = U @ V(S).T + E given a *pre-specified* structure S.

    This wrapper class estimates the ranks automatically:
    - total ranks via the Optimal Hard Threshold (OHT);
    - the joint rank via the JIVE approach (OHT on the concatenated signal
      bases), capped at min(total ranks).
    """
    def __init__(self, max_iter=100, tol=1e-6):
        """
        Args:
            max_iter (int): Maximum iterations for the algorithm.
            tol (float): Convergence tolerance on the Frobenius change of U.
        """
        self.max_iter = max_iter
        self.tol = tol

    @staticmethod
    def _structure_mask(p1, p2, k_j, k_i1, k_i2, device=None, dtype=None):
        """
        Binary (p1 + p2, k_j + k_i1 + k_i2) mask encoding the structure S.

        - Columns [0, k_j) load on both views (joint).
        - The next k_i1 columns load only on the first p1 rows (individual X).
        - The last k_i2 columns load only on the last p2 rows (individual Y).
        """
        total = k_j + k_i1 + k_i2
        mask = torch.zeros(p1 + p2, total, device=device, dtype=dtype)
        mask[:, :k_j] = 1
        mask[:p1, k_j:k_j + k_i1] = 1
        mask[p1:, k_j + k_i1:] = 1
        return mask

    def decompose(self, modalities):
        """
        Performs SLIDE decomposition (Algorithm 2) after estimating ranks.

        Args:
            modalities (list[torch.Tensor]): [X, Y]
                X (torch.Tensor): Shape (n_samples, n_features_x)
                Y (torch.Tensor): Shape (n_samples, n_features_y)

        Returns:
            dict: {
                'joint_X': Joint component for X plus the column means of X
                'joint_Y': Joint component for Y plus the column means of Y
                'individual_X': Individual component for X (torch.Tensor)
                'individual_Y': Individual component for Y (torch.Tensor)
            },
            dict: {
                'joint_rank': Estimated joint rank (int)
                'individual_rank_X': Estimated individual rank for X (int)
                'individual_rank_Y': Estimated individual rank for Y (int)
            }
        """
        X_raw, Y_raw = modalities
        p1 = X_raw.shape[1]
        p2 = Y_raw.shape[1]
        device = X_raw.device

        # 0. Center data
        X, mean_X = _center_data(X_raw)
        Y, mean_Y = _center_data(Y_raw)
        X_concat = torch.cat([X, Y], dim=1)

        # 1. Estimate ranks (using JIVE's method)
        # 1a. Total signal ranks
        r1 = _estimate_signal_rank(X)
        r2 = _estimate_signal_rank(Y)

        # 1b. Signal bases
        U1, _, _ = torch.linalg.svd(X, full_matrices=False)
        U2, _, _ = torch.linalg.svd(Y, full_matrices=False)

        # 1c. Joint rank from the concatenated bases. A joint component must
        #     appear in both views, so k_j <= min(r1, r2). This also keeps the
        #     individual ranks non-negative and the column blocks disjoint.
        basis_concat = torch.cat([U1[:, :r1], U2[:, :r2]], dim=1)
        k_j = min(_estimate_signal_rank(basis_concat), r1, r2)

        k_i1 = r1 - k_j
        k_i2 = r2 - k_j
        total_rank = k_j + k_i1 + k_i2

        # 2. Initialize U from the concatenated data
        U_init, _, _ = torch.linalg.svd(X_concat, full_matrices=False)
        U = U_init[:, :total_rank]

        # 3. Structure S as a mask (same dtype as the data, so float64 works)
        mask = self._structure_mask(p1, p2, k_j, k_i1, k_i2,
                                     device=device, dtype=X_concat.dtype)

        # 4. Alternate between V and U (Algorithm 2)
        for _ in range(self.max_iter):
            U_old = U

            # 4a. V: least-squares loadings given U, restricted to S
            V = (X_concat.T @ U) * mask

            # 4b. U: orthogonal Procrustes
            U_svd, _, V_svd_t = torch.linalg.svd(X_concat @ V, full_matrices=False)
            U = U_svd @ V_svd_t

            if torch.linalg.norm(U - U_old) < self.tol:
                break

        # Refresh V so the loadings match the final U (the loop ends right
        # after a U update).
        V = (X_concat.T @ U) * mask

        # 5. Reconstruct components
        idx_j = slice(0, k_j)
        idx_i1 = slice(k_j, k_j + k_i1)
        idx_i2 = slice(k_j + k_i1, total_rank)
        V1 = V[:p1, :]
        V2 = V[p1:, :]

        J_X = U[:, idx_j] @ V1[:, idx_j].T + mean_X
        J_Y = U[:, idx_j] @ V2[:, idx_j].T + mean_Y

        I_X = U[:, idx_i1] @ V1[:, idx_i1].T
        I_Y = U[:, idx_i2] @ V2[:, idx_i2].T

        decomposed_representations = {
            'joint_X': J_X,
            'joint_Y': J_Y,
            'individual_X': I_X,
            'individual_Y': I_Y
        }
        rank_estimates = {
            'joint_rank': k_j,
            'individual_rank_X': k_i1,
            'individual_rank_Y': k_i2
        }

        return decomposed_representations, rank_estimates

# --- 7. Shared and Individual Independent Component Analysis (ShIndICA) ---

class ShIndICA:
    """
    Implementation of ShIndICA from Pandeva & Forré (2023),
    "Multi-View Independent Component Analysis with Shared and Individual Sources".

    This is a fundamentally different approach based on ICA, not SVD.
    After whitening each view, it learns orthogonal unmixing matrices by
    minimising a log-cosh contrast on the shared (view-averaged) and
    individual sources (a super-Gaussian source prior) while maximising a
    trace term that rewards agreement between the views' shared sources.

    This class *automatically estimates* the joint rank by implementing the
    held-out normalised reconstruction error (NRE) model selection procedure
    from Section 5 of the paper. The per-view total ranks come from
    ``_estimate_signal_rank``.
    """
    def __init__(self, joint_rank_options, test_split_ratio=0.25,
                 n_iter=1000, lr=1e-3, lambda_reg=1.0, random_seed=42):
        """
        Args:
            joint_rank_options (list[int]): List of joint ranks to test
                                            (e.g., [1, 2, 4, 6, 8, 10]).
            test_split_ratio (float): Fraction of data to use for test split
                                      (paper uses 0.25 for 3:1 split).
            n_iter (int): Number of optimisation iterations for *each* fit.
            lr (float): Learning rate for Adam.
            lambda_reg (float): Weight for the shared trace term (Eq. 3).
            random_seed (int): Seed passed to ``torch.manual_seed`` in
                               ``decompose``.
        """
        self.joint_rank_options = joint_rank_options
        self.test_split_ratio = test_split_ratio
        self.n_iter = n_iter
        self.lr = lr
        self.lambda_reg = lambda_reg
        self.random_seed = random_seed

    def _negentropy_approx(self, z):
        """
        Return the mean log-cosh contrast ``mean(log(cosh(z)))`` of ``z``.

        Computed as ``logaddexp(z, -z) - log(2)`` so that large ``|z|`` does
        not overflow ``cosh`` to ``inf``. For a fixed variance, lower values
        indicate heavier-tailed (super-Gaussian) sources.

        An empty tensor (e.g. a view with zero individual components)
        contributes 0 instead of the NaN that ``mean()`` of an empty tensor
        would produce.
        """
        import math

        if z.numel() == 0:
            # The sum of an empty tensor is 0 and stays attached to autograd.
            return z.sum()
        return (torch.logaddexp(z, -z) - math.log(2.0)).mean()

    def _orthogonal_projection(self, W):
        """
        Return the orthogonal matrix closest to ``W`` (polar factor ``U @ V^T``).

        If the SVD fails (e.g. does not converge), ``W`` is returned
        unchanged so a single numerically bad step does not abort the fit.
        """
        try:
            U, _, V_t = torch.linalg.svd(W, full_matrices=False)
        except RuntimeError:
            # torch.linalg.LinAlgError is a subclass of RuntimeError.
            return W
        return U @ V_t

    def _fit_model(self, X_w, Y_w, k_j, k_i1, k_i2):
        """
        Fit orthogonal unmixing matrices for one choice of ranks.

        Args:
            X_w (torch.Tensor): Whitened X, shape (n, k_j + k_i1).
            Y_w (torch.Tensor): Whitened Y, shape (n, k_j + k_i2).
            k_j (int): Number of shared sources (first ``k_j`` columns of Z).
            k_i1 (int): Number of individual sources of X.
            k_i2 (int): Number of individual sources of Y.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Detached ``(W1, W2)`` such that
            ``Z1 = X_w @ W1`` and ``Z2 = Y_w @ W2``.
        """
        r1 = k_j + k_i1
        r2 = k_j + k_i2

        # Start from the identity; dtype/device follow the inputs so float64
        # data does not hit a dtype mismatch in the matmuls below.
        W1 = torch.eye(r1, device=X_w.device, dtype=X_w.dtype, requires_grad=True)
        W2 = torch.eye(r2, device=Y_w.device, dtype=Y_w.dtype, requires_grad=True)

        optimizer = torch.optim.Adam([W1, W2], lr=self.lr)

        for _ in range(self.n_iter):
            Z1 = X_w @ W1
            Z2 = Y_w @ W2

            Z1_j, Z1_i = Z1[:, :k_j], Z1[:, k_j:]
            Z2_j, Z2_i = Z2[:, :k_j], Z2[:, k_j:]

            S_bar = (Z1_j + Z2_j) / 2.0

            contrast = (self._negentropy_approx(S_bar)
                        + self._negentropy_approx(Z1_i)
                        + self._negentropy_approx(Z2_i))

            # Equals trace(Z1_j.T @ Z2_j), and is also well defined for k_j == 0.
            trace_term = (Z1_j * Z2_j).sum()

            # Super-Gaussian log-likelihood (log p(s) = -log cosh s): minimise
            # the log-cosh contrast while maximising shared-source agreement.
            loss = contrast - self.lambda_reg * trace_term

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Retract back onto the orthogonal group after each step.
            with torch.no_grad():
                W1.copy_(self._orthogonal_projection(W1))
                W2.copy_(self._orthogonal_projection(W2))

        return W1.detach(), W2.detach()

    def _calculate_nre(self, X_w_test, Y_w_test, W1_train, W2_train, k_j):
        """
        Calculate the Normalized Reconstruction Error (NRE) on test data.

        Per Section 5 of the paper, each view's shared sources are compared
        with their cross-view mean: ``z_hat_d = z_d,0 - mean_l(z_l,0)``, and
        ``NRE(k) = sum_d ||z_hat_d||_F^2 / k``.

        Returns:
            float: The NRE score, or ``inf`` when ``k_j == 0`` (division by zero).
        """
        if k_j == 0:
            return float('inf')  # Avoid division by zero

        Z1_j_test = (X_w_test @ W1_train)[:, :k_j]
        Z2_j_test = (Y_w_test @ W2_train)[:, :k_j]

        Z_mean = (Z1_j_test + Z2_j_test) / 2.0
        Z_hat_1 = Z1_j_test - Z_mean
        Z_hat_2 = Z2_j_test - Z_mean

        nre = (torch.norm(Z_hat_1, 'fro')**2 + torch.norm(Z_hat_2, 'fro')**2) / k_j
        return nre.item()

    def decompose(self, modalities):
        """
        Performs ShIndICA decomposition with automatic rank selection.

        Args:
            modalities (list[torch.Tensor]): [X, Y]
                X (torch.Tensor): Shape (n_samples, n_features_x)
                Y (torch.Tensor): Shape (n_samples, n_features_y)

        Returns:
            dict: {
                'joint_X': Joint component for X (torch.Tensor)
                'joint_Y': Joint component for Y (torch.Tensor)
                'individual_X': Individual component for X (torch.Tensor)
                'individual_Y': Individual component for Y (torch.Tensor)
            },
            dict: {
                'joint_rank': Estimated joint rank (int)
                'individual_rank_X': Estimated individual rank for X (int)
                'individual_rank_Y': Estimated individual rank for Y (int)
                'nre_scores': {k_j: NRE score} for every evaluated k_j
            }

        Raises:
            ValueError: If no entry of ``joint_rank_options`` lies in
                ``[0, min(r1, r2)]``, so no joint rank can be evaluated.
        """
        torch.manual_seed(self.random_seed)
        X_raw, Y_raw = modalities
        n = X_raw.shape[0]

        # 0. Center data
        X, mean_X = _center_data(X_raw)
        Y, mean_Y = _center_data(Y_raw)

        # 1. Estimate total ranks and whiten each view via its thin SVD:
        #    X_rank_r1 = X_w @ A1.T with X_w = U1[:, :r1], A1 = (S1 V1^T)^T.
        # NOTE(review): the columns of X_w have unit norm rather than unit
        # variance, so the log-cosh contrast sees values of order 1/sqrt(n).
        # Confirm whether sqrt(n) scaling (with a rescaled lambda_reg) is intended.
        r1 = _estimate_signal_rank(X)
        r2 = _estimate_signal_rank(Y)

        U1, S1, V1_t = torch.linalg.svd(X, full_matrices=False)
        U2, S2, V2_t = torch.linalg.svd(Y, full_matrices=False)

        X_w = U1[:, :r1]
        Y_w = U2[:, :r2]

        # De-whitening matrices mapping whitened coordinates back to features.
        A1_dewhiten = (torch.diag(S1[:r1]) @ V1_t[:r1, :]).T
        A2_dewhiten = (torch.diag(S2[:r2]) @ V2_t[:r2, :]).T

        # 2. Create train/test split for NRE (the last rows are held out).
        n_test = int(n * self.test_split_ratio)
        n_train = n - n_test

        X_w_train, X_w_test = X_w[:n_train], X_w[n_train:]
        Y_w_train, Y_w_test = Y_w[:n_train], Y_w[n_train:]

        # 3. Grid search for the best k_j
        nre_scores = {}
        max_k_j = min(r1, r2)

        for k_j in self.joint_rank_options:
            if not 0 <= k_j <= max_k_j:
                continue  # Rank is infeasible for these views

            W1_train, W2_train = self._fit_model(
                X_w_train, Y_w_train, k_j, r1 - k_j, r2 - k_j)
            nre_scores[k_j] = self._calculate_nre(
                X_w_test, Y_w_test, W1_train, W2_train, k_j)

        if not nre_scores:
            raise ValueError(
                f"No joint rank in joint_rank_options={list(self.joint_rank_options)} "
                f"lies in [0, min(r1, r2)] = [0, {max_k_j}]; cannot select a joint rank."
            )

        # Select best k_j (max of the argmin set, per paper Sec. 5)
        min_nre = min(nre_scores.values())
        best_k_j = max(k for k, score in nre_scores.items() if score == min_nre)

        best_k_i1 = r1 - best_k_j
        best_k_i2 = r2 - best_k_j

        # 4. Final fit on full data
        W1_final, W2_final = self._fit_model(X_w, Y_w, best_k_j, best_k_i1, best_k_i2)

        # 5. Reconstruct components
        Z1 = X_w @ W1_final
        Z2 = Y_w @ W2_final
        Z1_j, Z1_i = Z1[:, :best_k_j], Z1[:, best_k_j:]
        Z2_j, Z2_i = Z2[:, :best_k_j], Z2[:, best_k_j:]
        S_bar = (Z1_j + Z2_j) / 2.0

        # X_w = Z1 @ W1^T (W1 orthogonal), hence X_rank_r1 = Z1 @ (A1 @ W1)^T:
        # the source mixing matrix is A1 @ W1, not the whitening basis A1.
        M1 = A1_dewhiten @ W1_final
        M2 = A2_dewhiten @ W2_final
        M1_j, M1_i = M1[:, :best_k_j], M1[:, best_k_j:]
        M2_j, M2_i = M2[:, :best_k_j], M2[:, best_k_j:]

        # Reconstruct data components
        J_X = S_bar @ M1_j.T + mean_X
        J_Y = S_bar @ M2_j.T + mean_Y

        I_X = Z1_i @ M1_i.T
        I_Y = Z2_i @ M2_i.T

        decomposed_representations = {
            'joint_X': J_X,
            'joint_Y': J_Y,
            'individual_X': I_X,
            'individual_Y': I_Y
        }
        rank_estimates = {
            'joint_rank': best_k_j,
            'individual_rank_X': best_k_i1,
            'individual_rank_Y': best_k_i2,
            'nre_scores': nre_scores
        }

        return decomposed_representations, rank_estimates

# --- Example Usage ---

if __name__ == "__main__":
    # 1. Create synthetic data
    n_samples = 100
    n_features_x = 50
    n_features_y = 60
    
    # Common signal
    common_signal = torch.randn(n_samples, 1)
    
    # X = Joint_X + Ind_X + Noise_X
    J_X = common_signal @ torch.randn(1, n_features_x)
    I_X = torch.randn(n_samples, 2) @ torch.randn(2, n_features_x)
    E_X = torch.randn(n_samples, n_features_x) * 0.5
    X = J_X + I_X + E_X
    
    # Y = Joint_Y + Ind_Y + Noise_Y
    J_Y = common_signal @ torch.randn(1, n_features_y)
    I_Y = torch.randn(n_samples, 3) @ torch.randn(3, n_features_y)
    E_Y = torch.randn(n_samples, n_features_y) * 0.5
    Y = J_Y + I_Y + E_Y
    
    modalities = [X, Y]
    
    print(f"Data created: X shape {X.shape}, Y shape {Y.shape}")
    print("-" * 30)

    # 2. Test CCA
    print("Testing CCA...")
    cca = CCA(energy_threshold=0.8)
    cca_decomp, cca_ranks = cca.decompose(modalities)
    print(f"  Rank: {cca_ranks['rank']}")
    print(f"  Correlations: {cca_decomp['correlations']}")
    print(f"  Scores X shape: {cca_decomp['scores_X'].shape}")
    print("-" * 30)
    
    # 3. Test JIVE (PyTorch implementation)
    print("Testing JIVE (PyTorch)...")
    jive = JIVE()
    jive_decomp, jive_ranks = jive.decompose(modalities)
    print(f"  Estimated Joint Rank: {jive_ranks['joint_rank']}")
    print(f"  Estimated Ind Rank X: {jive_ranks['individual_rank_X']}")
    print(f"  Estimated Ind Rank Y: {jive_ranks['individual_rank_Y']}")
    print(f"  Joint X shape: {jive_decomp['joint_X'].shape}")
    print(f"  Ind Y shape: {jive_decomp['individual_Y'].shape}")
    print("-" * 30)

    # 4. Test AJIVE (py-jive library)
    if PyAJIVE is not None:
        print("Testing AJIVE (py-jive library)...")
        ajive_model = AJIVE()
        ajive_decomp, ajive_ranks = ajive_model.decompose(modalities)
        print(f"  Estimated Joint Rank: {ajive_ranks['joint_rank']}")
        print(f"  Estimated Ind Rank X: {ajive_ranks['individual_rank_X']}")
        print(f"  Estimated Ind Rank Y: {ajive_ranks['individual_rank_Y']}")
        print(f"  Joint X shape: {ajive_decomp['joint_X'].shape}")
        print(f"  Ind Y shape: {ajive_decomp['individual_Y'].shape}")
        print("-" * 30)
    else:
        print("Skipping AJIVE test: 'py-jive' library not found.")
    
    # 5. Test DIVAS (concatenated SVD)
    print("Testing DIVAS...")
    divas = DIVAS()
    divas_decomp, divas_ranks = divas.decompose(modalities)
    print(f"  Estimated Joint Rank: {divas_ranks['joint_rank']}")
    print(f"  Joint X shape: {divas_decomp['joint_X'].shape}")
    print("-" * 30)