import numpy as np
import torch
import time
from Pn_util import *
from util import *

### Implements the Riemannian Least-Squares Log-Density Gradients (R-LSLDG) method for data on P(n)
### (symmetric positive-definite n x n matrices), following
### M. Ashizawa, H. Sasaki, T. Sakai, M. Sugiyama, "Least-squares log-density gradient clustering for Riemannian manifolds".

def getSqDist_torch(data1, data2, eps1 = 1e-14, eps2 = 1e-7):
    """Pairwise squared distances between two batches of P(n) matrices.

    For every pair (X_i, Y_j) this forms T_ij = X_i^{-1/2} Y_j X_i^{-1/2} and
    returns sum_k log(lambda_k(T_ij))^2, i.e. ||log(T_ij)||_F^2.

    Args:
        data1: (N1, dim, dim) batch of matrices X_i (assumed symmetric positive-definite).
        data2: (N2, dim, dim) batch of matrices Y_j (same dim as data1).
        eps1: floor applied to the eigenvalues of data1 before the (inverse) square root.
        eps2: floor applied to the eigenvalues of T_ij before taking the log.

    Returns:
        (N1, N2) tensor; entry (i, j) is the squared distance between data1[i] and data2[j].
    """
    assert(data1.shape[1] == data2.shape[1])
    N1 = data1.shape[0]
    N2 = data2.shape[0]
    dim = data1.shape[1]
    # NOTE(review): assumes batch_eigsym returns (eigenvalues, eigenvectors) — confirm in Pn_util.
    S, U = batch_eigsym(data1)
    S[S < eps1] = eps1
    _, X_invsqrt = get_sqrt_sym(data1, eps = eps1, returnInvAlso = True, S = S, U = U)

    # Broadcast over all (i, j) pairs: (N1,1,d,d) @ (1,N2,d,d) @ (N1,1,d,d) -> (N1,N2,d,d),
    # flattened so that row i*N2 + j holds T_ij.
    T = torch.matmul(torch.matmul(X_invsqrt.unsqueeze(1), data2.unsqueeze(0)),
                     X_invsqrt.unsqueeze(1)).view(-1, dim, dim)
    S_T, _ = batch_eigsym(T)
    # Clamp tiny/negative eigenvalues (numerical noise) so the log stays finite.
    S_T[S_T < eps2] = eps2
    logS = torch.log(S_T)
    return torch.sum(logS * logS, dim=-1).view(N1, N2)

def calculate_required_quantities(centerPoints, data, eps1 = 1e-14, eps2 = 1e-7, returnxiOnly = False, returnDataSqrtAlso = False, 
                                  S = None, U = None, X_sqrt = None, X_invsqrt = None):
    """Precompute the per-(center, data point) quantities consumed by xi_func.

    Args:
        centerPoints: (Nc, dim, dim) basis-function centers C.
        data: (N, dim, dim) data points X.
        eps1: eigenvalue floor for the data matrices (square roots).
        eps2: eigenvalue floor for X^{-1/2} C X^{-1/2} (logarithms).
        returnxiOnly: if True, skip the metric-lowered vectors and the derivative
            traces; those two outputs are then zero placeholders.
        returnDataSqrtAlso: if True, also return X_sqrt and X_invsqrt.
        S, U, X_sqrt, X_invsqrt: optional precomputed eigen-decomposition and
            square roots of `data`, reused instead of recomputed when given.

    Returns:
        Tuple (sqdistSet, Log_X_C_vec_set, Log_X_C_covec_set,
        Log_X_C_vec_derivTrace_set, christoffelSum) with shapes
        (Nc, N), (Nc, N, vec_dim), (Nc, N, vec_dim), (Nc, N), and whatever
        christoffelSum_P_n returns. vec_dim = dim*(dim+1)/2.
        If returnDataSqrtAlso, returns (that tuple, X_sqrt, X_invsqrt).
    """
    assert(centerPoints.shape[1] == data.shape[1])
    N = data.shape[0]
    Nc = centerPoints.shape[0]
    dim = data.shape[1]
    vec_dim = dim * (dim + 1) // 2
    if S is None or U is None:
        S, U = batch_eigsym(data)
        S[S < eps1] = eps1
    if X_sqrt is None or X_invsqrt is None:
        X_sqrt, X_invsqrt = get_sqrt_sym(data, eps = eps1, returnInvAlso = True, S = S, U = U)
    X_inv = torch.inverse(data)
    # Re-symmetrize: the numerical inverse is not exactly symmetric.
    X_inv = 0.5 * (X_inv + X_inv.permute(0, 2, 1))
    christoffelSum = christoffelSum_P_n(data, X_inv)

    # T[c*N + n] = X_n^{-1/2} C_c X_n^{-1/2}, broadcast (1,N,d,d) @ (Nc,1,d,d) @ (1,N,d,d).
    T = torch.matmul(torch.matmul(X_invsqrt.unsqueeze(0), centerPoints.unsqueeze(1)),
                     X_invsqrt.unsqueeze(0)).view(-1, dim, dim)
    S_T, U_T = batch_eigsym(T)
    S_T[S_T < eps2] = eps2
    logS = torch.log(S_T)
    sqdistSet = torch.sum(logS * logS, dim=-1).view(Nc, N)

    # X^{1/2} log(T) X^{1/2}, vectorized with mat2vec (presumably Log_X(C)).
    LogT = Log_mat(T, eps = eps2, S = S_T, U = U_T)
    Log_X_C_vec_set = mat2vec(torch.matmul(torch.matmul(X_sqrt.unsqueeze(0), LogT.view(Nc, N, dim, dim)),
                                           X_sqrt.unsqueeze(0)).view(-1, dim, dim)).view(Nc, N, vec_dim)

    if returnxiOnly:
        # Placeholders kept for a stable return structure; xi-only callers ignore them.
        Log_X_C_covec_set = torch.zeros(Nc, N, vec_dim, dtype=Log_X_C_vec_set.dtype, device=Log_X_C_vec_set.device)
        Log_X_C_vec_derivTrace_set = torch.zeros(Nc, N, dtype=Log_X_C_vec_set.dtype, device=Log_X_C_vec_set.device)
    else:
        # The metric is only needed here, so it is not computed in xi-only mode.
        Metric = metric_P_n(data, X_inv)
        Log_X_C_covec_set = torch.matmul(
            Metric.unsqueeze(0), Log_X_C_vec_set.unsqueeze(-1)).squeeze(-1)
        deriv = logarithmMapDeriv_P_n(data.unsqueeze(0).repeat(Nc,1,1,1).view(-1,dim,dim), 
                                      centerPoints.unsqueeze(1).repeat(1,N,1,1).view(-1,dim,dim), 
                                      eps1 = eps1, eps2 = eps2, 
                                      S = S.unsqueeze(0).repeat(Nc,1,1).view(-1,dim), 
                                      U = U.unsqueeze(0).repeat(Nc,1,1,1).view(-1,dim,dim), 
                                      X_sqrt = X_sqrt.unsqueeze(0).repeat(Nc,1,1,1).view(-1,dim,dim), 
                                      X_invsqrt = X_invsqrt.unsqueeze(0).repeat(Nc,1,1,1).view(-1,dim,dim), 
                                      T = T, S_T = S_T, U_T = U_T, LogT = LogT)
        # Trace of each (vec_dim x vec_dim) derivative matrix. This avoids allocating an
        # identity matrix and keeps non-finite off-diagonal entries out of the trace.
        Log_X_C_vec_derivTrace_set = torch.diagonal(deriv, dim1=-2, dim2=-1).sum(-1).view(Nc, N)

    quantities = (sqdistSet, Log_X_C_vec_set, Log_X_C_covec_set, Log_X_C_vec_derivTrace_set, christoffelSum)
    if returnDataSqrtAlso:
        return quantities, X_sqrt, X_invsqrt
    return quantities

def xi_func(quantities, sigma, returnxiOnly = True, eps1 = 1e-14, eps2 = 1e-7):
    """Evaluate the R-LSLDG basis terms and, optionally, the quadratic-objective terms.

    Args:
        quantities: tuple produced by calculate_required_quantities:
            (sqdistSet (N_basis, N), Log_X_C_vec_set (N_basis, N, vec_dim),
             Log_X_C_covec_set (N_basis, N, vec_dim),
             Log_X_C_vec_derivTrace_set (N_basis, N), christoffelSum).
        sigma: scalar bandwidth shared by all basis functions
            (TO DO: allow one bandwidth per basis function).
        returnxiOnly: if True, return only xi.
        eps1, eps2: accepted for signature compatibility; unused here.

    Returns:
        xi: (N_basis, N, vec_dim), phi * Log_X_C_vec_set / sigma^2 with
            phi = exp(-sqdist / (2 sigma^2)) -- vector form of psi in the paper
            (derivative of phi w.r.t. data times the inverse metric).
        If returnxiOnly is False, also
        G: (N_basis, N_basis) matrix, and
        h: (N_basis,) vector, the coefficients of the objective theta^T G theta + 2 theta^T h.
    """
    sq_dist, log_vec, log_covec, deriv_trace, christoffel = quantities
    n_basis = sq_dist.shape[0]
    n_data = sq_dist.shape[1]

    # Gaussian-type weights and the basis vector fields.
    weights = torch.exp(- 0.5 * sq_dist / sigma ** 2)
    xi = weights.unsqueeze(-1) * log_vec / sigma ** 2
    if returnxiOnly:
        return xi

    weighted_vec = (weights.unsqueeze(-1) * log_vec).reshape(n_basis, -1)
    weighted_covec = (weights.unsqueeze(-1) * log_covec).reshape(n_basis, -1)
    G = torch.mm(weighted_vec, weighted_covec.t()) / n_data / sigma ** 4

    christoffel_term = torch.sum(xi * christoffel.unsqueeze(0), dim=(1, 2)) / n_data
    trace_term = torch.sum(weights * deriv_trace, dim=-1) / n_data / sigma ** 2
    pairing_term = torch.sum(weights * torch.sum(log_vec * log_covec, dim=-1), dim=-1) / n_data / sigma ** 4
    h = christoffel_term + trace_term + pairing_term

    return xi, G, h


class RLSLDG_Pn:
    """Fitted R-LSLDG model on P(n).

    The model is the theta-weighted combination of the basis vector fields xi
    (see xi_func) centered at `centerPoints` with bandwidth `_sigma`.
    """

    def __init__(self, _sigma, theta, centerPoints):
        """Store the fitted parameters.

        Args:
            _sigma: scalar bandwidth of the basis functions.
            theta: (N_basis,) coefficient vector, one entry per center.
            centerPoints: (N_basis, dim, dim) basis-function centers.
        """
        assert(len(theta) == centerPoints.shape[0])
        self.dim = centerPoints.shape[1]
        self._sigma = _sigma
        self.centerPoints = centerPoints
        self.theta = theta
        self.N_basis = centerPoints.shape[0]

    def _gscore_from_xi(self, metric, xi):
        """Map the theta-weighted sum of xi (N_basis, N, vec_dim) through `metric` (N, vec_dim, vec_dim)."""
        weighted = torch.sum(self.theta.view(self.N_basis, 1, 1) * xi, dim=0)
        return torch.bmm(metric, weighted.unsqueeze(-1)).squeeze(-1)

    def estimate_gscore(self, x, calculateObj = False):
        """Evaluate the model at the points `x`.

        Args:
            x: (N, dim, dim) batch of P(n) matrices.
            calculateObj: if truthy, also return the objective
                theta^T G theta + 2 theta^T h evaluated on `x` (used as the
                hold-out error during cross validation).

        Returns:
            (N, vec_dim) tensor metric_P_n(x) @ sum_b theta_b xi_b(x), or the
            tuple (that tensor, objective of shape (1,)) when calculateObj is truthy.
        """
        if self.dim != x.shape[1]:
            print("check dimension of input")

        metric = metric_P_n(x)

        if not calculateObj:
            quantities = calculate_required_quantities(self.centerPoints, x, returnxiOnly = True)
            xi = xi_func(quantities, self._sigma, returnxiOnly = True)
            return self._gscore_from_xi(metric, xi)

        quantities = calculate_required_quantities(self.centerPoints, x, returnxiOnly = False)
        xi, G, h = xi_func(quantities, self._sigma, returnxiOnly = False)

        gscore_est = self._gscore_from_xi(metric, xi)
        theta_col = self.theta.view(-1, 1)
        obj = torch.mm(torch.mm(theta_col.t(), G), theta_col).view(-1) \
            + 2.0 * torch.sum(self.theta.view(-1) * h)

        return gscore_est, obj

    def meanShiftUpdate(self, x, eps = 1e-7, returnDir=False, quantities_for_x = None):
        """Perform one mean-shift-style step from each point of `x`.

        The step direction is the (phi * theta)-weighted average of the
        vectorized Log_X_C vectors over the centers. The new point is
        X^{1/2} Exp_mat(X^{-1/2} V X^{-1/2}) X^{1/2}, with V = vec2mat(direction).

        Args:
            x: (N, dim, dim) batch of current points.
            eps: added to the weight normalizer to avoid division by zero.
            returnDir: if True, return the (N, vec_dim) direction instead of the new points.
            quantities_for_x: optional (S, U, X_sqrt, X_invsqrt) for `x`, reused to
                skip recomputing the eigen-decomposition and square roots.

        Returns:
            (N, vec_dim) direction if returnDir, else (N, dim, dim) updated points.
        """
        if quantities_for_x is None:
            S = U = X_sqrt = X_invsqrt = None
        else:
            S, U, X_sqrt, X_invsqrt = quantities_for_x
        quantities, X_sqrt, X_invsqrt = calculate_required_quantities(self.centerPoints, x, returnxiOnly = True, returnDataSqrtAlso = True, 
                                                                     S = S, U = U, X_sqrt = X_sqrt, X_invsqrt = X_invsqrt)
        (sqdistSet, Log_X_C_vec_set, _, _, _) = quantities

        # Per-(center, point) weights phi * theta, shape (N_basis, N).
        phi = torch.exp(- 0.5 * sqdistSet / self._sigma**2)
        weights = phi * self.theta.view(-1, 1)
        num = (weights.unsqueeze(-1) * Log_X_C_vec_set).sum(0)
        den = weights.sum(0).unsqueeze(-1)
        log_x_r = num / (den + eps)

        if returnDir:
            return log_x_r
        V = torch.bmm(torch.bmm(X_invsqrt, vec2mat(log_x_r)), X_invsqrt)
        return torch.bmm(torch.bmm(X_sqrt, Exp_mat(V)), X_sqrt)
        
    
def RLSLDG_Pn_trainer(data, centerPoints, _sigma, _lambda, quantities = None):
    """Fit an R-LSLDG model by solving the regularized least-squares system.

    theta = -(G + _lambda * I)^{-1} h, where G and h come from xi_func.

    Args:
        data: (N, dim, dim) training points used to estimate the log-density gradient.
        centerPoints: (N_basis, dim, dim) centers of the basis functions.
        _sigma: bandwidth of the basis functions.
        _lambda: ridge regularization weight on theta.
        quantities: optional precomputed output of calculate_required_quantities
            (with returnxiOnly=False). When given, `data` is NOT filtered and is
            only used for the dimension check.

    Returns:
        RLSLDG_Pn model with the fitted theta.
    """
    if quantities is None:
        # Drop training points that (numerically) coincide with a center point.
        sqdistmat = getSqDist_torch(data, centerPoints)
        eps = 1e-6
        checkmat = sqdistmat < eps
        data = data[torch.sum(checkmat, dim=1) == 0]
        quantities = calculate_required_quantities(centerPoints, data, returnxiOnly = False)

    if data.shape[1] != centerPoints.shape[1]:
        print("check dimension of input data and center points")
    N_basis = centerPoints.shape[0]
    _, G, h = xi_func(quantities, _sigma, returnxiOnly = False)

    # Build the regularizer on G's device/dtype instead of forcing CUDA float32.
    A = G + _lambda * torch.eye(N_basis, dtype=G.dtype, device=G.device)
    b = h.view(-1, 1)
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"):
        theta = torch.linalg.solve(A, b)
    else:
        # PyTorch < 1.8 fallback; torch.solve was removed in later releases.
        theta, _ = torch.solve(b, A)
    theta = - theta.view(-1)

    return RLSLDG_Pn(_sigma, theta, centerPoints)

def _remove_points_near_centers(points, centerPoints, eps = 1e-6):
    """Drop rows of `points` whose squared distance to any center is below `eps`."""
    sqdistmat = getSqDist_torch(points, centerPoints)
    return points[torch.sum(sqdistmat < eps, dim=1) == 0]


def _argmin_ignoring_nan(errors):
    """Return (i, j, value) of the smallest non-NaN entry of a 2-D error grid.

    Ties keep the first entry in row-major order. If every entry is NaN,
    returns (0, 0, errors[0, 0]).
    """
    best_i, best_j, best = 0, 0, None
    for i in range(errors.shape[0]):
        for j in range(errors.shape[1]):
            value = errors[i, j]
            if torch.isnan(value):
                continue
            if best is None or value < best:
                best_i, best_j, best = i, j, value
    if best is None:
        best = errors[0, 0]
    return best_i, best_j, best


def RLSLDG_Pn_model_selection_by_cross_validation(data, centerPoints, sigma0, Ncv = 5, 
                                              sigmaSet = None, lambdaSet = None, loggingFileName = None):
    """Select (sigma, lambda) by Ncv-fold cross validation, then refit on all data.

    Args:
        data: (N, dim, dim) data points.
        centerPoints: (N_basis, dim, dim) basis-function centers.
        sigma0: base bandwidth; the default sigma grid is sigma0 * 10^{-2..1} (half-decade steps).
        Ncv: number of folds (2 <= Ncv <= N).
        sigmaSet, lambdaSet: optional candidate grids (default: half-decade grids over 10^{-2..1}).
        loggingFileName: optional log file passed to set_logger.

    Returns:
        (model fitted on all data with the selected hyperparameters,
         holdoutErrors grid of shape (len(sigmaSet), len(lambdaSet)),
         minimum non-NaN hold-out error).

    Raises:
        ValueError: if Ncv < 2 or there are fewer data points than folds.
    """
    N = data.shape[0]
    Nk = N // Ncv
    # With Ncv < 2 the training split is empty, and with Nk == 0 the folds are empty.
    if Ncv < 2 or Nk == 0:
        raise ValueError("Ncv must satisfy 2 <= Ncv <= number of data points (got Ncv={}, N={})".format(Ncv, N))

    logger = None if loggingFileName is None else set_logger(loggingFileName)

    # set hyper parameter (sigma, lambda) from cross validation
    if sigmaSet is None:
        sigmaSet = torch.pow(torch.FloatTensor([10]), torch.FloatTensor([-2, -1.5, -1, -0.5, 0, 0.5, 1])) * sigma0
    if lambdaSet is None:
        lambdaSet = torch.pow(torch.FloatTensor([10]), torch.FloatTensor([-2, -1.5, -1, -0.5, 0, 0.5, 1]))
    holdoutErrors = torch.zeros(len(sigmaSet), len(lambdaSet)).cuda()

    print_info("Start R-LSLDG model selection by {:d}-fold cross validation".format(Ncv), logger)
    for k in range(Ncv):
        start = time.time()
        if k < Ncv - 1:
            traindata_k = torch.cat((data[:k*Nk], data[(k+1)*Nk:]), 0)
            testdata_k = data[k*Nk:(k+1)*Nk]
        else:
            # The last fold absorbs the remainder of N / Ncv.
            traindata_k = data[:k*Nk]
            testdata_k = data[k*Nk:]
        # Remove points that coincide with center points from both splits.
        traindata_k = _remove_points_near_centers(traindata_k, centerPoints)
        testdata_k = _remove_points_near_centers(testdata_k, centerPoints)

        print_info("k: {:d}, # of train data: {:d}, # of val data: {:d}".format(k, len(traindata_k), len(testdata_k)), logger)

        # These quantities do not depend on (sigma, lambda), so compute them once per fold.
        quantities_temp = calculate_required_quantities(centerPoints, traindata_k, returnxiOnly = False)

        for i, _sigma in enumerate(sigmaSet):
            for j, _lambda in enumerate(lambdaSet):
                rlsldg_temp = RLSLDG_Pn_trainer(traindata_k, centerPoints, _sigma, _lambda, quantities = quantities_temp)
                _, error = rlsldg_temp.estimate_gscore(testdata_k, calculateObj = True)
                holdoutErrors[i,j] += error[0] / Ncv
        print_info("elapsed time: {:.1f}".format(time.time() - start), logger)

    # Pick the hyperparameters with minimum validation error, skipping NaN entries
    # (a NaN in the first cell previously blocked every comparison).
    min_idx_sigma, min_idx_lambda, min_error = _argmin_ignoring_nan(holdoutErrors)

    # train with all data from selected parameter
    sigma_sel = sigmaSet[min_idx_sigma]
    lambda_sel = lambdaSet[min_idx_lambda]
    print_info("min estimated error: {:f}".format(min_error.item()), logger)
    print_info("selected sigma: {:f}".format(sigma_sel), logger)
    print_info("selected lambda: {:f}".format(lambda_sel), logger)

    return RLSLDG_Pn_trainer(data, centerPoints, sigma_sel, lambda_sel), holdoutErrors, min_error