import torch
from project_utils.model_utils import get_spn_mpe_output

# Shared mean-squared-error criterion (torch.nn.MSELoss with its default
# settings) reused by the loss helpers defined below in this module.
l2_loss = torch.nn.MSELoss()


def supervised_loss(
    output_for_spn,
    output,
    query_bucket,
    library_spn,
    supervised_loss_lambda,
    device,
):
    """Return the weighted MSE between network outputs and SPN MPE outputs.

    The MPE outputs are obtained from ``get_spn_mpe_output`` and only the
    entries selected by ``query_bucket`` take part in the comparison.

    Args:
        output_for_spn: Tensor handed (as a NumPy array) to the SPN helper.
        output: Network output tensor that is compared against the MPE result.
        query_bucket: Tensor used to index both ``output`` and the MPE result.
        library_spn: SPN object forwarded untouched to ``get_spn_mpe_output``.
        supervised_loss_lambda: Multiplicative weight applied to the loss.
        device: Device on which the MPE result tensor is created.

    Returns:
        ``supervised_loss_lambda`` times the MSE over the selected entries.
    """
    # The SPN helper receives NumPy copies, so the target built from its
    # result is a fresh tensor with no autograd history of its own.
    spn_input_np = output_for_spn.detach().cpu().numpy()
    query_index_np = query_bucket.detach().cpu().numpy()
    mpe_target = torch.tensor(
        get_spn_mpe_output(library_spn, spn_input_np, query_index_np),
        device=device,
    )

    target_at_query = mpe_target[query_bucket]
    prediction_at_query = output[query_bucket]
    return supervised_loss_lambda * l2_loss(prediction_at_query, target_at_query)


def distance_loss(spn, output_for_spn, initial_data, buckets, evid_lambda):
    """Return the weighted MSE between two SPN evaluations on evidence only.

    Copies of ``output_for_spn`` and ``initial_data`` have the positions listed
    in ``buckets["query"]`` and ``buckets["unobs"]`` overwritten with NaN, and
    both copies are then passed through ``spn.evaluate``. The inputs themselves
    are not modified.

    Args:
        spn: Object exposing ``evaluate(tensor)``.
        output_for_spn: Network output tensor; gradients can flow through it.
        initial_data: Reference data tensor; it is detached before use.
        buckets: Mapping providing the ``"query"`` and ``"unobs"`` index sets.
        evid_lambda: Multiplicative weight applied to the loss.

    Returns:
        ``evid_lambda`` times the MSE between the two evaluation results.
    """
    predicted_evidence = output_for_spn.clone()
    reference_evidence = initial_data.detach().clone()

    # NOTE(review): NaN presumably tells spn.evaluate to treat these variables
    # as unobserved; confirm against the SPN implementation.
    hidden_index_sets = (buckets["query"], buckets["unobs"])
    for hidden in hidden_index_sets:
        predicted_evidence[hidden] = float("nan")
        reference_evidence[hidden] = float("nan")

    predicted_score = spn.evaluate(predicted_evidence)
    # The reference side is only a target, so it is evaluated without autograd.
    with torch.no_grad():
        reference_score = spn.evaluate(reference_evidence)

    return evid_lambda * l2_loss(predicted_score, reference_score)


def evid_loss(output, evid_bucket, initial_data, evid_lambda):
    """Return the weighted MSE between output and data on the evidence entries.

    Args:
        output: Network output tensor.
        evid_bucket: Index used to select the same entries from both tensors.
        initial_data: Reference data tensor.
        evid_lambda: Multiplicative weight applied to the loss.

    Returns:
        ``evid_lambda`` times the MSE over the selected entries.
    """
    mismatch = l2_loss(output[evid_bucket], initial_data[evid_bucket])
    return evid_lambda * mismatch


import torch


def entropy_loss_function(predictions, lambda_param=0.01, epsilon=1e-5):
    """Return the scaled mean binary entropy of ``predictions``.

    For every element ``p`` the term
    ``-(p * log(p + epsilon) + (1 - p) * log(1 - p + epsilon))`` is computed;
    the terms are averaged over all elements and multiplied by
    ``lambda_param``.

    Args:
        predictions: Tensor of values treated as Bernoulli probabilities.
        lambda_param: Multiplicative weight applied to the mean entropy.
        epsilon: Constant added inside each logarithm to avoid ``log(0)``.

    Returns:
        A scalar tensor holding the weighted mean entropy.
    """
    complement = 1 - predictions
    positive_term = predictions * torch.log(predictions + epsilon)
    negative_term = complement * torch.log(complement + epsilon)
    pointwise_entropy = -(positive_term + negative_term)
    return lambda_param * pointwise_entropy.mean()


def entropy_loss(output, entropy_lambda):
    """Return the entropy loss of ``output`` weighted by ``entropy_lambda``.

    Delegates to ``entropy_loss_function``, passing ``entropy_lambda`` as its
    weight and leaving its ``epsilon`` at the default.
    """
    return entropy_loss_function(output, lambda_param=entropy_lambda)


def calculate_kl_divergence(mu, log_var):
    """Return ``-0.5 * sum(1 + log_var - mu**2 - exp(log_var))``.

    The sum runs over every element of the inputs, so the result is a single
    scalar tensor.

    Args:
        mu: Tensor of means.
        log_var: Tensor of log-variances, combined elementwise with ``mu``.

    Returns:
        A scalar tensor with the summed divergence.
    """
    per_element = 1 + log_var - mu.pow(2) - log_var.exp()
    return -0.5 * per_element.sum()
