"""
Utilities for processing a credal deep ensemble.
"""
import torch
from torch import nn
from torch.nn import functional as F
import torch.backends.cudnn as cudnn

from uncertainties.min_max_entropy import min_max_entropy_calculation
from uncertainties.gh_measure import compute_gh_measure

def reachable_probability_intervals(lower_probs, upper_probs):
    lower_probs_mod = torch.maximum(lower_probs, 1.0 - (torch.sum(upper_probs, dim=-1, keepdim=True) - upper_probs))
    upper_probs_mod = torch.minimum(upper_probs, 1.0 - (torch.sum(lower_probs, dim=-1, keepdim=True) - lower_probs))
    return lower_probs_mod, upper_probs_mod


def credal_ensemble_forward_pass(model_ensemble, data, uncertainty_type=None):
    """Run every ensemble member on ``data`` and average the credal intervals.

    Each model's output is split in half along its last dimension: the first
    half holds per-class lower probabilities, the second half per-class upper
    probabilities. Each pair is made reachable with
    ``reachable_probability_intervals`` and the bounds are then averaged over
    ensemble members.

    Args:
        model_ensemble: Iterable of callables (e.g. ``nn.Module`` instances)
            applied to ``data``. Presumably each returns a tensor of shape
            ``(batch, 2 * num_classes)`` laid out as ``[lower..., upper...]``
            -- TODO confirm against the model definition.
        data: Input passed unchanged to every model.
        uncertainty_type: ``None`` skips uncertainty computation; ``'GH'``
            computes ``compute_gh_measure`` on the averaged bounds; any other
            value computes ``min_max_entropy_calculation`` on the concatenated
            averaged bounds.

    Returns:
        Tuple ``(probs_avg, max_entropy, diff_entropy, gh)``:
        ``probs_avg`` is ``[lower_avg, upper_avg]`` concatenated on the last
        dimension (a tensor, still attached to the autograd graph if the
        models' outputs were). ``max_entropy`` / ``diff_entropy`` are the
        ``'Hu'`` / ``'EU'`` entries from ``min_max_entropy_calculation`` and
        ``gh`` is the result of ``compute_gh_measure``; each is ``None`` when
        its branch is not selected.

    An empty ``model_ensemble`` is not supported (``torch.stack`` fails on an
    empty list).
    """
    lower_outputs = []
    upper_outputs = []

    for model in model_ensemble:
        probs = model(data)
        half = probs.shape[-1] // 2
        # Split on the last dimension -- the same axis the interval sums in
        # reachable_probability_intervals and the final torch.cat operate on.
        lower_probs = probs[..., :half]
        upper_probs = probs[..., half:]

        # Make sure every bound is reachable by some distribution.
        lower_probs_mod, upper_probs_mod = reachable_probability_intervals(lower_probs, upper_probs)

        lower_outputs.append(lower_probs_mod)
        upper_outputs.append(upper_probs_mod)

    # Average lower and upper bounds over ensemble members (stacked on dim 0).
    lower_probs_avg = torch.mean(torch.stack(lower_outputs), dim=0)
    upper_probs_avg = torch.mean(torch.stack(upper_outputs), dim=0)
    probs_avg = torch.cat([lower_probs_avg, upper_probs_avg], dim=-1)

    max_entropy = None
    diff_entropy = None
    gh = None
    # detach() before .numpy(): numpy() raises on tensors that require grad,
    # which is the case unless the caller runs under torch.no_grad().
    if uncertainty_type == 'GH':
        gh = compute_gh_measure(
            lower_probs_avg.detach().cpu().numpy(),
            upper_probs_avg.detach().cpu().numpy(),
        )
    elif uncertainty_type is not None:
        entropies = min_max_entropy_calculation(probs_avg.detach().cpu().numpy())
        max_entropy = entropies['Hu']
        diff_entropy = entropies['EU']

    return probs_avg, max_entropy, diff_entropy, gh