'''Helper functions.'''
import numpy as np
import torch


def float_x(data):
    '''Cast *data* to single precision (``np.float32``).

    A scalar input yields a NumPy ``float32`` scalar; an array-like input
    yields a ``float32`` ndarray.
    '''
    single_precision = np.dtype("float32").type
    return single_precision(data)

def convert_one_hot(y, c):
    '''Return a one-hot encoding of integer class labels.

    Args:
        y: Integer class labels in ``[0, c)``. Any array-like is accepted and
            it is flattened, so a column vector of shape ``(n, 1)`` is treated
            the same as a flat ``(n,)`` array.
        c: Number of classes, i.e. the width of the output.

    Returns:
        A float64 array of shape ``(n, c)`` whose row ``i`` holds a single 1
        at column ``y[i]`` and 0 elsewhere.

    Raises:
        IndexError: if any label lies outside ``[0, c)``.
    '''
    # Flatten so a column vector indexes one cell per row instead of
    # broadcasting against arange into an (n, n) index grid.
    labels = np.asarray(y).ravel()
    n = labels.size
    o = np.zeros((n, c))
    if n == 0:
        # Nothing to mark; also avoids indexing with an empty float array.
        return o
    # Negative labels would otherwise silently wrap to the last classes.
    if labels.min() < 0 or labels.max() >= c:
        raise IndexError(
            f"labels must lie in [0, {c}); got range "
            f"[{labels.min()}, {labels.max()}]"
        )
    o[np.arange(n), labels] = 1
    return o

def matrix_sqrt(A):
    '''Return the square root of a symmetric positive semi-definite matrix.

    Uses the SVD ``A = U @ diag(S) @ Vt`` and returns
    ``U @ diag(sqrt(S)) @ Vt``. The result ``R`` satisfies ``R @ R == A``
    (up to round-off) when ``A`` is symmetric positive semi-definite. For
    other inputs ``U`` and ``V`` differ, and the value is not a matrix
    square root.

    Args:
        A: Square tensor of shape ``(..., n, n)``. Leading batch dimensions
            are supported.

    Returns:
        Tensor with the same shape as ``A``.
    '''
    U, S, Vt = torch.linalg.svd(A)
    # Scaling the columns of U by sqrt(S) equals U @ diag(sqrt(S)) for a
    # single matrix, and unlike torch.diag it broadcasts over batch dims.
    return (U * torch.sqrt(S).unsqueeze(-2)) @ Vt


def matrix_pow(A, pow):
    '''Raise a symmetric positive semi-definite matrix to a real power.

    Uses the SVD ``A = U @ diag(S) @ Vt`` and returns
    ``U @ diag(S ** pow) @ Vt``. This equals ``A ** pow`` in the matrix
    sense when ``A`` is symmetric positive semi-definite. For other inputs
    ``U`` and ``V`` differ, and the value is not a matrix power.

    Args:
        A: Square tensor of shape ``(..., n, n)``. Leading batch dimensions
            are supported.
        pow: Exponent applied to every singular value. The name shadows the
            builtin ``pow``; it is kept because callers may pass it by
            keyword.

    Returns:
        Tensor with the same shape as ``A``. With a negative ``pow``, any
        zero singular value (singular ``A``) produces ``inf`` entries.
    '''
    U, S, Vt = torch.linalg.svd(A)
    # Scaling the columns of U by S**pow equals U @ diag(S**pow) for a
    # single matrix, and unlike torch.diag it broadcasts over batch dims.
    return (U * torch.pow(S, pow).unsqueeze(-2)) @ Vt