from math import log, pi
import torch
from torch.distributions import MultivariateNormal
import torch.nn as nn

# Constant term of the univariate Normal log-density: log(1 / sqrt(2 * pi)).
C = -log(2 * pi) / 2
import torch.nn.functional as F


def nll_loss(input, target, reduction='mean', min_var=1e-4, max_var=10):
    """Heteroscedastic Gaussian negative log-likelihood.

    Column 0 of ``input`` is the predicted mean and column 1 the predicted
    variance of a Normal distribution for the matching row of ``target``
    (mean/variance parameterization, not natural parameters).

    Parameters
    ----------
    input : torch.Tensor (n, 2)
        predicted mean (column 0) and variance (column 1) per data point
    target : torch.Tensor (n, 1)
        targets, one column
    reduction : {'mean', 'sum'}
        'mean' divides the summed NLL by n; 'sum' returns it unscaled
    min_var, max_var : float
        the predicted variance is clamped to ``[min_var, max_var]`` before
        use, keeping ``log(var)`` and ``1 / var`` finite. The defaults are
        the bounds this function always used.

    Returns
    -------
    torch.Tensor
        0-dim tensor holding the reduced negative log-likelihood.

    Raises
    ------
    ValueError
        if ``reduction`` is neither 'mean' nor 'sum'.
    """
    # Fail fast, before doing any tensor work.
    if reduction not in ('mean', 'sum'):
        raise ValueError('Invalid reduction', reduction)
    assert input.ndim == target.ndim == 2
    assert input.shape[0] == target.shape[0]
    # A (n, k) target would broadcast against the (n, 1) mean and silently
    # mix n*k squared errors with only n log-variance/constant terms.
    assert target.shape[1] == 1

    n = input.shape[0]
    var = torch.clamp(input[..., [1]], min=min_var, max=max_var)
    term1 = 0.5 * torch.log(var)
    term2 = torch.div((target - input[..., [0]]) ** 2, 2 * var)
    # C is the per-sample Normal log-normalizer, -0.5 * log(2 * pi).
    log_lik = n * C - term1.sum() - term2.sum()

    return -log_lik / n if reduction == 'mean' else -log_lik


# def loss_nll_all(prediction, target, correlation, covariates, arr_muter, reduction='mean'):
#     arr_muter = arr_muter.bool()
#     loss_pred = nll_loss(prediction, target, reduction='mean')
#     # loss_covariates = 0
#     # for ith_subnet in range(covariates.shape[-1]):
#     #     for ith_cov in range(covariates.shape[-1]):
#     #         current_muter = arr_muter[..., ith_cov, ith_subnet]
#     #         current_loss = heteroscedastic_nll_loss(correlation[..., ith_cov, ith_subnet][current_muter], covariates[..., [ith_cov]][current_muter], reduction='mean')
#     #         loss_covariates += current_loss
#     #         if torch.isnan(current_loss):
#     #             print(1)
#     pred_feat_mean = correlation['mean']
#     pred_feat_covariance = correlation['covariance']
#     loss_covariates = multivariate_gaussian_nll(pred_feat_mean, pred_feat_covariance, covariates, arr_muter)
#     loss_correlation = loss_covariates #/ covariates.shape[-1] / covariates.shape[-1]
#     if torch.isnan(loss_pred + loss_correlation):
#         print(1)
#     #print(loss_pred + loss_correlation)
#     return loss_pred + loss_correlation, loss_pred, loss_correlation
#


def features_loss_mean(fnn_out):
    """Return the average squared entry of a 2-D tensor.

    Unpacking ``fnn_out.shape`` into two names deliberately raises for
    input that is not 2-D.
    """
    rows, cols = fnn_out.shape
    total = fnn_out.pow(2).sum()
    return total / (rows * cols)


def features_loss_var(fnn_out):
    """Return the mean of the squared entries of a 2-D tensor.

    In loss_nll_and_corr_sc this penalty is applied to the concatenated
    ``list_var`` outputs. Non-2-D input raises when the shape is unpacked.
    """
    n_rows, n_cols = fnn_out.shape
    squared = torch.pow(fnn_out, 2)
    return squared.sum() / (n_rows * n_cols)


def loss_nll_and_corr_sc(
        dict_pred: dict,
        target: torch.Tensor,
        reg_config: 'dict | None' = None,
        reduction: str = 'mean'):
    """Combine the prediction NLL, the covariate Gaussian NLL and an optional feature penalty.

    Parameters
    ----------
    dict_pred : dict
        Model outputs. Required keys:

        - ``"overall_mu_and_var"``: passed to ``nll_loss``.
        - ``"dict_correlation"``: a mapping with ``'mean'`` and ``'covariance'``.
        - ``"arr_covariates_ori"`` and ``"arr_muter"``: passed to
          ``multivariate_gaussian_nll``.

        Optional keys ``"list_mu"`` / ``"list_var"`` are lists of tensors used
        only by the feature penalty.
    target : torch.Tensor
        regression targets for ``nll_loss``
    reg_config : dict, optional
        ``'features_loss'`` is the weight of the feature penalty. It defaults
        to 0.0, which disables the penalty.
    reduction : str
        forwarded to ``nll_loss`` only; the correlation term is always a
        mean over the selected rows.

    Returns
    -------
    tuple
        ``(loss_total, loss_pred, loss_correlation)``. ``loss_total``
        includes the weighted feature penalty when it is enabled.

    Raises
    ------
    KeyError
        if a required key is missing from ``dict_pred`` or ``dict_correlation``.
    """
    if reg_config is None:
        reg_config = {}

    # Required outputs: index directly so a missing key fails here with a
    # clear KeyError instead of an opaque error deeper in the loss code.
    prediction = dict_pred["overall_mu_and_var"]
    correlation = dict_pred["dict_correlation"]
    covariates = dict_pred["arr_covariates_ori"]
    arr_muter = dict_pred["arr_muter"]

    loss_pred = nll_loss(prediction, target, reduction=reduction)
    loss_correlation = multivariate_gaussian_nll(
        correlation['mean'], correlation['covariance'], covariates, arr_muter)
    loss_total = loss_pred + loss_correlation

    # Optional L2-style penalty on the per-feature network outputs.
    lambda_feat = reg_config.get('features_loss', 0.0)
    list_mu = dict_pred.get("list_mu")
    list_var = dict_pred.get("list_var")
    if lambda_feat > 0 and list_mu is not None and list_var is not None:
        fnn_out_mu = torch.cat(list_mu, dim=-1)
        fnn_out_var = torch.cat(list_var, dim=-1)
        loss_feat = features_loss_mean(fnn_out_mu) + features_loss_var(fnn_out_var)
        # Out-of-place add: do not mutate a tensor that is part of the graph.
        loss_total = loss_total + lambda_feat * loss_feat

    return loss_total, loss_pred, loss_correlation


def multivariate_gaussian_nll(mean, covariance, gt, arr_muter):
    """Average multivariate Gaussian NLL over the fully observed rows.

    A row contributes only if ``arr_muter`` summed over its last dimension
    equals ``mean.shape[-1]``. With a 0/1 mask this means every entry of
    the row is set (presumably "observed" — TODO confirm with the caller).

    Parameters
    ----------
    mean : torch.Tensor
        predicted means; assumed (batch, d) — TODO confirm
    covariance : torch.Tensor
        predicted covariance matrices; assumed (batch, d, d) — TODO confirm
    gt : torch.Tensor
        observed values, indexed with the same row mask as ``mean``
    arr_muter : torch.Tensor
        mask whose last-dim sum selects the rows to use

    Returns
    -------
    torch.Tensor
        0-dim tensor with the mean negative log-likelihood of the selected
        rows. It is zero (best effort, no gradient) when no row is selected
        or the distribution cannot be built, e.g. when a covariance is not
        positive definite.
    """
    selected = torch.sum(arr_muter, dim=-1) == mean.shape[-1]
    zero = torch.zeros((), dtype=mean.dtype, device=mean.device)

    if not selected.any():
        # The mean over an empty selection is NaN, and MultivariateNormal's
        # log_prob may raise on a zero-size batch, so skip the term.
        print('missing value for corr loss')
        return zero

    try:
        mvn = MultivariateNormal(loc=mean[selected],
                                 covariance_matrix=covariance[selected])
        nll = -mvn.log_prob(gt[selected]).mean()
    except (ValueError, RuntimeError) as err:
        # ValueError: parameter validation (e.g. non positive-definite
        # covariance). RuntimeError covers torch.linalg.LinAlgError from the
        # Cholesky factorization. Other exceptions are real bugs and propagate.
        print(f'corr loss skipped: {err}')
        return zero
    return nll


class MultivariateGaussianNLLLoss(nn.Module):
    """Gaussian NLL with a diagonal covariance parameterized by log-variances.

    ``pred`` is split along dim 1 into two equal halves. The first half is
    the mean and the second half the log-variance of each output dimension.

    Parameters
    ----------
    eps : float
        added to ``exp(log_var)`` so the variance stays strictly positive.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        """Return the batch-averaged negative log-likelihood.

        Parameters
        ----------
        pred : torch.Tensor (batch, 2 * dim)
            means followed by log-variances
        target : torch.Tensor
            observations; must broadcast to (batch, dim) against the means

        Raises
        ------
        ValueError
            if the means/target do not match the log-variances in shape
            (e.g. an odd ``pred`` width), or if the difference is not 2-D.
        """
        half = pred.shape[1] // 2
        mean, log_var = pred[:, :half], pred[:, half:]

        # exp() guarantees positivity; eps keeps the variance away from zero.
        var = torch.exp(log_var) + self.eps

        diff = target - mean
        if diff.shape != var.shape:
            raise ValueError(
                f"target/mean shape {tuple(diff.shape)} does not match "
                f"log-variance shape {tuple(var.shape)}")
        _, dim = diff.shape

        # The covariance is diagonal, so the log-determinant is the sum of the
        # log-variances and the Mahalanobis term is an elementwise division.
        # No dense (dim x dim) matrix or O(dim^3) inverse is needed.
        logdet = torch.log(var).sum(-1)
        mahalanobis = (diff ** 2 / var).sum(-1)

        const = dim * log(2 * pi)
        nll = 0.5 * (const + logdet + mahalanobis)

        return nll.mean()

