from typing import Optional

import torch

def _reduce_over_samples(
    values: torch.Tensor,
    sample_weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Collapse the sample axis (dim 1) of ``values``.

    :param values:
    [n_points, n_samples] or [n_points, n_samples, n_classes]
    :param sample_weights:
    None for a plain mean over samples; otherwise [n_points, n_samples] or
    [n_samples] weights that are multiplied in before summing over samples.
    The weights are used as given, they are not normalized here.
    :return: ``values`` with dim 1 removed
    """
    if sample_weights is None:
        return torch.mean(values, dim=1)

    weights = sample_weights
    if weights.dim() == 1:
        # One weight per sample, shared by every point.
        weights = weights.unsqueeze(0)
    if values.dim() == 3:
        # Trailing axis so the weights broadcast over the class dimension.
        weights = weights.unsqueeze(2)
    return torch.sum(values * weights, dim=1)


def calculate_uncertainties(
    sample_preds: torch.Tensor,
    reference_net_pred: Optional[torch.Tensor] = None,
    sample_weights: Optional[torch.Tensor] = None,
    gamma: float = 1e-10,
) -> dict:
    """
    Calculate total / aleatoric / epistemic uncertainties for all settings at once.

    Calculating all uncertainties at once is much more efficient, because the
    logarithm of the sample predictions and the averaged prediction are shared
    between the settings.
    n_samples refers to number of different networks, which are samples from the posterior.

    :param sample_preds:
    [n_points, n_samples, n_classes]
    :param reference_net_pred:
    [n_points, n_classes]; optional, only needed for the "single" setting
    :param sample_weights:
    [n_points, n_samples], or [n_samples] to share the same weights across all
    points. If None, all samples are averaged with equal weight. The weights are
    used as given (no normalization), so they should sum to one over the samples.
    :param gamma:
    Small additive value for numerical stability (added inside the logarithms)
    :return: dict with uncertainties for all settings. The keys are "mi", "rmi",
    "epkl" and "single"; each maps to a dict with the keys 'total', 'aleatoric'
    and 'epistemic' holding tensors of shape [n_points]. All entries of "single"
    are None when reference_net_pred is None.
    """
    # The largest tensor involved; take its logarithm exactly once.
    log_sample_preds = torch.log(sample_preds + gamma)

    avg_preds = _reduce_over_samples(sample_preds, sample_weights)  # BMA
    # [n_points, n_classes]
    log_avg_preds = torch.log(avg_preds + gamma)

    total_mi = - torch.sum(avg_preds * log_avg_preds, dim=-1)  # H[BMA]

    # Entropy of every sampled network, then averaged over the samples.
    sample_entropies = - torch.sum(sample_preds * log_sample_preds, dim=-1)
    # [n_points, n_samples]
    aleatoric_mi = _reduce_over_samples(sample_entropies, sample_weights)

    # KL(BMA || sample) for every sampled network, then averaged over the samples.
    kl_avg_to_samples = torch.sum(
        avg_preds.unsqueeze(1) * (log_avg_preds.unsqueeze(1) - log_sample_preds),
        dim=-1,
    )
    # [n_points, n_samples]
    epistemic_rmi = _reduce_over_samples(kl_avg_to_samples, sample_weights)

    total_single, aleatoric_single, epistemic_single = None, None, None
    if reference_net_pred is not None:
        # Cross-entropy between the reference network and every sampled network.
        cross_entropies = - torch.sum(
            reference_net_pred.unsqueeze(1) * log_sample_preds, dim=-1
        )
        # [n_points, n_samples]
        total_single = _reduce_over_samples(cross_entropies, sample_weights)

        # Entropy of the reference network; gamma is added to both factors here.
        smoothed_reference = reference_net_pred + gamma
        aleatoric_single = - torch.sum(
            smoothed_reference * torch.log(smoothed_reference), dim=-1
        )
        epistemic_single = total_single - aleatoric_single

    epistemic_mi = total_mi - aleatoric_mi

    return {
        "mi":
        {
            'total': total_mi,
            'aleatoric': aleatoric_mi,
            'epistemic': epistemic_mi,
        },
        "rmi":
        {
            'total': total_mi + epistemic_rmi,
            'aleatoric': total_mi,
            'epistemic': epistemic_rmi,
        },
        "epkl":
        {
            'total': total_mi + epistemic_rmi,
            'aleatoric': aleatoric_mi,
            'epistemic': epistemic_mi + epistemic_rmi,
        },
        "single":
        {
            'total': total_single,
            'aleatoric': aleatoric_single,
            'epistemic': epistemic_single,
        },
    }