# define functions for spherical data (S^n)

import numpy as np

def getData(coord):
    """Map hyperspherical angles to embedded positions in R^(n+1).

    `coord` is indexed as (N, n): one row of angles per point. The result
    has shape (n+1, N): one column per point, with component k equal to
    sin(th_0) * ... * sin(th_{k-1}) * cos(th_k) and the last component
    equal to the product of all sines.
    """
    n_points = coord.shape[0]
    n_angles = coord.shape[1]
    sines = np.sin(coord)
    cosines = np.cos(coord)

    data = np.ones((n_angles + 1, n_points))
    for axis in range(n_angles):
        # Component `axis` closes with a cosine; every later component
        # picks up the sine of the same angle.
        data[axis] *= cosines[:, axis]
        data[axis + 1:] *= sines[:, axis]

    return data

def getCoord(data):
    """Recover hyperspherical angles from embedded positions.

    `data` is indexed as (n+1, N): one column per point. The result has
    shape (N, n). The first n-1 angles come from `arccos` (range [0, pi]);
    the last angle comes from `arctan2` of the last two components.

    Cosine arguments are clamped to [-1, 1] so that round-off (e.g. after a
    rotation) cannot produce NaN. Where the running product of sines is
    exactly zero the remaining polar angles are not determined by the
    point; they are set to 0 instead of evaluating 0/0.
    """
    dim = data.shape[0] - 1
    N = data.shape[1]
    th = np.zeros((N, dim))
    th[:, 0] = np.arccos(np.clip(data[0], -1.0, 1.0))
    # For dim == 1 this deliberately overwrites column 0 with the signed angle.
    th[:, dim-1] = np.arctan2(data[dim], data[dim-1])
    sin_multi = np.ones(N)
    for i in range(1, dim-1):
        sin_multi *= np.sin(th[:, i-1])
        # Rows with a zero sine product keep the angle 0 already stored in th.
        regular = sin_multi != 0
        ratio = data[i][regular] / sin_multi[regular]
        th[regular, i] = np.arccos(np.clip(ratio, -1.0, 1.0))

    return th

def getJacobianPos(coord):
    """Jacobian of the embedded position with respect to the angles.

    `coord` is indexed as (N, n). The result has shape (N, n+1, n), with
    entry [p, i, j] holding d pos_i / d th_j for point p. Entries with
    j > i are zero because component i only depends on angles 0..i.
    """
    n_points = coord.shape[0]
    n_angles = coord.shape[1]
    sines = np.sin(coord)
    cosines = np.cos(coord)

    jac = np.ones((n_points, n_angles + 1, n_angles))
    for row in range(n_angles + 1):
        jac[:, row, row + 1:] = 0
        # Only angles 0..min(row, n-1) appear in component `row`.
        upper = min(row + 1, n_angles)
        for col in range(upper):
            for k in range(upper):
                if col == row and k == col:
                    # Differentiating the closing cosine of this component.
                    factor = -sines[:, k]
                elif k == col or k == row:
                    # Either a differentiated sine or the closing cosine.
                    factor = cosines[:, k]
                else:
                    factor = sines[:, k]
                jac[:, row, col] *= factor

    return jac

def coordChange(coord2, R1, R2, print_pos=False):
    """Convert angles given in frame 2 to angles in frame 1.

    The angles are embedded with `getData`, multiplied by R1 * R2^T, and
    converted back with `getCoord`.

    Returns the new angles, or the tuple (new angles, transformed
    positions, original positions) when `print_pos` is true.
    """
    pos2 = getData(coord2)
    frame_change = np.dot(R1, R2.T)
    pos = np.dot(frame_change, pos2)
    coord1 = getCoord(pos)

    if print_pos:
        return coord1, pos, pos2
    return coord1

def coordJacobian(coord2, R1, R2):
    """Jacobian of the frame-1 angles with respect to the frame-2 angles.

    For each point the result is pinv(dpos1/dcoord1) * (R1 * R2^T) *
    (dpos2/dcoord2). The returned array has shape (N, n, n) where
    (N, n) is the shape of `coord2`.
    """
    coord1 = coordChange(coord2, R1, R2)
    jac_pos1 = getJacobianPos(coord1)
    jac_pos2 = getJacobianPos(coord2)
    frame_change = np.dot(R1, R2.T)

    n_points = coord2.shape[0]
    n_angles = coord2.shape[1]
    J = np.zeros((n_points, n_angles, n_angles))
    for n in range(n_points):
        # The position Jacobian is not square, hence the pseudo-inverse.
        to_angles = np.linalg.pinv(jac_pos1[n])
        J[n] = np.dot(np.dot(to_angles, frame_change), jac_pos2[n])

    return J

def metricSqrt(data):
    """Square roots of the diagonal metric components.

    `data` is indexed as (N, n) angles. Column 0 of the result is 1 and
    column i is sin(th_0) * ... * sin(th_{i-1}). Only the diagonal is
    returned, in an array of the same shape as `data`.
    """
    sines = np.sin(data)
    root = np.ones(data.shape)
    # Column i is the running product of the sines of all earlier angles.
    root[:, 1:] = np.cumprod(sines[:, :-1], axis=1)
    return root

def metric(data):
    """Diagonal metric components for the angles in `data`.

    `data` is indexed as (N, n). Column 0 of the result is 1 and column i
    is (sin(th_0) * ... * sin(th_{i-1}))**2. Only the diagonal is
    returned, in an array of the same shape as `data`.
    """
    sines = np.sin(data)
    scale = np.ones(data.shape)
    running = scale[:, 0]
    for col in range(1, data.shape[1]):
        # Out-of-place product so the column already stored is not modified.
        running = running * sines[:, col - 1]
        scale[:, col] = running
    return scale**2

def metricDeriv(data):
    """Derivatives of the diagonal metric components w.r.t. the angles.

    `data` is indexed as (N, n). The result has shape (N, n, n) with entry
    [p, j, i] = d g_jj / d th_i, where g_jj = prod_{k<j} sin(th_k)**2.
    Entries with i >= j are zero.

    Each derivative is built as a product in which sin(th_i)**2 is
    replaced by 2*sin(th_i)*cos(th_i). This avoids the metric * cos/sin
    form, which evaluates to 0*inf = NaN where sin(th_i) is zero even
    though the derivative there is finite.
    """
    N = data.shape[0]
    dim = data.shape[1]
    s = np.sin(data)
    c = np.cos(data)
    sin_sq = s**2

    deriv = np.zeros((N, dim, dim))
    for i in range(dim - 1):
        factors = np.array(sin_sq)
        factors[:, i] = 2.0 * s[:, i] * c[:, i]
        # cumprod column j-1 is the product over k < j, i.e. d g_jj / d th_i
        # for every j > i.
        deriv[:, i+1:, i] = np.cumprod(factors[:, :dim-1], axis=1)[:, i:]
    return deriv

def metricInvSqrt(data):
    """Reciprocal square roots of the diagonal metric components.

    `data` is indexed as (N, n). Column 0 of the result is 1 and column i
    is 1 / (sin(th_0) * ... * sin(th_{i-1})). Only the diagonal is
    returned, in an array of the same shape as `data`.
    """
    sines = np.sin(data)
    root = np.ones(data.shape)
    # Running product of the sines of all earlier angles.
    root[:, 1:] = np.cumprod(sines[:, :-1], axis=1)
    return 1.0 / root

def metricInvDeriv(data):
    """Derivatives of the diagonal inverse-metric components.

    Computed as -g^{-1} * dg * g^{-1} per diagonal entry, using
    `metricDeriv` and the square of `metricInvSqrt`. The result has shape
    (N, n, n), laid out like the output of `metricDeriv`.
    """
    n_points = data.shape[0]
    n_angles = data.shape[1]
    deriv = metricDeriv(data)
    # Trailing unit axis so each inverse component scales its own row.
    inv_diag = (metricInvSqrt(data)**2).reshape((n_points, n_angles, 1))
    return -inv_diag * deriv * inv_diag

def christoffelSum(data):
    """Per-angle values (n - i - 1) * cot(th_i).

    `data` is indexed as (N, n) and the result has the same shape. The
    last column is identically zero because its prefactor is zero. It is
    left at zero rather than computed, since 0 * cos/sin evaluates to NaN
    when the last angle is 0 or +-pi.
    """
    dim = data.shape[1]
    christoffel = np.zeros(data.shape)
    for i in range(dim - 1):
        th = data[:, i]
        christoffel[:, i] = (dim - i - 1) * np.cos(th) / np.sin(th)

    return christoffel

def christoffelSumDeriv(data):
    """Derivative of (n - i - 1) * cot(th_i) with respect to the angles.

    `data` is indexed as (N, n). The result has shape (N, n, n) and is
    diagonal: entry [p, i, i] = -(n - i - 1) / sin(th_i)**2. The last
    diagonal entry has a zero prefactor and is left at zero rather than
    computed, since 0 / sin**2 evaluates to NaN when the last angle is 0
    or +-pi.
    """
    N = data.shape[0]
    dim = data.shape[1]
    christoffelDeriv = np.zeros((N, dim, dim))
    for i in range(dim - 1):
        th = data[:, i]
        christoffelDeriv[:, i, i] = -(dim - i - 1) / np.sin(th)**2

    return christoffelDeriv

def Exp(x):
    """Exponential map Exp: T_0 S^n -> (th_1, ..., th_n) coordinates.

    `x` is indexed as (N, n) tangent vectors and needs at least two
    columns. The first output angle is the norm of each row, the last is
    arctan2 of the last two components, and the angles in between come
    from arccos of successive normalised components.

    Cosine arguments are clamped to [-1, 1]. Where the running
    normalisation is exactly zero (the zero vector, or a vector on a
    coordinate axis) the remaining arccos angles are not determined; they
    are left at 0 instead of evaluating 0/0.
    """
    coord = np.zeros(x.shape)
    dim = x.shape[1]
    coord[:, 0] = np.linalg.norm(x, axis=1)
    # Copy so the in-place updates below do not modify coord[:, 0].
    sin_multi = np.array(coord[:, 0])
    coord[:, -1] = np.arctan2(x[:, -1], x[:, -2])
    for i in range(1, dim-1):
        # Degenerate rows keep the 0 already stored in coord.
        regular = sin_multi != 0
        ratio = x[regular, i-1] / sin_multi[regular]
        coord[regular, i] = np.arccos(np.clip(ratio, -1.0, 1.0))
        sin_multi *= np.sin(coord[:, i])

    return coord

def ExpInv(coord):
    """Inverse exponential map: (th_1, ..., th_n) -> T_0 S^n.

    `coord` is indexed as (N, n). The first angle acts as the radius and
    the remaining angles as direction angles: component k-1 of the result
    is radius * sin(th_1..k-1 product) * cos(th_k), and the last component
    is the radius times the product of all remaining sines.
    """
    n_angles = coord.shape[1]
    tangent = np.zeros(coord.shape)
    # Copy so the in-place scaling does not modify the caller's array.
    radial = np.array(coord[:, 0])
    for k in range(1, n_angles):
        angle = coord[:, k]
        tangent[:, k - 1] = radial * np.cos(angle)
        radial *= np.sin(angle)
    tangent[:, n_angles - 1] = radial

    return tangent
    
def getJacobianExpInv(coord):
    """Jacobian of Exp^-1: (th_1, ..., th_n) -> T_0 S^n.

    `coord` is indexed as (N, n). The result has shape (N, n, n) with
    entry [p, i, j] = d x_i / d th_j for point p. Entries with j > i + 1
    are zero.
    """
    n_points = coord.shape[0]
    n_angles = coord.shape[1]
    sines = np.sin(coord)
    cosines = np.cos(coord)

    jac = np.ones((n_points, n_angles, n_angles))
    # Every column except the radial one carries the radius as a factor.
    jac[:, :, 1:] *= coord[:, 0].reshape((n_points, 1, 1))
    for row in range(n_angles):
        jac[:, row, row + 2:] = 0
        # Component `row` involves direction angles 1..min(row + 1, n - 1).
        stop = min(row + 2, n_angles)
        for col in range(stop):
            for k in range(1, stop):
                if col == row + 1 and k == col:
                    # Differentiating the closing cosine of this component.
                    factor = -sines[:, k]
                elif (col > 0 and k == col) or k == row + 1:
                    # Either a differentiated sine or the closing cosine.
                    factor = cosines[:, k]
                else:
                    factor = sines[:, k]
                jac[:, row, col] *= factor
    return jac
    
def geometricScore_coord0_tangentGaussian(coord0, CovInv):
    """Geometric score for x ~ N(0, Cov) in the tangent space at th = 0.

    `coord0` is indexed as (N, n) and `CovInv` is the inverse covariance
    applied to tangent vectors. For each point the Gaussian term is
    -x^T CovInv (dx/dth), with x = ExpInv(coord0). The first component
    additionally receives (n - 1) * (1/th_1 - cot(th_1)).
    Returns an array with the same shape as `coord0`.
    """
    n_points = coord0.shape[0]
    n_angles = coord0.shape[1]

    tangent = ExpInv(coord0)
    jac = getJacobianExpInv(coord0)

    gscore = np.zeros(coord0.shape)
    for n in range(n_points):
        precision_jac = np.dot(CovInv, jac[n])
        gscore[n] = -np.dot(tangent[n:n+1], precision_jac)

    radial = coord0[:, 0]
    correction = 1.0/radial - np.cos(radial)/np.sin(radial)
    gscore[:, 0] += (n_angles - 1)*correction

    return gscore

def geometricScore_tangentGaussian(coord, R, CovInv):
    """Geometric score for x ~ N(0, Cov) in the tangent space at th = 0 of coord0.

    `coord` is indexed as (N, n) angles expressed in the frame `R`. They
    are converted to the reference frame (identity), scored there with
    `geometricScore_coord0_tangentGaussian`, and mapped back through the
    coordinate Jacobian.

    `R` multiplies embedded positions, which `getData` produces with n + 1
    components, so the identity reference frame must be (n+1) x (n+1);
    an n x n identity cannot conform with those products.
    NOTE(review): assumes `R` is a square (n+1) x (n+1) matrix - confirm
    against callers.

    Returns (gscore, score): `gscore` is the transformed geometric score
    and `score` is `gscore` plus (n - i - 1) * cot(th_i) per angle.
    """
    N = coord.shape[0]
    dim = coord.shape[1]
    frame0 = np.eye(dim + 1)
    coord0 = coordChange(coord, frame0, R)
    gscore0 = geometricScore_coord0_tangentGaussian(coord0, CovInv)
    J = coordJacobian(coord, frame0, R)
    gscore = np.zeros(gscore0.shape)
    for i in range(N):
        gscore[i] = np.dot(gscore0[i], J[i])
    score = np.array(gscore)
    # The last angle has a zero prefactor; skipping it avoids 0 * cos/sin,
    # which is NaN when that angle is 0 or +-pi.
    for i in range(dim - 1):
        score[:, i] += (dim-i-1)*np.cos(coord[:, i]) / np.sin(coord[:, i])

    return gscore, score
   
    
def geometricScore_coord0_tangentGaussianMixture(coord0, weights, means, CovInvs):
    """Geometric score for a Gaussian mixture in the tangent space at th = 0.

    The density is p = sum_i w_i * p_i(mean_i, Cov_i), with every
    component assumed to have the same partition function (the
    normalising constants are omitted).

    `coord0` is indexed as (N, n). `weights` is used through `weights.T`
    followed by `[:, 0]`, so it must be two-dimensional with one row of
    `weights.size` entries. `means[i, :]` and `CovInvs[i]` give the mean
    and inverse covariance of component i.
    Returns an array with the same shape as `coord0`.
    """
    n_points = coord0.shape[0]
    n_angles = coord0.shape[1]
    n_comp = weights.size

    tangent = ExpInv(coord0)
    jac = getJacobianExpInv(coord0)

    density = np.zeros((n_points, n_comp))
    density_grad = np.zeros((n_points, n_angles, n_comp))
    for m in range(n_comp):
        offset = tangent - means[m, :]
        log_grad = np.zeros((n_points, n_angles))
        for n in range(n_points):
            # offset^T CovInv is shared by the density and its log-gradient.
            whitened = np.dot(offset[n, :], CovInvs[m])
            density[n, m] = np.exp(-0.5*np.dot(whitened, offset[n, :].T))
            log_grad[n] = -np.dot(whitened, jac[n])
        density_grad[:, :, m] = density[:, m:m+1]*log_grad

    weight_col = weights.T
    total = np.dot(density, weight_col)[:, 0]
    gscore = np.zeros(coord0.shape)
    for a in range(n_angles):
        per_comp = density_grad[:, a, :].reshape(n_points, n_comp)
        gscore[:, a] = np.dot(per_comp, weight_col)[:, 0] / total

    radial = coord0[:, 0]
    gscore[:, 0] += (n_angles - 1)*(1.0/radial - 1.0/np.tan(radial))

    return gscore

def geometricScore_tangentGaussianMixture(coord, R, weights, means, CovInvs):
    """Geometric score for a tangent-space Gaussian mixture defined in coord0.

    The density is p = sum_i w_i * p_i(mean_i, Cov_i) in the tangent space
    at th = 0 of the reference frame, with every component assumed to
    have the same partition function.

    `coord` is indexed as (N, n) angles expressed in the frame `R`. They
    are converted to the reference frame (identity), scored there with
    `geometricScore_coord0_tangentGaussianMixture`, and mapped back
    through the coordinate Jacobian.

    `R` multiplies embedded positions, which `getData` produces with n + 1
    components, so the identity reference frame must be (n+1) x (n+1);
    an n x n identity cannot conform with those products.
    NOTE(review): assumes `R` is a square (n+1) x (n+1) matrix - confirm
    against callers.

    Returns (gscore, score): `gscore` is the transformed geometric score
    and `score` is `gscore` plus (n - i - 1) * cot(th_i) per angle.
    """
    N = coord.shape[0]
    dim = coord.shape[1]
    frame0 = np.eye(dim + 1)
    coord0 = coordChange(coord, frame0, R)
    gscore0 = geometricScore_coord0_tangentGaussianMixture(coord0, weights, means, CovInvs)
    J = coordJacobian(coord, frame0, R)
    gscore = np.zeros(gscore0.shape)
    for i in range(N):
        gscore[i] = np.dot(gscore0[i], J[i])
    score = np.array(gscore)
    # The last angle has a zero prefactor; skipping it avoids 0 * cos/sin,
    # which is NaN when that angle is 0 or +-pi.
    for i in range(dim - 1):
        score[:, i] += (dim-i-1)*np.cos(coord[:, i]) / np.sin(coord[:, i])

    return gscore, score

def geodesicDistance(pos1, pos2):
    """Angle between positions, computed as arccos of their dot product.

    The dot product is taken along axis 0, so inputs indexed as
    (components, N) give N distances and a pair of 1-D vectors gives a
    single scalar. The cosine is clamped to [-1, 1] so round-off cannot
    produce NaN.
    NOTE(review): the result is the geodesic distance only if the inputs
    have unit norm - the function does not normalise them.
    """
    cos_angle = np.sum(pos1*pos2, axis=0)
    # np.clip also works on the scalar produced by 1-D inputs, where
    # masked in-place assignment would raise.
    return np.arccos(np.clip(cos_angle, -1.0, 1.0))

def tangentVectorAngleError(pos0, pos1, pos2):
    """Angle at `pos0` between the directions towards `pos1` and `pos2`.

    Each difference pos_k - pos0 has its component along `pos0` removed,
    and the angle between the two resulting vectors is returned. Sums and
    norms are taken along axis 0, so inputs indexed as (components, N)
    give N angles and 1-D vectors give a single scalar.

    The cosine is clamped to [-1, 1]. If either projected vector has zero
    length the division is 0/0 and the result is NaN.
    NOTE(review): the projection removes the full radial component only if
    `pos0` has unit norm - confirm against callers.
    """
    delta1 = pos1 - pos0
    tanvec1 = delta1 - np.sum(delta1*pos0, axis=0)*pos0
    delta2 = pos2 - pos0
    tanvec2 = delta2 - np.sum(delta2*pos0, axis=0)*pos0
    cos_angle = np.sum(tanvec2*tanvec1, axis=0) \
        / np.linalg.norm(tanvec1, axis=0) / np.linalg.norm(tanvec2, axis=0)
    # np.clip also works on the scalar produced by 1-D inputs, where
    # masked in-place assignment would raise.
    return np.arccos(np.clip(cos_angle, -1.0, 1.0))



############## hard-coded special cases for S^3
def getData_S3(coord):
    """Map (th1, th2, th3) to embedded positions in R^4.

    Columns 0..2 of `coord` are read as th1, th2, th3. The result has one
    row per Cartesian component (four rows).
    """
    th1, th2, th3 = coord[:, 0], coord[:, 1], coord[:, 2]
    sin1 = np.sin(th1)
    sin2 = np.sin(th2)
    components = [
        np.cos(th1),
        sin1*np.cos(th2),
        sin1*sin2*np.cos(th3),
        sin1*sin2*np.sin(th3),
    ]
    return np.asarray([comp.T for comp in components])

def getJacobianPos_S3(coord):
    """Jacobian of the R^4 position with respect to (th1, th2, th3).

    Columns 0..2 of `coord` are read as th1, th2, th3. The result has
    shape (N, 4, 3) with entry [p, i, j] = d X_{i+1} / d th_{j+1}; entries
    not listed below stay zero.
    """
    s1, c1 = np.sin(coord[:, 0]), np.cos(coord[:, 0])
    s2, c2 = np.sin(coord[:, 1]), np.cos(coord[:, 1])
    s3, c3 = np.sin(coord[:, 2]), np.cos(coord[:, 2])

    # (row, column) -> partial derivative, grouped by the angle differentiated.
    entries = {
        (0, 0): -s1,
        (1, 0): c1*c2,
        (2, 0): c1*s2*c3,
        (3, 0): c1*s2*s3,
        (1, 1): -s1*s2,
        (2, 1): s1*c2*c3,
        (3, 1): s1*c2*s3,
        (2, 2): -s1*s2*s3,
        (3, 2): s1*s2*c3,
    }

    dpos_dcoord = np.zeros((coord.shape[0], 4, 3))
    for (row, col), value in entries.items():
        dpos_dcoord[:, row, col] = value
    return dpos_dcoord

def getCoord_S3(data):
    """Recover (th1, th2, th3) from positions in R^4.

    `data[0]`..`data[3]` are read as arrays holding the four Cartesian
    components. The result has one row per point and three columns.
    th1 and th2 come from arccos (range [0, pi]); th3 comes from arctan2
    of the last two components.

    Cosine arguments are clamped to [-1, 1] so round-off cannot produce
    NaN. Where sin(th1) is exactly zero th2 is not determined by the
    point; it is set to 0 instead of evaluating 0/0.
    """
    th1 = np.arccos(np.clip(data[0], -1.0, 1.0))
    th3 = np.arctan2(data[3], data[2])

    sin1 = np.sin(th1)
    th2 = np.zeros(th1.shape)
    regular = sin1 != 0
    ratio = np.asarray(data[1])[regular] / sin1[regular]
    th2[regular] = np.arccos(np.clip(ratio, -1.0, 1.0))

    return np.asarray([th1, th2, th3]).T

def christoffelSum_S3(data):
    """Per-angle values (2*cot(th1), cot(th2), 0) for S^3.

    Columns 0 and 1 of `data` are read as th1 and th2. The result has one
    row per point and three columns; the third column is always zero.
    """
    th1, th2 = data[:, 0], data[:, 1]
    first = 2.0 * np.cos(th1)/np.sin(th1)
    second = np.cos(th2)/np.sin(th2)
    third = np.zeros(data.shape[0])
    return np.array([first, second, third]).T