import torch


def fast_exp4(X):
    """Fourth-order Taylor approximation of the matrix exponential.

    Evaluates ``I + X + X^2/2 + X^3/6 + X^4/24`` using three matrix
    products; ``X^2 / 2`` is computed once and reused for the cubic and
    quartic terms.

    Args:
        X: Tensor whose trailing two dimensions hold square matrices
            (``torch.linalg.matrix_power`` rejects anything else).

    Returns:
        A new tensor with the truncated series; ``X`` is not modified.
    """
    sq_half = torch.linalg.matrix_power(X, 2) / 2   # X^2 / 2
    cubic = torch.matmul(X, sq_half) / 3            # X^3 / 6
    quartic = torch.matmul(sq_half, sq_half) / 6    # X^4 / 24

    # Accumulate in the same order as the series: linear, quadratic,
    # cubic, quartic.
    out = (X + sq_half).add_(cubic).add_(quartic)

    # Add the identity by bumping the diagonal through a view of `out`.
    torch.diagonal(out, dim1=-2, dim2=-1).add_(1)
    return out


def fast_exp3(X):
    """Third-order Taylor approximation of the matrix exponential.

    Evaluates ``I + X + X^2/2 + X^3/6`` with two matrix products.

    Args:
        X: Tensor whose trailing two dimensions hold square matrices
            (``torch.linalg.matrix_power`` rejects anything else).

    Returns:
        A new tensor with the truncated series; ``X`` is not modified.
    """
    sq_half = torch.linalg.matrix_power(X, 2) / 2   # X^2 / 2
    cubic = torch.matmul(X, sq_half) / 3            # X^3 / 6

    out = (X + sq_half).add_(cubic)

    # Add the identity by bumping the diagonal through a view of `out`.
    torch.diagonal(out, dim1=-2, dim2=-1).add_(1)
    return out


def fast_exp2(X):
    """Second-order Taylor approximation of the matrix exponential.

    Evaluates ``I + X + X^2/2`` with a single matrix product.

    Args:
        X: Tensor whose trailing two dimensions hold square matrices
            (``torch.linalg.matrix_power`` rejects anything else).

    Returns:
        A new tensor with the truncated series; ``X`` is not modified.
    """
    quadratic = torch.linalg.matrix_power(X, 2) / 2   # X^2 / 2
    out = X + quadratic

    # Add the identity by bumping the diagonal through a view of `out`.
    torch.diagonal(out, dim1=-2, dim2=-1).add_(1)
    return out


@torch.no_grad()
def fast_exp(x):
    """Matrix exponential that uses a cheap Taylor series for small inputs.

    The largest Frobenius norm among the matrices in ``x`` selects the
    method for the whole input:

    * norm < 0.05 -> ``fast_exp2`` (2nd-order Taylor)
    * norm < 0.25 -> ``fast_exp3`` (3rd-order Taylor)
    * norm < 1    -> ``fast_exp4`` (4th-order Taylor)
    * otherwise   -> ``torch.matrix_exp``

    Runs under ``torch.no_grad()``, so no autograd graph is recorded.

    Args:
        x: A square matrix or a batch of them; the matrices live in the
            last two dimensions and any leading dimensions are batch
            dimensions.

    Returns:
        Tensor of the same shape as ``x`` holding the (approximate)
        exponential of every matrix.
    """
    # Reduce over the trailing matrix dimensions -- the same axes the
    # Taylor helpers and torch.matrix_exp operate on -- so 2-D inputs and
    # inputs with several batch dimensions are handled, not only 3-D ones.
    norm = x.norm(dim=(-2, -1)).max()

    if norm < 0.05:
        return fast_exp2(x)

    if norm < 0.25:
        return fast_exp3(x)

    if norm < 1:
        return fast_exp4(x)

    return torch.matrix_exp(x)


@torch.no_grad()
def polar(A, eps=1e-10):
    """Return the orthogonal polar factor ``A (A^T A)^{-1/2}`` of ``A``.

    The work is done in float64 and the result is cast back to
    ``A.dtype``. If ``A^T A`` is already within ``1e-5`` (Frobenius norm)
    of the identity for every matrix in the input, ``A`` itself is
    returned (the same object, not a copy).

    Runs under ``torch.no_grad()``, so no autograd graph is recorded.

    Args:
        A: A matrix or batch of matrices; the matrices live in the last
            two dimensions.
        eps: Lower bound applied to the eigenvalues of ``A^T A`` before
            the square root, guarding against division by zero.

    Returns:
        Tensor with the shape and dtype of ``A`` whose matrices have
        (numerically) orthonormal columns.
    """
    AA = A.double()
    ATA = AA.mT @ AA
    identity = torch.eye(A.shape[-1], device=AA.device, dtype=AA.dtype)

    # Worst-case deviation from orthonormality across the batch. Reducing
    # over the trailing dims keeps this valid for 2-D and N-D inputs.
    r = (ATA - identity).norm(dim=(-2, -1)).max().item()
    if r < 1e-5:
        return A

    try:
        # (A^T A)^{-1/2} = V diag(1 / sqrt(lambda)) V^T
        eigvals, eigvecs = torch.linalg.eigh(ATA)
        # (..., n) -> (..., 1, n) so the division scales eigenvector columns.
        sqrt_eigvals = eigvals.clamp(min=eps).sqrt().unsqueeze(-2)
        ATA_inv_sqrt = (eigvecs / sqrt_eigvals) @ eigvecs.mT
        U = AA @ ATA_inv_sqrt
        return U.to(A.dtype)
    except RuntimeError:
        # torch.linalg.LinAlgError (eigh failing to converge) is a
        # RuntimeError; KeyboardInterrupt / SystemExit are deliberately
        # no longer swallowed.
        # more stable but slower
        # Reduced SVD keeps U @ Vt well-shaped for non-square A and is
        # identical to the full SVD for square A.
        U, _, Vt = torch.linalg.svd(AA, full_matrices=False)
        Q = U @ Vt
        return Q.to(A.dtype)


@torch.compile
@torch.no_grad()
def so_proj(X, grad):
    """Return the skew-symmetric part of ``X^T @ grad``.

    Computes ``0.5 * (M - M^T)`` with ``M = X^T @ grad``, transposing the
    last two dimensions. Runs under ``torch.no_grad()`` and is wrapped by
    ``torch.compile``.

    NOTE(review): the name suggests this projects a gradient onto the
    tangent space of the rotation group at ``X`` -- confirm with callers.

    Args:
        X: Matrix or batch of matrices.
        grad: Tensor that can be matrix-multiplied with ``X^T``.

    Returns:
        Tensor ``S`` satisfying ``S == -S^T`` over its last two dims.
    """
    m = torch.matmul(X.mT, grad)
    return 0.5 * (m - m.mT)


def l1_normalize(x, eps=1e-8):
    """Scale ``x`` by its mean over dims 1 and 2, flooring at ``eps``.

    The mean over dimensions ``(1, 2)`` is clamped to at least ``eps``
    before dividing, and the quotient is clamped to at least ``eps`` too.

    NOTE(review): the mean is taken over signed values, not absolute
    values, so this matches an L1 normalization (up to a constant factor)
    only for non-negative input -- presumably what callers pass; confirm.

    Args:
        x: Tensor with at least three dimensions.
        eps: Floor for both the divisor and the returned values.

    Returns:
        Tensor of the same shape as ``x`` with every entry ``>= eps``.
    """
    denom = x.mean((1, 2), keepdim=True)
    denom = denom.clamp(min=eps)   # keep the divisor away from zero
    scaled = x / denom
    return scaled.clamp(min=eps)
