import numpy as np
import torch
import time
from sph_n_DataUtil import *
from sph_n_ambient import *
from util import *

### Implements the Riemannian Least-Squares Log-Density Gradients (R-LSLDG) method for hyperspherical data, from
### M. Ashizawa, H. Sasaki, T. Sakai, M. Sugiyama, 'Least-squares log-density gradient clustering for Riemannian manifolds'

def phi_xi_xiDerivDiag_func(centerPoints, data, sigma):
    """Evaluate the Gaussian basis functions and their first/second coordinate derivatives.

    Args:
        centerPoints: basis centers; ``centerPoints.shape[0]`` is N_basis.
        data: evaluation points, an N x dim tensor of coordinates.
        sigma: bandwidth. Either a scalar (Python number or one-element
            tensor) shared by all basis functions, or a tensor with N_basis
            entries (one bandwidth per basis function).

    Returns:
        phi: N_basis x N tensor, ``exp(-0.5 * dist**2 / sigma**2)``.
        xi: N_basis x N x dim tensor, covector form of psi in the paper
            (derivative of phi w.r.t. data).
        xiDerivDiag: N_basis x N x dim tensor (per the original notes, the
            diagonal of the derivative of xi w.r.t. data).
    """
    N_basis = centerPoints.shape[0]
    N = data.shape[0]
    dim = data.shape[1]

    # Prepare sigma**2 in shapes that broadcast against the 2-D (N_basis, N)
    # and the 3-D (N_basis, N, dim) tensors. A per-basis sigma must be
    # (N_basis, 1, 1) for the 3-D case; reusing the (N_basis, 1) shape there
    # would line N_basis up with the N axis.
    if torch.is_tensor(sigma) and sigma.numel() > 1:
        sigma_sq = (sigma**2).reshape(N_basis, 1)
        sigma_sq_3d = sigma_sq.reshape(N_basis, 1, 1)
    else:
        # Python number or one-element tensor: plain broadcasting is enough.
        sigma_sq = sigma**2
        sigma_sq_3d = sigma_sq

    # The position Jacobian is requested as in the original call but is not needed here.
    distmat, distjac, _ = getDist_torch(centerPoints, data, returnjac = True, returnposjac = True)

    ### calculate phi, xi, xiDerivDiag
    phi = torch.exp(- 0.5 * distmat**2 / sigma_sq)

    xi = - (phi * distmat / sigma_sq).view(N_basis,N,1).expand(-1,-1,dim) * distjac

    pos_hessdiag = getPosHessianDiagonal_torch(data)
    pos_center = getPos_torch(centerPoints)

    # Three contributions, summed left to right and then scaled by 1 / sigma**2:
    #   1) -distjac * xi * dist
    #   2) distjac**2 * phi * (safe_div_tan(dist) - 1)
    #   3) (pos_center . pos_hessdiag) * phi * safe_div_sin(dist)
    xiDerivDiag = - distjac * xi * distmat.view(N_basis,N,1).expand(-1,-1,dim) \
                  + distjac**2 * (phi * (safe_div_tan(distmat) - 1)).view(N_basis,N,1).expand(-1,-1,dim) \
                  + torch.matmul(
        pos_center.view(N_basis,1,1,dim+1).expand(-1,N,-1,-1),
        pos_hessdiag.view(1,N,dim+1,dim).expand(N_basis,-1,-1,-1)
    ).view(N_basis,N,dim) * (phi * safe_div_sin(distmat)).view(N_basis,N,1).expand(-1,-1,dim)
    xiDerivDiag /= sigma_sq_3d

    return phi, xi, xiDerivDiag
    


def xi_func(centerPoints, data, sigma, returnxiOnly = True, returnOthersAlso = False, memory_efficient = False):
    """Compute xi and, optionally, the matrices G and h built from it.

    Shapes (N_basis = number of centers, N = number of data points,
    dim = ``data.shape[1]``):
        phi:         N_basis x N
        xi:          N_basis x N x dim, covector form of psi in the paper
                     (derivative of phi w.r.t. data)
        xiDerivDiag: N_basis x N x dim
        G:           N_basis x N_basis
        h:           N_basis

    Args:
        centerPoints: centers of the basis functions.
        data: points at which the basis functions are evaluated.
        sigma: bandwidth, a single number (per-basis bandwidths are not
            supported here yet).
        returnxiOnly: when true, stop after xi and skip G and h.
        returnOthersAlso: together with ``returnxiOnly``, additionally
            return phi and the position Jacobian from ``getDist_torch``.
        memory_efficient: build xiDerivDiag one center at a time instead of
            in a single batched expression.

    Returns:
        ``xi``; or ``(xi, phi, posjac)``; or ``(xi, G, h)`` depending on the
        flags described above.
    """
    n_centers = centerPoints.shape[0]
    n_points = data.shape[0]
    n_coords = data.shape[1]

    dist, dist_jac, pos_jac = getDist_torch(centerPoints, data, returnjac = True, returnposjac = True)
    sigma_sq = sigma**2

    ### phi and xi
    phi = torch.exp(- 0.5 * dist**2 / sigma_sq)
    xi = - (phi * dist / sigma_sq).view(n_centers, n_points, 1) * dist_jac

    if returnxiOnly:
        return (xi, phi, pos_jac) if returnOthersAlso else xi

    ### xiDerivDiag
    pos_hessdiag = getPosHessianDiagonal_torch(data)
    pos_center = getPos_torch(centerPoints)

    if memory_efficient:
        # One center at a time keeps the intermediate tensors at N x dim.
        per_center = []
        for c in range(n_centers):
            dist_c = dist[c]
            phi_c = phi[c]
            jac_c = dist_jac[c]

            first_term = - jac_c * xi[c] * dist_c.view(n_points, 1)
            second_term = jac_c**2 * (phi_c * (safe_div_tan(dist_c) - 1)).view(n_points, 1)
            center_dot_hess = torch.matmul(
                pos_center[c].view(1, 1, n_coords + 1).expand(n_points, -1, -1),
                pos_hessdiag.view(n_points, n_coords + 1, n_coords),
            ).view(n_points, n_coords)
            third_term = center_dot_hess * (phi_c * safe_div_sin(dist_c)).view(n_points, 1)

            per_center.append(first_term + second_term + third_term)
        xiDerivDiag = torch.stack(per_center, dim=0) / sigma_sq
    else:
        first_term = - dist_jac * xi * dist.view(n_centers, n_points, 1)
        second_term = dist_jac**2 * (phi * (safe_div_tan(dist) - 1)).view(n_centers, n_points, 1)
        center_dot_hess = torch.matmul(
            pos_center.view(n_centers, 1, 1, n_coords + 1).expand(-1, n_points, -1, -1),
            pos_hessdiag.view(1, n_points, n_coords + 1, n_coords).expand(n_centers, -1, -1, -1),
        ).view(n_centers, n_points, n_coords)
        third_term = center_dot_hess * (phi * safe_div_sin(dist)).view(n_centers, n_points, 1)

        xiDerivDiag = first_term + second_term + third_term
        xiDerivDiag /= sigma_sq

    ### G and h
    inv_metric_sqrt = metricInvSqrt_torch(data)
    inv_metric = inv_metric_sqrt**2

    # xi rescaled by the square root of the inverse metric, per coordinate.
    xi_scaled = xi * inv_metric_sqrt.view(1, n_points, n_coords)
    christoffel_scaled = (christoffelSum_torch(data) * inv_metric_sqrt).view(1, n_points, n_coords)

    xi_flat = xi_scaled.view(n_centers, -1)
    G = torch.mm(xi_flat, xi_flat.t()) / n_points

    deriv_part = (xiDerivDiag * inv_metric.view(1, n_points, n_coords)).sum((1, 2))
    christoffel_part = (xi_scaled * christoffel_scaled).sum((1, 2))
    h = (deriv_part + christoffel_part) / n_points

    return xi, G, h

class RLSLDG_sph_n:
    """Trained R-LSLDG model: a weighted sum of basis-function derivatives.

    The model stores the bandwidth, the basis centers and the weight vector
    ``theta``; estimates are formed as ``sum_b theta[b] * xi[b]`` with ``xi``
    computed by ``xi_func``.
    """

    def __init__(self, _sigma, theta, centerPoints):
        """Store the model parameters.

        Args:
            _sigma: scalar bandwidth of the basis functions.
            theta: weight vector with one entry per basis function.
            centerPoints: basis centers, N_basis x dim (coordinate form).
        """
        ### _sigma: scalar
        self.dim = centerPoints.shape[1]
        self._sigma = _sigma
        self.centerPoints = centerPoints
        self.theta = theta
        # Centers converted once with getPos_torch; used by meanShiftUpdate,
        # which works on ambient-space inputs.
        self.centerPointsPos = getPos_torch(centerPoints)

        self.N_basis = centerPoints.shape[0]

    def estimate_gscore(self, x, calculateObj = False, memory_efficient = False):
        """Estimate the log-density gradient at ``x``.

        Args:
            x: query points, N x dim, in the same coordinates as the centers.
            calculateObj: if left at ``False`` only the estimate is returned;
                otherwise the objective ``theta^T G theta + 2 theta^T h``
                evaluated on ``x`` is returned as well.
            memory_efficient: forwarded to ``xi_func``.

        Returns:
            The N x dim estimate, or ``(estimate, obj)`` where ``obj`` is a
            one-element tensor.
        """
        if (self.dim != x.shape[1]):
            # Kept as a warning only, as in the original behavior.
            print("check dimension of input")

        weights = self.theta.view(self.N_basis,1,1)

        # Identity check on purpose: only the literal ``False`` takes the
        # estimate-only path, matching the original semantics.
        if calculateObj is False:
            xi = xi_func(self.centerPoints, x, self._sigma, returnxiOnly = True, memory_efficient = memory_efficient)
            return torch.sum(weights * xi, dim = 0)

        xi, G, h = xi_func(self.centerPoints, x, self._sigma, returnxiOnly = False, memory_efficient = memory_efficient)
        gscore_est = torch.sum(weights * xi, dim = 0)
        obj = torch.mm(torch.mm(self.theta.view(1,-1), G), self.theta.view(-1,1)).view(-1) \
        + 2.0*torch.sum(self.theta.view(-1)*h)

        return gscore_est, obj

    def meanShiftUpdate(self, x, eps = 1e-7, returnLog=False, memory_efficient = False):
        """Perform one mean-shift step for points given in the ambient space.

        Args:
            x: current points, N x (ambient dimension).
            eps: added to the denominator of the update to avoid division by zero.
            returnLog: if true, return the tangent update vector instead of
                the updated points.
            memory_efficient: evaluate the logarithm map one center at a time.

        Returns:
            ``exponential_map(x, log_x_r)``, or ``log_x_r`` when ``returnLog``.
        """
        # inputs are represented in the ambient space
        N_basis = self.centerPoints.shape[0]
        N = x.shape[0]
        dim = x.shape[1]
        if memory_efficient:
            distmat_set = []
            log_x_c_set = []
            for i in range(N_basis):
                cur_log_x_c, cur_distmat = logarithm_map(x, 
                                    self.centerPointsPos[i].unsqueeze(0).repeat(N,1), eps=1e-10, returnDistAlso = True)
                distmat_set.append(cur_distmat)
                log_x_c_set.append(cur_log_x_c)
            distmat = torch.stack(distmat_set, dim=0)
            log_x_c = torch.stack(log_x_c_set, dim=0)
        else:
            # Pair every point with every center in one flattened batch.
            log_x_c, distmat = logarithm_map(x.unsqueeze(0).repeat(N_basis,1,1).view(-1,dim), 
                                    self.centerPointsPos.unsqueeze(1).repeat(1,N,1).view(-1,dim), eps=1e-10, returnDistAlso = True)
            log_x_c = log_x_c.view(N_basis, N, dim)
            distmat = distmat.view(N_basis, N)

        ### calculate phi
        phi = torch.exp(- 0.5 * distmat**2 / self._sigma**2)
        # Weighted average of the log-map vectors, with weights theta[b] * phi[b, n].
        den = torch.sum(phi * self.theta.unsqueeze(-1), dim=0)
        num = (log_x_c * phi.unsqueeze(-1) * self.theta.view(N_basis,1,1)).sum(0)

        log_x_r = num / (den.view(-1,1) + eps)
        if returnLog:
            return log_x_r
        return exponential_map(x, log_x_r)

    def meanShiftUpdateInCoord(self, x, eps = 1e-7, returnCurPosAndLog=False, memory_efficient = False):
        """Perform one mean-shift step for points given in spherical coordinates.

        Args:
            x: current points, N x dim, in coordinates.
            eps: added to the denominator of the update to avoid division by zero.
            returnCurPosAndLog: if true, return ``(getPos_torch(x), log_x_r)``
                instead of the updated coordinates.
            memory_efficient: push xi to the ambient space one center at a time.

        Returns:
            Updated points converted back with ``getCoord_torch``, or the
            pair described under ``returnCurPosAndLog``.
        """
        # inputs are represented in spherical coordinates
        xi, phi, posjac =  xi_func(self.centerPoints, x, self._sigma, returnxiOnly = True, returnOthersAlso = True, 
                                   memory_efficient = memory_efficient)
        metricInv = metricInv_torch(x)
        den = torch.sum(phi * self.theta.unsqueeze(-1), dim=0)
        # Multiply xi by the inverse metric, then by the transposed position
        # Jacobian, to obtain vectors expressed in the ambient space.
        if memory_efficient:
            xi_amb = []
            for i in range(self.N_basis):
                xi_amb.append(torch.matmul((xi[i]*metricInv).unsqueeze(-2), posjac.permute(0,2,1)).squeeze(1))
            xi_amb = torch.stack(xi_amb, dim=0)
        else:
            xi_amb = torch.matmul((xi*metricInv.unsqueeze(0)).unsqueeze(-2), posjac.permute(0,2,1).unsqueeze(0)).squeeze(2)
        # xi carries a 1 / sigma**2 factor; multiplying by sigma**2 removes it.
        log_x_r = torch.sum(self.theta.view(self.N_basis,1,1) * xi_amb, dim=0) / (den.unsqueeze(-1) + eps) * self._sigma**2
        pos_x = getPos_torch(x)
        if returnCurPosAndLog:
            return pos_x, log_x_r
        pos_r = exponential_map(pos_x, log_x_r)
        return getCoord_torch(pos_r)
        
    
def RLSLDG_sph_n_trainer(data, centerPoints, _sigma, _lambda, memory_efficient = False):
    """Fit an R-LSLDG model by solving the regularized linear system for theta.

    Args:
        data: input data to estimate the derivative of log density.
        centerPoints: centering points for basis functions.
        _sigma: hyperparameter to control bandwidth of basis function.
        _lambda: hyperparameter to control weight norm.
        memory_efficient: forwarded to ``xi_func``.

    Returns:
        An ``RLSLDG_sph_n`` instance with ``theta = -(G + _lambda * I)^{-1} h``.
    """
    ####### remove data points that coincide with a center point
    distmat = getDist_torch(data, centerPoints)
    eps = 1e-6
    checkmat = distmat < eps
    data = data[torch.sum(checkmat, dim=1) == 0]
    ##########################################

    if data.shape[1] != centerPoints.shape[1]:
        # Kept as a warning only, as in the original behavior.
        print("check dimension of input data and center points")
    N_basis = centerPoints.shape[0]

    _, G, h = xi_func(centerPoints, data, _sigma, returnxiOnly = False, memory_efficient = memory_efficient)

    # Build the identity directly with G's dtype/device so the sum below
    # never mixes devices or precisions.
    Eye = torch.eye(N_basis, dtype=G.dtype, device=G.device)
    system = G + _lambda*Eye
    rhs = h.view(-1,1)

    # torch.solve(B, A) was deprecated and later removed; torch.linalg.solve
    # takes the arguments in the opposite order (A, B). Fall back to the old
    # API only on PyTorch versions that predate torch.linalg.solve.
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"):
        theta = torch.linalg.solve(system, rhs)
    else:
        theta, _ = torch.solve(rhs, system)
    theta = - theta.view(-1)

    return RLSLDG_sph_n(_sigma, theta, centerPoints)

def RLSLDG_sph_n_model_selection_by_cross_validation(data, centerPoints, sigma0, Ncv = 5,
                                              sigmaSet = None, lambdaSet = None, loggingFileName = None, memory_efficient = False):
    """Select (sigma, lambda) by Ncv-fold cross validation and retrain on all data.

    Args:
        data: full data set; folds are contiguous slices along dimension 0.
        centerPoints: centering points for basis functions.
        sigma0: base bandwidth; scales the default sigma grid.
        Ncv: number of folds.
        sigmaSet: candidate bandwidths; defaults to
            ``10**[-2, -1.5, ..., 1] * sigma0``.
        lambdaSet: candidate regularization weights; defaults to
            ``10**[-2, -1.5, ..., 1]``.
        loggingFileName: if given, passed to ``set_logger`` and the resulting
            logger is handed to ``print_info``.
        memory_efficient: forwarded to training and evaluation.

    Returns:
        ``(model, holdoutErrors, min_error)``: the model trained on all data
        with the selected hyperparameters, the len(sigmaSet) x len(lambdaSet)
        table of fold-averaged hold-out objectives, and its selected minimum.
    """
    logger = None if loggingFileName is None else set_logger(loggingFileName)

    def default_grid():
        # Half-decade steps from 1e-2 to 1e1.
        return torch.pow(torch.FloatTensor([10]), torch.FloatTensor([-2, -1.5, -1, -0.5, 0, 0.5, 1]))

    # set hyper parameter (sigma, lambda) candidates
    if sigmaSet is None:
        sigmaSet = default_grid() * sigma0
    if lambdaSet is None:
        lambdaSet = default_grid()

    holdoutErrors = torch.zeros(len(sigmaSet), len(lambdaSet))
    if data.is_cuda:
        holdoutErrors = holdoutErrors.cuda()

    fold_size = data.shape[0] // Ncv
    center_tol = 1e-6

    def drop_center_points(points):
        # Discard points lying within center_tol of any center point.
        near_center = getDist_torch(points, centerPoints) < center_tol
        return points[torch.sum(near_center, dim=1) == 0]

    print_info("Start R-LSLDG model selection by {:d}-fold cross validation".format(Ncv), logger)
    for k in range(Ncv):
        start = time.time()

        fold_begin = k * fold_size
        if k < Ncv - 1:
            fold_end = fold_begin + fold_size
            traindata_k = torch.cat((data[:fold_begin], data[fold_end:]), 0)
            testdata_k = data[fold_begin:fold_end]
        else:
            # The last fold absorbs the remainder of the data.
            traindata_k = data[:fold_begin]
            testdata_k = data[fold_begin:]

        ####### remove centerPoints from traindata and testdata
        traindata_k = drop_center_points(traindata_k)
        testdata_k = drop_center_points(testdata_k)

        print_info("k: {:d}, # of train data: {:d}, # of val data: {:d}".format(k, len(traindata_k), len(testdata_k)), logger)

        for i, _sigma in enumerate(sigmaSet):
            for j, _lambda in enumerate(lambdaSet):
                # train on the k-th training split, score on the held-out split
                fold_model = RLSLDG_sph_n_trainer(traindata_k, centerPoints, _sigma, _lambda, memory_efficient = memory_efficient)
                _, error = fold_model.estimate_gscore(testdata_k, calculateObj = True, memory_efficient = memory_efficient)
                holdoutErrors[i,j] += error[0] / Ncv
        print_info("elapsed time: {:.1f}".format(time.time() - start), logger)

    # determine hyperparameters with minimum validation error
    # (first strict minimum in row-major order)
    min_idx_sigma = 0
    min_idx_lambda = 0
    for i in range(len(sigmaSet)):
        for j in range(len(lambdaSet)):
            is_first_entry = (i == 0 and j == 0)
            if is_first_entry or holdoutErrors[i,j] < min_error:
                min_error = holdoutErrors[i,j]
                min_idx_sigma = i
                min_idx_lambda = j

    # train with all data from selected parameter
    sigma_sel = sigmaSet[min_idx_sigma]
    lambda_sel = lambdaSet[min_idx_lambda]
    print_info("min estimated error: {:f}".format(min_error.item()), logger)
    print_info("selected sigma: {:f}".format(sigma_sel), logger)
    print_info("selected lambda: {:f}".format(lambda_sel), logger)

    final_model = RLSLDG_sph_n_trainer(data, centerPoints, sigma_sel, lambda_sel, memory_efficient = memory_efficient)
    return final_model, holdoutErrors, min_error