import numpy as np
import torch

def gae_sph_n_amb_estimate_score(data, model, dx_dxth = None):
    """Return the detached score estimate of ``model`` at ``data``.

    Calls ``model.estimate_score(data, inCoord=True, dx_dxth=dx_dxth)``.
    When the model's first autoencoder layer is on the GPU, ``data`` and
    (if supplied) ``dx_dxth`` are moved there first so every tensor handed
    to the model shares its device.

    Args:
        data: input tensor forwarded to ``model.estimate_score``.
        model: object exposing ``autoencoder[0].weight`` and
            ``estimate_score``.
        dx_dxth: optional tensor forwarded unchanged (apart from the device
            move) as ``dx_dxth``; presumably a coordinate Jacobian -- TODO
            confirm against the model implementation.

    Returns:
        ``model.estimate_score(...).data``. NOTE: the result stays on the
        model's device; it is not moved back to the CPU.
    """
    if model.autoencoder[0].weight.is_cuda:
        data = data.cuda()
        # Keep dx_dxth on the same device as data, as
        # gae_sph_n_amb_estimate_score_error does for get_derivative.
        if dx_dxth is not None:
            dx_dxth = dx_dxth.cuda()
    return model.estimate_score(data, inCoord = True, dx_dxth = dx_dxth).data

def gae_sph_n_amb_estimate_score_error(data, est, model, metricInv_sqrt,
                             christoffelSum, dx_dxth = None):
    """Compute a score-matching-style error for the score estimate ``est``.

    With ``dim = data.shape[1] - 1`` and
    ``drdx = model.get_derivative(data, inCoord=True, dx_dxth=dx_dxth)``,
    the returned value is

        (term1 + term2 + term3) / N - 2 * dim / model.noise_std**2

    where
        term1 = sum((est * metricInv_sqrt)**2)
        term2 = 2 / model.noise_std**2 * sum(drdx * eye(dim))
                (sum of diagonal entries; assumes drdx ends in (dim, dim))
        term3 = sum((est * metricInv_sqrt) * (2 * metricInv_sqrt * christoffelSum))

    Args:
        data: batch of N samples; its second dimension is ``dim + 1``.
        est: score estimate (per the original note: "gscore est result of
            GDAE"); must broadcast elementwise with ``metricInv_sqrt``.
        model: object exposing ``autoencoder[0].weight``, ``noise_std`` and
            ``get_derivative``.
        metricInv_sqrt: elementwise factor applied to ``est`` and
            ``christoffelSum``; presumably the diagonal of the inverse
            metric's square root -- TODO confirm.
        christoffelSum: elementwise term combined with ``metricInv_sqrt``.
        dx_dxth: optional tensor forwarded to ``model.get_derivative``.

    Returns:
        A Python float. Per the original note, the constant part of term2 is
        ignored.
    """
    N = data.shape[0]
    dim = data.shape[1] - 1
    if model.autoencoder[0].weight.is_cuda:
        # dx_dxth is optional: only move it to the GPU when it was supplied.
        # Calling .cuda() on the None default raised AttributeError.
        dx_dxth_dev = dx_dxth.cuda() if dx_dxth is not None else None
        drdx = model.get_derivative(data.cuda(), inCoord = True, dx_dxth = dx_dxth_dev).data.cpu()
    else:
        drdx = model.get_derivative(data, inCoord = True, dx_dxth = dx_dxth).data
    # Multiplying by the identity keeps only diagonal entries (trace-like sum).
    term2 = 2/model.noise_std**2 * torch.sum(drdx * torch.eye(dim))

    temp = est * metricInv_sqrt
    temp2 = 2.0 * metricInv_sqrt * christoffelSum

    term1 = torch.sum(temp**2)
    term3 = torch.sum(temp*temp2)
    return (term1.item() + term2.item() + term3.item())/N - 2.0 * dim / model.noise_std**2

def gae_sph_n_amb_estimate_score_error2(data, est, model, var, metricInv_sqrt,
                             christoffelSum = None, metricDeriv = None, metricInv = None, diagonal_metric=False):
    """Compute a score-matching-style error for ``est`` under a general metric.

    With ``dim = data.shape[1] - 1`` and
    ``drdx = model.get_derivative(data, inCoord=True)``, the returned value is

        (term1 + term2 + term3) / N - 2 * dim / var

    where ``temp`` is ``est`` scaled by ``metricInv_sqrt``, either
    elementwise (``diagonal_metric=True``) or as ``est @ metricInv_sqrt^T``
    per sample. ``temp2`` is the correction vector scaled the same way.
    Then:

        term1 = sum(temp**2)
        term2 = 2 / var * sum(drdx * eye(dim))   (sum of diagonal entries)
        term3 = sum(temp * temp2)

    The correction vector is either ``2 * christoffelSum`` or, when that is
    not given, ``trace(metricInv @ d_i metric)`` for each coordinate ``i``,
    built from ``metricInv`` and ``metricDeriv``.

    Args:
        data: batch of N samples; its second dimension is ``dim + 1``.
        est: score estimate (per the original note: "gscore est result of
            GDAE"); viewed as (N, dim) in the non-diagonal case.
        model: object exposing ``autoencoder[0].weight`` and
            ``get_derivative``.
        var: scalar that plays the role of ``model.noise_std**2`` in
            ``gae_sph_n_amb_estimate_score_error`` (presumably the noise
            variance -- TODO confirm).
        metricInv_sqrt: same shape as ``est`` when ``diagonal_metric``,
            otherwise (N, dim, dim).
        christoffelSum: optional correction term; takes precedence over
            ``metricDeriv``/``metricInv``.
        metricDeriv: optional metric derivative. In the non-diagonal case it
            is indexed as (N, dim, dim, dim), with the last axis selecting
            the coordinate.
        metricInv: optional inverse metric, paired with ``metricDeriv``.
        diagonal_metric: select elementwise (True) or matrix (False) handling.

    Returns:
        A Python float. Per the original note, the constant part of term2 is
        ignored.

    Raises:
        ValueError: if ``christoffelSum`` is None and either ``metricDeriv``
            or ``metricInv`` is None.
    """
    if (metricDeriv is None or metricInv is None) and christoffelSum is None:
        # Without either source the correction term cannot be computed; fail
        # fast instead of printing and then crashing on a None operand.
        raise ValueError("At least one of either (metricDeriv and metricInv) or christoffelSum should not be None!!!")

    N = data.shape[0]
    dim = data.shape[1] - 1
    if model.autoencoder[0].weight.is_cuda:
        drdx = model.get_derivative(data.cuda(), inCoord = True).data.cpu()
    else:
        drdx = model.get_derivative(data, inCoord = True).data
    # Multiplying by the identity keeps only diagonal entries (trace-like sum).
    term2 = 2/var * torch.sum(drdx * torch.eye(dim))
    if diagonal_metric:
        temp = est * metricInv_sqrt
        if christoffelSum is not None:
            temp2 = 2.0 * metricInv_sqrt * christoffelSum
        else:
            temp2 = metricInv_sqrt * torch.sum(metricInv.view(N,dim,1)*metricDeriv, dim=1)
    else:
        temp = torch.bmm(est.view(N, 1, dim), metricInv_sqrt.permute(0,2,1)).view(N, dim)
        if christoffelSum is not None:
            temp2 = 2.0*torch.bmm(christoffelSum.view(N,1,dim), metricInv_sqrt.permute(0,2,1)).view(N,dim)
        else:
            # trace(metricInv @ metricDeriv[..., i]) for every i in one call:
            # sum_{j,k} metricInv[n,j,k] * metricDeriv[n,k,j,i].
            # This keeps the inputs' dtype. The previous float32 zeros buffer
            # made the following bmm fail for float64 inputs.
            temp2 = torch.einsum('njk,nkji->ni', metricInv, metricDeriv)
            temp2 = torch.bmm(temp2.unsqueeze(1), metricInv_sqrt.permute(0,2,1)).view(N, dim)
    term1 = torch.sum(temp**2)
    term3 = torch.sum(temp*temp2)
    return (term1.item() + term2.item() + term3.item())/N - 2.0 * dim / var