from dataclasses import dataclass
from typing import Optional

import torch
from torchtyping import TensorType


@dataclass
class TensorPCAOutput:
    """Result bundle produced by a tensor PCA decomposition.

    Attributes:
        projected_data: The centered input expressed in the basis of the
            selected principal components, i.e. the low-dimensional
            representation of the data.
        components: The selected principal components, stored one per
            column.
        component_variances: For each selected component, the share of
            the data's total variance that the component explains.
    """

    projected_data: TensorType
    components: TensorType
    component_variances: TensorType

def get_tensor_principal_components(
    tensor: TensorType, n_components: Optional[int] = None
) -> "TensorPCAOutput":
    """Run PCA on ``tensor`` through an SVD of the mean-centered data.

    The mean over dim 0 is subtracted, the right singular vectors of the
    centered data are taken as principal components, and the centered data
    is projected onto them.

    Args:
        tensor: Data to decompose. Assumed to be a 2-D floating-point
            matrix with samples along dim 0 and features along dim 1
            (the mean is taken over dim 0 and the SVD factor is transposed
            with ``.t()``) -- TODO confirm against callers.
        n_components: When given, only the first ``n_components`` columns
            of the component matrix are kept (a plain slice, so a value
            larger than the number of available components keeps them
            all). ``None`` keeps every component.

    Returns:
        A ``TensorPCAOutput`` holding the projected data, the retained
        components as columns, and, per retained component, the fraction
        of the total sum of squares of the centered data that it captures.
        If the centered data is entirely zero that fraction is 0/0, i.e.
        NaN.
    """
    centered_tensor = tensor - torch.mean(tensor, dim=0)

    # Only the right singular vectors are needed. For a tall (or square)
    # matrix the reduced SVD returns a Vh of the same shape as the full SVD
    # while skipping the (samples x samples) U factor, which the full SVD
    # would allocate just to be thrown away. Wide or non-2-D input keeps the
    # full SVD so the returned shapes (and torch's own error behavior) stay
    # exactly as before.
    is_tall = (
        centered_tensor.ndim == 2
        and centered_tensor.shape[0] >= centered_tensor.shape[1]
    )
    _, _, vh = torch.linalg.svd(centered_tensor, full_matrices=not is_tall)

    # torch.linalg.svd returns Vh, whose rows are the right singular vectors;
    # transpose so that each column is one principal component.
    components = vh.t()

    if n_components is not None:
        components = components[:, :n_components]

    # Coordinates of every sample in the retained component basis.
    projected_data = torch.matmul(centered_tensor, components)

    # Squared column norm of the projection relative to the total sum of
    # squares of the centered data gives the explained-variance fraction.
    total_variance = torch.sum(torch.square(centered_tensor))
    component_variances = (
        torch.square(torch.norm(projected_data, dim=0)) / total_variance
    )

    return TensorPCAOutput(
        projected_data=projected_data,
        components=components,
        component_variances=component_variances,
    )

def find_num_components_explaining_variance(
    component_variances: torch.Tensor, threshold: float
) -> int:
    """Return how many leading components explain more than ``threshold``.

    The components are consumed in the order given; the answer is the
    length of the shortest prefix whose cumulative variance is strictly
    greater than ``threshold`` times the total variance.

    Args:
        component_variances: 1-D tensor with one variance value per
            component.
        threshold: Fraction of the total variance to exceed (e.g. 0.90 for
            90%).

    Returns:
        The number of components needed. The result never exceeds the
        number of entries in ``component_variances``: when no prefix can
        strictly exceed the target (``threshold >= 1.0``, or all variances
        are zero) every component is returned, and an empty tensor yields 0.

    Raises:
        AssertionError: If ``component_variances`` is not 1-D.
    """
    assert component_variances.ndim == 1
    cumulative_variance = torch.cumsum(component_variances, dim=0)
    target = torch.sum(component_variances) * threshold

    # Prefixes that have not yet exceeded the target; one more component is
    # what pushes the cumulative variance past it.
    num_not_exceeding = int(torch.sum(cumulative_variance <= target).item())

    # Clamp: if every prefix is still at or below the target there is no
    # "one more" component to add, so report all available components
    # instead of an out-of-range count.
    return min(num_not_exceeding + 1, component_variances.numel())