import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
#%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.lines as lines
from mpl_toolkits.mplot3d import Axes3D


def metricSqrt_torch(data):
    """Return the square roots of the diagonal metric components.

    Column 0 is all ones; column ``i`` (``i >= 1``) is the running product
    ``sin(data[:, 0]) * ... * sin(data[:, i - 1])``.

    Args:
        data: 2-D tensor, one row per point and one column per coordinate.

    Returns:
        Tensor with the same shape as ``data`` (torch default dtype),
        allocated on the device that holds ``data``.
    """
    dim = data.shape[1]
    # Allocate on the input's own device. A bare ``.cuda()`` would pick the
    # *current* CUDA device, which need not be the one holding ``data``.
    # The deprecated ``Variable`` wrapper is not needed either.
    metricSqrt = torch.ones(data.shape, device=data.device)
    s = torch.sin(data)
    for i in range(1, dim):
        # clone() hands the product an independent copy of the previous
        # column, so the in-place column write never touches an operand
        # (presumably there to keep autograd's saved tensors valid).
        metricSqrt[:, i] = metricSqrt[:, i - 1].clone() * s[:, i - 1]
    return metricSqrt

def metric_torch(data):
    """Diagonal metric entries: the elementwise square of ``metricSqrt_torch(data)``."""
    root = metricSqrt_torch(data)
    return root.pow(2)

def metricDeriv_torch(data):
    """Return derivatives of the diagonal metric w.r.t. each coordinate.

    For ``j > i`` the entry ``[n, j, i]`` is
    ``metric[n, j] * 2 * cos(data[n, i]) / sin(data[n, i])``; every other
    entry stays zero.

    Args:
        data: 2-D tensor, one row per point and one column per coordinate.

    Returns:
        Tensor of shape ``(N, dim, dim)`` (torch default dtype) allocated on
        the device that holds ``data``.

    Note:
        The expression divides by ``sin(data)``, so entries are non-finite
        wherever that sine is zero.
    """
    N = data.shape[0]
    dim = data.shape[1]
    # Allocate on the input's device instead of the current CUDA device;
    # the deprecated ``Variable`` wrapper is dropped.
    metricDeriv = torch.zeros(N, dim, dim, device=data.device)
    s = torch.sin(data)
    c = torch.cos(data)
    metric = metric_torch(data)
    for i in range(dim - 1):
        # Keep column i as an (N, 1) slice so it broadcasts over rows j > i.
        metricDeriv[:, i + 1:, i] = metric[:, i + 1:] * 2.0 * c[:, i:i + 1] / s[:, i:i + 1]
    return metricDeriv

def metricInvSqrt_torch(data):
    """Diagonal entries only: elementwise reciprocal of ``metricSqrt_torch(data)``."""
    return 1.0 / metricSqrt_torch(data)

def metricInv_torch(data):
    """Diagonal entries only: reciprocal of the squared ``metricSqrt_torch(data)``."""
    root = metricSqrt_torch(data)
    squared = root.pow(2)
    return 1.0 / squared

def metricInvDeriv_torch(data):
    """Derivative of the inverse diagonal metric.

    For every coordinate index ``k`` the slice ``[:, :, k]`` is
    ``-metricInv * metricDeriv[:, :, k] * metricInv``, where ``metricInv`` is
    the square of ``metricInvSqrt_torch(data)``.
    """
    deriv = metricDeriv_torch(data)
    inv = metricInvSqrt_torch(data).pow(2)
    out = torch.zeros(deriv.shape).cuda() if data.is_cuda else torch.zeros(deriv.shape)
    n_coords = data.shape[1]
    for k in range(n_coords):
        out[:, :, k] = -inv * deriv[:, :, k] * inv
    return out

def christoffelSum_torch(data):
    """Per-coordinate weighted cotangent.

    Column ``k`` of the result is ``(dim - k - 1) * cos(data[:, k]) / sin(data[:, k])``
    where ``dim`` is the number of columns of ``data``.
    """
    n_coords = data.shape[1]
    out = torch.zeros(data.shape).cuda() if data.is_cuda else torch.zeros(data.shape)
    # ``remaining`` counts the coordinates that follow column ``col``.
    for col, remaining in zip(range(n_coords), range(n_coords - 1, -1, -1)):
        angle = data[:, col]
        out[:, col] = remaining * torch.cos(angle) / torch.sin(angle)
    return out

def christoffelSumDeriv_torch(data):
    """Diagonal derivative tensor matching ``christoffelSum_torch``.

    Returns an ``(N, dim, dim)`` tensor whose only non-zero entries are
    ``[:, k, k] = -(dim - k - 1) / sin(data[:, k]) ** 2``.
    """
    n_pts, n_coords = data.shape[0], data.shape[1]
    out = torch.zeros(n_pts, n_coords, n_coords)
    if data.is_cuda:
        out = out.cuda()
    for k in range(n_coords):
        weight = n_coords - k - 1
        out[:, k, k] = -weight / torch.sin(data[:, k]) ** 2
    return out

def Exp_torch(x, eps=1e-10):
    """Map tangent vectors at th = 0 to angular coordinates (th_1, ..., th_n).

    The first angle is the Euclidean norm of each row of ``x``, the last is
    ``atan2(x[:, -1], x[:, -2])``, and the intermediate ones are arccosines
    of ``x[:, k - 1]`` divided by a running scale (the norm times the sines
    of the angles found so far). Where that scale is below ``eps`` in
    magnitude the angle is set to 0.
    """
    n_coords = x.shape[1]
    angles = torch.zeros(x.shape)
    if x.is_cuda:
        angles = angles.cuda()

    angles[:, 0] = torch.sqrt(torch.sum(x ** 2, dim=1))
    scale = angles[:, 0].clone()
    angles[:, -1] = torch.atan2(x[:, -1], x[:, -2])

    for k in range(1, n_coords - 1):
        # Clamp into acos' domain; rounding can push the ratio past +-1.
        ratio = torch.clamp(x[:, k - 1] / scale, -1, 1)
        angles[:, k] = torch.acos(ratio)
        angles[torch.abs(scale) < eps, k] = 0.
        scale = scale * torch.sin(angles[:, k])
    return angles

def getCoord_torch(data, eps=1e-10):
    """Convert points in R^(n+1) to n angular coordinates.

    The first angle is ``acos(data[:, 0])``, the last is
    ``atan2(data[:, n], data[:, n - 1])`` and each intermediate angle ``i`` is
    ``acos(data[:, i] / prod_{j < i} sin(th_j))``, with the product evaluated
    as ``exp(sum(log(sin)))``.

    Args:
        data: 2-D tensor of shape ``(N, n + 1)``.
        eps: when the sine product is below ``eps`` in magnitude the
            corresponding angle is set to 0 instead of dividing by ~0.

    Returns:
        Tensor of shape ``(N, n)`` (torch default dtype) on ``data``'s device.
    """
    dim = data.shape[1] - 1
    N = data.shape[0]
    th = torch.zeros(N, dim, device=data.device)
    # Clamp like the other coordinates: rounding can leave |data[:, 0]|
    # marginally above 1, which would make acos return NaN.
    th[:, 0] = torch.acos(torch.clamp(data[:, 0], -1, 1))
    th[:, dim - 1] = torch.atan2(data[:, dim], data[:, dim - 1])

    for i in range(1, dim - 1):
        # The earlier angles come from acos, so their sines are
        # mathematically >= 0. In floating point sin(acos(-1)) can come out
        # slightly negative, and log() of that is NaN. Clamping at 0 turns
        # it into log(0) = -inf -> product 0, which the eps guard handles.
        sines = torch.sin(th[:, :i].clone()).clamp(min=0)
        sin_multi = torch.exp(torch.log(sines).sum(1))
        temp = data[:, i] / sin_multi
        temp[temp > 1] = 1
        temp[temp < -1] = -1
        th[:, i] = torch.acos(temp)
        th[torch.abs(sin_multi) < eps, i] = 0.
    return th

def getCoord_torch2(data, eps=1e-10):
    """Convert points in R^(n+1) to n angular coordinates.

    Same recipe as ``acos`` / ``atan2`` peeling of coordinates, with the
    running product of sines computed directly by ``torch.prod``.

    Args:
        data: 2-D tensor of shape ``(N, n + 1)``.
        eps: when the sine product is below ``eps`` in magnitude the
            corresponding angle is set to 0 instead of dividing by ~0.

    Returns:
        Tensor of shape ``(N, n)`` (torch default dtype) on ``data``'s device.
    """
    dim = data.shape[1] - 1
    N = data.shape[0]
    th = torch.zeros(N, dim, device=data.device)
    # Clamp like the other coordinates: rounding can leave |data[:, 0]|
    # marginally above 1, which would make acos return NaN.
    th[:, 0] = torch.acos(torch.clamp(data[:, 0], -1, 1))
    th[:, dim - 1] = torch.atan2(data[:, dim], data[:, dim - 1])

    for i in range(1, dim - 1):
        # clone() so the product does not alias ``th``, which is written
        # in place just below.
        sin_multi = torch.prod(torch.sin(th[:, :i].clone()), 1)
        temp = data[:, i] / sin_multi
        temp[temp > 1] = 1
        temp[temp < -1] = -1
        th[:, i] = torch.acos(temp)
        th[torch.abs(sin_multi) < eps, i] = 0.
    return th

def getPos_torch(th):
    """Convert n angular coordinates to points in R^(n+1).

    ``pos[:, 0] = cos(th_0)``; for ``1 <= i < n``
    ``pos[:, i] = sin(th_0) * ... * sin(th_{i-1}) * cos(th_i)``; the last
    column is the full sine product up to ``th_{n-2}`` times ``sin(th_{n-1})``.

    Args:
        th: 2-D tensor of shape ``(N, n)``.

    Returns:
        Tensor of shape ``(N, n + 1)`` (torch default dtype) on ``th``'s device.
    """
    dim = th.shape[1]
    N = th.shape[0]
    pos = torch.zeros(N, dim + 1, device=th.device)
    sin_multi = torch.ones(N, device=th.device)
    pos[:, 0] = torch.cos(th[:, 0])
    for i in range(1, dim):
        # Rebind instead of ``*=``: the previous product is an operand of the
        # multiplication that produced pos[:, i - 1], and overwriting it in
        # place makes autograd's backward pass fail for n >= 3.
        sin_multi = sin_multi * torch.sin(th[:, i - 1])
        pos[:, i] = sin_multi * torch.cos(th[:, i])
    pos[:, -1] = sin_multi * torch.sin(th[:, -1])

    return pos

"""
def getPosJacobian_torch(th):
    # get the jacobian of R^(n+1) w.r.t. th
    dim = th.shape[1]
    N = th.shape[0]
    jac = torch.zeros(N,dim+1,dim)
    if th.is_cuda:
        jac = jac.cuda()
        
    jac[:,0,0] = -torch.sin(th[:,0])
    for i in range(1, dim+1):
        sin_multi = -jac[:,i-1,i-1]
        for j in range(i):
            if i < dim:
                jac[:,i,j] = sin_multi * torch.cos(th[:,j]) / torch.sin(th[:,j]) * torch.cos(th[:,i])
        if i < dim:
            jac[:,i,i] = -sin_multi * torch.sin(th[:,i])
    for j in range(dim):
        jac[:,dim,j] = sin_multi * torch.cos(th[:,j]) / torch.sin(th[:,j])
        
    return jac
"""

def getPosJacobian_torch(th, eps=1e-10):
    """Jacobian of the R^(n+1) position w.r.t. the angular coordinates ``th``.

    Built from the positions themselves: diagonal entries are
    ``-pos_k * tan(th_k)`` and entries below the diagonal in column ``k`` are
    ``pos_j / tan(th_k)``. Rows where ``|tan(th_k)| < eps`` get explicit
    replacement values instead of the division result. The last column is
    finally set from the last two position components.

    Returns a tensor of shape ``(N, n + 1, n)``.
    """
    n_pts, n_ang = th.shape[0], th.shape[1]
    pos = getPos_torch(th)
    tangent = torch.tan(th)
    jac = torch.zeros(n_pts, n_ang + 1, n_ang)
    if pos.is_cuda:
        jac = jac.cuda()

    for k in range(n_ang):
        jac[:, k, k] = -pos[:, k] * tangent[:, k]

    for k in range(n_ang - 1):
        degenerate = torch.abs(tangent[:, k]) < eps
        jac[:, k + 1:, k] = pos[:, k + 1:] / tangent[:, k].view(-1, 1)
        # Overwrite the division result wherever tan(th_k) is ~0.
        jac[degenerate, k + 1:, k] = 0.
        if k == 0:
            jac[degenerate, k + 1, k] = 1.
        else:
            jac[degenerate, k + 1, k] = -jac[degenerate, k - 1, k - 1]

    jac[:, -2, -1] = -pos[:, -1]
    jac[:, -1, -1] = pos[:, -2]

    return jac

def getPosJacobianFromPos_torch(data, eps=1e-10):
    """Jacobian of the R^(n+1) position w.r.t. the angles, given positions.

    The angles are recovered with ``getCoord_torch(data)`` (called with its
    default ``eps``); ``eps`` here only controls which ``|tan(th_k)|`` values
    count as zero. Diagonal entries are ``-data_k * tan(th_k)``, entries
    below the diagonal in column ``k`` are ``data_j / tan(th_k)``, and the
    last column is finally set from the last two components of ``data``.

    Returns a tensor of shape ``(N, n + 1, n)``.
    """
    n_pts = data.shape[0]
    n_ang = data.shape[1] - 1
    tangent = torch.tan(getCoord_torch(data))
    jac = torch.zeros(n_pts, n_ang + 1, n_ang)
    if data.is_cuda:
        jac = jac.cuda()

    for k in range(n_ang):
        jac[:, k, k] = -data[:, k] * tangent[:, k]

    for k in range(n_ang - 1):
        degenerate = torch.abs(tangent[:, k]) < eps
        jac[:, k + 1:, k] = data[:, k + 1:] / tangent[:, k].view(-1, 1)
        # Overwrite the division result wherever tan(th_k) is ~0.
        jac[degenerate, k + 1:, k] = 0.
        if k == 0:
            jac[degenerate, k + 1, k] = 1.
        else:
            jac[degenerate, k + 1, k] = -jac[degenerate, k - 1, k - 1]

    jac[:, -2, -1] = -data[:, -1]
    jac[:, -1, -1] = data[:, -2]

    return jac

def getNormalizedJacobianFromPos_torch(data):
    """Normalized Jacobian of the R^(n+1) position w.r.t. the angles.

    Angles are recovered with ``getCoord_torch(data)``. The diagonal and the
    first sub-diagonal are filled directly from sines and cosines; each later
    row is derived from the row above it by multiplying with ``tan`` of the
    previous angle (and ``cos`` of the current one, except for the last row).
    Presumably each column equals the plain Jacobian column divided by its
    norm -- confirm against the caller.

    Returns a tensor of shape ``(N, n + 1, n)``.
    """
    n_pts = data.shape[0]
    n_ang = data.shape[1] - 1
    angles = getCoord_torch(data)
    jac = torch.zeros(n_pts, n_ang + 1, n_ang)
    if angles.is_cuda:
        jac = jac.cuda()
    sin_a, cos_a, tan_a = torch.sin(angles), torch.cos(angles), torch.tan(angles)

    # Diagonal and first sub-diagonal.
    for k in range(n_ang):
        jac[:, k, k] = -sin_a[:, k]
        below = cos_a[:, k]
        if k < n_ang - 1:
            below = below * cos_a[:, k + 1]
        jac[:, k + 1, k] = below

    # Remaining lower-triangular entries, propagated row by row.
    for row in range(2, n_ang + 1):
        carried = jac[:, row - 1, :row - 1].clone() * tan_a[:, row - 1].view(-1, 1)
        if row < n_ang:
            carried = carried * cos_a[:, row].view(-1, 1)
        jac[:, row, :row - 1] = carried
    return jac

def getPosHessianDiagonal_torch(th):
    """Fill an ``(N, n + 1, n)`` tensor from the negated positions.

    For every angle index ``j`` the rows ``j ..`` of column ``j`` are
    ``-pos[:, j:]`` with ``pos = getPos_torch(th)``; rows above ``j`` stay 0.
    """
    n_pts, n_ang = th.shape[0], th.shape[1]
    out = torch.zeros(n_pts, n_ang + 1, n_ang)
    if th.is_cuda:
        out = out.cuda()

    negated = -getPos_torch(th)
    for col in range(n_ang):
        out[:, col:, col] = negated[:, col:]

    return out

def safe_div_sin(input, eps = 1e-6):
    """Evaluate ``input / sin(input)`` elementwise, returning 1 near zero.

    Args:
        input: tensor of arguments.
        eps: entries with ``|input| <= eps`` are replaced by 1 instead of
            evaluating the ratio.

    Returns:
        Tensor shaped like ``input`` (torch default dtype) on ``input``'s
        device. Entries matching neither mask (e.g. NaN) are left at 0.
    """
    output = torch.zeros(input.shape, device=input.device)
    # Compare magnitudes: a plain ``input <= eps`` would also catch every
    # negative argument and wrongly return 1 for it.
    magnitude = torch.abs(input)
    near_zero = magnitude <= eps
    regular = magnitude > eps
    output[near_zero] = 1
    output[regular] = input[regular] / torch.sin(input[regular])
    return output

def safe_div_tan(input, eps = 1e-6):
    """Evaluate ``input / tan(input)`` elementwise, returning 1 near zero.

    Args:
        input: tensor of arguments.
        eps: entries with ``|input| <= eps`` are replaced by 1 instead of
            evaluating the ratio.

    Returns:
        Tensor shaped like ``input`` (torch default dtype) on ``input``'s
        device. Entries matching neither mask (e.g. NaN) are left at 0.
    """
    output = torch.zeros(input.shape, device=input.device)
    # Compare magnitudes: a plain ``input <= eps`` would also catch every
    # negative argument and wrongly return 1 for it.
    magnitude = torch.abs(input)
    near_zero = magnitude <= eps
    regular = magnitude > eps
    output[near_zero] = 1
    output[regular] = input[regular] / torch.tan(input[regular])
    return output

def getDist_torch(th1, th2, eps = 1e-6, returnjac = False, returnposjac = False):
    """Great-circle distances between two sets of points given as angles.

    Args:
        th1: ``(N1, dim)`` angular coordinates.
        th2: ``(N2, dim)`` angular coordinates.
        eps: distances ``<= eps`` use a separate formula for the derivative
            (sign of the projected Jacobian times the Jacobian column norm)
            instead of dividing by ``sin(distance)``.
        returnjac: if True, also return the derivative of the distance
            between th1 and th2 w.r.t. th2, shape ``(N1, N2, dim)``.
        returnposjac: if True (and ``returnjac`` is True), also return the
            derivative of the position of th2 w.r.t. th2.

    Returns:
        ``distmat`` of shape ``(N1, N2)``; or ``(distmat, distjac)``; or
        ``(distmat, distjac, posjac)``.

    Note:
        Only small distances are guarded; distances whose sine is ~0 for
        another reason still go through the division by ``sin(distance)``.
    """
    assert(th1.shape[1] == th2.shape[1])
    N1 = th1.shape[0]
    N2 = th2.shape[0]
    dim = th2.shape[1]
    pos1 = getPos_torch(th1)
    pos2 = getPos_torch(th2)
    # Allocate next to the positions so every operand lives on one device.
    distmat = torch.zeros(N1, N2, device=pos1.device)
    for i in range(N1):
        temp = torch.mm(pos2, pos1[i].view(-1, 1))
        # Clamp the cosine into acos' domain.
        temp[temp > 1] = 1
        temp[temp < -1] = -1
        distmat[i] = torch.acos(temp).view(-1)

    if not returnjac:
        return distmat

    posjac = getPosJacobian_torch(th2)
    distjac = torch.zeros(N1, N2, dim, device=pos1.device)

    # tempmat[i, j, k] = <pos1[i], d pos2[j] / d th2[j, k]>
    tempmat = torch.stack(
        [torch.matmul(pos1[i].view(1, 1, dim + 1), posjac).view(N2, dim) for i in range(N1)],
        dim=0,
    )

    # These do not depend on the coordinate index, so compute them once.
    far = distmat > eps
    near = distmat <= eps
    sin_far = torch.sin(distmat[far])

    for i in range(dim):
        tempmat_i = tempmat[:, :, i]
        # View into distjac: the masked writes below land in distjac itself.
        distjac_i = distjac[:, :, i]
        distjac_i[far] = - tempmat_i[far] / sin_far
        posjac_norm_i = torch.sqrt(torch.sum(posjac[:, :, i]**2, dim=1)).view(1, N2).expand(N1, -1)
        distjac_i[near] = - torch.sign(tempmat_i[near]) * posjac_norm_i[near]

    if returnposjac:
        return distmat, distjac, posjac

    return distmat, distjac

def getPairwiseDist_torch(th):
    """All-pairs arccosine of the inner products of ``getPos_torch(th)``.

    Returns an ``(N, N)`` tensor; inner products are clamped to [-1, 1]
    before ``acos``.
    """
    points = getPos_torch(th)
    gram = (points[None, :, :] * points[:, None, :]).sum(dim=-1)
    return torch.acos(torch.clamp(gram, -1, 1))
    
    
class SndataTangentGaussian(Dataset):
    """Angular-coordinate samples drawn from a Gaussian in the tangent space at th = 0.

    Standard-normal draws are multiplied by ``Cov_sqrt`` and pushed through
    ``Exp_torch``. Each item is the tuple
    ``(coordinates, metric_sqrt, metricInv_sqrt)`` for one sample.
    """

    def __init__(self, N, Cov_sqrt):
        self.dim = Cov_sqrt.shape[0]
        # x ~ N(0, Cov) in the tangent space of th = 0
        noise = torch.FloatTensor(N, self.dim).normal_(0.0, 1.0)
        coords = Exp_torch(torch.mm(noise, Cov_sqrt))
        self.train_data = coords
        self.N = N
        self.metric_sqrt = metricSqrt_torch(coords)
        self.metric = metric_torch(coords)
        self.metricInv_sqrt = metricInvSqrt_torch(coords)

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        sample = self.train_data[idx]
        return sample, self.metric_sqrt[idx], self.metricInv_sqrt[idx]

class SndataTangentGaussianMixture(Dataset):
    """Angular-coordinate samples from an equal-weight Gaussian mixture.

    The mixture lives in the tangent space at th = 0. ``N`` is split into
    ``N // Nmix`` samples per component, with the last component taking the
    remainder. Each item is the tuple
    ``(coordinates, metric_sqrt, metricInv_sqrt)`` for one sample.
    """

    def __init__(self, N, means, Cov_sqrts):
        self.dim = means.size()[1]
        self.Nmix = means.size()[0]
        self.means = means
        self.Cov_sqrts = Cov_sqrts
        per_component = N // self.Nmix
        last_count = N - (self.Nmix - 1) * per_component
        x = torch.zeros(N, self.dim)
        print(per_component)
        print(last_count)
        for k in range(self.Nmix - 1):
            start, stop = k * per_component, (k + 1) * per_component
            noise = torch.FloatTensor(per_component, self.dim).normal_(0.0, 1.0)
            x[start:stop] = torch.mm(noise, Cov_sqrts[k]) + means[k]
        # The final component absorbs whatever is left after the even split.
        tail_noise = torch.FloatTensor(last_count, self.dim).normal_(0.0, 1.0)
        x[(self.Nmix - 1) * per_component:] = torch.mm(tail_noise, Cov_sqrts[-1]) + means[-1]
        coords = Exp_torch(x)
        self.train_data = coords
        self.N = N
        self.metric_sqrt = metricSqrt_torch(coords)
        self.metric = metric_torch(coords)
        self.metricInv_sqrt = metricInvSqrt_torch(coords)

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        sample = self.train_data[idx]
        return sample, self.metric_sqrt[idx], self.metricInv_sqrt[idx]
    
class SndataFromPos(Dataset):
    """Angular-coordinate dataset built from positions via ``getCoord_torch``.

    Each item is the tuple ``(coordinates, metric_sqrt, metricInv_sqrt)``
    for one sample.
    """

    def __init__(self, pos):
        coords = getCoord_torch(pos)
        self.train_data = coords
        self.N = pos.size()[0]
        self.dim = coords.size()[1]
        self.metric_sqrt = metricSqrt_torch(coords)
        self.metric = metric_torch(coords)
        self.metricInv_sqrt = metricInvSqrt_torch(coords)

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        sample = self.train_data[idx]
        return sample, self.metric_sqrt[idx], self.metricInv_sqrt[idx]

class SnPosdataFromTh(Dataset):
    """Position dataset: stores ``getPos_torch(th)`` and serves one row per item."""

    def __init__(self, th):
        positions = getPos_torch(th)
        self.train_data = positions
        self.N = positions.size()[0]
        self.dim = positions.size()[1]

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        return self.train_data[idx]
    
class SnPosdataFromPos(Dataset):
    """Position dataset that wraps the given tensor as-is, one row per item."""

    def __init__(self, pos):
        self.train_data = pos
        n_rows, n_cols = pos.size()[0], pos.size()[1]
        self.N = n_rows
        self.dim = n_cols

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        return self.train_data[idx]
    
class SnPosdataTangentGaussianMixture(Dataset):
    """Position samples from a Gaussian mixture in the tangent space at th = 0.

    ``N`` is split into ``N // Nmix`` samples per component, with the last
    component taking the remainder. Samples go through ``Exp_torch`` and then
    ``getPos_torch``; ``self.dim`` is the width of the stored position tensor.
    """

    def __init__(self, N, means, Cov_sqrts):
        tangent_dim = means.size()[1]
        self.Nmix = means.size()[0]
        self.means = means
        self.Cov_sqrts = Cov_sqrts
        per_component = N // self.Nmix
        last_count = N - (self.Nmix - 1) * per_component
        x = torch.zeros(N, tangent_dim)
        print(per_component)
        print(last_count)
        for k in range(self.Nmix - 1):
            start, stop = k * per_component, (k + 1) * per_component
            noise = torch.FloatTensor(per_component, tangent_dim).normal_(0.0, 1.0)
            x[start:stop] = torch.mm(noise, Cov_sqrts[k]) + means[k]
        # The final component absorbs whatever is left after the even split.
        tail_noise = torch.FloatTensor(last_count, tangent_dim).normal_(0.0, 1.0)
        x[(self.Nmix - 1) * per_component:] = torch.mm(tail_noise, Cov_sqrts[-1]) + means[-1]
        self.train_data = getPos_torch(Exp_torch(x))
        self.N = N
        self.dim = self.train_data.size()[1]

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        return self.train_data[idx]