import numpy as np
import torch
import time
from util import *

### Implementation of the Least-Squares Log-Density Gradients (LSLDG) method from
### H. Sasaki, A. Hyvarinen, M. Sugiyama, 'Clustering via Mode Seeking by Direct Estimation of the Gradient of a Log-Density'

def psi_func(centerPoints, data, sigma, returnpsiOnly = True, returnDiag = True):
    """Evaluate derivatives of Gaussian basis functions at every data point.

    With ``phi[n, b] = exp(-||data[n] - centerPoints[b]||^2 / (2 * sigma^2))``,
    ``psi[n, b, i]`` is the partial derivative of ``phi[n, b]`` with respect to
    ``data[n, i]``.

    Args:
        centerPoints: (N_basis, dim) tensor of basis-function centres.
        data: (N, dim) tensor of evaluation points.
        sigma: scalar bandwidth of the Gaussian basis functions.
        returnpsiOnly: if True (default) return only ``psi``.
        returnDiag: only used when ``returnpsiOnly`` is False. If True return
            the unmixed second derivatives, otherwise the full second
            derivative matrix.

    Returns:
        ``psi`` of shape (N, N_basis, dim) when ``returnpsiOnly`` is True;
        ``(psi, dpsi_dx_diag)`` with ``dpsi_dx_diag`` of shape
        (N, N_basis, dim) when ``returnDiag`` is True;
        ``(psi, dpsi_dx)`` with ``dpsi_dx`` of shape (N, N_basis, dim, dim)
        otherwise.
    """
    # diff[n, b, :] = centerPoints[b] - data[n]
    diff = centerPoints.unsqueeze(0) - data.unsqueeze(1)
    phi = torch.exp(- 0.5 * (diff*diff).sum(-1) / sigma**2)
    psi = diff * phi.unsqueeze(-1) / sigma**2

    if returnpsiOnly:
        return psi

    if returnDiag:
        # d psi_i / d x_i = (diff_i^2 / sigma^4 - 1 / sigma^2) * phi
        dpsi_dx_diag = (diff**2 / sigma**4 - 1.0 / sigma**2) * phi.unsqueeze(-1)
        return psi, dpsi_dx_diag

    # Build the identity with phi's dtype and device so the expression below
    # never mixes devices (``.cuda()`` would pick the *current* GPU, which is
    # not necessarily the one holding ``data``) or floating-point precisions.
    Eye = torch.eye(data.shape[1], dtype=phi.dtype, device=phi.device)
    # d psi_i / d x_j = (diff_i * diff_j / sigma^4 - delta_ij / sigma^2) * phi
    dpsi_dx = (diff.unsqueeze(-1)*diff.unsqueeze(-2) / sigma**4 - Eye / sigma**2) * phi.unsqueeze(-1).unsqueeze(-1)
    return psi, dpsi_dx


class LSLDG:
    """Log-density gradient model built from Gaussian basis functions.

    The estimate for coordinate ``i`` at a point ``x`` is
    ``sum_b theta[b, i] * psi[b, i](x)``, with ``psi`` supplied by ``psi_func``.
    """

    def __init__(self, _sigma, theta, centerPoints):
        ### centerPoints: N_basis x dim matrix of basis centres
        ### theta: N_basis x dim matrix of basis weights
        ### _sigma: kernel bandwidth; forwarded unchanged to psi_func and used
        ###         as the divisor of the squared distances in meanShiftUpdate
        self.centerPoints = centerPoints
        self.theta = theta
        self._sigma = _sigma
        self.N_basis = centerPoints.shape[0]
        self.dim = centerPoints.shape[1]

    def estimateLDG(self, x, calculateObj = False):
        """Estimate the log-density gradient at the rows of ``x``.

        Returns the (N, dim) estimate; when ``calculateObj`` is True returns
        ``(estimate, obj)`` where ``obj`` is the sum of squared estimates plus
        twice the sum of the estimated unmixed derivatives, divided by N.
        Raises Exception when ``x.shape[1]`` differs from the model dimension.
        """
        if x.shape[1] != self.dim:
            raise Exception("check dimension of input")

        weights = self.theta.unsqueeze(0)
        if calculateObj:
            basis, basis_diag = psi_func(self.centerPoints, x, self._sigma,
                                         returnpsiOnly = False, returnDiag = True)
            gradient = (basis * weights).sum(1)
            gradient_diag = (basis_diag * weights).sum(1)
            n_points = x.shape[0]
            score = (torch.sum(gradient*gradient) + 2*torch.sum(gradient_diag)) / n_points
            return gradient, score

        basis = psi_func(self.centerPoints, x, self._sigma, returnpsiOnly = True)
        return (basis * weights).sum(1)

    def estimateLDG_Deriv(self, x):
        """Return ``(derivative, estimate)`` at the rows of ``x``.

        ``derivative`` has shape (N, dim, dim) and ``estimate`` has shape
        (N, dim). Raises Exception when ``x.shape[1]`` differs from the model
        dimension.
        """
        if x.shape[1] != self.dim:
            raise Exception("check dimension of input")

        basis, basis_deriv = psi_func(self.centerPoints, x, self._sigma,
                                      returnpsiOnly = False, returnDiag = False)
        weights = self.theta.unsqueeze(0)
        gradient = (basis * weights).sum(1)
        # one more trailing axis so the weights broadcast over the second
        # derivative index
        gradient_deriv = (basis_deriv * weights.unsqueeze(-1)).sum(1)
        return gradient_deriv, gradient

    def meanShiftUpdate(self, x, eps = 1e-7):
        """Return, per row of ``x``, the theta-weighted kernel average of the centres.

        Each output coordinate is ``sum_b phi_b * theta_b * c_b`` divided by
        ``sum_b phi_b * theta_b + eps``. Raises Exception when ``x.shape[1]``
        differs from the model dimension.
        """
        if x.shape[1] != self.dim:
            raise Exception("check dimension of input")

        offsets = self.centerPoints.unsqueeze(0) - x.unsqueeze(1)
        sq_dist = (offsets*offsets).sum(-1)
        kernel = torch.exp(- 0.5 * sq_dist / self._sigma**2).unsqueeze(-1)
        weighted_centers = (self.theta * self.centerPoints).unsqueeze(0)
        numerator = (kernel * weighted_centers).sum(1)
        denominator = (kernel * self.theta.unsqueeze(0)).sum(1)
        return numerator / (denominator + eps)
    
    
def LSLDG_trainer(data, centerPoints, _sigma, _lambda):
    """Fit an LSLDG model by regularised least squares.

    For every coordinate ``i`` the weights solve
    ``(G_i + _lambda * I) theta_i = -h_i`` where ``G_i`` is the sample mean of
    ``psi_i psi_i^T`` and ``h_i`` is the sample mean of the unmixed second
    derivative of the basis functions (both obtained from ``psi_func``).

    Args:
        data: (N, dim) input data used to estimate the derivative of the
            log density.
        centerPoints: (N_basis, dim) centering points for the basis functions.
        _sigma: hyperparameter controlling the bandwidth of the basis functions.
        _lambda: hyperparameter controlling the weight norm penalty.

    Returns:
        An ``LSLDG`` instance holding ``_sigma``, the fitted (N_basis, dim)
        weight matrix and ``centerPoints``.

    Raises:
        Exception: if ``data`` and ``centerPoints`` disagree on ``dim``.
    """
    N = data.shape[0]
    dim = data.shape[1]
    if dim != centerPoints.shape[1]:
        raise Exception("check dimension of input data and center points")
    N_basis = centerPoints.shape[0]

    psi, dpsi_dx_diag = psi_func(centerPoints, data, _sigma, returnpsiOnly = False, returnDiag = True)

    # Regulariser created directly with psi's dtype/device so it can be added
    # to G_i whatever precision or GPU the data lives on.
    ridge = _lambda * torch.eye(N_basis, dtype=psi.dtype, device=psi.device)

    # torch.solve(B, A) was deprecated and later removed from PyTorch; prefer
    # torch.linalg.solve(A, B) and keep the old call only for releases that
    # predate it.
    linalg = getattr(torch, "linalg", None)
    use_linalg = linalg is not None and hasattr(linalg, "solve")

    theta = []
    for i in range(dim):
        psi_i = psi[:, :, i]
        G_i = torch.mm(psi_i.permute(1, 0), psi_i) / N
        h_i = torch.sum(dpsi_dx_diag[:, :, i], dim = 0) / N
        rhs = h_i.view(-1, 1)
        if use_linalg:
            solution = linalg.solve(G_i + ridge, rhs)
        else:
            solution, _ = torch.solve(rhs, G_i + ridge)
        theta.append(-solution.reshape(-1))
    theta = torch.stack(theta, dim=1)
    return LSLDG(_sigma, theta, centerPoints)
   

def LSLDG_model_selection_by_cross_validation(data, centerPoints, sigma0, Ncv = 5,
                                              sigmaSet = None, lambdaSet = None, loggingFileName = None):
    """Select (sigma, lambda) by K-fold cross validation and refit on all data.

    The data are split into ``Ncv`` contiguous folds (the last fold absorbs
    the remainder). For every (sigma, lambda) pair a model is trained on the
    remaining folds and scored with the objective returned by
    ``LSLDG.estimateLDG(..., calculateObj=True)`` on the held-out fold; the
    scores are averaged over folds and the pair with the smallest average is
    used to train the returned model on the full data set.

    Args:
        data: (N, dim) tensor of samples.
        centerPoints: centering points for the basis functions.
        sigma0: scale multiplied into the default sigma grid.
        Ncv: number of folds; must satisfy ``2 <= Ncv <= N``.
        sigmaSet: candidate bandwidths; defaults to
            ``10**[-2, -1.5, ..., 1] * sigma0``.
        lambdaSet: candidate regularisation weights; defaults to
            ``10**[-2, -1.5, ..., 1]``.
        loggingFileName: if given, passed to ``set_logger`` and the resulting
            logger is handed to ``print_info``.

    Returns:
        ``(model, holdoutErrors, min_error)``: the model trained on all data
        with the selected pair, the (len(sigmaSet), len(lambdaSet)) table of
        averaged hold-out errors and the smallest entry of that table.

    Raises:
        ValueError: if the fold count is incompatible with the number of
            samples, or if a candidate grid is empty.
    """
    N = data.shape[0]
    # With fewer than two folds the training split is empty, and with more
    # folds than samples some validation splits are empty; both cases end in
    # divisions by zero, so reject them up front.
    if Ncv < 2 or N < Ncv:
        raise ValueError(
            "cross validation needs 2 <= Ncv <= number of samples "
            "(got Ncv={}, N={})".format(Ncv, N))

    # set hyper parameter (sigma, lambda) from cross validation
    if sigmaSet is None:
        sigmaSet = torch.pow(torch.FloatTensor([10]), torch.FloatTensor([-2, -1.5, -1, -0.5, 0, 0.5, 1])) * sigma0
    if lambdaSet is None:
        lambdaSet = torch.pow(torch.FloatTensor([10]), torch.FloatTensor([-2, -1.5, -1, -0.5, 0, 0.5, 1]))
    if len(sigmaSet) == 0 or len(lambdaSet) == 0:
        raise ValueError("sigmaSet and lambdaSet must not be empty")

    logger = None if loggingFileName is None else set_logger(loggingFileName)

    # Allocate the score table on the device that holds the data.
    holdoutErrors = torch.zeros(len(sigmaSet), len(lambdaSet), device=data.device)
    Nk = N // Ncv

    print_info("Start LSLDG model selection by {:d}-fold cross validation".format(Ncv), logger)
    for k in range(Ncv):
        start = time.time()
        if k < Ncv - 1:
            traindata_k = torch.cat((data[:k*Nk], data[(k+1)*Nk:]), 0)
            testdata_k = data[k*Nk:(k+1)*Nk]
        else:
            # the last fold also takes the N % Ncv leftover samples
            traindata_k = data[:k*Nk]
            testdata_k = data[k*Nk:]

        print_info("k: {:d}, # of train data: {:d}, # of val data: {:d}".format(k, len(traindata_k), len(testdata_k)), logger)

        for i, _sigma in enumerate(sigmaSet):
            for j, _lambda in enumerate(lambdaSet):
                # train model
                lsldg_temp = LSLDG_trainer(traindata_k, centerPoints, _sigma, _lambda)

                # calculate hold out error
                _, error = lsldg_temp.estimateLDG(testdata_k, calculateObj = True)
                holdoutErrors[i,j] += error.item() / Ncv
        print_info("elapsed time: {:.1f}".format(time.time() - start), logger)

    # determine hyperparameters with minimum validation error; the first
    # strict minimum in row-major order wins. NaN scores are skipped because
    # every comparison against NaN is False, so a NaN seed would otherwise
    # never be replaced by a valid score.
    min_idx_sigma = 0
    min_idx_lambda = 0
    best_error = None
    for i, row in enumerate(holdoutErrors.tolist()):
        for j, candidate in enumerate(row):
            if candidate != candidate:
                continue
            if best_error is None or candidate < best_error:
                best_error = candidate
                min_idx_sigma = i
                min_idx_lambda = j
    min_error = holdoutErrors[min_idx_sigma, min_idx_lambda]

    # train with all data from selected parameter
    sigma_sel = sigmaSet[min_idx_sigma]
    lambda_sel = lambdaSet[min_idx_lambda]
    print_info("min estimated error: {:f}".format(min_error.item()), logger)
    print_info("selected sigma: {:f}".format(sigma_sel), logger)
    print_info("selected lambda: {:f}".format(lambda_sel), logger)

    return LSLDG_trainer(data, centerPoints, sigma_sel, lambda_sel), holdoutErrors, min_error
        

    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    