import numpy as np
import torch

def list_to_cuda(data):
    """Return a new list holding ``item.cuda()`` for every item of *data*.

    The input iterable itself is not modified; each element only needs to
    expose a ``cuda()`` method (e.g. ``torch.Tensor``).
    """
    return [item.cuda() for item in data]

def list_to_cpu(data):
    """Return a new list holding ``item.cpu()`` for every item of *data*.

    The input iterable itself is not modified; each element only needs to
    expose a ``cpu()`` method (e.g. ``torch.Tensor``).
    """
    return [item.cpu() for item in data]

def gae_P_n_estimate_score(x, X_sqrt, metric, model):
    """Evaluate ``model.estimate_score`` and return the result's ``.data``.

    When the weight of ``model.autoencoder[0]`` lives on a CUDA device, ``x``
    and ``X_sqrt`` are moved to CUDA before the call. ``metric`` is forwarded
    unchanged in both cases.
    """
    first_layer_weight = model.autoencoder[0].weight
    if first_layer_weight.is_cuda:
        x, X_sqrt = x.cuda(), X_sqrt.cuda()
    score = model.estimate_score(x, X_sqrt, metric)
    return score.data

def gae_P_n_estimate_score_error(x, X_sqrt, est, model, var, metricInv_sqrt,
                                 X_sqrt_dirderiv_set=None,
                                 christoffelSum=None, metricDeriv=None,
                                 metricInv=None, diagonal_metric=False,
                                 other_quantities_at_x=None):
    """Return a scalar score-estimation error averaged over the batch.

    The value is ``(term1 + term2 + term3) / N - 2 * dim / var`` where
    ``N, dim = x.shape[0], x.shape[1]`` and

    * ``term1`` is the sum of squares of ``est`` mapped through
      ``metricInv_sqrt``;
    * ``term2`` is ``2 / var`` times the sum of the diagonal entries of
      ``model.get_derivative(...)`` (constant ignored, as in the original
      note on this term);
    * ``term3`` is the sum of the elementwise product of the mapped ``est``
      with a correction built either from ``christoffelSum`` or, when that is
      ``None``, from ``metricInv`` and ``metricDeriv``.

    Args:
        x: batch of points; only its first two dimensions ``(N, dim)`` are
            read here.
        X_sqrt: forwarded to ``model.get_derivative``.
        est: score estimate (the "gscore" result of the model); must be
            reshapeable to ``(N, 1, dim)`` in the non-diagonal case.
        model: object exposing ``autoencoder[0].weight`` and
            ``get_derivative``.
        var: scalar divisor used in ``term2`` and in the final offset.
        metricInv_sqrt: ``(N, dim, dim)`` when ``diagonal_metric`` is false
            (it is transposed and batch-multiplied); multiplied elementwise
            with ``est`` when ``diagonal_metric`` is true.
        X_sqrt_dirderiv_set: optional, forwarded to ``model.get_derivative``
            (moved to CUDA first when the model is on CUDA and it is not
            ``None``).
        christoffelSum: optional; when given it takes precedence over
            ``metricDeriv``/``metricInv``.
        metricDeriv: optional; in the non-diagonal case it is indexed as
            ``(N, dim, dim, dim)``.
        metricInv: optional; ``(N, dim, dim)`` in the non-diagonal case,
            viewed as ``(N, dim, 1)`` in the diagonal case.
        diagonal_metric: selects the elementwise (diagonal) formulas.
        other_quantities_at_x: optional list forwarded to
            ``model.get_derivative`` after being moved to the model's device.

    Returns:
        A Python float.

    Raises:
        ValueError: if ``christoffelSum`` is ``None`` and either
            ``metricDeriv`` or ``metricInv`` is ``None``.
    """
    if christoffelSum is None and (metricDeriv is None or metricInv is None):
        raise ValueError("At least one of either (metricDeriv and metricInv) or christoffelSum should not be None")

    N, dim = x.shape[0], x.shape[1]

    if model.autoencoder[0].weight.is_cuda:
        if other_quantities_at_x is not None:
            other_quantities_at_x = list_to_cuda(other_quantities_at_x)
        # X_sqrt_dirderiv_set defaults to None, so only move it when present.
        dirderiv_set = (None if X_sqrt_dirderiv_set is None
                        else X_sqrt_dirderiv_set.cuda())
        drdx = model.get_derivative(
            x.cuda(), X_sqrt.cuda(),
            X_sqrt_dirderiv_set=dirderiv_set,
            other_quantities_at_x=other_quantities_at_x,
        ).data.cpu()
    else:
        if other_quantities_at_x is not None:
            other_quantities_at_x = list_to_cpu(other_quantities_at_x)
        drdx = model.get_derivative(
            x, X_sqrt,
            X_sqrt_dirderiv_set=X_sqrt_dirderiv_set,
            other_quantities_at_x=other_quantities_at_x,
        ).data

    # Masking with the identity keeps only the diagonal entries of drdx.
    term2 = 2 / var * torch.sum(drdx * torch.eye(dim, device=drdx.device))

    def times_metric_inv_sqrt_t(rows):
        # Per-sample row vector times the transpose of metricInv_sqrt.
        return torch.bmm(rows.reshape(N, 1, dim),
                         metricInv_sqrt.permute(0, 2, 1)).reshape(N, dim)

    if diagonal_metric:
        temp = est * metricInv_sqrt
        if christoffelSum is not None:
            temp2 = 2.0 * metricInv_sqrt * christoffelSum
        else:
            temp2 = metricInv_sqrt * torch.sum(
                metricInv.view(N, dim, 1) * metricDeriv, dim=1)
    else:
        temp = times_metric_inv_sqrt_t(est)
        if christoffelSum is not None:
            temp2 = 2.0 * times_metric_inv_sqrt_t(christoffelSum)
        else:
            # traces[n, i] = trace(metricInv[n] @ metricDeriv[n, :, :, i]),
            # computed on whatever device/dtype the inputs already use
            # instead of a hard-coded CUDA float32 buffer.
            traces = torch.einsum('nab,nbai->ni', metricInv, metricDeriv)
            temp2 = times_metric_inv_sqrt_t(traces)

    term1 = torch.sum(temp ** 2)
    term3 = torch.sum(temp * temp2)
    return (term1.item() + term2.item() + term3.item()) / N - 2.0 * dim / var