import numpy as np
import time
import copy
import torch
from tensor_data_util import christoffelSum, christoffelSum_N_n, metric_N_n, metricInv_sqrt_N_n

def dae_estimate_score(data, model):
    """Return ``model.estimate_score`` evaluated on ``data``.

    ``data`` is moved to the GPU first when the model's first autoencoder
    layer weight is on the GPU; otherwise it is passed through unchanged.
    """
    on_gpu = model.autoencoder[0].weight.is_cuda
    inputs = data.cuda() if on_gpu else data
    return model.estimate_score(inputs)

def dae_estimate_score_from_recon(data, recon, var):
    """Return the DAE score estimate ``(recon - data) / var``.

    Works for any array type supporting subtraction and division
    (numpy arrays, torch tensors, scalars).
    """
    return (recon - data) / var

def dae_estimate_score_error(data, est, model, var, weight = None):
    """Estimate the (optionally weighted) score-matching error of a DAE.

    Per sample ``n`` the quantity is
    ``||est[n]||^2 + (2 / var) * trace(drdx[n])``, where ``drdx`` is
    ``model.get_derivative(data)``. The result is the mean over the ``N``
    samples minus the constant ``2 * dim / model.noise_std**2``. When
    ``weight`` is given, each per-sample value is multiplied by its weight
    and the constant is scaled by ``sum(weight) / N``.

    Args:
        data: inputs whose first two axes are (N, dim); moved to the GPU when
            the model's first autoencoder layer weight is on the GPU.
        est: score estimate of the DAE, summed over ``dim=1``.
        model: object providing ``autoencoder[0].weight``,
            ``get_derivative`` and ``noise_std``.
        var: variance scaling the trace term.
        weight: optional per-sample weights; either shape ``(N,)`` or
            ``(N, 1)`` (flattened before use).

    Returns:
        A 0-dim tensor holding the error estimate.
    """
    N = data.shape[0]
    dim = data.shape[1]
    if model.autoencoder[0].weight.is_cuda:
        data = data.cuda()
    drdx = model.get_derivative(data).data
    # Per-sample trace of the Jacobian; taking the diagonal keeps us on
    # drdx's own device instead of building a CPU/GPU identity matrix.
    term2 = 2.0 / var * torch.diagonal(drdx, dim1=1, dim2=2).sum(-1)
    term1 = torch.sum(est**2, dim=1)
    per_sample = term1.data + term2.data
    const = 2.0 * dim / model.noise_std**2
    if weight is None:
        return torch.sum(per_sample) / N - const
    # Flatten so an (N, 1) weight does not broadcast against (N,) into (N, N).
    weight = weight.reshape(-1)
    return torch.sum(per_sample * weight) / N - const * torch.sum(weight) / N

def dae_estimate_score_deriv(data, model, var):
    """Return the derivative of the DAE score estimate, ``(drdx - I) / var``.

    ``drdx`` is ``model.get_derivative(data)`` with ``I`` the ``dim x dim``
    identity (``dim = data.shape[1]``). In the GPU branch the Jacobian is
    copied back to the CPU before the identity is subtracted.
    """
    identity = torch.eye(data.shape[1])
    if model.autoencoder[0].weight.is_cuda:
        jacobian = model.get_derivative(data.cuda()).data.cpu()
    else:
        jacobian = model.get_derivative(data).data
    return (jacobian - identity) / var

def DTI_dae_estimate_score_error_truncated(posAndLogvec, idxSet, model, printError = False, weight = None):
    """Sample-size-weighted average of ``dae_estimate_score_error`` over chunks.

    Chunk ``i`` is ``posAndLogvec[idxSet[i]:idxSet[i+1]]``. Each chunk's error
    is multiplied by its number of samples, and the sum is divided by the
    number of samples covered, ``idxSet[-1] - idxSet[0]``. This equals
    ``idxSet[-1]`` when ``idxSet`` starts at 0.

    As noted by the original author, a Euclidean metric is assumed for the
    R^3 part of the input space (R^3 x P(3)).

    Tensors are moved to the device of ``model.autoencoder[0].weight``, so a
    CPU-resident model also works.

    Args:
        posAndLogvec: samples, sliced along the first axis.
        idxSet: increasing chunk boundaries (at least two entries).
        model: DAE model with ``estimate_score`` and ``noise_std``.
        printError: print each chunk's error and the final average.
        weight: optional per-sample weights, sliced like the samples.

    Returns:
        The averaged error (a tensor when at least one chunk was evaluated).
    """
    device = model.autoencoder[0].weight.device
    score_error = 0
    for start, end in zip(idxSet[:-1], idxSet[1:]):
        n_chunk = end - start
        chunk = posAndLogvec[start:end].to(device)
        score_est = model.estimate_score(chunk)
        cur_weight = None if weight is None else weight[start:end].to(device)
        score_error_temp = dae_estimate_score_error(chunk, score_est, model, model.noise_std**2,
                                                    weight = cur_weight)
        if printError:
            print(score_error_temp)
        score_error += score_error_temp * n_chunk
    score_error /= idxSet[-1] - idxSet[0]
    if printError:
        print(score_error)
    return score_error

def gae_estimate_score(data, metric, model):
    """Return ``model.estimate_score(data, metric)``.

    Both ``data`` and ``metric`` are moved to the GPU first when the model's
    first autoencoder layer weight is on the GPU.
    """
    if not model.autoencoder[0].weight.is_cuda:
        return model.estimate_score(data, metric)
    return model.estimate_score(data.cuda(), metric.cuda())

def gae_estimate_score_from_recon(data, recon, metric, var, diagonal_metric=False):
    """Return the GAE score estimate: ``(recon - data)`` pushed through ``metric``, over ``var``.

    * ``diagonal_metric=True``: ``metric`` is applied elementwise.
    * numpy ``data``: per sample, the residual row vector times ``metric``;
      with ``metric`` of shape (N, dim, k) (assumed) the result is (N, k).
    * torch ``data``: the same product via ``torch.bmm``; the singleton
      middle axis is kept, so the result is (N, 1, k).
    """
    residual = recon - data
    if diagonal_metric:
        return (residual * metric) / var
    if type(data) is np.ndarray:
        n_samples, n_dims = data.shape[0], data.shape[1]
        return np.sum(residual.reshape((n_samples, n_dims, 1)) * metric, axis=1) / var
    as_rows = residual.view(residual.size()[0], 1, residual.size()[1])
    return torch.bmm(as_rows, metric) / var

def gae_reconstruction_numpy(data_numpy, model):
    """Run ``model.clean_forward`` on a numpy array and return a numpy array.

    The input is converted with ``torch.FloatTensor``. It is evaluated on the
    GPU when the model's first autoencoder layer weight lives there, and the
    result is copied back to host memory.
    """
    inputs = torch.FloatTensor(data_numpy)
    if model.autoencoder[0].weight.is_cuda:
        return model.clean_forward(inputs.cuda()).data.cpu().numpy()
    return model.clean_forward(inputs).data.numpy()

def gae_estimate_score_error(data, est, model, var, metricInv_sqrt, 
                             christoffelSum = None, metricDeriv = None, metricInv = None, diagonal_metric=False,
                            weight = None):
    """Estimate the (optionally weighted) score-matching error of a GDAE estimate.

    Per sample ``n``:

    * ``term1 = ||temp||^2`` with ``temp = metricInv_sqrt`` applied to ``est``;
    * ``term2 = (2 / var) * trace(model.get_derivative(data)[n])``;
    * ``term3 = <temp, temp2>``, where ``temp2`` is ``metricInv_sqrt``
      applied to either ``2 * christoffelSum`` or, when ``christoffelSum`` is
      None, a term built from ``metricInv`` and ``metricDeriv``.

    The result is the mean of ``term1 + term2 + term3`` over ``N`` minus
    ``2 * dim / model.noise_std**2``. When ``weight`` is given, each per-sample
    value is multiplied by its weight and the constant is scaled by
    ``sum(weight) / N``.

    Args:
        data: inputs whose first two axes are (N, dim).
        est: gscore estimate of the GDAE.
        model: provides ``autoencoder[0].weight``, ``get_derivative`` and
            ``noise_std``.
        var: variance scaling the trace term.
        metricInv_sqrt: (N, dim, dim), or elementwise when ``diagonal_metric``.
        christoffelSum: optional precomputed Christoffel-sum vector per sample.
        metricDeriv, metricInv: used instead of ``christoffelSum`` when it is
            None. In the non-diagonal case ``metricDeriv[:, :, :, i]`` is the
            derivative along coordinate ``i``.
        diagonal_metric: treat the metric quantities elementwise.
        weight: optional per-sample weights, ``(N,)`` or ``(N, 1)``.

    Returns:
        A 0-dim tensor holding the error estimate.

    Raises:
        ValueError: if neither ``christoffelSum`` nor both ``metricDeriv`` and
            ``metricInv`` are provided.
    """
    if christoffelSum is None and (metricDeriv is None or metricInv is None):
        # Previously only printed, after which the code always crashed further down.
        raise ValueError("At least one of either (metricDeriv and metricInv) or christoffelSum should not be None!!!")
    # Original author's note: the constant part of term2 is ignored.
    N = data.shape[0]
    dim = data.shape[1]
    if model.autoencoder[0].weight.is_cuda:
        data = data.cuda()
    drdx = model.get_derivative(data).data
    # Per-sample trace of the Jacobian, computed on drdx's own device.
    term2 = 2.0 / var * torch.diagonal(drdx, dim1=1, dim2=2).sum(-1)
    if diagonal_metric:
        temp = est * metricInv_sqrt
        if christoffelSum is not None:
            temp2 = 2.0 * metricInv_sqrt * christoffelSum
        else:
            temp2 = metricInv_sqrt * torch.sum(metricInv.view(N, dim, 1) * metricDeriv, dim=1)
    else:
        metricInv_sqrt_T = metricInv_sqrt.permute(0, 2, 1)
        temp = torch.bmm(est.view(N, 1, dim), metricInv_sqrt_T).view(N, dim)
        if christoffelSum is not None:
            temp2 = 2.0 * torch.bmm(christoffelSum.view(N, 1, dim), metricInv_sqrt_T).view(N, dim)
        else:
            # traces[n, i] = trace(metricInv[n] @ metricDeriv[n, :, :, i]),
            # vectorised and on the inputs' device (the old loop built CPU tensors).
            traces = torch.einsum('nab,nbai->ni', metricInv, metricDeriv)
            temp2 = torch.bmm(traces.reshape(N, 1, dim), metricInv_sqrt_T).view(N, dim)
    term1 = torch.sum(temp**2, dim=1)
    term3 = torch.sum(temp * temp2, dim=1)
    per_sample = term1.data + term2.data + term3.data
    const = 2.0 * dim / model.noise_std**2
    if weight is None:
        return torch.sum(per_sample) / N - const
    weight = weight.reshape(-1)
    return torch.sum(per_sample * weight) / N - const * torch.sum(weight) / N

def DTI_gae_estimate_score_error_truncated(posAndLogvec, posMetric, posMetricInv_sqrt, idxSet, model, 
                                           printError = False, pos_dim = 3, weight = None):
    """Sample-size-weighted average of ``gae_estimate_score_error`` over chunks.

    Chunk ``i`` is ``posAndLogvec[idxSet[i]:idxSet[i+1]]``. For each chunk a
    block-diagonal ``metricInv_sqrt`` of size ``dim = pos_dim + cov_dim``
    (with ``cov_dim = pos_dim * (pos_dim + 1) / 2``) is assembled:

    * the upper-left ``pos_dim x pos_dim`` block comes from ``posMetricInv_sqrt``;
    * the lower-right block is ``I / model.cov_metric_coeff_sqrt``.

    Chunk errors are multiplied by chunk size and divided by the number of
    samples covered, ``idxSet[-1] - idxSet[0]``. This equals ``idxSet[-1]``
    when ``idxSet`` starts at 0.

    Tensors are moved to the device of ``model.autoencoder[0].weight``.

    Args:
        posAndLogvec: samples, sliced along the first axis.
        posMetric: metric passed to ``model.estimate_score``, sliced like the samples.
        posMetricInv_sqrt: per-sample position block of ``metricInv_sqrt``.
        idxSet: increasing chunk boundaries (at least two entries).
        model: GAE model (``estimate_score``, ``noise_std``, ``cov_metric_coeff_sqrt``).
        printError: print each chunk's error and the final average.
        pos_dim: size of the position block.
        weight: optional per-sample weights, sliced like the samples.

    Returns:
        The averaged error (a tensor when at least one chunk was evaluated).
    """
    cov_dim = int(pos_dim*(pos_dim+1)/2)
    dim = pos_dim + cov_dim
    device = model.autoencoder[0].weight.device
    # Lower-right block of metricInv_sqrt; identical for every chunk.
    cov_block = torch.eye(cov_dim) / model.cov_metric_coeff_sqrt

    gscore_error = 0
    for start, end in zip(idxSet[:-1], idxSet[1:]):
        n_chunk = end - start
        chunk = posAndLogvec[start:end]
        christoffel_sum = christoffelSum(chunk)
        chunk_dev = chunk.to(device)
        gae_score_est = model.estimate_score(chunk_dev, posMetric[start:end].to(device))
        metricInv_sqrt = torch.zeros(n_chunk, dim, dim)
        metricInv_sqrt[:, :pos_dim, :pos_dim] = posMetricInv_sqrt[start:end]
        metricInv_sqrt[:, pos_dim:, pos_dim:] = cov_block
        cur_weight = None if weight is None else weight[start:end].to(device)
        gscore_error_temp = gae_estimate_score_error(
            chunk_dev, gae_score_est, model, model.noise_std**2,
            metricInv_sqrt.to(device), christoffelSum = christoffel_sum.to(device),
            metricDeriv = None, metricInv = None, diagonal_metric=False,
            weight = cur_weight)
        if printError:
            print(gscore_error_temp)
        gscore_error += gscore_error_temp * n_chunk
    gscore_error /= idxSet[-1] - idxSet[0]
    if printError:
        print(gscore_error)
    return gscore_error

def gae_N_n_estimate_score_error(data, est, model, var, metricInv_sqrt, 
                             christoffelSum = None, metricDeriv = None, metricInv = None, diagonal_metric=False,
                            weight = None):
    """Score-matching error of a GDAE estimate, including ``model.cov_coeff`` scaling.

    Same structure as a plain GDAE error estimate, with these differences:

    * the reconstruction Jacobian comes from
      ``model.get_reconstruction_derivative(data, cov_sqrt)``, where
      ``cov_sqrt`` is the upper-left ``model.pos_dim`` block of
      ``metricInv_sqrt``;
    * a copy of ``metricInv_sqrt`` has its position block multiplied by
      ``model.cov_coeff_sqrt`` before being applied;
    * the Christoffel / metric-derivative term is divided by
      ``model.cov_coeff``.

    The result is the mean of ``term1 + term2 + term3`` over ``N`` minus
    ``2 * dim / model.noise_std**2``. When ``weight`` is given, each per-sample
    value is multiplied by its weight and the constant is scaled by
    ``sum(weight) / N``. The caller's ``metricInv_sqrt`` is not modified.

    Raises:
        ValueError: if neither ``christoffelSum`` nor both ``metricDeriv`` and
            ``metricInv`` are provided.
    """
    if christoffelSum is None and (metricDeriv is None or metricInv is None):
        # Previously only printed, after which the code always crashed further down.
        raise ValueError("At least one of either (metricDeriv and metricInv) or christoffelSum should not be None!!!")
    # est: gscore estimate of the GDAE. Original author's notes: metricInv_sqrt
    # is taken to be symmetric, and the constant part of term2 is ignored.
    N = data.shape[0]
    dim = data.shape[1]
    cov_sqrt = metricInv_sqrt[:, :model.pos_dim, :model.pos_dim]
    # 22.05.19: fix to reflect model.cov_coeff in estimating the score error.
    if model.autoencoder[0].weight.is_cuda:
        data = data.cuda()
        cov_sqrt = cov_sqrt.cuda()
    drdx = model.get_reconstruction_derivative(data, cov_sqrt).data
    # Per-sample trace of the Jacobian, computed on drdx's own device.
    term2 = 2.0 / var * torch.diagonal(drdx, dim1=1, dim2=2).sum(-1)

    if diagonal_metric:
        # NOTE(review): cov_sqrt above slices metricInv_sqrt with three axes, so
        # a 2-D (per-sample diagonal) metricInv_sqrt fails before reaching this
        # branch. The branch looks unreachable as written; confirm the
        # intended shape.
        temp_metricInv_sqrt = metricInv_sqrt.clone()
        temp_metricInv_sqrt[:, :model.pos_dim] *= model.cov_coeff_sqrt
        temp = est * temp_metricInv_sqrt
        if christoffelSum is not None:
            temp2 = 2.0 * temp_metricInv_sqrt * christoffelSum / model.cov_coeff
        else:
            temp2 = temp_metricInv_sqrt * torch.sum(metricInv.view(N, dim, 1) * metricDeriv, dim=1) / model.cov_coeff
    else:
        temp_metricInv_sqrt = metricInv_sqrt.clone()
        temp_metricInv_sqrt[:, :model.pos_dim, :model.pos_dim] *= model.cov_coeff_sqrt
        scale_T = temp_metricInv_sqrt.permute(0, 2, 1)
        temp = torch.bmm(est.view(N, 1, dim), scale_T).view(N, dim)
        if christoffelSum is not None:
            temp2 = 2.0 * torch.bmm(christoffelSum.view(N, 1, dim), scale_T).view(N, dim) / model.cov_coeff
        else:
            # traces[n, i] = trace(metricInv[n] @ metricDeriv[n, :, :, i]) / cov_coeff,
            # vectorised and on the inputs' device (the old loop built CPU tensors).
            traces = torch.einsum('nab,nbai->ni', metricInv, metricDeriv) / model.cov_coeff
            temp2 = torch.bmm(traces.reshape(N, 1, dim), scale_T).view(N, dim)
    term1 = torch.sum(temp**2, dim=1)
    term3 = torch.sum(temp * temp2, dim=1)
    per_sample = term1.data + term2.data + term3.data
    const = 2.0 * dim / model.noise_std**2
    if weight is None:
        return torch.sum(per_sample) / N - const
    weight = weight.reshape(-1)
    return torch.sum(per_sample * weight) / N - const * torch.sum(weight) / N

def gae_N_n_estimate_score_error_truncated(posAndCov, covInv, cov_sqrt, idxSet, model, 
                                           printError = False, pos_dim = 3, weight = None):
    """Sample-size-weighted average of ``gae_N_n_estimate_score_error`` over chunks.

    For chunk ``posAndCov[idxSet[i]:idxSet[i+1]]``:

    * the Christoffel sum comes from ``christoffelSum_N_n``;
    * ``metricInv_sqrt`` and the metric come from ``metricInv_sqrt_N_n``;
    * the position block of ``metricInv_sqrt`` is scaled by
      ``model.cov_coeff_sqrt`` and the remaining diagonal block by
      ``model.cov_coeff`` before being passed on.

    Chunk errors are multiplied by chunk size and divided by the number of
    samples covered, ``idxSet[-1] - idxSet[0]``. This equals ``idxSet[-1]``
    when ``idxSet`` starts at 0.

    Tensors are moved to the device of ``model.autoencoder[0].weight``.

    Args:
        posAndCov: samples, sliced along the first axis.
        covInv: per-sample inverse covariances, sliced like the samples.
        cov_sqrt: per-sample values passed to ``model.estimate_score``.
        idxSet: increasing chunk boundaries (at least two entries).
        model: GAE model (``estimate_score``, ``noise_std``, ``cov_coeff``,
            ``cov_coeff_sqrt``).
        printError: print each chunk's error and the final average.
        pos_dim: size of the position block.
        weight: optional per-sample weights, sliced like the samples.

    Returns:
        The averaged error (a tensor when at least one chunk was evaluated).
    """
    device = model.autoencoder[0].weight.device
    gscore_error = 0
    for start, end in zip(idxSet[:-1], idxSet[1:]):
        n_chunk = end - start
        chunk = posAndCov[start:end]
        christoffel_sum = christoffelSum_N_n(chunk)
        metricInv_sqrt, metric = metricInv_sqrt_N_n(chunk, covInv[start:end], returnMetric = True)
        # NOTE(review): gae_N_n_estimate_score_error multiplies the position
        # block of the matrix it receives by model.cov_coeff_sqrt again, so
        # that block ends up scaled twice. Confirm this is intended.
        metricInv_sqrt[:, :pos_dim, :pos_dim] *= model.cov_coeff_sqrt
        metricInv_sqrt[:, pos_dim:, pos_dim:] *= model.cov_coeff

        chunk_dev = chunk.to(device)
        gae_score_est = model.estimate_score(chunk_dev, metric.to(device), cov_sqrt[start:end].to(device))

        cur_weight = None if weight is None else weight[start:end].to(device)
        gscore_error_temp = gae_N_n_estimate_score_error(
            chunk_dev, gae_score_est, model, model.noise_std**2,
            metricInv_sqrt.to(device), christoffelSum = christoffel_sum.to(device),
            metricDeriv = None, metricInv = None, diagonal_metric=False,
            weight = cur_weight)
        if printError:
            print(gscore_error_temp)
        gscore_error += gscore_error_temp * n_chunk
    gscore_error /= idxSet[-1] - idxSet[0]
    if printError:
        print(gscore_error)
    return gscore_error

def estimate_gscore_error(est, estDeriv, metricInv, metricInv_sqrt, metricInvDeriv, 
                          christoffelSum = None, metricDeriv = None, diagonal_metric=False, weight = None):
    """Estimate the (optionally weighted) gscore-matching error of a given estimate.

    ``est`` is a gscore estimate (e.g. from a DAE or LSLDG) with shape
    (N, dim), and ``estDeriv`` is its derivative. Per sample:

    * ``term1 = ||temp||^2`` with ``temp = metricInv_sqrt`` applied to ``est``;
    * ``term2_1 = 2 * <estDeriv, metricInv>`` (diagonal case: only the
      diagonal of ``estDeriv`` is used);
    * ``term2_2 = 2 * sum_a est[a] * sum_b metricInvDeriv[a, b, b]``
      (diagonal case: ``2 * sum_j est[j] * metricInvDeriv[j, j]``);
    * ``term3 = <temp, temp2>``, where ``temp2`` is ``metricInv_sqrt`` applied
      to ``2 * christoffelSum`` or, when that is None, to the
      ``metricInv`` / ``metricDeriv`` trace term.

    Returns the mean of the per-sample sums over ``N``. Unlike the
    autoencoder variants, no constant is subtracted. With ``weight``, each
    per-sample sum is multiplied by its weight before averaging over ``N``.
    All intermediate tensors are built on the inputs' device.

    Raises:
        ValueError: if both ``metricDeriv`` and ``christoffelSum`` are None.
    """
    if christoffelSum is None and metricDeriv is None:
        # Previously only printed, after which the code always crashed further down.
        raise ValueError("At least one of either metricDeriv or christoffelSum should not be None!!!")
    N = est.shape[0]
    dim = est.shape[1]

    if diagonal_metric:
        temp = est * metricInv_sqrt
        if christoffelSum is not None:
            temp2 = 2.0 * metricInv_sqrt * christoffelSum
        else:
            temp2 = metricInv_sqrt * torch.sum(metricInv.view(N, dim, 1) * metricDeriv, dim=1)
        # sum_j metricInv[n, j] * estDeriv[n, j, j]
        term2_1 = 2.0 * (metricInv.reshape(N, dim) * torch.diagonal(estDeriv, dim1=1, dim2=2)).sum(1)
        # sum_j est[n, j] * metricInvDeriv[n, j, j]
        term2_2 = 2.0 * (est.reshape(N, dim) * torch.diagonal(metricInvDeriv, dim1=1, dim2=2)).sum(1)
    else:
        metricInv_sqrt_T = metricInv_sqrt.permute(0, 2, 1)
        temp = torch.bmm(est.view(N, 1, dim), metricInv_sqrt_T).view(N, dim)
        if christoffelSum is not None:
            temp2 = 2.0 * torch.bmm(christoffelSum.view(N, 1, dim), metricInv_sqrt_T).view(N, dim)
        else:
            # traces[n, i] = trace(metricInv[n] @ metricDeriv[n, :, :, i])
            traces = torch.einsum('nab,nbai->ni', metricInv, metricDeriv)
            temp2 = torch.bmm(traces.reshape(N, 1, dim), metricInv_sqrt_T).view(N, dim)
        term2_1 = 2.0 * (estDeriv * metricInv).sum((1, 2))
        # per (n, a): sum_b metricInvDeriv[n, a, b, b], then dotted with est[n]
        inner_traces = torch.diagonal(metricInvDeriv, dim1=2, dim2=3).sum(-1)
        term2_2 = 2.0 * torch.sum(est * inner_traces, dim=1)
        # Original author's note: term2_2 is zero for DTI with the
        # log-Euclidean metric.
    term1 = torch.sum(temp**2, dim=1)
    term3 = torch.sum(temp * temp2, dim=1)
    per_sample = term1.data + term2_1.data + term2_2.data + term3.data
    if weight is None:
        return torch.sum(per_sample) / N
    return torch.sum(per_sample * weight.reshape(-1)) / N

def DTI_estimate_gscore_error_truncated(posAndLogvec, posMetric, posMetricInv_sqrt, posMetricInv, posMetricInvDeriv, 
                                        idxSet, model, printError = False, pos_dim = 3, cov_metric_coeff_sqrt = 1.0,
                                       weight = None):
    """Sample-size-weighted average of ``estimate_gscore_error`` over chunks.

    For chunk ``posAndLogvec[idxSet[i]:idxSet[i+1]]``:

    * the gscore estimate is ``model.estimate_score(x) - christoffelSum(x)``;
    * its derivative is ``(model.get_derivative(x) - I) / model.noise_std**2``.
      A non-scalar ``noise_std`` is applied per row via ``view(1, dim, 1)``.

    Block-diagonal ``metricInv_sqrt`` / ``metricInv`` / ``metricInvDeriv`` are
    assembled from the ``posMetric*`` position blocks. The covariance blocks
    of the first two are ``I / cov_metric_coeff_sqrt`` and
    ``I / cov_metric_coeff_sqrt**2``; ``metricInvDeriv`` is zero outside the
    position block. ``posMetric`` is accepted for interface compatibility but
    not used.

    Chunk errors are multiplied by chunk size and divided by the number of
    samples covered, ``idxSet[-1] - idxSet[0]``. This equals ``idxSet[-1]``
    when ``idxSet`` starts at 0. Tensors are moved to the device of
    ``model.autoencoder[0].weight``.

    Returns:
        The averaged error (a tensor when at least one chunk was evaluated).
    """
    cov_dim = int(pos_dim*(pos_dim+1)/2)
    dim = pos_dim + cov_dim
    device = model.autoencoder[0].weight.device
    identity = torch.eye(dim, device=device).view(1, dim, dim)
    # Scalar noise (int or float) divides everything; otherwise scale per row.
    if isinstance(model.noise_std, (int, float)):
        noise_var = model.noise_std**2
    else:
        noise_var = (model.noise_std**2).to(device).view(1, dim, 1)
    # Constant covariance blocks, identical for every chunk.
    cov_block_inv_sqrt = torch.eye(cov_dim) / cov_metric_coeff_sqrt
    cov_block_inv = torch.eye(cov_dim) / cov_metric_coeff_sqrt**2

    gscore_error = 0
    for start, end in zip(idxSet[:-1], idxSet[1:]):
        n_chunk = end - start
        chunk = posAndLogvec[start:end]
        christoffel_sum = christoffelSum(chunk)
        chunk_dev = chunk.to(device)
        christoffel_sum_dev = christoffel_sum.to(device)
        # gscore estimate from the DAE / LSLDG score, minus the Christoffel sum
        gscore_est = model.estimate_score(chunk_dev) - christoffel_sum_dev
        gscore_estDeriv = (model.get_derivative(chunk_dev) - identity) / noise_var

        metricInv_sqrt = torch.zeros(n_chunk, dim, dim)
        metricInv_sqrt[:, :pos_dim, :pos_dim] = posMetricInv_sqrt[start:end]
        metricInv_sqrt[:, pos_dim:, pos_dim:] = cov_block_inv_sqrt
        metricInv = torch.zeros(n_chunk, dim, dim)
        metricInv[:, :pos_dim, :pos_dim] = posMetricInv[start:end]
        metricInv[:, pos_dim:, pos_dim:] = cov_block_inv
        metricInvDeriv = torch.zeros(n_chunk, dim, dim, dim)
        metricInvDeriv[:, :pos_dim, :pos_dim] = posMetricInvDeriv[start:end]

        cur_weight = None if weight is None else weight[start:end].to(device)
        gscore_error_temp = estimate_gscore_error(
            gscore_est, gscore_estDeriv, metricInv.to(device),
            metricInv_sqrt.to(device), metricInvDeriv.to(device),
            christoffelSum = christoffel_sum_dev, metricDeriv = None, diagonal_metric=False,
            weight = cur_weight)
        if printError:
            print(gscore_error_temp)
        gscore_error += gscore_error_temp * n_chunk
    gscore_error /= idxSet[-1] - idxSet[0]
    if printError:
        print(gscore_error)
    return gscore_error