import numpy as np
import torch
from torch.autograd import Variable

def dae_estimate_score(data, model):
    """Return the DAE's score estimate at ``data`` as a detached tensor.

    The input is moved to the GPU when the model's first autoencoder layer has
    its weights there; the returned tensor stays on whatever device
    ``model.estimate_score`` produced it on.
    """
    inputs = Variable(data)
    on_gpu = model.autoencoder[0].weight.is_cuda
    if on_gpu:
        inputs = inputs.cuda()
    score = model.estimate_score(inputs)
    return score.data

def dae_estimate_score_from_recon(data, recon, var):
    """Score estimate from a DAE reconstruction: ``(recon - data) / var``.

    Works element-wise for any array-like supporting ``-`` and ``/``
    (NumPy arrays or torch tensors).
    """
    return (recon - data) / var

def dae_estimate_score_error(data, est, model, var):
    """Estimate the score-matching error of a DAE score estimate (up to a constant).

    Args:
        data: batch of samples; N = data.shape[0], dim = data.shape[1].
        est: score estimate at ``data`` (e.g. from ``dae_estimate_score``).
        model: autoencoder exposing ``autoencoder[0].weight`` and
            ``get_derivative(x)``.
        var: noise variance scaling the score.

    Returns:
        Python float ``(sum(est**2) + 2/var * sum(drdx * I)) / N`` where
        ``drdx = model.get_derivative(data)``.  Unlike
        ``gae_estimate_score_error``, the constant ``-2*dim/var`` contributed by
        the identity part of the score Jacobian ``(drdx - I)/var`` is NOT
        subtracted here.
    """
    N = data.shape[0]
    dim = data.shape[1]
    on_gpu = model.autoencoder[0].weight.is_cuda
    # Move to the model's device *before* enabling grad so the tensor handed to
    # get_derivative is a leaf; `.cuda()` on a grad-requiring tensor is not one,
    # so its `.grad` would never be populated.
    inputs = data.cuda() if on_gpu else data
    dataVariable = Variable(inputs, requires_grad = True)
    drdx = model.get_derivative(dataVariable).data
    if on_gpu:
        drdx = drdx.cpu()
    # Masking with the identity keeps only diagonal entries; assuming drdx is
    # (N, dim, dim) this is the sum of per-sample Jacobian traces.
    term2 = 2/var * torch.sum(drdx * torch.eye(dim))
    term1 = torch.sum(est**2)
    return (term1.item() + term2.item())/N

def dae_estimate_score_deriv(data, model, var, force_cpu = True):
    """Return the derivative of the DAE score estimate, ``(drdx - I) / var``.

    Args:
        data: batch of samples; dim = data.shape[1].
        model: autoencoder exposing ``autoencoder[0].weight`` and
            ``get_derivative(x)``.
        var: noise variance scaling the score.
        force_cpu: when the model is on the GPU, move the result to the CPU
            (default). Has no effect for a CPU model.

    Returns:
        Tensor ``(drdx - I) / var`` where ``drdx = model.get_derivative(data)``,
        on the CPU unless the model is on the GPU and ``force_cpu`` is False.
    """
    dim = data.shape[1]
    on_gpu = model.autoencoder[0].weight.is_cuda
    # Move to the model's device before enabling grad so the differentiated
    # tensor is a leaf (a `.cuda()` copy of a grad-requiring tensor is not).
    inputs = data.cuda() if on_gpu else data
    dataVariable = Variable(inputs, requires_grad = True)
    drdx = model.get_derivative(dataVariable).data
    if on_gpu and force_cpu:
        drdx = drdx.cpu()
    # Build the identity on the same device as drdx (not the default CUDA device).
    eye = torch.eye(dim, device=drdx.device)
    return (drdx - eye) / var
    
def gae_estimate_score(data, metric, model):
    """Return the GAE's score estimate at ``data`` under ``metric``, detached.

    Both ``data`` and ``metric`` are moved to the GPU when the model's first
    autoencoder layer has its weights there.
    """
    inputs, metric_in = Variable(data), Variable(metric)
    if model.autoencoder[0].weight.is_cuda:
        inputs, metric_in = inputs.cuda(), metric_in.cuda()
    score = model.estimate_score(inputs, metric_in)
    return score.data

def gae_estimate_score_from_recon(data, recon, metric, var, diagonal_metric=False):
    """Score estimate from a GAE reconstruction, weighted by ``metric``.

    With ``diagonal_metric`` the residual ``recon - data`` is multiplied
    element-wise by ``metric``. Otherwise each residual row vector r is mapped
    to ``r @ metric[n]`` (per-sample row-vector times matrix).

    NOTE(review): output shapes differ between backends in the full-metric
    case. The NumPy path returns (N, dim); the torch path uses ``bmm`` and
    returns (N, 1, dim). Callers in this module reshape with ``.view`` so both
    work, but be careful elsewhere.
    """
    residual = recon - data
    if diagonal_metric:
        return (residual * metric) / var

    if type(data) is np.ndarray:
        n_samples, n_dims = data.shape[0], data.shape[1]
        # residual[n, i] * metric[n, i, j], summed over i  ==  (r^T M)_j
        weighted = residual.reshape((n_samples, n_dims, 1)) * metric
        return np.sum(weighted, axis=1) / var

    n_samples, n_dims = residual.size()[0], residual.size()[1]
    row_vectors = residual.view(n_samples, 1, n_dims)
    return torch.bmm(row_vectors, metric) / var

def gae_reconstruction_numpy(data_numpy, model):
    """Run ``model.clean_forward`` on a NumPy batch and return a NumPy array.

    The input is converted with ``torch.FloatTensor`` (float32) and moved to
    the GPU if the model's first autoencoder layer lives there; the result is
    always brought back to the CPU before conversion.
    """
    inputs = Variable(torch.FloatTensor(data_numpy))
    if not model.autoencoder[0].weight.is_cuda:
        return model.clean_forward(inputs).data.numpy()
    inputs = inputs.cuda()
    return model.clean_forward(inputs).data.cpu().numpy()

def gae_estimate_score_error(data, est, model, var, metricInv_sqrt, 
                             christoffelSum = None, metricDeriv = None, metricInv = None, diagonal_metric=False):
    """Estimate the Riemannian score-matching error of a GAE score estimate.

    N = data.shape[0], dim = data.shape[1]. The layouts below are the ones
    implied by the ``view``/``bmm``/indexing calls in this function.

    Args:
        data: batch of samples fed to ``model.get_derivative``.
        est: (N, dim) score estimate (e.g. from ``gae_estimate_score``).
        model: autoencoder exposing ``autoencoder[0].weight`` and
            ``get_derivative(x)``.
        var: noise variance.
        metricInv_sqrt: (N, dim) diagonal if ``diagonal_metric``, else
            (N, dim, dim), applied per sample as ``v @ metricInv_sqrt^T``.
        christoffelSum: optional (N, dim) term; when given it is used (times 2)
            instead of ``metricDeriv``/``metricInv``.
        metricDeriv: optional metric derivative. If diagonal: (N, dim, dim),
            contracted over axis 1 with ``metricInv``. Otherwise:
            (N, dim, dim, dim), with ``metricDeriv[:, :, :, i]`` presumably the
            derivative along coordinate i.
        metricInv: optional inverse metric, (N, dim) if diagonal else
            (N, dim, dim).
        diagonal_metric: whether metric quantities are stored as diagonals.

    Returns:
        Python float ``(term1 + term2 + term3) / N - 2*dim/var``.

    Raises:
        ValueError: if ``christoffelSum`` is None and either ``metricDeriv`` or
            ``metricInv`` is None (the computation needs one of the two forms).
    """
    if christoffelSum is None and (metricDeriv is None or metricInv is None):
        # Previously only printed and then crashed on a None attribute access.
        raise ValueError("At least one of either (metricDeriv and metricInv) or christoffelSum should not be None!!!")
    # est: gscore est result of GDAE
    N = data.shape[0]
    dim = data.shape[1]

    # Create the grad-requiring tensor directly on the model's device (a leaf).
    if model.autoencoder[0].weight.is_cuda:
        dataVariable = Variable(data.cuda(), requires_grad = True)
        drdx = model.get_derivative(dataVariable).data.cpu()
    else:
        dataVariable = Variable(data, requires_grad = True)
        drdx = model.get_derivative(dataVariable).data
    # Uses only tr(drdx); the -I/var part of the score Jacobian contributes the
    # constant -2*dim/var that is subtracted in the return statement.
    term2 = 2/var * torch.sum(drdx * torch.eye(dim))
    if diagonal_metric:
        temp = est * metricInv_sqrt
        if christoffelSum is not None:
            temp2 = 2.0 * metricInv_sqrt * christoffelSum
        else:
            temp2 = metricInv_sqrt * torch.sum(metricInv.view(N,dim,1)*metricDeriv, dim=1)
    else:
        sqrt_t = metricInv_sqrt.permute(0, 2, 1)
        temp = torch.bmm(est.view(N, 1, dim), sqrt_t).view(N, dim)
        if christoffelSum is not None:
            contracted = 2.0 * christoffelSum
        else:
            # contracted[n, i] = trace(metricInv[n] @ metricDeriv[n, :, :, i]);
            # einsum keeps the inputs' device and dtype (the former loop wrote
            # into a CPU float32 buffer, breaking CUDA and float64 inputs).
            contracted = torch.einsum('njk,nkji->ni', metricInv, metricDeriv)
        temp2 = torch.bmm(contracted.reshape(N, 1, dim), sqrt_t).view(N, dim)
    term1 = torch.sum(temp**2)
    term3 = torch.sum(temp*temp2)
    return (term1.item() + term2.item() + term3.item())/N - 2.0 * dim / var

def estimate_gscore_error(est, estDeriv, metricInv, metricInv_sqrt, metricInvDeriv, 
                          christoffelSum = None, metricDeriv = None, diagonal_metric=False):
    """Estimate the Riemannian score-matching error of a score estimate (DAE, LSLDG).

    N = est.shape[0], dim = est.shape[1]. The layouts below are the ones
    implied by the ``view``/``bmm``/indexing calls in this function.

    Args:
        est: (N, dim) score estimate.
        estDeriv: (N, dim, dim) derivative of ``est``.
        metricInv: (N, dim) diagonal if ``diagonal_metric``, else (N, dim, dim).
        metricInv_sqrt: same layout as ``metricInv``; in the full case applied
            per sample as ``v @ metricInv_sqrt^T``.
        metricInvDeriv: (N, dim, dim) if diagonal (its diagonal is used), else
            (N, dim, dim, dim) traced over the last two axes.
        christoffelSum: optional (N, dim) term; when given it is used (times 2)
            instead of ``metricDeriv``.
        metricDeriv: optional metric derivative. If diagonal: (N, dim, dim),
            contracted over axis 1 with ``metricInv``. Otherwise:
            (N, dim, dim, dim), with ``metricDeriv[:, :, :, i]`` presumably the
            derivative along coordinate i.
        diagonal_metric: whether metric quantities are stored as diagonals.

    Returns:
        Python float ``(term1 + term2_1 + term2_2 + term3) / N``.

    Raises:
        ValueError: if both ``metricDeriv`` and ``christoffelSum`` are None.
    """
    if metricDeriv is None and christoffelSum is None:
        # Previously only printed and then crashed on a None operand.
        raise ValueError("At least one of either metricDeriv or christoffelSum should not be None!!!")
    # est: gscore est result of DAE, LSLDG
    N = est.shape[0]
    dim = est.shape[1]
    # Identity on est's exact device (not merely the default CUDA device).
    Eye = torch.eye(dim, device=est.device)
    if diagonal_metric:
        temp = est * metricInv_sqrt
        if christoffelSum is not None:
            temp2 = 2.0 * metricInv_sqrt * christoffelSum
        else:
            temp2 = metricInv_sqrt * torch.sum(metricInv.view(N,dim,1)*metricDeriv, dim=1)
        term2_1 = 2*torch.sum(torch.bmm(metricInv.view(N,1,dim), estDeriv * Eye))
        term2_2 = 2*torch.sum(torch.bmm(est.view(N,1,dim), metricInvDeriv * Eye))
    else:
        sqrt_t = metricInv_sqrt.permute(0, 2, 1)
        temp = torch.bmm(est.view(N, 1, dim), sqrt_t).view(N, dim)
        if christoffelSum is not None:
            contracted = 2.0 * christoffelSum
        else:
            # contracted[n, i] = trace(metricInv[n] @ metricDeriv[n, :, :, i]);
            # einsum keeps the inputs' device and dtype (the former loop wrote
            # into a CPU float32 buffer, breaking CUDA and float64 inputs).
            contracted = torch.einsum('njk,nkji->ni', metricInv, metricDeriv)
        temp2 = torch.bmm(contracted.reshape(N, 1, dim), sqrt_t).view(N, dim)
        term2_1 = 2.0*torch.sum(estDeriv * metricInv)
        # Broadcasting Eye over the last two axes keeps only their diagonal;
        # summing those axes gives a per-(n, component) trace.
        term2_2 = 2.0*torch.sum(est * (metricInvDeriv * Eye).sum((2, 3)))
    term1 = torch.sum(temp**2)
    term3 = torch.sum(temp*temp2)
    return (term1.item() + term2_1.item() + term2_2.item() + term3.item())/N

def compareScores(true_score, est_score, metricInv_sqrt, pointwise=False):
    """Compare a true and an estimated score field under a metric weighting.

    If ``metricInv_sqrt`` is 2-D it is treated as per-sample diagonal weights
    and multiplied element-wise. Otherwise each row vector v is mapped to
    ``v @ metricInv_sqrt[n]`` via ``bmm``.

    NOTE(review): the error functions in this module apply
    ``v @ metricInv_sqrt^T`` instead. The two agree only if metricInv_sqrt is
    symmetric; confirm which convention callers use.

    Returns:
        ``(squared_distance, angle)`` per sample when ``pointwise`` is True.
        Otherwise it returns their means; the angle mean skips NaN angles, which
        arise e.g. when a score vector has zero weighted norm.
    """
    if len(metricInv_sqrt.shape) != 2:
        n_true = true_score.shape[1]
        n_est = est_score.shape[1]
        diff = torch.bmm((true_score - est_score).view(-1, 1, n_true), metricInv_sqrt).view(-1, n_true)
        proj_true = torch.bmm(true_score.view(-1, 1, n_true), metricInv_sqrt)
        proj_est = torch.bmm(est_score.view(-1, 1, n_est), metricInv_sqrt)
        dot = (proj_true * proj_est).sum((1, 2))
        len_true = torch.sqrt((proj_true ** 2).sum((1, 2)))
        len_est = torch.sqrt((proj_est ** 2).sum((1, 2)))
    else:
        diff = (true_score - est_score) * metricInv_sqrt
        dot = torch.sum(true_score * est_score * metricInv_sqrt * metricInv_sqrt, dim=1)
        len_true = torch.sqrt(torch.sum((true_score * metricInv_sqrt) ** 2, dim=1))
        len_est = torch.sqrt(torch.sum((est_score * metricInv_sqrt) ** 2, dim=1))
    cosine = dot / len_true / len_est

    sq_dist = torch.sum(diff * diff, dim=1)
    # Rounding can push the cosine slightly outside [-1, 1]; keep acos defined.
    cosine[cosine > 1] = 1
    cosine[cosine < -1] = -1
    angle = torch.acos(cosine)

    if pointwise:
        return sq_dist, angle
    finite_angle = angle[~torch.isnan(angle)]
    return torch.mean(sq_dist), torch.mean(finite_angle)