import math
import torch
from torch_batch_svd import svd

##### codes for generic P(n) operations #####

def vecdim2matdim(vdim):
    """Return the matrix size n whose packed upper triangle has `vdim` entries.

    Solves vdim = n*(n+1)/2 for n using the positive root of the quadratic
    and truncates the result with ``int``.
    """
    discriminant = 1 + 8*vdim
    positive_root = (math.sqrt(discriminant) - 1) / 2
    return int(positive_root)

def vec2mat(x):
    """Unpack batched upper-triangle vectors into symmetric matrices.

    Args:
        x: 2-D tensor of shape (N, d*(d+1)/2). Each row lists the upper
            triangle (diagonal included) of a d x d matrix in row-major
            order, i.e. (0,0), (0,1), ..., (0,d-1), (1,1), ...

    Returns:
        Tensor of shape (N, d, d) where entries (i, j) and (j, i) both hold
        the matching element of `x`. The result lives on the same device and
        has the same dtype as `x` (previously it was always float32 on the
        current CUDA device or on the CPU).
    """
    assert len(x.shape) == 2
    dim = int((math.sqrt(1 + 8*x.shape[1])-1)/2)
    # triu_indices enumerates the upper triangle row by row, which is the
    # packing order used by mat2vec.
    rows, cols = torch.triu_indices(dim, dim, device=x.device)
    # Trailing columns are ignored when x.shape[1] is not a triangular number,
    # matching the behaviour of the former element-wise loop.
    upper = x[:, :rows.numel()]
    mat = x.new_zeros(x.shape[0], dim, dim)
    mat[:, rows, cols] = upper
    mat[:, cols, rows] = upper
    return mat

def mat2vec(X):
    """Pack the upper triangle of batched square matrices into vectors.

    Args:
        X: 3-D tensor of shape (N, d, d). Only entries (i, j) with j >= i are
            read, so the lower triangle is ignored.

    Returns:
        Tensor of shape (N, d*(d+1)/2) listing the upper triangle (diagonal
        included) row by row. The result keeps the dtype and device of `X`
        (previously it was always float32 on the current CUDA device or on
        the CPU).
    """
    assert len(X.shape) == 3
    dim = X.shape[1]
    # Row-major upper-triangle indices; advanced indexing returns a new tensor.
    rows, cols = torch.triu_indices(dim, dim, device=X.device)
    return X[:, rows, cols]

def batch_eigsym(X, safe_backward = False, maxIter = 10):
    """Eigendecompose batched symmetric matrices through a batched SVD.

    Assumes symmetric input matrices X of size N x dim x dim.

    Args:
        X: batch of symmetric matrices.
        safe_backward: when True, re-run the SVD on X plus a small random
            diagonal perturbation while any two consecutive singular values
            are closer than 1e-7 (presumably to keep the SVD backward pass
            well conditioned -- TODO confirm against torch_batch_svd).
        maxIter: bound on the perturbation retries; the loop stops once the
            retry counter exceeds this value.

    Returns:
        (S_, U): S_ is the singular-value tensor with the sign flipped
        wherever the matching diagonal entry of U^T V is negative, which
        turns singular values into signed eigenvalues; U holds the left
        singular vectors, used as eigenvectors.
    """
    U,S,V = svd(X)

    # With a single singular value there is no gap to test; torch.min on the
    # empty difference tensor would raise.
    if safe_backward and S.shape[1] > 1:
        attempts = 0
        while torch.min(S[:,:-1]-S[:,1:]) < 1e-7: # 1e-7
            # Draw the jitter on X's own device/dtype rather than forcing a
            # float32 tensor on the current CUDA device.
            jitter = torch.rand(X.shape[0], X.shape[1], device=X.device, dtype=X.dtype)*5e-7
            U,S,V = svd(X+torch.diag_embed(jitter))
            attempts += 1
            if attempts > maxIter:
                break
    # For a symmetric matrix the left/right singular vectors agree up to sign;
    # a negative diagonal entry of U^T V marks a negative eigenvalue.
    UtV = torch.diagonal(torch.matmul(U.permute(0,2,1),V), dim1=1, dim2=2)
    S_ = S.clone()
    S_[UtV < 0] = - S[UtV < 0]
    return S_, U

def Exp_mat(X, S = None, U = None, safe_backward = False):
    """Matrix exponential of batched symmetric matrices, U exp(S) U^T.

    The eigenvalues S and eigenvectors U are taken from the arguments when
    both are given; otherwise they are computed from X with batch_eigsym.
    """
    eig_supplied = S is not None and U is not None
    if not eig_supplied:
        S, U = batch_eigsym(X, safe_backward)
    exp_diag = torch.diag_embed(torch.exp(S))
    U_t = U.permute(0,2,1)
    return torch.bmm(torch.bmm(U, exp_diag), U_t)

def Log_mat(X, eps = 1e-7, S = None, U = None, safe_backward = False):
    """Matrix logarithm of batched symmetric matrices, U log(S) U^T.

    When S or U is missing both are computed from X with batch_eigsym and
    eigenvalues below `eps` are raised to `eps` before taking the log.
    Eigenvalues passed in by the caller are used as they are (no clamping).
    """
    eig_supplied = S is not None and U is not None
    if not eig_supplied:
        S, U = batch_eigsym(X, safe_backward)
        S[S<eps] = eps
    log_diag = torch.diag_embed(torch.log(S))
    U_t = U.permute(0,2,1)
    return torch.bmm(torch.bmm(U, log_diag), U_t)

def Exp_vec(x, S = None, U = None, safe_backward = False):
    """Matrix exponential of the symmetric matrices packed in `x`.

    `x` is unpacked with vec2mat and handed to Exp_mat together with the
    optional precomputed eigenvalues S and eigenvectors U.
    """
    X = vec2mat(x)
    return Exp_mat(X, S = S, U = U, safe_backward = safe_backward)

def Exp_vec_approx(v, approx = None, Eye = None, safe_backward = False):
    """Matrix exponential of packed symmetric matrices, optionally truncated.

    Args:
        v: packed upper-triangle vectors accepted by vec2mat.
        approx: 1, 2, 3 or 4 selects the Taylor expansion of exp truncated
            after the term of that order; any other value falls back to the
            eigendecomposition-based Exp_vec.
        Eye: optional identity broadcastable against the unpacked matrices.
            When omitted it is created on the device and with the dtype of
            the unpacked matrices (previously it was always placed on CUDA,
            which broke CPU inputs).
        safe_backward: forwarded to Exp_vec in the fallback branch.

    Returns:
        Batch of matrices with the same shape as vec2mat(v).
    """
    V = vec2mat(v)
    if Eye is None:
        mat_dim = V.shape[1]
        Eye = torch.eye(mat_dim, device=V.device, dtype=V.dtype).view(1, mat_dim, mat_dim)
    if approx == 1:
        Exp_v = Eye + V
    elif approx == 2:
        Exp_v = Eye + V + 0.5*torch.bmm(V, V)
    elif approx == 3:
        V_sq = torch.bmm(V, V)
        Exp_v = Eye + V + 0.5*V_sq + torch.bmm(V_sq, V)/6.
    elif approx == 4:
        V_sq = torch.bmm(V, V)
        V_cub = torch.bmm(V_sq, V)
        Exp_v = Eye + V + 0.5*V_sq + V_cub/6. + torch.bmm(V_cub, V)/24.
    else:
        Exp_v = Exp_vec(v, safe_backward = safe_backward)
    return Exp_v

def Log_mat_FnormSq_approx(dX, approx = None, Eye = None, eps = 1e-7, safe_backward = False):
    """Return the (approximated) squared Frobenius norm of Log(dX).

    The result is summed over the whole batch, i.e. it is a scalar tensor.

    Args:
        dX: batch of square matrices of shape (N, d, d).
        approx: 1, 2, 3 or 4 replaces Log(dX) by the series of log(I + delta),
            delta = dX - I, truncated after the term of that order. Any other
            value uses the eigenvalues from batch_eigsym instead.
        Eye: optional identity broadcastable against dX. When omitted it is
            created on dX's device with dX's dtype (previously it was always
            placed on CUDA, which broke CPU inputs).
        eps: lower bound applied to the eigenvalues in the exact branch.
        safe_backward: forwarded to batch_eigsym in the exact branch.
    """
    if Eye is None:
        mat_dim = dX.shape[1]
        Eye = torch.eye(mat_dim, device=dX.device, dtype=dX.dtype).view(1, mat_dim, mat_dim)
    delta = dX-Eye
    if approx == 1:
        return torch.sum((delta)**2)
    elif approx == 2:
        return torch.sum((delta - 0.5*torch.bmm(delta, delta))**2)
    elif approx == 3:
        delta_sq = torch.bmm(delta, delta)
        return torch.sum((delta - 0.5*delta_sq + torch.bmm(delta_sq, delta)/3.)**2)
    elif approx == 4:
        delta_sq = torch.bmm(delta, delta)
        delta_cub = torch.bmm(delta_sq, delta)
        return torch.sum((delta - 0.5*delta_sq + delta_cub/3. - 0.25*torch.bmm(delta_cub, delta))**2)
    # Exact branch: ||Log(dX)||_F^2 equals the sum of squared log-eigenvalues.
    S, _ = batch_eigsym(dX, safe_backward = safe_backward)
    S[S<eps] = eps
    return torch.sum(torch.log(S)**2)

def get_sqrt_sym(X, eps = 1e-14, returnInvAlso = False, S = None, U = None, safe_backward = False):
    """Square root (and optionally inverse square root) of symmetric matrices.

    Eigenvalues S and eigenvectors U are computed from X with batch_eigsym
    when either is missing, in which case eigenvalues below `eps` are raised
    to `eps`. Caller-supplied eigenvalues are used unchanged.

    Returns U sqrt(S) U^T, or the pair (U sqrt(S) U^T, U S^(-1/2) U^T) when
    `returnInvAlso` is True.
    """
    if S is None or U is None:
        S, U = batch_eigsym(X, safe_backward)
        S[S<eps] = eps
    U_t = U.permute(0,2,1)
    root_vals = torch.sqrt(S)
    sqrt = torch.bmm(U, torch.bmm(torch.diag_embed(root_vals), U_t))
    if not returnInvAlso:
        return sqrt
    inv_root_vals = 1.0/torch.sqrt(S)
    invsqrt = torch.bmm(torch.bmm(U, torch.diag_embed(inv_root_vals)), U_t)
    return sqrt, invsqrt

def get_sqrt_sym_DirDeriv(X, Xdot, eps = 1e-14, S = None, U = None, Xdot_trans = None, safe_backward = False):
    """Directional derivative of the symmetric matrix square root.

    Args:
        X: batch of symmetric matrices, shape (N, d, d).
        Xdot: batch of symmetric perturbation directions, same shape as X.
        eps: lower bound for internally computed eigenvalues, and threshold
            below which two eigenvalues are treated as equal.
        S, U: optional eigenvalues / eigenvectors of X; computed with
            batch_eigsym (and clamped to `eps`) when either is missing.
        Xdot_trans: optional U^T Xdot U; computed when missing.
        safe_backward: forwarded to batch_eigsym.

    Returns:
        Tensor of shape (N, d, d): the derivative of sqrt(X) along Xdot.
    """
    if S is None or U is None:
        S, U = batch_eigsym(X, safe_backward)
        S[S<eps] = eps
    if Xdot_trans is None:
        Xdot_trans = torch.matmul(torch.matmul(U.permute(0,2,1), Xdot), U)
    dim = X.shape[1]
    sqrtS = torch.sqrt(S)
    Sdot = torch.diagonal(Xdot_trans, dim1=-2, dim2=-1)
    # Scratch matrix lives next to U (same device and dtype) instead of being
    # forced onto the current CUDA device.
    tempMat = U.new_zeros(U.shape)
    for i in range(dim):
        for j in range(i+1, dim):
            # Off-diagonal factor: divided difference of sqrt at (S_i, S_j).
            tempMat[:,i,j] = Xdot_trans[:,i,j] * (sqrtS[:,i] - sqrtS[:,j]) / (S[:,i] - S[:,j])
            temp_norm = torch.abs(S[:,i] - S[:,j])
            # (Nearly) repeated eigenvalues: use the limit 1 / (2 sqrt(S_i)).
            tempMat[temp_norm < eps,i,j] = Xdot_trans[temp_norm < eps,i,j] * 0.5 / sqrtS[temp_norm < eps,i]
            tempMat[:,j,i] = tempMat[:,i,j]

    tempjac = torch.matmul(torch.matmul(U, torch.diag_embed(Sdot/sqrtS/2.0) + tempMat), U.permute(0,2,1))
    return tempjac

def get_sqrtInv_sym_DirDeriv(X, Xdot, eps = 1e-14, S = None, U = None, Xdot_trans = None, safe_backward = False):
    """Directional derivative of the inverse symmetric matrix square root.

    Args:
        X: batch of symmetric matrices, shape (N, d, d).
        Xdot: batch of symmetric perturbation directions, same shape as X.
        eps: lower bound for internally computed eigenvalues, and threshold
            below which two eigenvalues are treated as equal.
        S, U: optional eigenvalues / eigenvectors of X; computed with
            batch_eigsym (and clamped to `eps`) when either is missing.
        Xdot_trans: optional U^T Xdot U; computed when missing.
        safe_backward: forwarded to batch_eigsym.

    Returns:
        Tensor of shape (N, d, d): the derivative of X^(-1/2) along Xdot.
    """
    if S is None or U is None:
        S, U = batch_eigsym(X, safe_backward)
        S[S<eps] = eps
    if Xdot_trans is None:
        Xdot_trans = torch.matmul(torch.matmul(U.permute(0,2,1), Xdot), U)
    dim = X.shape[1]
    sqrtS_inv = 1./torch.sqrt(S)
    Sdot = torch.diagonal(Xdot_trans, dim1=-2, dim2=-1)
    # Scratch matrix lives next to U (same device and dtype) instead of being
    # forced onto the current CUDA device.
    tempMat = U.new_zeros(U.shape)
    for i in range(dim):
        for j in range(i+1, dim):
            # Off-diagonal factor: divided difference of s^(-1/2) at (S_i, S_j).
            tempMat[:,i,j] = Xdot_trans[:,i,j] * (sqrtS_inv[:,i] - sqrtS_inv[:,j]) / (S[:,i] - S[:,j])
            temp_norm = torch.abs(S[:,i] - S[:,j])
            # (Nearly) repeated eigenvalues: the divided difference tends to
            # d/ds s^(-1/2) = -0.5 * s^(-3/2). The sign must be negative, in
            # agreement with the diagonal term below.
            tempMat[temp_norm < eps,i,j] = -Xdot_trans[temp_norm < eps,i,j] * 0.5 * sqrtS_inv[temp_norm < eps,i]**3
            tempMat[:,j,i] = tempMat[:,i,j]

    tempjac = torch.matmul(torch.matmul(U, torch.diag_embed(-Sdot*sqrtS_inv**3/2.0) + tempMat), U.permute(0,2,1))
    return tempjac

def ExpDirDeriv(X, Xdot, eps = 1e-14, S = None, U = None, Xdot_trans = None, safe_backward = False):
    """Directional derivative of the symmetric matrix exponential.

    Args:
        X: batch of symmetric matrices, shape (N, d, d).
        Xdot: batch of symmetric perturbation directions, same shape as X.
        eps: threshold below which two eigenvalues are treated as equal.
        S, U: optional eigenvalues / eigenvectors of X; computed with
            batch_eigsym when either is missing (no clamping is applied).
        Xdot_trans: optional U^T Xdot U; computed when missing.
        safe_backward: forwarded to batch_eigsym.

    Returns:
        Tensor of shape (N, d, d): the derivative of Exp(X) along Xdot.
    """
    if S is None or U is None:
        S, U = batch_eigsym(X, safe_backward)
    if Xdot_trans is None:
        Xdot_trans = torch.matmul(torch.matmul(U.permute(0,2,1), Xdot), U)
    dim = X.shape[1]
    expS = torch.exp(S)
    Sdot = torch.diagonal(Xdot_trans, dim1=-2, dim2=-1)
    # Scratch matrix lives next to U (same device and dtype) instead of being
    # forced onto the current CUDA device.
    tempMat = U.new_zeros(U.shape)
    for i in range(dim):
        for j in range(i+1, dim):
            # Off-diagonal factor: divided difference of exp at (S_i, S_j).
            tempMat[:,i,j] = Xdot_trans[:,i,j] * (expS[:,i] - expS[:,j]) / (S[:,i] - S[:,j])
            temp_norm = torch.abs(S[:,i] - S[:,j])
            # (Nearly) repeated eigenvalues: use the limit exp(S_i).
            tempMat[temp_norm < eps,i,j] = Xdot_trans[temp_norm < eps,i,j] * expS[temp_norm < eps,i]
            tempMat[:,j,i] = tempMat[:,i,j]

    tempjac = torch.matmul(torch.matmul(U, torch.diag_embed(expS*Sdot) + tempMat), U.permute(0,2,1))
    return tempjac

def ExpDir2ndDeriv(X, Xdot1, Xdot2, Xdot12, eps = 1e-14, S = None, U = None, 
                   Xdot1_trans = None, Xdot2_trans = None, Xdot12_trans = None, safe_backward = False):
    """Second-order directional derivative of the symmetric matrix exponential.

    Args:
        X: batch of symmetric matrices, shape (N, d, d).
        Xdot1, Xdot2: first-order perturbation directions, same shape as X.
        Xdot12: mixed second-order perturbation of X, same shape as X.
        eps: accepted for signature parity with the first-order routines; it
            is not used here.
        S, U: optional eigenvalues / eigenvectors of X; computed with
            batch_eigsym when either is missing.
        Xdot1_trans, Xdot2_trans, Xdot12_trans: optional U^T (.) U images of
            the three perturbations; computed when missing.
        safe_backward: forwarded to batch_eigsym.

    Returns:
        Tensor of shape (N, d, d).

    NOTE(review): the off-diagonal terms divide by S_j - S_i with no special
    case for repeated eigenvalues, so such inputs produce inf/nan entries.
    """
    if S is None or U is None:
        S, U = batch_eigsym(X, safe_backward)
    if Xdot1_trans is None:
        Xdot1_trans = torch.matmul(torch.matmul(U.permute(0,2,1), Xdot1), U)
    if Xdot2_trans is None:
        Xdot2_trans = torch.matmul(torch.matmul(U.permute(0,2,1), Xdot2), U)
    if Xdot12_trans is None:
        Xdot12_trans = torch.matmul(torch.matmul(U.permute(0,2,1), Xdot12), U)
    dim = X.shape[1]
    expS = torch.exp(S)
    # First-order eigenvalue rates: diagonals of the rotated perturbations.
    Sdot1 = torch.diagonal(Xdot1_trans, dim1=-2, dim2=-1)
    Sdot2 = torch.diagonal(Xdot2_trans, dim1=-2, dim2=-1)
    # Scratch matrices live next to U (same device and dtype) instead of
    # being forced onto the current CUDA device.
    Omega1 = U.new_zeros(U.shape)
    Omega2 = U.new_zeros(U.shape)
    # Antisymmetric matrices built from the off-diagonal part of each rotated
    # perturbation divided by the eigenvalue gap (presumably the first-order
    # eigenvector rotation generators -- TODO confirm).
    for i in range(dim):
        for j in range(i+1, dim):
            Omega1[:,i,j] = Xdot1_trans[:,i,j] / (S[:,j] - S[:,i])
            Omega1[:,j,i] = -Omega1[:,i,j]
            Omega2[:,i,j] = Xdot2_trans[:,i,j] / (S[:,j] - S[:,i])
            Omega2[:,j,i] = -Omega2[:,i,j]

    temp1 = Omega1 * Sdot2.unsqueeze(1)
    temp2 = torch.bmm(Omega1 * S.unsqueeze(1), -Omega2)
    temp3 = Omega2 * Sdot1.unsqueeze(1)
    temp123 = temp1 + temp2 + temp3
    Temp = Xdot12_trans - temp123 - temp123.permute(0,2,1)
    # Symmetrise before splitting into diagonal and off-diagonal parts.
    Temp = 0.5*(Temp + Temp.permute(0,2,1))
    temp = torch.bmm(Omega1, Omega2)
    Temp2 = Temp - temp.permute(0,2,1)*S.unsqueeze(1) - S.unsqueeze(-1)*temp
    # Diagonal of Temp2 is used as the mixed second-order eigenvalue rate.
    Sdot12 = torch.diagonal(Temp2, dim1=-2, dim2=-1)
    Beta12 = U.new_zeros(U.shape)
    for i in range(dim):
        for j in range(i+1, dim):
            Beta12[:,i,j] = Temp2[:,i,j] / (S[:,j] - S[:,i])
            Beta12[:,j,i] = - Beta12[:,i,j]
    Omega12 = Beta12 + temp.permute(0,2,1)

    temp1 = Omega12 * expS.unsqueeze(1)
    temp2 = Omega1 * (expS * Sdot2).unsqueeze(1)
    temp3 = torch.matmul(Omega1 * (expS).unsqueeze(1), Omega2.permute(0,2,1))
    temp4 = Omega2 * (expS * Sdot1).unsqueeze(1)
    temp5 = torch.diag_embed(expS * (Sdot1 * Sdot2 + Sdot12))
    temp1234 = temp1 + temp2 + temp3 + temp4
    # Assemble in the eigenbasis, then rotate back with U.
    tempjac = torch.matmul(torch.matmul(U, temp1234 + temp5 + temp1234.permute(0,2,1)), U.permute(0,2,1))
    return tempjac

def LogDirDeriv(X, Xdot, eps = 1e-14, S = None, U = None, Xdot_trans = None, safe_backward = False):
    """Directional derivative of the symmetric matrix logarithm.

    Args:
        X: batch of symmetric matrices, shape (N, d, d).
        Xdot: batch of symmetric perturbation directions, same shape as X.
        eps: lower bound for internally computed eigenvalues, and threshold
            below which two eigenvalues are treated as equal.
        S, U: optional eigenvalues / eigenvectors of X; computed with
            batch_eigsym (and clamped to `eps`) when either is missing.
        Xdot_trans: optional U^T Xdot U; computed when missing.
        safe_backward: forwarded to batch_eigsym.

    Returns:
        Tensor of shape (N, d, d): the derivative of Log(X) along Xdot.
    """
    if S is None or U is None:
        S, U = batch_eigsym(X, safe_backward)
        S[S<eps] = eps
    if Xdot_trans is None:
        Xdot_trans = torch.matmul(torch.matmul(U.permute(0,2,1), Xdot), U)
    dim = X.shape[1]
    logS = torch.log(S)
    Sdot = torch.diagonal(Xdot_trans, dim1=-2, dim2=-1)
    # Scratch matrix lives next to U (same device and dtype) instead of being
    # forced onto the current CUDA device.
    tempMat = U.new_zeros(U.shape)
    for i in range(dim):
        for j in range(i+1, dim):
            # Off-diagonal factor: divided difference of log at (S_i, S_j).
            tempMat[:,i,j] = Xdot_trans[:,i,j] * (logS[:,i] - logS[:,j]) / (S[:,i] - S[:,j])
            temp_norm = torch.abs(S[:,i] - S[:,j])
            # (Nearly) repeated eigenvalues: use the limit 1 / S_i.
            tempMat[temp_norm < eps,i,j] = Xdot_trans[temp_norm < eps,i,j] / S[temp_norm < eps,i]
            tempMat[:,j,i] = tempMat[:,i,j]

    tempjac = torch.matmul(torch.matmul(U, torch.diag_embed(Sdot/S) + tempMat), U.permute(0,2,1))
    return tempjac

def exponential_map(X, V, eps = 1e-14, safe_backward = False):
    """Return X^(1/2) Exp(X^(-1/2) V X^(-1/2)) X^(1/2) for batched X and V.

    `eps` is the eigenvalue floor passed to get_sqrt_sym.
    """
    X_sqrt, X_invsqrt = get_sqrt_sym(X, eps = eps, returnInvAlso = True, safe_backward = safe_backward)
    whitened = torch.bmm(torch.bmm(X_invsqrt, V), X_invsqrt)
    exp_whitened = Exp_mat(whitened, safe_backward = safe_backward)
    return torch.bmm(torch.bmm(X_sqrt, exp_whitened), X_sqrt)

def logarithm_map(X, Y, eps1 = 1e-14, eps2 = 1e-7, safe_backward = False):
    """Return X^(1/2) Log(X^(-1/2) Y X^(-1/2)) X^(1/2) for batched X and Y.

    `eps1` is the eigenvalue floor passed to get_sqrt_sym for X, `eps2` the
    one passed to Log_mat for the whitened matrix.
    """
    X_sqrt, X_invsqrt = get_sqrt_sym(X, eps = eps1, returnInvAlso = True, safe_backward = safe_backward)
    whitened = torch.bmm(torch.bmm(X_invsqrt, Y), X_invsqrt)
    log_whitened = Log_mat(whitened, eps = eps2, safe_backward = safe_backward)
    return torch.bmm(torch.bmm(X_sqrt, log_whitened), X_sqrt)

def squared_distance(X, Y, eps1 = 1e-14, eps2 = 1e-7, S = None, U = None, safe_backward = False, X_invsqrt = None):
    """Sum of squared log-eigenvalues of X^(-1/2) Y X^(-1/2), per batch item.

    X_invsqrt may be supplied to skip its computation; otherwise it comes
    from get_sqrt_sym (which can reuse the optional S and U of X). The
    eigenvalues of the whitened matrix are floored at `eps2` before the log.
    """
    if X_invsqrt is None:
        _, X_invsqrt = get_sqrt_sym(X, eps = eps1, returnInvAlso = True, S = S, U = U, safe_backward = safe_backward)
    whitened = torch.matmul(torch.matmul(X_invsqrt, Y), X_invsqrt)
    eigvals, _ = batch_eigsym(whitened, safe_backward)
    eigvals[eigvals<eps2] = eps2
    log_eig = torch.log(eigvals)
    return torch.sum(log_eig*log_eig, dim=-1)

def pairwise_distance(X, eps1 = 1e-14, eps2 = 1e-7, S = None, U = None, safe_backward = False):
    """All-pairs distance matrix of the batch X, shape (N, N).

    Entry (a, b) is the square root of the summed squared log-eigenvalues of
    X_b^(-1/2) X_a X_b^(-1/2); eigenvalues are floored at `eps2` first.
    """
    batch = X.shape[0]
    mat_dim = X.shape[1]
    _, X_invsqrt = get_sqrt_sym(X, eps = eps1, returnInvAlso = True, S = S, U = U, safe_backward = safe_backward)
    # Broadcast: inverse roots vary along axis 1, the matrices along axis 0.
    invsqrt_row = X_invsqrt.unsqueeze(0)
    X_col = X.unsqueeze(1)
    whitened = torch.matmul(torch.matmul(invsqrt_row, X_col), invsqrt_row).view(-1, mat_dim, mat_dim)
    eigvals, _ = batch_eigsym(whitened, safe_backward)
    eigvals[eigvals<eps2] = eps2
    log_eig = torch.log(eigvals)
    return (log_eig*log_eig).sum(-1).sqrt().view(batch,batch)

def getConstForMinDist(X, Y, eps1 = 1e-14, eps2 = 1e-7):
    """Find the constant C that minimizes the distance between C*X and 1/C*Y.

    Returns half the mean log-eigenvalue of X^(-1/2) Y X^(-1/2) per batch
    item (eigenvalues floored at `eps2`).
    """
    _, X_invsqrt = get_sqrt_sym(X, eps = eps1, returnInvAlso = True)
    whitened = torch.bmm(torch.bmm(X_invsqrt, Y), X_invsqrt)
    eigvals, _ = batch_eigsym(whitened)
    eigvals[eigvals<eps2] = eps2
    mean_log = torch.log(eigvals).mean(1)
    return 0.5*mean_log

def logarithm_map_DirDeriv(X, Y, Xdot, eps1 = 1e-14, eps2 = 1e-7, S = None, U = None, Xdot_trans = None, X_sqrt = None, X_invsqrt = None, T = None, S_T = None, U_T = None, LogT = None, safe_backward = False):
    """Derivative of logarithm_map with respect to X along the direction Xdot.

    Output shape: N x mat_dim x mat_dim.

    Every optional argument is an intermediate quantity that the caller may
    have cached (eigen-data of X, U^T Xdot U, the roots of X, the whitened
    matrix T = X^(-1/2) Y X^(-1/2), its eigen-data and its logarithm);
    whatever is missing is recomputed here.
    """
    # --- fill in the intermediates that were not supplied ---
    if S is None or U is None:
        S, U = batch_eigsym(X, safe_backward)
        S[S<eps1] = eps1
    if Xdot_trans is None:
        U_t = U.permute(0,2,1)
        Xdot_trans = torch.matmul(torch.matmul(U_t, Xdot), U)
    if X_sqrt is None or X_invsqrt is None:
        X_sqrt, X_invsqrt = get_sqrt_sym(X, eps = eps1, returnInvAlso = True, S = S, U = U, safe_backward = safe_backward)
    if T is None:
        T = torch.bmm(torch.bmm(X_invsqrt, Y), X_invsqrt)
    if S_T is None or U_T is None:
        S_T, U_T = batch_eigsym(T, safe_backward)
        S_T[S_T<eps2] = eps2
    if LogT is None:
        LogT = Log_mat(T, eps = eps2, S = S_T, U = U_T, safe_backward = safe_backward)

    # --- derivatives of the two roots of X ---
    root_kwargs = dict(eps = eps1, S = S, U = U, Xdot_trans = Xdot_trans, safe_backward = safe_backward)
    d_sqrt = get_sqrt_sym_DirDeriv(X, Xdot, **root_kwargs)
    d_invsqrt = get_sqrtInv_sym_DirDeriv(X, Xdot, **root_kwargs)

    # --- product rule on T, then chain rule through Log ---
    d_T = torch.bmm(torch.bmm(d_invsqrt, Y), X_invsqrt) + torch.bmm(torch.bmm(X_invsqrt, Y), d_invsqrt)
    d_LogT = LogDirDeriv(T, d_T, eps = eps1, S = S_T, U = U_T, safe_backward = safe_backward)

    # --- product rule on X^(1/2) Log(T) X^(1/2) ---
    left_term = torch.bmm(torch.bmm(d_sqrt, LogT), X_sqrt)
    right_term = torch.bmm(torch.bmm(X_sqrt, LogT), d_sqrt)
    middle_term = torch.bmm(torch.bmm(X_sqrt, d_LogT), X_sqrt)
    return left_term + right_term + middle_term


###### codes for local coordinates ######

def idx2rowcol(idx, mat_dim):
    """Map a packed upper-triangle index to its (row, col) pair.

    The packing is row-major over the upper triangle of a mat_dim x mat_dim
    matrix. Falls through (returning None) when `idx` is past the last
    packed entry.
    """
    row_start = 0
    for row in range(mat_dim):
        # Row `row` contributes mat_dim - row packed entries.
        row_end = row_start + (mat_dim - row)
        if idx < row_end:
            return row, idx - row_start + row
        row_start = row_end

def rowcol2idx(row, col, mat_dim):
    """Map a (row, col) pair to its packed upper-triangle index.

    The pair is order-insensitive: (row, col) and (col, row) give the same
    index, since only the upper triangle is stored (row-major).
    """
    lo, hi = (col, row) if row > col else (row, col)
    # Entries stored in the rows above `lo`, then the offset inside row `lo`.
    entries_above = sum(mat_dim - r for r in range(lo))
    return entries_above + (hi - lo)
        
def metric_P_n(X, X_inv = None):
    """Metric matrix in packed upper-triangle coordinates, built from X^-1.

    Returns a CUDA float tensor of shape (N, vec_dim, vec_dim) with
    vec_dim = d*(d+1)/2. Entry (a, b), with a <-> (i, j) and b <-> (k, l),
    is a product of entries of X_inv chosen by whether each coordinate is a
    diagonal or an off-diagonal one. X_inv is computed (and symmetrised)
    when not supplied.
    """
    if X_inv is None:
        X_inv = torch.inverse(X)
        X_inv = 0.5*(X_inv + X_inv.permute(0,2,1))
    mat_dim = X.shape[1]
    vec_dim = int(mat_dim*(mat_dim + 1) / 2)
    metric = torch.cuda.FloatTensor(X.shape[0],vec_dim,vec_dim).zero_()
    # (row, col) pair of every packed coordinate, looked up once.
    pairs = [idx2rowcol(idx, mat_dim) for idx in range(vec_dim)]
    for a in range(vec_dim):
        i, j = pairs[a]
        for b in range(a, vec_dim):
            k, l = pairs[b]
            first_is_diag = (i == j)
            second_is_diag = (k == l)
            if first_is_diag and second_is_diag:
                entry = X_inv[:,i,k]**2
            elif first_is_diag:
                entry = 2*X_inv[:,i,k]*X_inv[:,i,l]
            elif second_is_diag:
                entry = 2*X_inv[:,i,k]*X_inv[:,j,k]
            else:
                entry = 2*(X_inv[:,i,k]*X_inv[:,j,l] + X_inv[:,i,l]*X_inv[:,j,k])
            metric[:,a,b] = entry
            # Mirror into the lower triangle.
            metric[:,b,a] = metric[:,a,b]

    return metric

def metricInv_P_n(X):
    """Inverse metric matrix in packed upper-triangle coordinates.

    Returns a CUDA float tensor of shape (N, vec_dim, vec_dim) with
    vec_dim = d*(d+1)/2, assembled entry by entry from products of entries
    of X itself (no matrix inversion is performed).
    """
    mat_dim = X.shape[1]
    vec_dim = int(mat_dim*(mat_dim + 1) / 2)
    metricInv = torch.cuda.FloatTensor(X.shape[0],vec_dim,vec_dim).zero_()
    # (row, col) pair of every packed coordinate, looked up once.
    pairs = [idx2rowcol(idx, mat_dim) for idx in range(vec_dim)]
    for a in range(vec_dim):
        i, j = pairs[a]
        for b in range(a, vec_dim):
            k, l = pairs[b]
            first_is_diag = (i == j)
            second_is_diag = (k == l)
            if first_is_diag and second_is_diag:
                entry = X[:,i,k]**2
            elif first_is_diag:
                entry = X[:,i,k]*X[:,i,l]
            elif second_is_diag:
                entry = X[:,i,k]*X[:,j,k]
            else:
                entry = 0.5*(X[:,i,k]*X[:,j,l] + X[:,i,l]*X[:,j,k])
            metricInv[:,a,b] = entry
            # Mirror into the lower triangle.
            metricInv[:,b,a] = metricInv[:,a,b]

    return metricInv

def metricInv2_P_n(X, X_inv = None):
    """Inverse metric obtained by numerically inverting metric_P_n(X, X_inv).

    The result is moved to the CPU when X is not a CUDA tensor.
    """
    metricInv = torch.inverse(metric_P_n(X, X_inv))
    return metricInv if X.is_cuda else metricInv.cpu()

def PinvDeriv_P_n(X, X_inv = None):
    """Jacobian of packed X^-1 with respect to the packed coordinates of X.

    Column `c` of the (N, vec_dim, vec_dim) CUDA float result is
    -mat2vec(X_inv E_c X_inv), where E_c is the symmetric 0/1 matrix with
    ones at the (row, col) pair of coordinate c and its mirror.
    X_inv is computed (and symmetrised) when not supplied.
    """
    mat_dim = X.shape[1]
    vec_dim = int(mat_dim*(mat_dim + 1) / 2)
    if X_inv is None:
        X_inv = torch.inverse(X)
        X_inv = 0.5*(X_inv + X_inv.permute(0,2,1))
    dX_inv_dX = torch.cuda.FloatTensor(X.shape[0],vec_dim,vec_dim).zero_()
    basis = torch.cuda.FloatTensor(X.shape).zero_()
    for col in range(vec_dim):
        i, j = idx2rowcol(col, mat_dim)
        # Switch the basis direction on, use it, then switch it off again so
        # the same buffer can serve the next coordinate.
        basis[:,i,j] = 1
        basis[:,j,i] = 1
        response = torch.bmm(torch.bmm(X_inv, basis), X_inv)
        dX_inv_dX[:,:,col] = - mat2vec(response)
        basis[:,i,j] = 0
        basis[:,j,i] = 0
    return dX_inv_dX

def metricDeriv_P_n(X, X_inv = None):
    """Derivative of the metric_P_n entries w.r.t. the packed coordinates of X.

    Returns a CUDA float tensor of shape (N, dim, dim, dim) with
    dim = d*(d+1)/2; slot [:, idx1, idx2, idx3] holds the derivative of the
    metric entry (idx1, idx2) along packed coordinate idx3. Each branch is
    the product rule applied to the matching branch of metric_P_n, with the
    derivative of X_inv taken from PinvDeriv_P_n. X_inv is computed (and
    symmetrised) when not supplied.
    """
    N = X.shape[0]
    dim = int(X.shape[1]*(X.shape[1] + 1) / 2)
    if X_inv is None:
        X_inv = torch.inverse(X)
        X_inv = 0.5*(X_inv + X_inv.permute(0,2,1))
    # dX_inv_dX[:, p, q]: derivative of packed X_inv entry p along coordinate q.
    dX_inv_dX = PinvDeriv_P_n(X, X_inv)
    metricDeriv = torch.cuda.FloatTensor(N,dim,dim,dim).zero_()
    for idx1 in range(dim):
        i, j = idx2rowcol(idx1, X.shape[1])
        for idx2 in range(idx1, dim):
            k, l = idx2rowcol(idx2, X.shape[1])
            for idx3 in range(dim):
                if i == j and k == l:
                    # d/dx of X_inv[i,k]**2
                    tempidx = rowcol2idx(i, k, X.shape[1])
                    metricDeriv[:,idx1,idx2,idx3] = 2*X_inv[:,i,k]*dX_inv_dX[:,tempidx,idx3]
                elif i == j:
                    # d/dx of 2*X_inv[i,k]*X_inv[i,l]
                    tempidx1 = rowcol2idx(i, k, X.shape[1])
                    tempidx2 = rowcol2idx(i, l, X.shape[1])
                    metricDeriv[:,idx1,idx2,idx3] = 2*X_inv[:,i,k]*dX_inv_dX[:,tempidx2,idx3] + 2*dX_inv_dX[:,tempidx1,idx3]*X_inv[:,i,l]
                elif k == l:
                    # d/dx of 2*X_inv[i,k]*X_inv[j,k]
                    tempidx1 = rowcol2idx(i, k, X.shape[1])
                    tempidx2 = rowcol2idx(j, k, X.shape[1])
                    metricDeriv[:,idx1,idx2,idx3] = 2*X_inv[:,i,k]*dX_inv_dX[:,tempidx2,idx3] + 2*dX_inv_dX[:,tempidx1,idx3]*X_inv[:,j,k]
                else:
                    # d/dx of 2*(X_inv[i,k]*X_inv[j,l] + X_inv[i,l]*X_inv[j,k])
                    tempidx1 = rowcol2idx(i, k, X.shape[1])
                    tempidx2 = rowcol2idx(j, l, X.shape[1])
                    tempidx3 = rowcol2idx(i, l, X.shape[1])
                    tempidx4 = rowcol2idx(j, k, X.shape[1])
                    metricDeriv[:,idx1,idx2,idx3] = 2*(X_inv[:,i,k]*dX_inv_dX[:,tempidx2,idx3] + dX_inv_dX[:,tempidx1,idx3]*X_inv[:,j,l] \
                                                       + X_inv[:,i,l]*dX_inv_dX[:,tempidx4,idx3] + dX_inv_dX[:,tempidx3,idx3]*X_inv[:,j,k])
                # The metric is symmetric in (idx1, idx2); mirror the entry.
                metricDeriv[:,idx2,idx1,idx3] = metricDeriv[:,idx1,idx2,idx3]
                
    return metricDeriv

def metricInvDeriv_P_n(X):
    """Derivative of the metricInv_P_n entries w.r.t. the packed coordinates of X.

    Returns a CUDA float tensor of shape (N, dim, dim, dim) with
    dim = d*(d+1)/2; slot [:, idx1, idx2, p] holds the derivative of the
    inverse-metric entry (idx1, idx2) along packed coordinate p. Because the
    inverse metric is a polynomial in the entries of X, each branch simply
    accumulates the partner factor at the coordinate of the factor being
    differentiated.
    """
    N = X.shape[0]
    dim = int(X.shape[1]*(X.shape[1] + 1) / 2)
    metricInvDeriv = torch.cuda.FloatTensor(N,dim,dim,dim).zero_()
    for idx1 in range(dim):
        i, j = idx2rowcol(idx1, X.shape[1])
        for idx2 in range(idx1, dim):
            k, l = idx2rowcol(idx2, X.shape[1])
            if i == j and k == l:
                # entry is X[i,k]**2
                tempidx = rowcol2idx(i, k, X.shape[1])
                metricInvDeriv[:,idx1,idx2,tempidx] += 2*X[:,i,k]
            elif i == j:
                # entry is X[i,k]*X[i,l]
                tempidx1 = rowcol2idx(i, k, X.shape[1])
                tempidx2 = rowcol2idx(i, l, X.shape[1])
                metricInvDeriv[:,idx1,idx2,tempidx1] += X[:,i,l]
                metricInvDeriv[:,idx1,idx2,tempidx2] += X[:,i,k]
            elif k == l:
                # entry is X[i,k]*X[j,k]
                tempidx1 = rowcol2idx(i, k, X.shape[1])
                tempidx2 = rowcol2idx(j, k, X.shape[1])
                metricInvDeriv[:,idx1,idx2,tempidx1] += X[:,j,k]
                metricInvDeriv[:,idx1,idx2,tempidx2] += X[:,i,k]
            else:
                # entry is 0.5*(X[i,k]*X[j,l] + X[i,l]*X[j,k]); += is needed
                # because some of the four coordinates may coincide.
                tempidx1 = rowcol2idx(i, k, X.shape[1])
                tempidx2 = rowcol2idx(j, l, X.shape[1])
                tempidx3 = rowcol2idx(i, l, X.shape[1])
                tempidx4 = rowcol2idx(j, k, X.shape[1])
                metricInvDeriv[:,idx1,idx2,tempidx1] += 0.5*X[:,j,l]
                metricInvDeriv[:,idx1,idx2,tempidx2] += 0.5*X[:,i,k]
                metricInvDeriv[:,idx1,idx2,tempidx3] += 0.5*X[:,j,k]
                metricInvDeriv[:,idx1,idx2,tempidx4] += 0.5*X[:,i,l]
            # Mirror the whole derivative vector into the (idx2, idx1) slot.
            metricInvDeriv[:,idx2,idx1] = metricInvDeriv[:,idx1,idx2]
    
    return metricInvDeriv

def metric_sqrt_P_n(X, X_inv = None, returnMetric = False, safe_backward = False):
    """Matrix square root of metric_P_n(X, X_inv).

    Returns the square root, or (square root, metric) when `returnMetric`
    is True. Results are moved to the CPU when X is not a CUDA tensor.
    """
    metric = metric_P_n(X, X_inv)
    metric_sqrt = get_sqrt_sym(metric, safe_backward = safe_backward)
    on_cpu = not X.is_cuda
    if on_cpu:
        metric = metric.cpu()
        metric_sqrt = metric_sqrt.cpu()
    return (metric_sqrt, metric) if returnMetric else metric_sqrt

def metricInv_sqrt_P_n(X, X_inv = None, returnMetric = False, safe_backward = False):
    """Inverse matrix square root of metric_P_n(X, X_inv).

    Returns the inverse square root, or (inverse square root, metric) when
    `returnMetric` is True. Results are moved to the CPU when X is not a
    CUDA tensor.
    """
    metric = metric_P_n(X, X_inv)
    roots = get_sqrt_sym(metric, returnInvAlso = True, safe_backward = safe_backward)
    metricInv_sqrt = roots[1]
    on_cpu = not X.is_cuda
    if on_cpu:
        metricInv_sqrt = metricInv_sqrt.cpu()
        metric = metric.cpu()
    return (metricInv_sqrt, metric) if returnMetric else metricInv_sqrt

def metricDetSqrt_P_n(X):
    """Return 2^(d(d-1)/4) / det(X)^((d+1)/2) for each matrix in the batch.

    Note from the original author: this should be proved mathematically...
    """
    mat_dim = X.shape[1]
    numerator = math.pow(2, mat_dim*(mat_dim-1)/4)
    denominator = torch.pow(X.det(), (mat_dim+1)/2)
    return numerator / denominator

def metricDetSqrtLog_P_n(X):
    """Return d(d-1)/4 * log(2) - (d+1)/2 * logdet(X) for each batch matrix.

    This is the logarithm of the expression used in metricDetSqrt_P_n.
    Note from the original author: this should be proved mathematically...
    """
    mat_dim = X.shape[1]
    const_term = mat_dim*(mat_dim-1)/4*math.log(2)
    logdet_term = (mat_dim+1)/2*X.logdet()
    return const_term - logdet_term

def christoffelSum_P_n(X, X_inv = None):
    """Scaled packed X^-1: -(d+1)/2 on diagonal coordinates, -(d+1) elsewhere.

    This can be obtained by differentiating metricDetSqrtLog_P_n.
    X_inv is computed (and symmetrised) when not supplied. The buffer is
    allocated on CUDA and moved to the CPU when X is not a CUDA tensor.
    """
    mat_dim = X.shape[1]
    if X_inv is None:
        X_inv = torch.inverse(X)
        X_inv = 0.5*(X_inv + X_inv.permute(0,2,1))
    x_inv = mat2vec(X_inv)

    chSum = torch.cuda.FloatTensor(x_inv.shape).zero_()
    for idx in range(chSum.shape[1]):
        row, col = idx2rowcol(idx, mat_dim)
        # Off-diagonal coordinates carry twice the diagonal weight.
        scale = -(mat_dim+1)/2 if row == col else -(mat_dim+1)
        chSum[:,idx] = scale*x_inv[:,idx]
    return chSum if X.is_cuda else chSum.cpu()

def christoffelSumDeriv_P_n(X, X_inv = None):
    """Derivative of christoffelSum_P_n w.r.t. the packed coordinates of X.

    Row `idx` of the (N, vec_dim, vec_dim) result is the matching row of
    PinvDeriv_P_n scaled by -(d+1)/2 for diagonal coordinates and -(d+1)
    otherwise. X_inv is computed (and symmetrised) when not supplied. The
    buffer is allocated on CUDA and moved to the CPU when X is not CUDA.
    """
    mat_dim = X.shape[1]
    vec_dim = int(mat_dim*(mat_dim + 1) / 2)
    if X_inv is None:
        X_inv = torch.inverse(X)
        X_inv = 0.5*(X_inv + X_inv.permute(0,2,1))
    dX_inv_dX = PinvDeriv_P_n(X, X_inv)

    chSumDeriv = torch.cuda.FloatTensor(X.shape[0],vec_dim,vec_dim).zero_()
    for idx in range(vec_dim):
        row, col = idx2rowcol(idx, mat_dim)
        scale = -(mat_dim+1)/2 if row == col else -(mat_dim+1)
        chSumDeriv[:,idx] = scale*dX_inv_dX[:,idx]
    return chSumDeriv if X.is_cuda else chSumDeriv.cpu()

def logarithmMapDeriv_P_n(X, Y, eps1 = 1e-14, eps2 = 1e-7, S = None, U = None, X_sqrt = None, X_invsqrt = None, T = None, S_T = None, U_T = None, LogT = None, safe_backward = False):
    """Jacobian of logarithm_map with respect to X in packed coordinates.

    Output shape: N x vec_dim x vec_dim. Column k is the packed directional
    derivative (logarithm_map_DirDeriv) along the symmetric 0/1 matrix of the
    k-th packed coordinate. The optional arguments are cached intermediates
    that are recomputed here when missing and then shared by every column.
    """
    # --- fill in the intermediates that were not supplied ---
    if S is None or U is None:
        S, U = batch_eigsym(X, safe_backward)
        S[S<eps1] = eps1
    if X_sqrt is None or X_invsqrt is None:
        X_sqrt, X_invsqrt = get_sqrt_sym(X, eps = eps1, returnInvAlso = True, S = S, U = U, safe_backward = safe_backward)
    if T is None:
        T = torch.bmm(torch.bmm(X_invsqrt, Y), X_invsqrt)
    if S_T is None or U_T is None:
        S_T, U_T = batch_eigsym(T, safe_backward)
        S_T[S_T<eps2] = eps2
    if LogT is None:
        LogT = Log_mat(T, eps = eps2, S = S_T, U = U_T, safe_backward = safe_backward)

    batch = X.shape[0]
    mat_dim = X.shape[1]
    vec_dim = int(mat_dim*(mat_dim + 1) / 2)
    jacobian = torch.cuda.FloatTensor(batch,vec_dim,vec_dim).zero_()
    direction = torch.cuda.FloatTensor(batch,mat_dim,mat_dim).zero_()
    upper_pairs = [(i, j) for i in range(mat_dim) for j in range(i, mat_dim)]
    for k, (i, j) in enumerate(upper_pairs):
        # Switch the basis direction on, differentiate, then switch it off so
        # the buffer can be reused for the next coordinate.
        direction[:,i,j] = 1
        direction[:,j,i] = 1
        column = logarithm_map_DirDeriv(X, Y, direction, eps1 = eps1, eps2 = eps2, S = S, U = U, X_sqrt = X_sqrt, X_invsqrt = X_invsqrt, T = T, S_T = S_T, U_T = U_T, LogT = LogT, safe_backward = safe_backward)
        jacobian[:,:,k] = mat2vec(column)
        direction[:,i,j] = 0
        direction[:,j,i] = 0
    return jacobian


def sqdistDeriv_P_n(X, Y, X_sqrt, X_invsqrt, metric, eps=1e-7, returnsqdistAlso = False):
    """Derivative of the squared distance from each X to each Y[c].

    For every matrix Y[c] the packed logarithm map at X is formed, lowered
    with `metric`, and scaled by -2. The results are stacked along axis 1,
    one slice per entry of Y. With `returnsqdistAlso` the squared distances
    (stacked the same way) are returned as a second value.
    """
    batch = X.shape[0]
    sq_dists = []
    covectors = []
    for c in range(Y.shape[0]):
        Y_c = Y[c].unsqueeze(0).repeat(batch,1,1)
        T = torch.bmm(torch.bmm(X_invsqrt, Y_c), X_invsqrt)
        S_T, U_T = batch_eigsym(T)
        S_T[S_T<eps] = eps
        log_eig = torch.log(S_T)
        sq_dists.append(torch.sum(log_eig*log_eig, dim=-1))
        LogT = Log_mat(T, eps = eps, S = S_T, U = U_T)
        log_vec = mat2vec(torch.bmm(torch.bmm(X_sqrt, LogT), X_sqrt))
        # Lower the index with the metric: vector -> covector.
        log_covec = torch.matmul(metric, log_vec.view(batch,metric.shape[1],1)).squeeze(-1)
        covectors.append(log_covec)
    sqdistSet = torch.stack(sq_dists, dim=1)
    sqdistDerivSet = -2*torch.stack(covectors, dim=1)
    if not returnsqdistAlso:
        return sqdistDerivSet
    return sqdistDerivSet, sqdistSet

def sqdist2ndDeriv_P_n(X, Y, X_sqrt, X_invsqrt, metric, metricDeriv, eps=1e-7, returnsqdistAlso = False):
    """Second derivative of the squared distance from each X to each Y[c].

    Per entry of Y the second derivative is assembled from two pieces: the
    metric derivative contracted with the packed logarithm map, and the
    metric applied to the Jacobian of the logarithm map. The stack is scaled
    by -2 and symmetrised over its last two axes. With `returnsqdistAlso`
    the first derivative and the squared distances are returned as well.
    """
    batch = X.shape[0]
    sq_dists = []
    covectors = []
    second_derivs = []
    # Eigen-data of X is shared by every logarithm-map Jacobian below.
    S, U = batch_eigsym(X)
    S[S<eps] = eps
    for c in range(Y.shape[0]):
        Y_c = Y[c].unsqueeze(0).repeat(batch,1,1)
        T = torch.bmm(torch.bmm(X_invsqrt, Y_c), X_invsqrt)
        S_T, U_T = batch_eigsym(T)
        S_T[S_T<eps] = eps
        log_eig = torch.log(S_T)
        sq_dists.append(torch.sum(log_eig*log_eig, dim=-1))
        LogT = Log_mat(T, eps = eps, S = S_T, U = U_T)
        log_vec = mat2vec(torch.bmm(torch.bmm(X_sqrt, LogT), X_sqrt))
        covectors.append(torch.matmul(metric, log_vec.view(batch,metric.shape[1],1)).squeeze(-1))

        logmapDeriv = logarithmMapDeriv_P_n(X, Y_c, S = S, U = U, X_sqrt = X_sqrt, X_invsqrt = X_invsqrt, 
                                            T = T, S_T = S_T, U_T = U_T, LogT = LogT)
        # Piece 1: derivative of the metric contracted with the log vector.
        from_metric = torch.matmul(metricDeriv.permute(0,3,1,2), log_vec.unsqueeze(1).unsqueeze(-1)).squeeze(-1).permute(0,2,1)
        # Piece 2: metric applied to the Jacobian of the log map.
        from_logmap = torch.matmul(metric, logmapDeriv)
        second_derivs.append(from_metric + from_logmap)
    sqdistSet = torch.stack(sq_dists, dim=1)
    sqdistDerivSet = -2*torch.stack(covectors, dim=1)
    sqdist2ndDerivSet = -2*torch.stack(second_derivs, dim=1)
    sqdist2ndDerivSet = 0.5*(sqdist2ndDerivSet + sqdist2ndDerivSet.permute(0,1,3,2))
    if not returnsqdistAlso:
        return sqdist2ndDerivSet
    return sqdist2ndDerivSet, sqdistDerivSet, sqdistSet

##### codes for tangent space Gaussian (using local coordinates) #####

def getCoeff(dim, is_cuda = True):
    """Multiplying coefficient to get exponential coordinate (from local coordinate).

    Returns a (1, dim*(dim+1)/2) packed vector holding 1 at diagonal
    coordinates and sqrt(2) at off-diagonal ones, on CUDA or CPU according
    to `is_cuda`.
    """
    tensor_type = torch.cuda.FloatTensor if is_cuda else torch.FloatTensor
    coeff_mat = tensor_type(dim, dim).fill_(math.sqrt(2))
    coeff_mat.fill_diagonal_(1)
    return mat2vec(coeff_mat.view(1, dim, dim))

def getCoeff2(dim, is_cuda = True):
    """Multiplying coefficient to get local coordinate (from exponential coordinate).

    Returns a (1, dim*(dim+1)/2) packed vector holding 1 at diagonal
    coordinates and 1/sqrt(2) at off-diagonal ones, on CUDA or CPU according
    to `is_cuda`.
    """
    tensor_type = torch.cuda.FloatTensor if is_cuda else torch.FloatTensor
    coeff_mat = tensor_type(dim, dim).fill_(1/math.sqrt(2))
    coeff_mat.fill_diagonal_(1)
    return mat2vec(coeff_mat.view(1, dim, dim))

def getOrthogonalAffineTransformJacobian(R):
    """Jacobian of the map E -> R E R^T in scaled packed coordinates.

    Args:
        R: batch of square matrices, shape (N, d, d) (presumably orthogonal,
            going by the function name -- TODO confirm with callers).

    Returns:
        Tensor of shape (N, vec_dim, vec_dim), vec_dim = d*(d+1)/2. Column k
        is the packed image R B_k R^T of the k-th basis matrix B_k (1 on a
        diagonal coordinate, 1/sqrt(2) on an off-diagonal pair), multiplied
        by the getCoeff weights.
    """
    N = R.shape[0]
    mat_dim = R.shape[1]
    dim = int(mat_dim*(mat_dim+1)/2)
    if R.is_cuda:
        J = torch.cuda.FloatTensor(N, dim, dim).fill_(0)
        temp = torch.cuda.FloatTensor(1, mat_dim, mat_dim).fill_(0)
    else:
        J = torch.FloatTensor(N, dim, dim).fill_(0)
        temp = torch.FloatTensor(1, mat_dim, mat_dim).fill_(0)
    # The weights are loop-invariant and must follow R's device: the former
    # per-iteration getCoeff(mat_dim) call always requested a CUDA tensor,
    # contradicting the CPU branch above.
    coeff = getCoeff(mat_dim, is_cuda = R.is_cuda)
    R_t = R.permute(0,2,1)
    k = 0
    for i in range(mat_dim):
        for j in range(i,mat_dim):
            # Switch basis matrix B_k on ...
            if i==j:
                temp[0,i,j] = 1
            else:
                temp[0,i,j] = 1/math.sqrt(2)
                temp[0,j,i] = 1/math.sqrt(2)
            J[:,:,k] = mat2vec(torch.matmul(torch.matmul(R, temp), R_t))*coeff
            # ... and off again so the buffer can be reused.
            temp[0,i,j] = 0
            temp[0,j,i] = 0
            k += 1
    return J

def getExpCoord(X, mean_invsqrt):
    """Exponential coordinates e of X about a mean m, X = m^1/2 * Exp([e]) * m^1/2.

    `mean_invsqrt` is m^(-1/2) for a single mean; it is tiled over the
    batch. The packed logarithm of the whitened matrices is multiplied by
    the getCoeff weights.
    """
    batch, mat_dim = X.shape[0], X.shape[1]
    tiled_invsqrt = mean_invsqrt.view(1,mat_dim,mat_dim).repeat(batch,1,1)
    whitened = torch.bmm(torch.bmm(tiled_invsqrt, X), tiled_invsqrt)
    local_coord = mat2vec(Log_mat(whitened))
    return local_coord * getCoeff(mat_dim)

def getExpCoordJacobian(X, mean_invsqrt):
    """Jacobian de/dx of the exponential coordinates w.r.t. packed X.

    Column k of the (N, vec_dim, vec_dim) CUDA float result is the packed
    LogDirDeriv of the whitened matrix along the whitened k-th basis
    direction, multiplied by the getCoeff weights.
    """
    batch = X.shape[0]
    mat_dim = X.shape[1]
    vec_dim = int(mat_dim*(mat_dim+1)/2)
    tiled_invsqrt = mean_invsqrt.view(1,mat_dim,mat_dim).repeat(batch,1,1)
    whitened = torch.bmm(torch.bmm(tiled_invsqrt, X), tiled_invsqrt)
    basis = torch.cuda.FloatTensor(batch,mat_dim,mat_dim).zero_()
    coeff = getCoeff(mat_dim)

    expCoordJac = torch.cuda.FloatTensor(batch,vec_dim,vec_dim).zero_()
    upper_pairs = [(i, j) for i in range(mat_dim) for j in range(i, mat_dim)]
    for k, (i, j) in enumerate(upper_pairs):
        # Switch the basis direction on, differentiate, switch it off again.
        basis[:,i,j] = 1
        basis[:,j,i] = 1
        whitened_dir = torch.bmm(torch.bmm(tiled_invsqrt, basis), tiled_invsqrt)
        expCoordJac[:,:,k] = mat2vec(LogDirDeriv(whitened, whitened_dir)) * coeff
        basis[:,i,j] = 0
        basis[:,j,i] = 0
    return expCoordJac

def getExpCoordJacobianInv(expCoord, mean_sqrt):
    """Jacobian dx/de of packed X w.r.t. its exponential coordinates.

    Column k of the (N, vec_dim, vec_dim) CUDA float result is the packed
    matrix m^1/2 * dExp * m^1/2, where dExp is ExpDirDeriv at the unpacked
    coordinates along the k-th unit coordinate direction (both rescaled by
    the getCoeff2 weights).
    """
    batch = expCoord.shape[0]
    vec_dim = expCoord.shape[1]
    mat_dim = vecdim2matdim(vec_dim)
    coeff = getCoeff2(mat_dim)
    E = vec2mat(expCoord * coeff)
    mean_sqrt_b = mean_sqrt.view(1,mat_dim,mat_dim)
    one_hot = torch.cuda.FloatTensor(batch,vec_dim).zero_()
    expCoordJacInv = torch.cuda.FloatTensor(batch,vec_dim,vec_dim).zero_()
    # One column per packed upper-triangle coordinate.
    for k in range(mat_dim*(mat_dim+1)//2):
        one_hot[:,k] = 1
        dExp = ExpDirDeriv(E, vec2mat(one_hot*coeff))
        expCoordJacInv[:,:,k] = mat2vec(torch.matmul(torch.matmul(mean_sqrt_b, dExp), mean_sqrt_b))
        one_hot[:,k] = 0
    return expCoordJacInv

def getExpCoordJacobianInvLogDetDeriv(X, mean_sqrt, mean_invsqrt):
    """Accumulate a per-sample vector from second derivatives of the exp chart.

    Going by the name, this is presumably the derivative of
    log|det(dx/de)| with respect to X -- TODO confirm. What the code does:
    for every coordinate i it builds deriv_i, whose column j is the packed
    m^1/2 * ExpDir2ndDeriv(E; V_i, V_j) * m^1/2, contracts it with the
    transposed Jacobian J = de/dx, and adds the resulting scalar times row i
    of J to the output.

    Returns a CUDA float tensor of shape (N, vec_dim).
    """
    N = X.shape[0]
    dim = X.shape[1]
    vec_dim = int(dim*(dim+1)/2)
    output = torch.cuda.FloatTensor(N,vec_dim).zero_()
    # One-hot selectors for the two differentiation directions; entries are
    # toggled in place inside the loops below.
    temp1 = torch.cuda.FloatTensor(N,vec_dim).zero_()
    temp2 = torch.cuda.FloatTensor(N,vec_dim).zero_()
    
    ecoord = getExpCoord(X, mean_invsqrt)
    J = getExpCoordJacobian(X, mean_invsqrt)
    
    coeff = getCoeff2(dim)
    E = vec2mat(ecoord * coeff)
    # Eigen-data of E is computed once and shared by all ExpDir2ndDeriv calls.
    S_E, U_E = batch_eigsym(E)
    
    # temp1 is still all zeros here, so both mixed-perturbation inputs are
    # zero matrices.
    V12 = vec2mat(temp1)
    Xdot12_trans = vec2mat(temp1)
    for i in range(vec_dim):
        temp1[:,i] = 1
        V1 = vec2mat(temp1 * coeff)
        Xdot1_trans = torch.matmul(torch.matmul(U_E.permute(0,2,1), V1), U_E)
        deriv_i = torch.cuda.FloatTensor(N, vec_dim, vec_dim).zero_()
        for j in range(vec_dim):
            temp2[:,j] = 1
            V2 = vec2mat(temp2 * coeff)
            deriv_i[:,:,j] = mat2vec(
                torch.matmul(
                    mean_sqrt.view(1,dim,dim), 
                    torch.matmul(
                        ExpDir2ndDeriv(E, V1, V2, V12, S = S_E, U = U_E, Xdot1_trans = Xdot1_trans, Xdot12_trans = Xdot12_trans), 
                        mean_sqrt.view(1,dim,dim)
                    )
                )
            )
            # Reset the selector for the next j.
            temp2[:,j] = 0
        # Full contraction of J^T with deriv_i gives one scalar per sample,
        # which then weights row i of J.
        output += (J.permute(0,2,1) * deriv_i).sum((-1,-2)).unsqueeze(-1) * J[:,i]
        # Reset the selector for the next i.
        temp1[:,i] = 0
    return output

def log_rho_tangentGaussian(X, mean, CovInv):
    """Log-density (up to an additive constant) of a tangent-space Gaussian at X.

    The partition-function constant is ignored. The value returned is
    -0.5 * e^T CovInv e - log|det(getExpCoordJacobianInv(e, mean_sqrt))|,
    where e = getExpCoord(X, mean_invsqrt).

    Args:
        X: batch of matrices; X.shape[0] is the batch size, X.shape[1] is dim.
        mean: tensor viewed as (1, dim, dim).
        CovInv: tensor viewed as (1, vec_dim, vec_dim), vec_dim = dim*(dim+1)/2.

    Returns:
        Tensor with one value per batch element.
    """
    batch = X.shape[0]
    n = X.shape[1]
    n_vec = int(n*(n+1)/2)

    sqrt_mean, invsqrt_mean = get_sqrt_sym(mean.view(1, n, n), returnInvAlso=True)
    coords = getExpCoord(X, invsqrt_mean)

    # Log |det| of the inverse exponential-coordinate Jacobian (sign discarded).
    jac_inv = getExpCoordJacobianInv(coords, sqrt_mean)
    log_abs_det = jac_inv.slogdet()[1]

    # Quadratic form e^T CovInv e, one scalar per sample.
    row = coords.view(batch, 1, n_vec)
    col = coords.view(batch, n_vec, 1)
    row_times_cov = torch.matmul(row, CovInv.view(1, n_vec, n_vec))
    quad_form = torch.bmm(row_times_cov, col).view(batch)

    return -0.5*quad_form - log_abs_det

def log_rho_g_tangentGaussian(X, mean, CovInv):
    """Return log_rho_tangentGaussian(X, mean, CovInv) minus log(metricDetSqrt_P_n(X)).

    Additive constants are ignored. Presumably this re-expresses the density
    with respect to the Riemannian volume on P(n) -- verify against
    metricDetSqrt_P_n.
    """
    log_density = log_rho_tangentGaussian(X, mean, CovInv)
    log_volume = torch.log(metricDetSqrt_P_n(X))
    return log_density - log_volume

def log_rho_g_tangentGaussian2(X, mean, CovInv):
    """Return log_rho_tangentGaussian(X, mean, CovInv) minus metricDetSqrtLog_P_n(X).

    Additive constants are ignored. Unlike log_rho_g_tangentGaussian, the
    correction term comes from metricDetSqrtLog_P_n directly and no torch.log
    is applied here (presumably that helper already returns a logarithm --
    verify).
    """
    log_density = log_rho_tangentGaussian(X, mean, CovInv)
    log_volume = metricDetSqrtLog_P_n(X)
    return log_density - log_volume

def log_rho_tangentGaussianMixture(X, weights, means, CovInvs):
    """Log-density (up to an additive constant) of a tangent-Gaussian mixture at X.

    Constants are ignored, and only Gaussians with identical partition
    functions are considered. Component i contributes
    log p_i = -0.5 * e_i^T CovInvs[i] e_i - log|det(getExpCoordJacobianInv(e_i, means_sqrt[i]))|
    with e_i = getExpCoord(X, means_invsqrt[i]).

    The mixture is combined entirely in the log domain with torch.logsumexp,
    so samples far from every component keep a meaningful finite value instead
    of underflowing to log(0).

    Args:
        X: batch of matrices; X.shape[0] is the batch size N, X.shape[1] is dim.
        weights: mixture weights, length Nmix. Assumed non-negative (a negative
            weight yields NaN through log); a zero weight drops its component.
        means: passed to get_sqrt_sym; indexed per component.
        CovInvs: indexable by component; CovInvs[i] is viewed as
            (1, vec_dim, vec_dim) with vec_dim = dim*(dim+1)/2.

    Returns:
        Tensor of shape (N,). Entries that would be +/-inf (for example when all
        weights are zero) are replaced by -103, the same fallback value the
        previous implementation used.
    """
    N = X.shape[0]
    dim = X.shape[1]
    vec_dim = int(dim*(dim+1)/2)
    Nmix = weights.shape[0]

    means_sqrt, means_invsqrt = get_sqrt_sym(means, returnInvAlso = True)
    log_p_i = []
    for i in range(Nmix):
        expCoord = getExpCoord(X, means_invsqrt[i])
        _, logabsdet = getExpCoordJacobianInv(expCoord, means_sqrt[i]).slogdet()
        quad = torch.bmm(torch.matmul(expCoord.view(N,1,vec_dim), CovInvs[i].view(1,vec_dim,vec_dim)),
                         expCoord.view(N,vec_dim,1)).view(N)
        log_p_i.append(-0.5*quad - logabsdet)

    # log(sum_i w_i * p_i) = logsumexp_i(log w_i + log p_i); never forms exp(log p_i)
    # on its own, so there is no float underflow/overflow in the intermediate.
    log_terms = torch.stack(log_p_i, dim=-1) + torch.log(weights.view(1,Nmix))
    logp = torch.logsumexp(log_terms, dim=-1)

    # to avoid -inf: out-of-place so the logsumexp result stays valid for autograd
    return torch.where(torch.isinf(logp), torch.full_like(logp, -103), logp)

def log_rho_g_tangentGaussianMixture(X, weights, means, CovInvs):
    """Return log_rho_tangentGaussianMixture(...) minus log(metricDetSqrt_P_n(X)).

    Additive constants are ignored. Presumably this re-expresses the mixture
    density with respect to the Riemannian volume on P(n) -- verify against
    metricDetSqrt_P_n.
    """
    log_density = log_rho_tangentGaussianMixture(X, weights, means, CovInvs)
    log_volume = torch.log(metricDetSqrt_P_n(X))
    return log_density - log_volume

def log_rho_g_tangentGaussianMixture2(X, weights, means, CovInvs):
    """Return log_rho_tangentGaussianMixture(...) minus metricDetSqrtLog_P_n(X).

    Additive constants are ignored. Unlike log_rho_g_tangentGaussianMixture,
    the correction term comes from metricDetSqrtLog_P_n directly and no
    torch.log is applied here (presumably that helper already returns a
    logarithm -- verify).
    """
    log_density = log_rho_tangentGaussianMixture(X, weights, means, CovInvs)
    log_volume = metricDetSqrtLog_P_n(X)
    return log_density - log_volume

def geometricScore_tangentGaussian(X, mean, CovInv):
    """Score and geometric score of a tangent-space Gaussian at X.

    Model: X = mean^(1/2) * Exp([e]) * mean^(1/2), where e ~ N(0, Cov) on the
    tangent space of I and is then transformed to mean. All inputs are assumed
    to be on the GPU.

    Args:
        X: batch of matrices; X.shape[0] is the batch size, X.shape[1] is dim.
        mean: tensor viewed as (1, dim, dim).
        CovInv: tensor viewed as (1, vec_dim, vec_dim), vec_dim = dim*(dim+1)/2.

    Returns:
        (gscore, score): score is viewed as (N, vec_dim); gscore is score minus
        christoffelSum_P_n(X).
    """
    batch = X.shape[0]
    n = X.shape[1]
    n_vec = int(n*(n+1)/2)

    sqrt_mean, invsqrt_mean = get_sqrt_sym(mean.view(1, n, n), returnInvAlso=True)
    coords = getExpCoord(X, invsqrt_mean)
    coords_jac = getExpCoordJacobian(X, invsqrt_mean)

    # Gradient of the quadratic term: (e^T CovInv) @ de/dx.
    row_times_cov = torch.matmul(coords.view(batch, 1, n_vec), CovInv.view(1, n_vec, n_vec))
    quad_grad = torch.bmm(row_times_cov, coords_jac).view(batch, n_vec)

    # Gradient of the log-determinant term of the inverse Jacobian.
    logdet_grad = getExpCoordJacobianInvLogDetDeriv(X, sqrt_mean, invsqrt_mean)

    score = - quad_grad - logdet_grad
    gscore = score - christoffelSum_P_n(X)
    return gscore, score

def geometricScore_tangentGaussianMixture(X, weights, means, CovInvs):
    """Score and geometric score of a tangent-Gaussian mixture at X.

    Model: X ~ p = sum_i w_i * p_i; from each p_i, X is sampled as
    X = mean_i^(1/2) * Exp([e]) * mean_i^(1/2), where e ~ N(0, Cov_i) on the
    tangent space of I and is then transformed to mean_i. All inputs are
    assumed to be on the GPU.

    The mixture score is evaluated as
        score = sum_i r_i * dlog(p_i)/dx,   r_i = softmax_i(log w_i + log p_i),
    which equals (sum_i w_i dp_i/dx) / (sum_i w_i p_i) but never forms p_i
    itself, so samples far from every component no longer underflow to a
    zero score.

    Args:
        X: batch of matrices; X.shape[0] is the batch size N, X.shape[1] is dim.
        weights: mixture weights, length Nmix. Assumed non-negative (a negative
            weight yields NaN through log); a zero weight drops its component.
        means: passed to get_sqrt_sym; indexed per component.
        CovInvs: indexable by component; CovInvs[i] is viewed as
            (1, vec_dim, vec_dim) with vec_dim = dim*(dim+1)/2.

    Returns:
        (gscore, score): score has shape (N, vec_dim); gscore is score minus
        christoffelSum_P_n(X).
    """
    N = X.shape[0]
    dim = X.shape[1]
    vec_dim = int(dim*(dim+1)/2)
    Nmix = weights.shape[0]

    means_sqrt, means_invsqrt = get_sqrt_sym(means, returnInvAlso = True)
    log_p_i = []
    dlogp_i_dx = []
    for i in range(Nmix):
        expCoord = getExpCoord(X, means_invsqrt[i])
        dexpCoord_dx = getExpCoordJacobian(X, means_invsqrt[i])
        _, logabsdet = getExpCoordJacobianInv(expCoord, means_sqrt[i]).slogdet()
        # e^T CovInv is shared by the log-density and by its gradient.
        eCovInv = torch.matmul(expCoord.view(N,1,vec_dim), CovInvs[i].view(1,vec_dim,vec_dim))
        log_p_i.append(-0.5*torch.bmm(eCovInv, expCoord.view(N,vec_dim,1)).view(N) - logabsdet)
        dlogp_i_dx.append(- torch.bmm(eCovInv, dexpCoord_dx).view(N, vec_dim)
                          - getExpCoordJacobianInvLogDetDeriv(X, means_sqrt[i], means_invsqrt[i]))

    # Responsibilities r_i = w_i p_i / sum_j w_j p_j, computed stably in the log domain.
    log_terms = torch.stack(log_p_i, dim=-1) + torch.log(weights.view(1,Nmix))
    resp = torch.softmax(log_terms, dim=-1)
    # to avoid nan and inf: rows with no finite term (e.g. all weights zero)
    # fall back to a zero score, as the previous implementation did.
    resp = torch.where(torch.isfinite(resp), resp, torch.zeros_like(resp))

    score = (torch.stack(dlogp_i_dx, dim=-1) * resp.view(N,1,Nmix)).sum(-1)
    gscore = score - christoffelSum_P_n(X)

    return gscore, score