import torch
from torch import Tensor
from typing import Optional, Tuple

def whitening(x: Tensor, eps: float, ridge: float = 1e-3) -> Tuple[Tensor, Tensor]:
    """Fit a PCA-whitening transform to the rows of ``x``.

    Args:
        x: Data matrix of shape ``(N, K)``: ``N`` samples of ``K`` features.
            At least two samples are required for the unbiased covariance.
        eps: Absolute floor added to the eigenvalue clamp so that
            near-zero eigenvalues do not blow up ``1 / sqrt(s)``.
        ridge: Term added to the covariance diagonal before the
            eigendecomposition, for numerical stability.

    Returns:
        ``(mu, whitener)``. ``mu`` is the ``(1, K)`` feature mean.
        ``whitener`` is the ``(K, K)`` matrix ``U @ diag(s ** -0.5)``, taken
        from the eigendecomposition ``cov + ridge * I = U diag(s) U^T``.
        ``(x - mu) @ whitener`` has approximately identity covariance
        (slightly below 1 per axis because of ``ridge``). Both outputs share
        ``x``'s dtype and device.

    Raises:
        ValueError: If ``x`` is not 2-D or has fewer than two rows.
    """
    if x.ndim != 2:
        raise ValueError(f"x must be 2-D (N, K), got shape {tuple(x.shape)}")
    n_samples = x.shape[0]
    if n_samples < 2:
        raise ValueError(
            f"need at least 2 samples to estimate a covariance, got {n_samples}"
        )

    mu = x.mean(dim=0, keepdim=True)
    x_c = x - mu
    cov = (x_c.T @ x_c) / (n_samples - 1)  # (K, K), unbiased estimate

    # Build the ridge in cov's own dtype. A default-dtype eye would promote
    # the sum (e.g. float32 data under a float64 default). The whitener's
    # dtype would then differ from x's, and downstream matmuls would fail.
    ridge_id = ridge * torch.eye(cov.shape[0], dtype=cov.dtype, device=cov.device)
    s, U = torch.linalg.eigh(cov + ridge_id)

    # Relative-plus-absolute floor on the spectrum. It is kept as a 0-dim
    # tensor so no host/device synchronisation (``.item()``) is needed.
    s_floor = s.max() * 1e-12 + eps
    s_stable = torch.maximum(s, s_floor)

    # Scaling U's columns equals U @ diag(1 / sqrt(s)) without a K x K matmul.
    # The trailing ``@ U.T`` of the symmetric (ZCA) inverse square root is
    # intentionally omitted, so whitened data stays in the eigenbasis.
    cov_inv_sqrt = U * torch.rsqrt(s_stable)
    return mu, cov_inv_sqrt
     
def compute_eigenvalues(
    x: Tensor,
    y: Tensor,
    mu: Tensor,
    whitener: Tensor,
    std: Optional[Tensor] = None,
    normalize: bool = True,
) -> Tensor:
    """Return the per-component mean product of two whitened batches.

    ``x`` and ``y`` are both centred with the same ``mu`` and projected with
    the same ``whitener``. Their element-wise product is then averaged over
    the batch dimension.

    Args:
        x: Batch of shape ``(B, K)``.
        y: Batch of shape ``(B, K)``, paired row-wise with ``x``.
        mu: Mean to subtract, broadcastable against ``x`` and ``y``.
        whitener: ``(K, K)`` projection applied on the right.
        std: Optional fixed per-component scale. Used only when
            ``normalize`` is true; both whitened batches are divided by it.
        normalize: If true, rescale the whitened batches per component.
            Uses ``std`` when it is provided; otherwise each batch is
            divided by its own standard deviation over dim 0.

    Returns:
        Tensor of shape ``(K,)`` holding ``mean_b(x_w[b, k] * y_w[b, k])``.
    """
    tiny = 1e-8  # keeps the division finite for zero-variance components

    # Centre, then whitene
    x_w = (x - mu) @ whitener  # (B, K)
    y_w = (y - mu) @ whitener  # (B, K)

    # Rescaling per component is better for training stability.
    if normalize:
        if std is None:
            x_w = x_w / (x_w.std(0, keepdim=True) + tiny)
            y_w = y_w / (y_w.std(0, keepdim=True) + tiny)
        else:
            denom = std + tiny
            x_w = x_w / denom
            y_w = y_w / denom

    return (x_w * y_w).mean(0)  # (K,)
