import numpy as np
import torch
from torch_batch_svd import svd
from data_util import vector2tensor, tensor2vector, vector2tensor_1dim, tensor2vector_1dim, Log_vec2Log

def get_sqrt_sym(data, eps = 1e-5, returnInvAlso = False):
    """Symmetric square root of a batch of symmetric matrices.

    The batch is decomposed with ``svd``. A singular value whose left/right
    singular vectors disagree in sign (negative diagonal of U^T V) is treated as
    a negative eigenvalue. The result is ``U diag(sqrt(lambda + eps)) U^T``.

    Args:
        data: batch of matrices, or a 2-D batch in vector form that is converted
            with ``vector2tensor_1dim``.
        eps: shift added to the eigenvalues before taking the square root.
        returnInvAlso: if True, also return ``U diag(1 / sqrt(lambda + eps)) U^T``.

    Returns:
        ``sqrt``, or the pair ``(sqrt, invsqrt)``.
    """
    mats = vector2tensor_1dim(data) if len(data.shape) == 2 else data
    U, sing, V = svd(mats)
    Ut = U.permute(0, 2, 1)
    # svd drops eigenvalue signs; a negative diagonal of U^T V recovers them
    flipped = torch.diagonal(torch.matmul(Ut, V), dim1=1, dim2=2) < 0
    eig = torch.where(flipped, -sing, sing)
    root = torch.sqrt(eig + eps)
    sqrt = torch.bmm(U, torch.bmm(torch.diag_embed(root), Ut))
    if not returnInvAlso:
        return sqrt
    invsqrt = torch.bmm(torch.bmm(U, torch.diag_embed(1.0 / root)), Ut)
    return sqrt, invsqrt

def get_sqrt_sym_DirDeriv(data, vecdir, eps = 1e-5, returnVec = False):
    """Directional derivative of the symmetric matrix square root.

    With ``data = U diag(S) U^T`` (eigenvalue signs recovered from U^T V) and
    ``T = U^T vecdir U``, this returns
    ``U (diag(dS / sqrt(S + eps) / 2) + M) U^T``, where ``dS = diag(T)`` and, for i != j,
    ``M_ij = (sqrt(S_i + eps) - sqrt(S_j + eps)) * T_ij / (S_i - S_j)``.

    Works for any matrix size. Equal eigenvalues give 0/0 in ``M`` and therefore
    NaN entries.

    Args:
        data: base points in vector form (converted with ``vector2tensor_1dim``).
        vecdir: directions in vector form (converted with ``vector2tensor_1dim``).
        eps: shift added to the eigenvalues before the square root.
        returnVec: return the result through ``tensor2vector_1dim``.

    Returns:
        (-1, dim, dim) tensor of derivatives, or its vector form.
    """
    data = vector2tensor_1dim(data)
    U, S, V = svd(data)
    # correct sign of the eigenvalues
    UtV = torch.diagonal(torch.matmul(U.permute(0, 2, 1), V), dim1=1, dim2=2)
    S[UtV < 0] = - S[UtV < 0]
    dim = data.shape[1]
    sqrtS = torch.sqrt(S + eps)

    # direction expressed in the eigenbasis of data
    vecdir_trans = torch.matmul(torch.matmul(U.permute(0, 2, 1), vector2tensor_1dim(vecdir)), U)
    dS = torch.diagonal(vecdir_trans, dim1=-2, dim2=-1)

    # off-diagonal divided differences over the strict upper triangle, mirrored
    row, col = torch.triu_indices(dim, dim, 1, device=U.device)
    omega = vecdir_trans[:, row, col] / (S[:, row] - S[:, col])
    offdiag = (sqrtS[:, row] - sqrtS[:, col]) * omega
    tempMat = torch.zeros_like(U)
    tempMat[:, row, col] = offdiag
    tempMat[:, col, row] = offdiag

    tempjac = torch.matmul(torch.matmul(U, torch.diag_embed(dS / sqrtS / 2.0) + tempMat), U.permute(0, 2, 1)).view(-1, dim, dim)
    if returnVec:
        return tensor2vector_1dim(tempjac)
    return tempjac

def get_sqrt(data, eps = 1e-5, returnInvAlso = False):
    """Non-symmetric square-root factor of a batch of symmetric matrices.

    With ``data = U diag(lambda) U^T`` (signs recovered from U^T V), returns
    ``sqrt = diag(sqrt(lambda + eps)) U^T``, so that ``sqrt^T sqrt = U diag(lambda + eps) U^T``.

    Args:
        data: batch of matrices, or a 2-D batch in vector form that is converted
            with ``vector2tensor_1dim``.
        eps: shift added to the eigenvalues before the square root.
        returnInvAlso: if True, also return ``invsqrt = U diag(1 / sqrt(lambda + eps))``.

    Returns:
        ``sqrt``, or ``(sqrt, invsqrt)``.
    """
    mats = data if len(data.shape) != 2 else vector2tensor_1dim(data)
    U, sing, V = svd(mats)
    Ut = U.permute(0, 2, 1)
    flipped = torch.diagonal(torch.matmul(Ut, V), dim1=1, dim2=2) < 0
    root = torch.sqrt(torch.where(flipped, -sing, sing) + eps)
    sqrt = torch.bmm(torch.diag_embed(root), Ut)
    if returnInvAlso:
        return sqrt, torch.bmm(U, torch.diag_embed(1.0 / root))
    return sqrt

def Log_mat(data, eps = 1e-5, returnVec = False, returnBadIdx = False):
    """Matrix logarithm of a batch of symmetric matrices.

    Eigenvalues flagged negative (negative diagonal of U^T V) are set to 0 before
    evaluating ``log(lambda + eps)``, so the result stays finite for matrices that
    are not positive-definite.

    Args:
        data: tensor whose last two dims are dim x dim; reshaped to (-1, dim, dim).
        eps: shift inside the log. A negative value is replaced by 1e-5.
        returnVec: return ``tensor2vector_1dim`` of the result instead of matrices.
        returnBadIdx: also return a long tensor of the batch indices that had at
            least one negative eigenvalue, on the same device as the input.

    Returns:
        The logarithms (matrix or vector form), optionally with the bad indices.
    """
    if eps < 0:
        eps = 1e-5
    dim = data.shape[-1]
    U, S, V = svd(data.view(-1, dim, dim))
    # check sign of the eigenvalues and clamp negative ones to 0 -> log(eps)
    UtV = torch.diagonal(torch.matmul(U.permute(0, 2, 1), V), dim1=1, dim2=2)
    negative = UtV < 0
    S = S.masked_fill(negative, 0)
    output = torch.matmul(torch.matmul(U, torch.diag_embed(torch.log(S + eps))), U.permute(0, 2, 1)).view(-1, dim, dim)
    if returnVec:
        output = tensor2vector_1dim(output)
    if returnBadIdx:
        # integer arange: exact for any batch size (float32 indices round above 2**24)
        # and valid on CPU as well as CUDA
        bad_idx = torch.arange(S.shape[0], device=S.device)[torch.sum(negative, dim=1) > 0]
        return output, bad_idx
    return output

def Exp_mat(data, returnVec = False, returnsvd = False):
    """Matrix exponential of a batch of symmetric matrices.

    ``data`` is reshaped to (-1, dim, dim) and decomposed with ``svd``. Eigenvalue
    signs are recovered from U^T V, and ``U diag(exp(lambda)) U^T`` is returned.

    Args:
        data: tensor whose last two dims are dim x dim.
        returnVec: return the result through ``tensor2vector_1dim``.
        returnsvd: also return ``(U, S)``, where ``S`` holds the singular values as
            returned by ``svd`` (before sign correction).
    """
    dim = data.shape[-1]
    U, S, V = svd(data.view(-1, dim, dim))
    Ut = U.permute(0, 2, 1)
    flipped = torch.diagonal(torch.matmul(Ut, V), dim1=1, dim2=2) < 0
    eig = torch.where(flipped, -S, S)
    result = torch.matmul(torch.matmul(U, torch.diag_embed(torch.exp(eig))), Ut).view(-1, dim, dim)
    if returnVec:
        result = tensor2vector_1dim(result)
    return (result, (U, S)) if returnsvd else result

def Exp_sqrt_mat(data, returnVec = False, returnsvd = False):
    """Square root of the matrix exponential, i.e. ``U diag(exp(lambda / 2)) U^T``.

    ``data`` is reshaped to (-1, dim, dim) and decomposed with ``svd``. Eigenvalue
    signs are recovered from U^T V.

    Args:
        data: tensor whose last two dims are dim x dim.
        returnVec: return the result through ``tensor2vector_1dim``.
        returnsvd: also return ``(U, S)``, with ``S`` as returned by ``svd``
            (before sign correction).
    """
    dim = data.shape[-1]
    U, S, V = svd(data.view(-1, dim, dim))
    Ut = U.permute(0, 2, 1)
    flipped = torch.diagonal(torch.matmul(Ut, V), dim1=1, dim2=2) < 0
    eig = torch.where(flipped, -S, S)
    result = torch.matmul(torch.matmul(U, torch.diag_embed(torch.exp(0.5 * eig))), Ut).view(-1, dim, dim)
    if returnVec:
        result = tensor2vector_1dim(result)
    return (result, (U, S)) if returnsvd else result

def Log(data, eps = 1e-5, returnVec = True, returnBadIdx = False):
    """Matrix logarithm for a batch in vector form; converts and defers to ``Log_mat``."""
    return Log_mat(vector2tensor_1dim(data), eps, returnVec, returnBadIdx)
    
def Exp(data, returnVec = True, returnsvd = False):
    """Matrix exponential for a batch in vector form; converts and defers to ``Exp_mat``."""
    return Exp_mat(vector2tensor_1dim(data), returnVec, returnsvd)

def Exp_sqrt(data, returnVec = True, returnsvd = False):
    """exp(X / 2) for a batch in vector form; converts and defers to ``Exp_sqrt_mat``."""
    return Exp_sqrt_mat(vector2tensor_1dim(data), returnVec, returnsvd)

def deltaMat_sqrt_approx(U, S, vec, eps = 1e-5):
    """First-order change of the matrix square root of P along ``vec``.

    For ``P = U diag(S) U^T`` and ``T = U^T vec U``, returns
    ``U (diag(0.5 * dS / sqrt(S + eps)) + M) U^T``, where ``dS = diag(T)`` and, for i != j,
    ``M_ij = (sqrt(S_i + eps) - sqrt(S_j + eps)) * T_ij / (S_i - S_j)``.

    Handles any matrix size. Equal eigenvalues give 0/0 in ``M`` (NaN entries).

    Args:
        U: eigenvectors of matrix P.
        S: eigenvalues of matrix P.
        vec: change direction of P, in vector form (2-D, converted with
            ``vector2tensor_1dim``) or matrix form.
        eps: shift added to S before the square root.

    Returns:
        (-1, dim, dim) tensor.
    """
    if len(vec.shape) == 2:
        vec = vector2tensor_1dim(vec)
    dim = vec.shape[1]
    vec_trans = torch.matmul(torch.matmul(U.permute(0, 2, 1), vec), U)
    dS = torch.diagonal(vec_trans, dim1=-2, dim2=-1)
    sqrtS = torch.sqrt(S + eps)

    # off-diagonal divided differences over the strict upper triangle, mirrored
    row, col = torch.triu_indices(dim, dim, 1, device=U.device)
    omega = vec_trans[:, row, col] / (S[:, row] - S[:, col])
    offdiag = (sqrtS[:, row] - sqrtS[:, col]) * omega
    tempMat = torch.zeros_like(U)
    tempMat[:, row, col] = offdiag
    tempMat[:, col, row] = offdiag

    tempjac = torch.matmul(torch.matmul(U, torch.diag_embed(0.5 * dS / sqrtS) + tempMat), U.permute(0, 2, 1)).view(-1, dim, dim)
    return tempjac

def volumetricLog(data, returnVec = True, eps = 1e-5):
    """Voxel-wise matrix logarithm of a volume of symmetric matrices.

    Input may not be positive-definite. ``Log_mat`` sets negative eigenvalues to 0,
    so they map to log(eps). The first three dims of ``data`` are kept as the output
    layout (presumably the volume axes - confirm with ``vector2tensor``).

    Args:
        data: volume in vector form, converted with ``vector2tensor``.
        returnVec: return the result through ``tensor2vector``.
        eps: shift inside the log.
    """
    data_shape = data.shape
    mats = vector2tensor(data)
    dim = mats.shape[-1]
    flat_log = Log_mat(mats.view(-1, dim, dim), eps)
    volume = flat_log.view(data_shape[0], data_shape[1], data_shape[2], dim, dim)
    if not returnVec:
        return volume
    return tensor2vector(volume)

def volumetricExp(data, returnVec = True):
    """Voxel-wise matrix exponential of a volume of symmetric matrices.

    The first three dims of ``data`` are kept as the output layout (presumably the
    volume axes - confirm with ``vector2tensor``).

    Args:
        data: volume in vector form, converted with ``vector2tensor``.
        returnVec: return the result through ``tensor2vector``.
    """
    data_shape = data.shape
    mats = vector2tensor(data)
    dim = mats.shape[-1]
    flat_exp = Exp_mat(mats.view(-1, dim, dim))
    volume = flat_exp.view(data_shape[0], data_shape[1], data_shape[2], dim, dim)
    if not returnVec:
        return volume
    return tensor2vector(volume)

def ExpDirDeriv(data, vecdir, returnVec = True):
    """Directional derivative of the matrix exponential.

    With ``data = U diag(S) U^T`` (signs recovered from U^T V) and
    ``T = U^T vecdir U``, returns ``U (diag(exp(S) * dS) + M) U^T``, where
    ``dS = diag(T)`` and, for i != j,
    ``M_ij = (exp(S_i) - exp(S_j)) * T_ij / (S_i - S_j)``.

    Handles any matrix size. Equal eigenvalues give 0/0 in ``M`` (NaN entries).

    Args:
        data: base points in vector form (converted with ``vector2tensor_1dim``).
        vecdir: directions in vector form (converted with ``vector2tensor_1dim``).
        returnVec: return the result through ``tensor2vector_1dim``.
    """
    data = vector2tensor_1dim(data)
    U, S, V = svd(data)
    # correct sign of the eigenvalues
    UtV = torch.diagonal(torch.matmul(U.permute(0, 2, 1), V), dim1=1, dim2=2)
    S[UtV < 0] = - S[UtV < 0]
    dim = data.shape[1]
    expS = torch.exp(S)

    # direction expressed in the eigenbasis of data
    vecdir_trans = torch.matmul(torch.matmul(U.permute(0, 2, 1), vector2tensor_1dim(vecdir)), U)
    dS = torch.diagonal(vecdir_trans, dim1=-2, dim2=-1)

    # off-diagonal divided differences over the strict upper triangle, mirrored
    row, col = torch.triu_indices(dim, dim, 1, device=U.device)
    omega = vecdir_trans[:, row, col] / (S[:, row] - S[:, col])
    offdiag = (expS[:, row] - expS[:, col]) * omega
    tempMat = torch.zeros_like(U)
    tempMat[:, row, col] = offdiag
    tempMat[:, col, row] = offdiag

    tempjac = torch.matmul(torch.matmul(U, torch.diag_embed(expS * dS) + tempMat), U.permute(0, 2, 1)).view(-1, dim, dim)
    if returnVec:
        return tensor2vector_1dim(tempjac)
    return tempjac

def ExpJacobian(data, returnVec = True):
    """Jacobian of the matrix exponential at each matrix of a batch.

    Column ``i`` is the directional derivative of Exp at ``data`` along the i-th
    unit vector of the vector representation. It is built from the
    eigendecomposition: a diagonal term ``exp(S) * dS`` plus off-diagonal divided
    differences ``(exp(S_i) - exp(S_j)) * T_ij / (S_i - S_j)``.

    Args:
        data: batch in vector form, N x cov_dim (cov_dim = 6 for 3x3, 3 for 2x2).
        returnVec: if True the result is N x cov_dim x cov_dim (derivative columns in
            vector form); otherwise N x dim x dim x cov_dim.

    Returns:
        The Jacobian. ``jac`` is allocated with ``torch.zeros`` without a device, so
        it lives on the default device (CPU unless the default tensor type was
        changed). Returns None when the matrices are neither 3x3 nor 2x2.
    """
    # unit-direction buffer in vector form; one entry is set to 1 per column below
    deriv = torch.zeros(data.shape).cuda()
    data = vector2tensor_1dim(data)
    dim = data.shape[-1]
    U,S,V = svd(data.view(-1,dim,dim))
    # correct sign of the eigenvalues
    UtV = torch.diagonal(torch.matmul(U.permute(0,2,1),V), dim1=1, dim2=2)
    S[UtV < 0] = - S[UtV < 0]
    N = data.shape[0]
    expS = torch.exp(S)
    if dim == 3:
        if returnVec:
            jac = torch.zeros(N,6,6)
        else:
            jac = torch.zeros(N,3,3,6)
        for i in range(6):
            deriv[:,i] = 1
            # direction e_i expressed in the eigenbasis of data
            deriv_trans = torch.matmul(torch.matmul(U.permute(0,2,1), vector2tensor_1dim(deriv)), U)
            dS = torch.diagonal(deriv_trans, dim1=-2, dim2=-1)
            # omega[:,k] pairs with the off-diagonal entry opposite to index k; NaN if eigenvalues coincide
            omega = torch.cat(
                ((deriv_trans[:,1,2]/(S[:,1] - S[:,2])).view(-1,1), 
                (deriv_trans[:,0,2]/(S[:,2] - S[:,0])).view(-1,1),
                (deriv_trans[:,0,1]/(S[:,0] - S[:,1])).view(-1,1)),
            1)
            tempMat = torch.zeros(U.shape).cuda()
            tempMat[:,0,1] = (expS[:,0] - expS[:,1]) * omega[:,2]
            tempMat[:,1,0] = tempMat[:,0,1]
            tempMat[:,0,2] = (expS[:,2] - expS[:,0]) * omega[:,1]
            tempMat[:,2,0] = tempMat[:,0,2]
            tempMat[:,1,2] = (expS[:,1] - expS[:,2]) * omega[:,0]
            tempMat[:,2,1] = tempMat[:,1,2]
            tempjac = torch.matmul(torch.matmul(U, torch.diag_embed(expS*dS) + tempMat), U.permute(0,2,1)).view(-1,3,3)
            if returnVec:
                jac[:,:,i] = tensor2vector_1dim(tempjac)
            else:
                jac[:,:,:,i] = tempjac
            # reset the buffer for the next unit direction
            deriv[:,i] = 0
        return jac
    elif dim == 2:
        if returnVec:
            jac = torch.zeros(N,3,3)
        else:
            jac = torch.zeros(N,2,2,3)
        for i in range(3):
            deriv[:,i] = 1
            deriv_trans = torch.matmul(torch.matmul(U.permute(0,2,1), vector2tensor_1dim(deriv)), U)
            dS = torch.diagonal(deriv_trans, dim1=-2, dim2=-1)
            omega = deriv_trans[:,0,1]/(S[:,0] - S[:,1])
            tempMat = torch.zeros(U.shape).cuda()
            tempMat[:,0,1] = (expS[:,0] - expS[:,1]) * omega
            tempMat[:,1,0] = tempMat[:,0,1]
            tempjac = torch.matmul(torch.matmul(U, torch.diag_embed(expS*dS) + tempMat), U.permute(0,2,1)).view(-1,2,2)
            if returnVec:
                jac[:,:,i] = tensor2vector_1dim(tempjac)
            else:
                jac[:,:,:,i] = tempjac
            deriv[:,i] = 0
        return jac
    # NOTE(review): unsupported sizes fall through and return None
    return

def LogJacobian(data, eps = 1e-5, returnVec = True):
    """Jacobian of the (eps-shifted) matrix logarithm at each matrix of a batch.

    Column ``i`` is the directional derivative of ``U diag(log(S + eps)) U^T`` along
    the i-th unit vector of the vector representation. It is built from a diagonal
    term ``dS / (S + eps)`` plus off-diagonal divided differences
    ``(log(S_i + eps) - log(S_j + eps)) * T_ij / (S_i - S_j)``.

    Args:
        data: batch in vector form, N x cov_dim (cov_dim = 6 for 3x3, 3 for 2x2).
        eps: shift applied to the eigenvalues inside the log.
        returnVec: if True the result is N x cov_dim x cov_dim; otherwise
            N x dim x dim x cov_dim.

    Returns:
        The Jacobian. ``jac`` is allocated with ``torch.zeros`` without a device,
        i.e. on the default device. Returns None when the matrices are neither 3x3
        nor 2x2.
    """
    # unit-direction buffer in vector form; one entry is set to 1 per column below
    deriv = torch.zeros(data.shape).cuda()
    data = vector2tensor_1dim(data)
    dim = data.shape[-1]
    U,S,V = svd(data.view(-1,dim,dim))
    # correct sign of the eigenvalues
    UtV = torch.diagonal(torch.matmul(U.permute(0,2,1),V), dim1=1, dim2=2)
    S[UtV < 0] = - S[UtV < 0]
    N = data.shape[0]
    logS = torch.log(S + eps)
    if dim == 3:
        if returnVec:
            jac = torch.zeros(N,6,6)
        else:
            jac = torch.zeros(N,3,3,6)
        for i in range(6):
            deriv[:,i] = 1
            # direction e_i expressed in the eigenbasis of data
            deriv_trans = torch.matmul(torch.matmul(U.permute(0,2,1), vector2tensor_1dim(deriv)), U)
            dS = torch.diagonal(deriv_trans, dim1=-2, dim2=-1)
            # omega[:,k] pairs with the off-diagonal entry opposite to index k; NaN if eigenvalues coincide
            omega = torch.cat(
                ((deriv_trans[:,1,2]/(S[:,1] - S[:,2])).view(-1,1), 
                (deriv_trans[:,0,2]/(S[:,2] - S[:,0])).view(-1,1),
                (deriv_trans[:,0,1]/(S[:,0] - S[:,1])).view(-1,1)),
            1)
            tempMat = torch.zeros(U.shape).cuda()
            tempMat[:,0,1] = (logS[:,0] - logS[:,1]) * omega[:,2]
            tempMat[:,1,0] = tempMat[:,0,1]
            tempMat[:,0,2] = (logS[:,2] - logS[:,0]) * omega[:,1]
            tempMat[:,2,0] = tempMat[:,0,2]
            tempMat[:,1,2] = (logS[:,1] - logS[:,2]) * omega[:,0]
            tempMat[:,2,1] = tempMat[:,1,2]
            # diagonal term: derivative of log(S + eps) is dS / (S + eps)
            tempjac = torch.matmul(torch.matmul(U, torch.diag_embed(dS/(S + eps)) + tempMat), U.permute(0,2,1)).view(-1,3,3)
            if returnVec:
                jac[:,:,i] = tensor2vector_1dim(tempjac)
            else:
                jac[:,:,:,i] = tempjac
            # reset the buffer for the next unit direction
            deriv[:,i] = 0
        return jac
    elif dim == 2:
        if returnVec:
            jac = torch.zeros(N,3,3)
        else:
            jac = torch.zeros(N,2,2,3)
        for i in range(3):
            deriv[:,i] = 1
            deriv_trans = torch.matmul(torch.matmul(U.permute(0,2,1), vector2tensor_1dim(deriv)), U)
            dS = torch.diagonal(deriv_trans, dim1=-2, dim2=-1)
            omega = deriv_trans[:,0,1]/(S[:,0] - S[:,1])
            tempMat = torch.zeros(U.shape).cuda()
            tempMat[:,0,1] = (logS[:,0] - logS[:,1]) * omega
            tempMat[:,1,0] = tempMat[:,0,1]
            tempjac = torch.matmul(torch.matmul(U, torch.diag_embed(dS/(S + eps)) + tempMat), U.permute(0,2,1)).view(-1,2,2)
            if returnVec:
                jac[:,:,i] = tensor2vector_1dim(tempjac)
            else:
                jac[:,:,:,i] = tempjac
            deriv[:,i] = 0
        return jac
    # NOTE(review): unsupported sizes fall through and return None
    return

def group_action(tensor, A, returnVec = False):
    """Congruence action ``A X A^T`` of a batch of matrices A on a batch X.

    Args:
        tensor: N x dim x dim, or N x cov_dim vector form (converted with
            ``vector2tensor_1dim``).
        A: N x dim x dim.
        returnVec: return the result through ``tensor2vector_1dim``.
    """
    mats = vector2tensor_1dim(tensor) if len(tensor.shape) == 2 else tensor
    left = torch.bmm(A, mats)
    acted = torch.bmm(left, A.permute(0, 2, 1))
    return tensor2vector_1dim(acted) if returnVec else acted

def posMetric_func_N_n(vec, covInv_sqrt):
    """Position-block metric ``A U diag(exp(-lambda)) U^T A^T`` with ``A = covInv_sqrt``.

    ``vec`` is converted with ``vector2tensor_1dim`` and decomposed as
    ``U diag(lambda) U^T`` via ``svd`` (signs recovered from U^T V).
    """
    U, sing, V = svd(vector2tensor_1dim(vec))
    flipped = torch.diagonal(torch.bmm(U.permute(0, 2, 1), V), dim1=1, dim2=2) < 0
    eig = torch.where(flipped, -sing, sing)
    AU = torch.bmm(covInv_sqrt, U)
    inner = torch.bmm(torch.diag_embed(torch.exp(-eig)), AU.permute(0, 2, 1))
    return torch.bmm(AU, inner)

def exponential_map_fromcovvec(covvec, vec, eps = 1e-5, returnsvd = False):
    """Exponential map at covariance ``covvec`` applied to tangent vector ``vec``.

    ### covvec, vec: N x cov_dim dimensional matrix
    With covvec = R D R^T (eigenvalue signs recovered from U^T V), builds
    cov_sqrt = D^(1/2) R^T and covInv_sqrt = R (D + eps)^(-1/2), and returns the
    vector form of cov_sqrt^T Exp(covInv_sqrt^T vec covInv_sqrt) cov_sqrt.

    Returns:
        exp_vec, or (exp_vec, (U, S)) with S the sign-corrected eigenvalues when
        ``returnsvd`` is True.

    Raises:
        AssertionError: if covvec has a negative eigenvalue (diagnostics for the
            offending sample are printed first) or a NaN appears.
    """
    U, S, V = svd(vector2tensor_1dim(covvec))
    UtV = torch.diagonal(torch.bmm(U.permute(0,2,1),V), dim1=1, dim2=2)
    S[UtV < 0] = -(S[UtV < 0])

    ### for debugging
    if torch.min(S) < 0:
        idx = torch.argmin(S.view(-1))
        # S is N x dim, so the flat index belongs to sample idx // dim
        dataidx = idx // S.shape[1]
        print(torch.min(S))
        print("U\n", U[dataidx])
        print("S\n", S[dataidx])
        print("covvec\n", covvec[dataidx])
        print("vec\n", vec[dataidx])
    assert(torch.min(S) >= 0)
    cov_sqrt = torch.bmm(
        torch.diag_embed(torch.sqrt(S)),
        U.permute(0,2,1)
    )
    assert(torch.sum(torch.isnan(cov_sqrt)) == 0)
    covInv_sqrt = torch.bmm(
        U,
        torch.diag_embed(1.0/torch.sqrt(S + eps))
    )
    assert(torch.sum(torch.isnan(covInv_sqrt)) == 0)
    exp_vec = tensor2vector_1dim(
        torch.bmm(
            cov_sqrt.permute(0,2,1),
            torch.bmm(
                Exp_mat(
                    torch.bmm(
                        torch.bmm(covInv_sqrt.permute(0,2,1), vector2tensor_1dim(vec)),
                        covInv_sqrt
                    )
                ),
                cov_sqrt
            )
        )
    )
    assert(torch.sum(torch.isnan(exp_vec)) == 0)
    if returnsvd:
        return exp_vec, (U,S)
    return exp_vec

def exponential_map(vec, cov_sqrt, covInv_sqrt):
    """Exponential map at the covariance described by ``cov_sqrt`` / ``covInv_sqrt``.

    Returns the vector form of ``cov_sqrt^T Exp(covInv_sqrt^T V covInv_sqrt) cov_sqrt``.

    Args:
        vec: N x cov_dim tangent vectors.
        cov_sqrt: D^(1/2) R^T of the covariance R D R^T.
        covInv_sqrt: R D^(-1/2) of the covariance R D R^T.
    """
    tangent = vector2tensor_1dim(vec)
    whitened = torch.bmm(torch.bmm(covInv_sqrt.permute(0, 2, 1), tangent), covInv_sqrt)
    exp_whitened = Exp_mat(whitened)
    result = torch.bmm(cov_sqrt.permute(0, 2, 1), torch.bmm(exp_whitened, cov_sqrt))
    return tensor2vector_1dim(result)

def logarithm_map(cov2, cov_sqrt, covInv_sqrt):
    """Logarithm map Log_(cov)(cov2) at the covariance described by the sqrt factors.

    Returns the vector form of ``cov_sqrt^T Log(covInv_sqrt^T C2 covInv_sqrt) cov_sqrt``.

    Args:
        cov2: N x cov_dim vector form of the second covariance matrix.
        cov_sqrt: D^(1/2) R^T of the covariance R D R^T.
        covInv_sqrt: R D^(-1/2) of the covariance R D R^T.
    """
    target = vector2tensor_1dim(cov2)
    whitened = torch.bmm(torch.bmm(covInv_sqrt.permute(0, 2, 1), target), covInv_sqrt)
    log_whitened = Log_mat(whitened)
    result = torch.bmm(cov_sqrt.permute(0, 2, 1), torch.bmm(log_whitened, cov_sqrt))
    return tensor2vector_1dim(result)

def exponential_map_Inv_sqrt(vec, cov_sqrt, covInv_sqrt, eps = 1e-5):
    """Inverse square-root factor of exp_cov(vec).

    Builds ``exp_cov = cov_sqrt^T U diag(exp(lambda)) U^T cov_sqrt``, where
    ``U diag(lambda) U^T = covInv_sqrt^T V covInv_sqrt``. Then, for
    ``exp_cov = U2 diag(S2) U2^T``, returns ``U2 diag(1 / sqrt(S2 + eps))``, the
    same R D^(-1/2) layout as ``covInv_sqrt``.

    Args:
        vec: N x cov_dim tangent vectors.
        cov_sqrt: D^(1/2) R^T of the covariance R D R^T.
        covInv_sqrt: R D^(-1/2) of the covariance R D R^T.
        eps: shift added before the square root.
    """
    temp = torch.bmm(
        torch.bmm(covInv_sqrt.permute(0,2,1), vector2tensor_1dim(vec)),
        covInv_sqrt
    )
    U, S, V = svd(temp)
    UtV = torch.diagonal(torch.bmm(U.permute(0,2,1),V), dim1=1, dim2=2)
    S_ = S.clone()
    S_[UtV < 0] = - S[UtV < 0]
    expS = torch.exp(S_)

    exp_cov_vec = torch.bmm(
        torch.bmm(cov_sqrt.permute(0,2,1), U),
        torch.bmm(
            torch.diag_embed(expS),
            torch.bmm(U.permute(0,2,1), cov_sqrt)
        )
    )

    # exp_cov_vec is a congruence transform of U diag(exp(.)) U^T (positive-definite),
    # hence symmetric positive semi-definite: its singular values already are its
    # eigenvalues, so no sign correction is needed before the square root.
    U2, S2, _ = svd(exp_cov_vec)
    return torch.bmm(U2, torch.diag_embed(1.0/torch.sqrt(S2+eps)))

def metric_N_n(data, covInv = None):
    """Metric tensor on R^n x P(n) (position + covariance) at each sample.

    ### data: pos(0:3) and cov(3:9) for N(3)
    ### data: pos(0:2) and cov(2:5) for N(2)

    The position block is ``covInv``. The covariance block holds the coefficients
    written out below, which are quadratic in the entries g_ij of ``covInv``.

    Args:
        data: N x 5 or N x 9 tensor.
        covInv: optional inverse covariance (N x n x n). Computed from ``data`` with
            ``torch.inverse`` when omitted.

    Returns:
        N x dim x dim metric, moved to CUDA when ``data`` is on CUDA.

    Raises:
        NotImplementedError: if ``data.shape[1]`` is neither 5 nor 9.
    """
    N = data.shape[0]
    dim = data.shape[1]
    if dim not in (5, 9):
        raise NotImplementedError(
            "metric_N_n supports data of width 5 (N(2)) or 9 (N(3)), got {}".format(dim))
    metric = torch.zeros(N,dim,dim)
    if dim == 5:
        if covInv is None:
            cov = vector2tensor_1dim(data[:,2:])
            covInv = torch.inverse(cov)

        metric[:,:2,:2] = covInv

        g11 = covInv[:,0,0]
        g12 = covInv[:,0,1]
        g22 = covInv[:,1,1]

        metric[:,2,2] = g11**2/2.0
        metric[:,2,3] = g11*g12
        metric[:,2,4] = g12**2/2.0

        metric[:,3,2] = g11*g12
        metric[:,3,3] = g12**2 + g11*g22
        metric[:,3,4] = g12*g22

        metric[:,4,2] = g12**2/2.0
        metric[:,4,3] = g12*g22
        metric[:,4,4] = g22**2/2.0
    else:
        if covInv is None:
            cov = vector2tensor_1dim(data[:,3:])
            covInv = torch.inverse(cov)

        metric[:,:3,:3] = covInv

        g11 = covInv[:,0,0]
        g12 = covInv[:,0,1]
        g13 = covInv[:,0,2]
        g22 = covInv[:,1,1]
        g23 = covInv[:,1,2]
        g33 = covInv[:,2,2]

        metric[:,3,3] = g11**2/2.0
        metric[:,3,4] = g11*g12
        metric[:,3,5] = g11*g13
        metric[:,3,6] = g12**2/2.0
        metric[:,3,7] = g12*g13
        metric[:,3,8] = g13**2/2.0

        metric[:,4,3] = g11*g12
        metric[:,4,4] = g12**2 + g11*g22
        metric[:,4,5] = g12*g13 + g11*g23
        metric[:,4,6] = g12*g22
        metric[:,4,7] = g12*g23 + g13*g22
        metric[:,4,8] = g13*g23

        metric[:,5,3] = g11*g13
        metric[:,5,4] = g12*g13 + g11*g23
        metric[:,5,5] = g13**2 + g11*g33
        metric[:,5,6] = g12*g23
        metric[:,5,7] = g13*g23 + g12*g33
        metric[:,5,8] = g13*g33

        metric[:,6,3] = g12**2/2.0
        metric[:,6,4] = g12*g22
        metric[:,6,5] = g12*g23
        metric[:,6,6] = g22**2/2.0
        metric[:,6,7] = g22*g23
        metric[:,6,8] = g23**2/2.0

        metric[:,7,3] = g12*g13
        metric[:,7,4] = g12*g23 + g13*g22
        metric[:,7,5] = g13*g23 + g12*g33
        metric[:,7,6] = g22*g23
        metric[:,7,7] = g23**2 + g22*g33
        metric[:,7,8] = g23*g33

        metric[:,8,3] = g13**2/2.0
        metric[:,8,4] = g13*g23
        metric[:,8,5] = g13*g33
        metric[:,8,6] = g23**2/2.0
        metric[:,8,7] = g23*g33
        metric[:,8,8] = g33**2/2.0
    if data.is_cuda:
        metric = metric.cuda()
    return metric

def metric_sqrt_N_n(data, covInv = None, returnMetric = False):
    """Block-wise symmetric square root of ``metric_N_n``.

    ### data: pos(0:3) and cov(3:9) for N(3)
    ### data: pos(0:2) and cov(2:5) for N(2)

    The position block and the covariance block of the metric are each passed
    through ``get_sqrt_sym``; cross blocks stay zero.

    Returns:
        ``metric_sqrt`` (moved to CPU when ``data`` is on CPU), or
        ``(metric_sqrt, metric)`` when ``returnMetric`` is True.

    Raises:
        NotImplementedError: if ``data.shape[1]`` is neither 5 nor 9.
    """
    dim = data.shape[1]
    if dim == 5:
        pos_dim = 2
    elif dim == 9:
        pos_dim = 3
    else:
        raise NotImplementedError(
            "metric_sqrt_N_n supports data of width 5 or 9, got {}".format(dim))
    metric = metric_N_n(data, covInv)
    metric_sqrt = torch.zeros(metric.shape).cuda()
    metric_sqrt[:,:pos_dim,:pos_dim] = get_sqrt_sym(metric[:,:pos_dim,:pos_dim].cuda())
    metric_sqrt[:,pos_dim:,pos_dim:] = get_sqrt_sym(metric[:,pos_dim:,pos_dim:].cuda())
    if data.is_cuda:
        # must be called: a bare ``metric.cuda`` would bind the method, not the tensor
        metric = metric.cuda()
    else:
        metric_sqrt = metric_sqrt.cpu()
    if returnMetric:
        return metric_sqrt, metric
    return metric_sqrt

def metricInv_sqrt_N_n(data, covInv = None, returnMetric = False):
    """Block-wise inverse symmetric square root of ``metric_N_n``.

    ### data: pos(0:3) and cov(3:9) for N(3)
    ### data: pos(0:2) and cov(2:5) for N(2)

    Each diagonal block of the metric is passed through
    ``get_sqrt_sym(..., returnInvAlso=True)``, and only the inverse root is kept.

    Returns:
        ``metricInv_sqrt`` (moved to CPU when ``data`` is on CPU), or
        ``(metricInv_sqrt, metric)`` when ``returnMetric`` is True.

    Raises:
        NotImplementedError: if ``data.shape[1]`` is neither 5 nor 9.
    """
    dim = data.shape[1]
    if dim == 5:
        pos_dim = 2
    elif dim == 9:
        pos_dim = 3
    else:
        raise NotImplementedError(
            "metricInv_sqrt_N_n supports data of width 5 or 9, got {}".format(dim))
    metric = metric_N_n(data, covInv)
    metricInv_sqrt = torch.zeros(metric.shape).cuda()
    _, pos_inv_sqrt = get_sqrt_sym(metric[:,:pos_dim,:pos_dim].cuda(), returnInvAlso = True)
    _, cov_inv_sqrt = get_sqrt_sym(metric[:,pos_dim:,pos_dim:].cuda(), returnInvAlso = True)
    metricInv_sqrt[:,:pos_dim,:pos_dim] = pos_inv_sqrt
    metricInv_sqrt[:,pos_dim:,pos_dim:] = cov_inv_sqrt
    if data.is_cuda:
        metric = metric.cuda()
    else:
        metricInv_sqrt = metricInv_sqrt.cpu()
    if returnMetric:
        return metricInv_sqrt, metric
    return metricInv_sqrt

def christoffelSum_N_n(data, covInv = None):
    """Per-sample Christoffel-symbol sum for the ``metric_N_n`` geometry.

    ### data: pos(0:3) and cov(3:9) for N(3)
    ### data: pos(0:2) and cov(2:5) for N(2)

    Only the covariance components are non-zero. They are fixed multiples of the
    entries g_ij of ``covInv`` (see the coefficients below).

    Args:
        data: N x 5 or N x 9 tensor.
        covInv: optional inverse covariance (N x n x n). Computed from ``data`` with
            ``torch.inverse`` when omitted.

    Returns:
        N x dim tensor, moved to CUDA when ``data`` is on CUDA.

    Raises:
        NotImplementedError: if ``data.shape[1]`` is neither 5 nor 9.
    """
    N = data.shape[0]
    dim = data.shape[1]
    if dim not in (5, 9):
        raise NotImplementedError(
            "christoffelSum_N_n supports data of width 5 or 9, got {}".format(dim))
    chSum = torch.zeros(N,dim)
    if dim == 5:
        if covInv is None:
            cov = vector2tensor_1dim(data[:,2:])
            covInv = torch.inverse(cov)

        g11 = covInv[:,0,0]
        g12 = covInv[:,0,1]
        g22 = covInv[:,1,1]

        chSum[:,2] = - 2.0*g11
        chSum[:,3] = - 4.0*g12
        chSum[:,4] = - 2.0*g22
    else:
        if covInv is None:
            cov = vector2tensor_1dim(data[:,3:])
            covInv = torch.inverse(cov)

        g11 = covInv[:,0,0]
        g12 = covInv[:,0,1]
        g13 = covInv[:,0,2]
        g22 = covInv[:,1,1]
        g23 = covInv[:,1,2]
        g33 = covInv[:,2,2]

        chSum[:,3] = - 2.5*g11
        chSum[:,4] = - 5.0*g12
        chSum[:,5] = - 5.0*g13
        chSum[:,6] = - 2.5*g22
        chSum[:,7] = - 5.0*g23
        chSum[:,8] = - 2.5*g33
    if data.is_cuda:
        chSum = chSum.cuda()
    return chSum

def christoffelSum(data):
    """Constant Christoffel-symbol sum for R^n x P(n).

    Uses the covariance-inverse metric for the position and the log-Euclidean
    metric for the covariance. The vector does not depend on the sample values.

    Returns:
        An N x dim expanded view (N x 9 or N x 5), moved to CUDA when ``data`` is
        on CUDA, or None for any other width.
    """
    N = data.shape[0]
    table = {
        9: [0, 0, 0, -0.5, 0, 0, -0.5, 0, -0.5],
        5: [0, 0, -0.5, 0, -0.5],
    }
    values = table.get(data.shape[1])
    if values is None:
        return None
    chSum = torch.FloatTensor(values)
    if data.is_cuda:
        chSum = chSum.cuda()
    return chSum.view(1, -1).expand(N, -1)

def posMetric_sqrt_func(data_vec, eps = 1e-5):
    """Square-root factor of the position metric (log-Euclidean covariance setting).

    The columns after the position part are turned into matrices through
    ``Log_vec2Log`` and ``vector2tensor_1dim`` (on CUDA) and decomposed as
    ``U diag(lambda) U^T``. Returns ``U diag(1 / sqrt(exp(lambda) + eps))``.

    Args:
        data_vec: N x 9 (3-D positions) or otherwise treated as 2-D positions.
        eps: shift added before the square root.
    """
    pos_dim = 3 if data_vec.shape[1] == 9 else 2
    log_mats = vector2tensor_1dim(Log_vec2Log(data_vec[:, pos_dim:].cuda())).view(-1, pos_dim, pos_dim)
    U, sing, V = svd(log_mats)
    flipped = torch.diagonal(torch.matmul(U.permute(0, 2, 1), V), dim1=1, dim2=2) < 0
    eig = torch.where(flipped, -sing, sing)
    return torch.bmm(U, torch.diag_embed(1.0 / torch.sqrt(torch.exp(eig) + eps)))

def posMetric_sqrt_func_interpolate(data_vec, eps = 1e-5):
    """Position-metric square-root factor; identical computation to ``posMetric_sqrt_func``.

    Returns ``U diag(1 / sqrt(exp(lambda) + eps))`` for the eigendecomposition of the
    covariance-log part of ``data_vec``.
    """
    return posMetric_sqrt_func(data_vec, eps)

def metricInv(data_vec, metricCoeff = 1):
    """Inverse metric for the log-Euclidean setting, on CPU.

    The position block is ``Exp`` of the log part of ``data_vec`` (converted with
    ``Log_vec2Log``). The remaining diagonal is ``1 / metricCoeff``.

    Args:
        data_vec: N x 9 (3-D positions) or otherwise treated as 2-D positions.
        metricCoeff: scale of the covariance part of the metric.

    Returns:
        N x dim x dim CPU tensor.
    """
    N, dim = data_vec.shape[0], data_vec.shape[1]
    pos_dim = 3 if dim == 9 else 2
    logdata = Log_vec2Log(data_vec[:, pos_dim:]).cuda()
    pos_block = Exp(logdata, returnVec = False).cpu()
    metInv = torch.zeros(N, dim, dim)
    metInv[:, :pos_dim, :pos_dim] = pos_block
    # diagonal() is a view, so this writes the covariance diagonal in place
    metInv.diagonal(dim1=1, dim2=2)[:, pos_dim:] = 1.0 / metricCoeff
    return metInv

def metricInvDeriv(data_vec, returnPosPartOnly = False):
    """Derivative of the inverse metric with respect to the data coordinates.

    Only the position block depends on the data, through ``ExpJacobian`` of the log
    part. Jacobian columns at the off-diagonal vector indices ([1, 2, 4] for 3x3,
    [1] for 2x2) are divided by sqrt(2).

    Returns:
        N x d x d x d (d = 9 or 5), or N x pos_dim x pos_dim x d when
        ``returnPosPartOnly`` is True; CPU tensors.
    """
    N, dim = data_vec.shape[0], data_vec.shape[1]
    if dim == 9:
        pos_dim, offdiag_idx = 3, [1, 2, 4]
    else:
        pos_dim, offdiag_idx = 2, [1]
    jac = ExpJacobian(Log_vec2Log(data_vec[:, pos_dim:]), returnVec = False).cpu()
    jac[:, :, :, offdiag_idx] /= np.sqrt(2)
    if returnPosPartOnly:
        out = torch.zeros(N, pos_dim, pos_dim, dim)
        out[:, :, :, pos_dim:] = jac
    else:
        out = torch.zeros(N, dim, dim, dim)
        out[:, :pos_dim, :pos_dim, pos_dim:] = jac
    return out

def DTI2dim2ellipseInfo(data, num = 20):
    """Sample ``num`` boundary points of each 2-D tensor ellipse (log parametrization).

    Column layout: position in [:, :2], log-tensor vector in [:, 2:] (converted with
    ``Log_vec2Log``). Radii are ``exp`` of the sign-corrected eigenvalues. Axes follow
    the eigenvectors, and points are translated to the position.

    Returns:
        N x num x 2 CUDA tensor.
    """
    data = data.cuda()
    N = data.shape[0]
    centers = data[:, :2].view(N, 1, 2)
    U, sing, V = svd(vector2tensor_1dim(Log_vec2Log(data[:, 2:])))
    flipped = torch.diagonal(torch.matmul(U.permute(0, 2, 1), V), dim1=1, dim2=2) < 0
    radii = torch.exp(torch.where(flipped, -sing, sing))
    theta = torch.linspace(0, 2 * np.pi, num).cuda().view(1, -1)
    xs = (radii[:, 0:1] * torch.cos(theta)).view(N, num, 1)
    ys = (radii[:, 1:2] * torch.sin(theta)).view(N, num, 1)
    points = torch.matmul(torch.cat([xs, ys], 2), U.permute(0, 2, 1))
    return points + centers

def DTI2dim2ellipseInfoFromCov(data, num = 20):
    """Sample ``num`` boundary points of each 2-D ellipse given a covariance vector.

    Column layout: position in [:, :2], covariance vector in [:, 2:]. Radii are the
    sign-corrected eigenvalues themselves, not their square roots.
    NOTE(review): confirm eigenvalues rather than standard deviations are intended.

    Returns:
        N x num x 2 CUDA tensor.
    """
    data = data.cuda()
    N = data.shape[0]
    centers = data[:, :2].view(N, 1, 2)
    U, sing, V = svd(vector2tensor_1dim(data[:, 2:]))
    flipped = torch.diagonal(torch.matmul(U.permute(0, 2, 1), V), dim1=1, dim2=2) < 0
    radii = torch.where(flipped, -sing, sing)
    theta = torch.linspace(0, 2 * np.pi, num).cuda().view(1, -1)
    xs = (radii[:, 0:1] * torch.cos(theta)).view(N, num, 1)
    ys = (radii[:, 1:2] * torch.sin(theta)).view(N, num, 1)
    points = torch.matmul(torch.cat([xs, ys], 2), U.permute(0, 2, 1))
    return points + centers