'''Helper functions.'''
import numpy as np
import torch


def float_x(data):
    '''Cast ``data`` to single precision (``np.float32``).

    A Python/NumPy scalar comes back as an ``np.float32`` scalar; a
    sequence or ndarray comes back as a ``float32`` ndarray.
    '''
    single = np.float32
    return single(data)

def matrix_sqrt(A, method='svd'):
    '''Return a square root of a symmetric (Hermitian) PSD matrix.

    Args:
        A: Square matrix tensor, or a batch of them with shape (..., n, n).
            The result is only a true square root when A is symmetric
            (Hermitian) positive semi-definite.
        method: 'svd' builds U @ diag(sqrt(S)) @ Vh from the SVD of A;
            'torch_eigh' builds V @ diag(sqrt(w)) @ V^H from
            torch.linalg.eigh, mapping negative eigenvalues (whose sqrt
            is NaN) to 0.

    Returns:
        Tensor with the same shape as A.

    Raises:
        ValueError: if ``method`` is not recognised.
    '''
    if method == 'svd':
        U, S, Vt = torch.linalg.svd(A, full_matrices=False)
        # Scaling U's columns by broadcasting equals U @ diag(sqrt(S)) but
        # also works for batched input (torch.diag is limited to 1-D/2-D).
        return (U * torch.sqrt(S).unsqueeze(-2)) @ Vt
    if method == 'torch_eigh':
        eigs, vecs = torch.linalg.eigh(A)
        roots = torch.sqrt(eigs)
        # Negative eigenvalues (e.g. round-off on a PSD matrix) give NaN.
        # Zero them out-of-place: sqrt's output is saved for backward, so
        # an in-place write would make autograd raise.
        roots = torch.where(torch.isnan(roots), torch.zeros_like(roots), roots)
        # conj() is a no-op for real tensors and makes Hermitian input correct.
        return (vecs * roots.unsqueeze(-2)) @ vecs.transpose(-2, -1).conj()
    raise ValueError(
        f"Unknown method {method!r}; expected 'svd' or 'torch_eigh'")



def matrix_pow(A, pow, method='svd'):
    '''Raise a symmetric (Hermitian) matrix to a real power spectrally.

    Args:
        A: Square matrix tensor, or a batch of them with shape (..., n, n).
            The spectral result only equals A**pow when A is symmetric
            (Hermitian); for the 'svd' method A should also be PSD.
        pow: Exponent applied to the singular values / eigenvalues.
            Zero values with a negative ``pow`` give inf.
        method: 'svd' builds U @ diag(S**pow) @ Vh from the SVD of A;
            'torch_eigh' builds V @ diag(w**pow) @ V^H via torch.linalg.eigh;
            'numpy_eigh' does the same with np.linalg.eigh on a CPU copy and
            moves the result back to A's device. It is not differentiable,
            and Tensor.numpy() refuses tensors that require grad.
            For both eigh methods, NaN results (e.g. a fractional power of a
            negative eigenvalue) are replaced by 0.

    Returns:
        Tensor with the same shape as A.

    Raises:
        ValueError: if ``method`` is not recognised.
    '''
    if method == 'svd':
        U, S, Vt = torch.linalg.svd(A, full_matrices=False)
        # Broadcast column scaling == U @ diag(S**pow); also handles batches.
        return (U * torch.pow(S, pow).unsqueeze(-2)) @ Vt

    if method == 'torch_eigh':
        eigs, vecs = torch.linalg.eigh(A)
        # sqrt is kept for pow == 0.5 to preserve the original numerics.
        powered = torch.sqrt(eigs) if pow == 0.5 else torch.pow(eigs, pow)
        # Out-of-place NaN replacement: sqrt saves its output for backward,
        # so writing into it in place would make autograd raise.
        powered = torch.where(
            torch.isnan(powered), torch.zeros_like(powered), powered)
        # conj() is a no-op for real tensors and makes Hermitian input correct.
        return (vecs * powered.unsqueeze(-2)) @ vecs.transpose(-2, -1).conj()

    if method == 'numpy_eigh':
        eigs, vecs = np.linalg.eigh(A.cpu().numpy())
        powered = np.sqrt(eigs) if pow == 0.5 else np.power(eigs, pow)
        powered[np.isnan(powered)] = 0.0
        result = (vecs * powered[..., np.newaxis, :]) @ np.swapaxes(
            vecs, -1, -2).conj()
        return torch.from_numpy(result).to(A.device)

    raise ValueError(
        f"Unknown method {method!r}; expected 'svd', 'torch_eigh' "
        f"or 'numpy_eigh'")