import math

import gpytorch
import torch
from torch.nn.functional import binary_cross_entropy


def compute_mae(prop_post, prop_gt):
    """Mean absolute error between the posterior predictive mean and the ground truth.

    Args:
        prop_post: Posterior distribution exposing ``mean`` and
            ``covariance_matrix``. When the covariance is 3-D the posterior is a
            batch of components, and the predictive mean is the average of the
            component means (equally weighted mixture).
        prop_gt: Ground-truth values. A 2-D tensor is squeezed along dim 1
            (presumably shape (M, 1) -- confirm against callers).

    Returns:
        0-d tensor holding the mean absolute error.
    """
    pred_mean = prop_post.mean
    if prop_post.covariance_matrix.ndim == 3:
        # Batch of components: the mixture's mean is the mean of component means.
        pred_mean = pred_mean.mean(dim=0)

    # Reshape the target on its own line so the conditional applies only to the
    # target, not to the whole difference (`a - b if c else d` would parse as
    # `(a - b) if c else d`).
    target = prop_gt.squeeze(dim=1) if prop_gt.ndim == 2 else prop_gt
    return torch.abs(pred_mean - target).mean()

def compute_mse(prop_post, prop_gt):
    """Mean squared error between the posterior predictive mean and the ground truth.

    Args:
        prop_post: Posterior distribution exposing ``mean`` and
            ``covariance_matrix``. When the covariance is 3-D the posterior is a
            batch of components, and the predictive mean is the average of the
            component means (equally weighted mixture).
        prop_gt: Ground-truth values. A 2-D tensor is squeezed along dim 1
            (presumably shape (M, 1) -- confirm against callers).

    Returns:
        0-d tensor holding the mean squared error.
    """
    pred_mean = prop_post.mean
    if prop_post.covariance_matrix.ndim == 3:
        # Batch of components: the mixture's mean is the mean of component means.
        pred_mean = pred_mean.mean(dim=0)

    # Reshape the target on its own line so the conditional applies only to the
    # target, not to the whole difference (`a - b if c else d` would parse as
    # `(a - b) if c else d`).
    target = prop_gt.squeeze(dim=1) if prop_gt.ndim == 2 else prop_gt
    return torch.square(pred_mean - target).mean()


def compute_nlpd(prop_post, prop_gt):
    """Negative log predictive density of ``prop_gt`` under ``prop_post``, per observation.

    Each of the M dimensions of the property posterior is treated as one test
    observation, so the joint negative log density is divided by M.

    Args:
        prop_post: Posterior distribution. A 3-D ``covariance_matrix`` denotes a
            batch of S component posteriors, scored as an equally weighted
            Gaussian mixture.
        prop_gt: Ground-truth values. A 2-D tensor is squeezed along dim 1
            (presumably shape (M, 1) -- confirm against callers).

    Returns:
        Scalar tensor with the per-observation NLPD.
    """
    test_y = prop_gt.squeeze(dim=1) if prop_gt.ndim == 2 else prop_gt

    if prop_post.covariance_matrix.ndim == 3:
        # Mixture NLPD: -(1/M) * log((1/S) * sum_s p_s(y)).
        # Build the mixture from the *joint* component log-densities and
        # normalize by M only at the end. Applying logsumexp to values that are
        # already divided by M would compute
        # -log((1/S) * sum_s p_s(y)**(1/M)), which is not the mixture's density.
        # NOTE(review): assumes a plain (non-multitask) MVN, so M is the last
        # dim of test_y.
        component_log_probs = prop_post.log_prob(test_y)  # one value per component
        num_components = component_log_probs.shape[0]
        mixture_log_prob = torch.logsumexp(component_log_probs, dim=0) - math.log(
            num_components
        )  # log-sum-exp trick for numerical stability
        return -mixture_log_prob / test_y.shape[-1]

    # Vanilla NLPD; gpytorch normalizes by the number of test observations.
    return gpytorch.metrics.negative_log_predictive_density(prop_post, test_y)


def compute_ce_loss(predicted_prob, prop_gt_binary):
    """Element-wise binary cross-entropy between predicted probabilities and labels.

    Args:
        predicted_prob: Predicted probabilities (values in [0, 1], as required by
            ``binary_cross_entropy``).
        prop_gt_binary: Binary target labels matching ``predicted_prob``.

    Returns:
        Unreduced per-element loss tensor (``reduction="none"``).
    """
    per_element_loss = binary_cross_entropy(
        predicted_prob, prop_gt_binary, reduction="none"
    )
    return per_element_loss


def compute_accuracy(predicted_prob, prop_gt_binary):
    """Per-element 0/1 correctness of thresholded predictions.

    A prediction is labelled positive when its probability is strictly greater
    than 0.5.

    Returns:
        float64 tensor with 1.0 where the thresholded prediction equals the
        label and 0.0 elsewhere.
    """
    predicted_label = (predicted_prob > 0.5).double()
    is_correct = predicted_label == prop_gt_binary
    return is_correct.double()
