import torch

def ldl(C, epsilon=1e-6):
    """Return the LDL^T factors of ``C + epsilon * I`` via a Cholesky factorization.

    The regularized matrix is factored as ``chol @ chol^H`` and then rescaled
    so that ``L`` has a unit diagonal and ``D`` holds the squared Cholesky
    pivots, i.e. ``L @ diag(D) @ L^H == C + epsilon * I``.

    Args:
        C: Square matrix, or a batch of square matrices in the last two
            dimensions, that ``torch.linalg.cholesky`` can factor once
            ``epsilon * I`` has been added.
        epsilon: Multiple of the identity added to ``C`` before factoring.

    Returns:
        Tuple ``(L, D)``: ``L`` is unit lower triangular with the shape of
        ``C``; ``D`` contains the diagonal entries and has shape
        ``C.shape[:-1]``.

    Raises:
        Whatever ``torch.linalg.cholesky`` raises when the factorization fails.
    """
    # Use the trailing dimension (not len(C)) so leading batch dims are allowed.
    n = C.shape[-1]
    jitter = epsilon * torch.eye(n, dtype=C.dtype, device=C.device)

    # The addition already allocates a new tensor, so C itself is never modified.
    chol = torch.linalg.cholesky(C + jitter, upper=False)

    # Diagonal of every matrix in the batch, shape (..., n).
    pivots = chol.diagonal(dim1=-2, dim2=-1)

    D = pivots ** 2
    # Divide column j by its pivot so that L gets a unit diagonal.
    L = chol / pivots.unsqueeze(-2)

    return L, D

def sqrtm(M, epsilon=1e-12):
    """Return the matrix square root of ``M``.

    Thin wrapper that calls ``mpow`` with exponent one half; ``epsilon`` is
    forwarded unchanged. ``mpow`` is marked as valid for symmetric matrices
    only, so the same restriction applies here.
    """
    half = 0.5
    return mpow(M, half, epsilon=epsilon)

def mpow(M, power, epsilon=1e-12):           # ONLY FOR SYMMETRIC MATRICES!
    """Raise a symmetric (or Hermitian) matrix to a real power.

    Computes the eigendecomposition ``Q diag(w) Q^H`` of ``M + epsilon * I``
    with ``torch.linalg.eigh`` and returns ``Q diag(w ** power) Q^H``.

    Args:
        M: Symmetric/Hermitian matrix, or a batch of them in the last two
            dimensions. ``eigh`` does not check symmetry, so passing any
            other matrix gives a meaningless result.
        power: Exponent applied elementwise to the eigenvalues.
        epsilon: Multiple of the identity added to ``M`` before the
            decomposition (shifts every eigenvalue by ``epsilon``).

    Returns:
        Tensor with the same shape as ``M``.

    Note:
        Eigenvalues that are negative after the shift produce NaN for
        fractional powers, since the power is taken on real eigenvalues.
    """
    # Use the trailing dimension (not len(M)) so leading batch dims are allowed.
    n = M.shape[-1]
    shifted = M + epsilon * torch.eye(n, dtype=M.dtype, device=M.device)
    eigvals, Q = torch.linalg.eigh(shifted)

    # Scaling the columns of Q equals Q @ diag(w ** power) without building the
    # dense diagonal; broadcasting also promotes the real eigenvalues to Q's
    # dtype when M is complex.
    scaled = Q * eigvals.pow(power).unsqueeze(-2)
    return scaled @ Q.mH

def anymul(M1, M2):
    """Contract the second axis of ``M1`` with the first axis of ``M2``.

    ``M2`` is a 2-D matrix. ``M1`` has at least two dimensions; any
    dimensions after its first two are carried through unchanged and stay
    trailing in the result, so ``out[i, k, ...] = sum_j M1[i, j, ...] * M2[j, k]``.
    """
    # Ellipsis stands for the (possibly empty) trailing dimensions of M1.
    equation = 'ij..., jk->ik...'
    product = torch.einsum(equation, M1, M2)
    return product

def rev_cumsum(x, dim=0):
    """Return the reverse (suffix) cumulative sum of ``x`` along ``dim``.

    Element ``i`` of the result along ``dim`` is the sum of elements
    ``i, i + 1, ..., end`` of ``x`` along that dimension.

    Args:
        x: Input tensor.
        dim: Dimension to accumulate over. Defaults to 0, which was the
            previously hard-coded axis.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Flip, accumulate forwards, then flip back to restore the original order.
    reversed_x = torch.flip(x, dims=[dim])
    running = torch.cumsum(reversed_x, dim=dim)
    return torch.flip(running, dims=[dim])


