import torch


@torch.no_grad()
@torch.amp.autocast(device_type="cuda", enabled=False)
def get_cum_var(data, eps=1e-4):
    """Return the cumulative explained-variance ratio of the principal components of ``data``.

    Args:
        data: 2-D tensor of shape ``[num_samples, num_features]``. It is cast to
            float32 (with CUDA autocast disabled) before any computation.
        eps: Unused; kept only for backward compatibility of the signature.

    Returns:
        float32 tensor of shape ``[D + 1]`` on ``data``'s device. Entry ``k`` is the
        fraction of total variance captured by the top ``k`` principal components,
        so entry 0 is always 0 and the last entry is ~1. If the data has zero total
        variance, every entry after the first is NaN.

    Raises:
        AssertionError: if there are fewer samples than features.
    """
    data = data.float()
    N, D = data.shape
    assert N >= D, (
        f"Number of samples must be at least the number of features. "
        f"Got {N} samples and {D} features."
    )
    data = data - data.mean(dim=0, keepdim=True)
    cov = data.T @ data / (N - 1)

    # The covariance matrix is symmetric positive semi-definite, so use the
    # symmetric solver: it is faster and more stable than the general ``eig``,
    # always returns real values, and we do not need the eigenvectors.
    eigvals = torch.linalg.eigvalsh(cov)
    # Round-off can yield tiny negative eigenvalues; a variance cannot be negative,
    # and clamping keeps the cumulative sum monotonically non-decreasing.
    eigvals = eigvals.clamp_min(0)
    # eigvalsh returns ascending order; PCA ranks components by descending variance.
    eigvals = torch.flip(eigvals, dims=(0,))

    explained_var = eigvals / eigvals.sum()
    cum_explained_var = explained_var.cumsum(dim=0)  # [D]
    # Prepend 0 (zero components explain zero variance) with matching dtype/device.
    return torch.cat([cum_explained_var.new_zeros(1), cum_explained_var])  # [D+1]


def get_cum_var_ref(data, eps=1e-4):
    """Reference cumulative explained-variance ratio computed with scikit-learn's PCA.

    Intended as a cross-check for ``get_cum_var``.

    Args:
        data: Array-like of shape ``[num_samples, num_features]``. Torch tensors
            (on any device, any floating dtype, with or without grad) are accepted
            and converted to a host float64 NumPy array first.
        eps: Unused; kept only for signature parity.

    Returns:
        float32 CPU tensor whose first entry is 0 followed by the cumulative sum of
        ``PCA.explained_variance_ratio_``. With scikit-learn's default
        ``n_components=None``, there is one entry per kept component
        (min(num_samples, num_features)), plus the leading 0.
    """
    import numpy as np
    from sklearn.decomposition import PCA

    if isinstance(data, torch.Tensor):
        # sklearn needs a host NumPy array: detach from autograd, move off the GPU,
        # and upcast (NumPy has no bfloat16; float64 also tightens the reference).
        data = data.detach().to(device="cpu", dtype=torch.float64).numpy()

    pca = PCA()
    pca.fit(data)
    cum_var = np.concatenate([[0.0], np.cumsum(pca.explained_variance_ratio_)])
    return torch.tensor(cum_var, dtype=torch.float32)


@torch.no_grad()
def get_effective_dim(data, thresh=0.99):
    """Return how many principal components are needed to explain more than ``thresh`` of the variance.

    Args:
        data: 2-D tensor ``[num_samples, num_features]`` forwarded to ``get_cum_var``.
        thresh: Fraction of total variance that must be exceeded.

    Returns:
        As a float, the smallest ``k`` with ``cum_var[k] > thresh``.
        If the cumulative ratios are all finite but none exceeds ``thresh``
        (e.g. ``thresh >= 1``, or round-off leaves the final ratio just below it),
        every component is needed and ``num_features`` is returned.
        If any exception occurs (e.g. ``get_cum_var`` rejecting fewer samples than
        features), it is printed and ``nan`` is returned instead.
    """
    try:
        cum_var = get_cum_var(data)
        above = cum_var > thresh
        if not above.any() and torch.isfinite(cum_var).all():
            # Threshold never crossed: all D components are required
            # (argmax would otherwise wrongly report 0).
            return float(cum_var.numel() - 1)
        # argmax returns the first maximal index, i.e. the first position above thresh.
        # Non-finite ratios (zero-variance data) still fall through to argmax -> 0.
        return float(torch.argmax(above.long()))
    except Exception as e:
        # Deliberate best-effort: report and return NaN rather than raising.
        print(f"Error: {e}")
        return float("nan")