"""
Description: Code for conducting confidence calibration
"""
"""
[Table of Contents]
* [1] Uncertainty Measures
* [2] Codes for F-EDL
* [3] Codes for EDL Methods (EDL, I-EDL, R-EDL)
* [4] Codes for Softmax and Dropout 
* [5] Codes for DDU
* [6] Codes for DAEDL 
"""

import math
import numpy as np
import numpy.random as npr
import pandas as pd
import scipy
import matplotlib.pyplot as plt
from tqdm import tqdm  
import sklearn 
import torch

from utils.ensemble_utils import ensemble_forward_pass
from metrics.classification_metrics import get_logits_labels, get_logits_labels2
from metrics.uncertainty_confidence import entropy, logsumexp, confidence
import warnings
warnings.filterwarnings('ignore')

import density_estimation
from density_estimation import *

################################################################################################
### [1] Utility functions & uncertainty measures
def compute_mu(alpha, p, tau):
    """Return ``(alpha + tau * p) / (alpha0 + tau)``.

    ``alpha0`` is the sum of ``alpha`` over dim 1, kept as a column
    (``keepdim=True``) so that it broadcasts against ``alpha``.

    Args:
        alpha: tensor that is summed over dim 1.
        p: tensor combined element-wise with ``alpha``.
        tau: scalar or tensor broadcastable against ``alpha0``.

    Returns:
        Tensor with the broadcast shape of ``alpha`` and ``tau * p``.
    """
    row_total = alpha.sum(dim=1, keepdim=True)
    numerator = alpha + tau * p
    denominator = row_total + tau
    return numerator / denominator
    
def compute_var(alpha, p, tau):
    """Return the sum of two variance terms built from ``alpha``, ``p``, ``tau``.

    With ``alpha0 = alpha.sum(dim=1, keepdim=True)`` and
    ``mu = (alpha + tau * p) / (alpha0 + tau)`` the result is::

        mu * (1 - mu) / (alpha0 + tau + 1)
        + tau**2 * p * (1 - p) / ((alpha0 + tau) * (alpha0 + tau + 1))

    Args:
        alpha: tensor that is summed over dim 1.
        p: tensor combined element-wise with ``alpha``.
        tau: scalar or tensor broadcastable against ``alpha0``.

    Returns:
        Tensor with the same broadcast shape as ``mu``.
    """
    strength = alpha.sum(dim=1, keepdim=True) + tau
    strength_plus_one = strength + 1

    mean = (alpha + tau * p) / strength
    mean_part = mean * (1 - mean) / strength_plus_one
    prior_part = (tau**2) * p * (1 - p) / (strength * strength_plus_one)

    return mean_part + prior_part
    
def get_mean_var(model, loader, fix_tau, fix_p, device):
    """Run ``model`` over ``loader`` and gather its outputs as NumPy arrays.

    For every batch the model is called as ``model(x, fix_tau, fix_p)`` and
    is expected to return ``(alpha, p, tau)``; ``compute_mu`` and
    ``compute_var`` are then evaluated on that triple. Labels yielded by the
    loader are ignored. Everything runs under ``torch.no_grad()``.

    Args:
        model: callable returning ``(alpha, p, tau)`` tensors.
        loader: iterable of ``(inputs, labels)`` pairs.
        fix_tau, fix_p: forwarded unchanged to ``model``.
        device: device the inputs are moved to before the forward pass.

    Returns:
        Tuple ``(MU, VAR, ALPHA, P, TAU)`` of NumPy arrays, each the
        concatenation (along dim 0) of the per-batch tensors.
    """
    keys = ("mu", "var", "alpha", "p", "tau")
    batches = {key: [] for key in keys}

    with torch.no_grad():
        for inputs, _ in loader:
            alpha_b, p_b, tau_b = model(inputs.to(device), fix_tau, fix_p)

            batches["mu"].append(compute_mu(alpha_b, p_b, tau_b))
            batches["var"].append(compute_var(alpha_b, p_b, tau_b))
            batches["alpha"].append(alpha_b)
            batches["p"].append(p_b)
            batches["tau"].append(tau_b)

    # Concatenate in the same order as the returned tuple.
    return tuple(torch.cat(batches[key]).cpu().numpy() for key in keys)

def TU(mu, var):
    """Return ``1 - sum(mu**2)`` taken over axis 1.

    ``var`` is accepted so all uncertainty measures share one signature,
    but it is not used here.
    """
    squared_mass = (mu**2).sum(1)
    return 1 - squared_mass

def AU(mu, var):
    """Return ``TU(mu, var) - DU(mu, var)``."""
    total = TU(mu, var)
    spread = DU(mu, var)
    return total - spread

def DU(mu, var):
    """Return ``var`` summed over axis 1 (``mu`` is unused)."""
    summed_variance = var.sum(1)
    return summed_variance

def EU(mu, var):
    """Return ``1 - max(mu)`` taken over axis 1 (``var`` is unused).

    ``mu.max(1)`` must itself be an array for the subtraction to work, which
    is the case for the NumPy arrays passed in by ``conf_calibration_fedl``.
    """
    top_probability = mu.max(1)
    return 1 - top_probability


### [2] Confidence calibration for F-EDL
def conf_calibration_fedl(model, testloader, fix_tau, fix_p, device):
    """Evaluate confidence calibration of an F-EDL model on ``testloader``.

    The test set is traversed once. For each batch the model is called as
    ``model(x, fix_tau, fix_p)`` and must return ``(alpha, p, tau)``; the
    predictive mean ``mu`` and variance ``var`` are derived from that triple.
    Using the same forward pass (with the same ``fix_tau`` / ``fix_p`` flags)
    for the correctness labels, the Brier score and the uncertainty scores
    guarantees that they all describe the same predictions and stay aligned
    sample-by-sample even if the loader shuffles.

    Confidence scores are ``-AU`` and ``-EU`` (higher means more confident);
    they are ranked against "prediction is correct" as the positive class.

    Args:
        model: callable returning ``(alpha, p, tau)`` tensors.
        testloader: iterable of ``(inputs, labels)`` batches.
        fix_tau, fix_p: forwarded unchanged to ``model``.
        device: device on which inputs and labels are placed.

    Returns:
        ``(AUROC, AUPR, BRIER)`` where ``AUROC`` and ``AUPR`` are dicts with
        keys ``"AU"`` and ``"EU"`` and ``BRIER`` is the mean over samples of
        the per-sample mean squared difference between ``mu`` and the one-hot
        label.
    """
    # Neither `F` nor the `sklearn.metrics` submodule is imported explicitly
    # at file level, so import them here instead of relying on side effects
    # of other imports.
    import torch.nn.functional as F
    from sklearn.metrics import average_precision_score, roc_auc_score

    correct_chunks = []
    brier_scores = []
    mu_chunks, var_chunks = [], []

    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            alpha, p, tau = model(x, fix_tau, fix_p)

            mu = compute_mu(alpha, p, tau)
            var = compute_var(alpha, p, tau)

            # 1 where the arg-max of the predictive mean matches the label.
            is_correct = (mu.argmax(1) == y).cpu().numpy().astype(int)
            correct_chunks.append(is_correct)

            y_oh = F.one_hot(y, num_classes=p.shape[1]).to(device)
            brier_batch = torch.mean((y_oh - mu) ** 2, dim=1)  # mean over classes
            brier_scores.extend(brier_batch.cpu().numpy())

            mu_chunks.append(mu)
            var_chunks.append(var)

    BRIER = np.mean(brier_scores)
    CORRECT = np.concatenate(correct_chunks)
    MU_id = torch.cat(mu_chunks).cpu().numpy()
    VAR_id = torch.cat(var_chunks).cpu().numpy()

    # NOTE(review): `check` is not defined in this file; presumably it comes
    # from the `density_estimation` star import -- confirm what it validates.
    alea_id = check(-AU(MU_id, VAR_id))
    epis_id = check(-EU(MU_id, VAR_id))

    AUROC = {
        "AU": roc_auc_score(CORRECT, alea_id),
        "EU": roc_auc_score(CORRECT, epis_id),
    }
    AUPR = {
        "AU": average_precision_score(CORRECT, alea_id),
        "EU": average_precision_score(CORRECT, epis_id),
    }

    return AUROC, AUPR, BRIER


### [3] Confidence calibration for EDL Methods
### [3-1] Code for getting the parameters
def get_alpha(model, edl_type, trainloader, loader, lamb, device,
              num_classes=10, embedding_dim=512):
    """Compute concentration parameters ``alpha`` for every sample in ``loader``.

    The model's raw outputs are mapped to ``alpha`` according to ``edl_type``:

    * ``"EDL"``:   ``1 + relu(logits)``
    * ``"I-EDL"``: ``1 + softplus(logits)``
    * ``"R-EDL"``: ``1e-6 + softplus(logits) + lamb``
    * ``"DAEDL"``: ``1e-6 + exp(logits)``, afterwards raised to the power of a
      per-sample density score that is min-max normalised (and clipped) to
      the range observed on ``trainloader``.

    Args:
        model: callable mapping an input batch to logits.
        edl_type: one of ``"EDL"``, ``"I-EDL"``, ``"R-EDL"``, ``"DAEDL"``.
        trainloader: only used for ``"DAEDL"`` (passed to ``fit_gda``).
        loader: iterable of ``(inputs, labels)`` batches to evaluate.
        lamb: additive constant used by ``"R-EDL"``.
        device: device the inputs are moved to.
        num_classes: forwarded to ``fit_gda`` / ``gmm_evaluate`` (DAEDL only).
        embedding_dim: forwarded to ``fit_gda`` (DAEDL only).

    Returns:
        NumPy array with one row of ``alpha`` values per sample.

    Raises:
        ValueError: if ``edl_type`` is not one of the supported names.
    """
    # `F` is not imported explicitly at file level; import it locally.
    import torch.nn.functional as F

    supported = ("EDL", "I-EDL", "R-EDL", "DAEDL")
    if edl_type not in supported:
        # Fail early and clearly instead of hitting an unbound `alpha_t`
        # inside the loop (and before any expensive density fitting).
        raise ValueError(
            f"Unknown edl_type {edl_type!r}; expected one of {supported}"
        )

    if edl_type == "DAEDL":
        gda, p_z_train = fit_gda(model, trainloader, num_classes, embedding_dim, device)
        # The last positional argument of gmm_evaluate is given `device` too.
        logits_edl, _ = gmm_evaluate(model, gda, loader, device, num_classes, device)

        # NOTE(review): `check` is not defined in this file; presumably it
        # comes from the `density_estimation` star import -- confirm.
        p_z = check(torch.logsumexp(logits_edl, dim=-1))
        p_z_train_min, p_z_train_max = p_z_train.min(), p_z_train.max()

        # Clip to the range seen on the training data, then rescale to [0, 1].
        p_z[p_z < p_z_train_min] = p_z_train_min
        p_z[p_z > p_z_train_max] = p_z_train_max
        p_z = (p_z - p_z_train_min) / (p_z_train_max - p_z_train_min)

    alpha_chunks = []
    with torch.no_grad():
        for x, _ in loader:
            logits_t = model(x.to(device))

            if edl_type == "I-EDL":
                alpha_t = 1 + F.softplus(logits_t)
            elif edl_type == "EDL":
                alpha_t = 1 + torch.relu(logits_t)
            elif edl_type == "DAEDL":
                alpha_t = 1e-6 + torch.exp(logits_t)
            else:  # "R-EDL"
                alpha_t = 1e-6 + F.softplus(logits_t) + lamb

            alpha_chunks.append(alpha_t)

    ALPHA = torch.cat(alpha_chunks)

    if edl_type == "DAEDL":
        # alpha ** p_z, computed in log-space.
        ALPHA = torch.exp(torch.log(ALPHA) * p_z.reshape(-1, 1))

    return ALPHA.cpu().numpy()

### [3-2] Code for conducting confidence calibration
def conf_calibration_edl(model, edl_type, trainloader, testloader, lamb, device):
    """Evaluate confidence calibration for EDL, I-EDL and R-EDL models.

    The test set is traversed once. Each batch of logits is mapped to
    concentration parameters ``alpha``:

    * ``"EDL"``:   ``1 + relu(logits)``
    * ``"I-EDL"``: ``1 + softplus(logits)``
    * ``"R-EDL"``: ``1e-6 + softplus(logits) + lamb``

    Correctness labels, the Brier score and the confidence scores are all
    derived from that single forward pass, so they stay aligned per sample.
    The scores are the maximum of the normalised ``alpha`` (reported under
    ``"AU"``) and the sum of ``alpha`` (reported under ``"EU"``).

    Args:
        model: callable mapping an input batch to logits.
        edl_type: ``"EDL"``, ``"I-EDL"`` or ``"R-EDL"``.
        trainloader: unused; kept so existing callers keep working.
        testloader: iterable of ``(inputs, labels)`` batches.
        lamb: additive constant used by ``"R-EDL"``.
        device: device on which inputs and labels are placed.

    Returns:
        ``(AUROC, AUPR, BRIER)`` where ``AUROC`` and ``AUPR`` are dicts with
        keys ``"AU"`` and ``"EU"`` and ``BRIER`` is the mean Brier score.

    Raises:
        ValueError: if ``edl_type`` is not one of the three supported names.
    """
    # Neither `F` nor the `sklearn.metrics` submodule is imported explicitly
    # at file level, so import them here.
    import torch.nn.functional as F
    from sklearn.metrics import average_precision_score, roc_auc_score

    supported = ("EDL", "I-EDL", "R-EDL")
    if edl_type not in supported:
        # Previously this surfaced as an unbound `alpha_t` inside the loop.
        raise ValueError(
            f"Unsupported edl_type {edl_type!r}; expected one of {supported}"
        )

    def _alpha_from_logits(logits):
        """Map raw logits to concentration parameters for ``edl_type``."""
        if edl_type == "I-EDL":
            return 1 + F.softplus(logits)
        if edl_type == "EDL":
            return 1 + torch.relu(logits)
        return 1e-6 + F.softplus(logits) + lamb  # "R-EDL"

    correct_chunks = []
    alpha_chunks = []
    brier_scores = []

    with torch.no_grad():
        for x, y in testloader:
            y = y.to(device)
            logits_t = model(x.to(device))
            alpha_t = _alpha_from_logits(logits_t)

            PI = alpha_t / alpha_t.sum(1, keepdim=True)
            y_oh = F.one_hot(y, num_classes=alpha_t.shape[1]).to(device)

            brier_batch = torch.mean((y_oh - PI) ** 2, dim=1)  # mean over classes
            brier_scores.extend(brier_batch.cpu().numpy())

            # Correctness is judged on the arg-max of the raw logits.
            is_correct = (logits_t.argmax(1) == y).cpu().numpy().astype(int)
            correct_chunks.append(is_correct)
            alpha_chunks.append(alpha_t)

    BRIER = np.mean(brier_scores)
    CORRECT = np.concatenate(correct_chunks)
    ALPHA = torch.cat(alpha_chunks).cpu().numpy()

    alea_id = (ALPHA / ALPHA.sum(1).reshape(-1, 1)).max(1)
    epis_id = ALPHA.sum(1)

    AUROC = {
        "AU": roc_auc_score(CORRECT, alea_id),
        "EU": roc_auc_score(CORRECT, epis_id),
    }
    AUPR = {
        "AU": average_precision_score(CORRECT, alea_id),
        "EU": average_precision_score(CORRECT, epis_id),
    }

    return AUROC, AUPR, BRIER


#############################################################################################
### [4] Confidence calibration for Softmax models (Softmax, Dropout, DDU)
def conf_calibration_softmax(model, trainloader, testloader, sm_type, num_passes, device):
    """Evaluate confidence calibration for softmax-based models.

    For ``sm_type == "Dropout"`` the model is switched to train mode and
    ``num_passes`` stochastic forward passes are averaged per batch; the
    per-sample sum over classes of the variance across passes is the
    epistemic score. For any other ``sm_type`` the model is put in eval mode,
    a single forward pass is used and no epistemic metric is reported.

    Args:
        model: ``torch.nn.Module`` mapping an input batch to logits.
        trainloader: unused; kept so existing callers keep working.
        testloader: iterable of ``(inputs, labels)`` batches.
        sm_type: ``"Dropout"`` enables Monte-Carlo sampling; anything else is
            treated as a deterministic softmax model.
        num_passes: number of stochastic passes (only used for ``"Dropout"``).
        device: device on which inputs and labels are placed.

    Returns:
        ``(AUROC, AUPR, BRIER)`` where ``AUROC`` and ``AUPR`` are dicts with
        keys ``"AU"`` and ``"EU"`` (``"EU"`` is ``None`` unless
        ``sm_type == "Dropout"``) and ``BRIER`` is the mean Brier score.

    Raises:
        ValueError: if ``sm_type == "Dropout"`` and ``num_passes < 2``.
    """
    # Neither `F` nor the `sklearn.metrics` submodule is imported explicitly
    # at file level, so import them here.
    import torch.nn.functional as F
    from sklearn.metrics import average_precision_score, roc_auc_score

    use_mc_dropout = sm_type == "Dropout"
    if use_mc_dropout and num_passes < 2:
        # One pass gives a NaN variance and zero passes an empty stack; both
        # used to fail later with an unrelated-looking error.
        raise ValueError("num_passes must be >= 2 when sm_type == 'Dropout'")

    correct_chunks = []
    brier_scores = []
    mean_probs_list = []
    var_probs_list = []

    was_training = model.training
    if use_mc_dropout:
        # NOTE(review): train() switches every mode-dependent layer, not only
        # dropout (e.g. BatchNorm, if the model has any) -- confirm intended.
        model.train()
    else:
        model.eval()

    try:
        with torch.no_grad():
            for x, y in testloader:
                x = x.to(device)
                y = y.to(device)

                if use_mc_dropout:
                    all_probs = torch.stack(
                        [F.softmax(model(x), dim=1) for _ in range(num_passes)]
                    )
                    mean_probs = all_probs.mean(dim=0)
                    var_probs = all_probs.var(dim=0).sum(dim=1)
                else:
                    mean_probs = F.softmax(model(x), dim=1)
                    var_probs = torch.zeros(mean_probs.size(0)).to(device)

                y_oh = F.one_hot(y, num_classes=mean_probs.shape[1]).to(device)
                brier_batch = torch.mean((y_oh - mean_probs) ** 2, dim=1)
                brier_scores.extend(brier_batch.cpu().numpy())

                is_correct = (mean_probs.argmax(1) == y).cpu().numpy().astype(int)
                correct_chunks.append(is_correct)

                mean_probs_list.append(mean_probs)
                var_probs_list.append(var_probs)
    finally:
        if use_mc_dropout:
            # Do not leave the caller's model stuck in train mode.
            model.train(was_training)

    BRIER = np.mean(brier_scores)
    CORRECT = np.concatenate(correct_chunks)
    mean_probs = torch.cat(mean_probs_list, dim=0).cpu().numpy()
    var_probs = torch.cat(var_probs_list, dim=0).cpu().numpy()

    alea_id = mean_probs.max(1)
    # Negated so that a higher score means higher confidence.
    epis_id = -var_probs.astype(np.float64)

    aupr_alea = average_precision_score(CORRECT, alea_id)
    auroc_alea = roc_auc_score(CORRECT, alea_id)
    if use_mc_dropout:
        aupr_epis = average_precision_score(CORRECT, epis_id)
        auroc_epis = roc_auc_score(CORRECT, epis_id)
    else:
        aupr_epis = None
        auroc_epis = None

    AUROC = {"AU": auroc_alea, "EU": auroc_epis}
    AUPR = {"AU": aupr_alea, "EU": aupr_epis}

    return AUROC, AUPR, BRIER


#############################################################################################
### [5] Confidence calibration for DAEDL
def conf_calibration_daedl(model, gda, p_z_train, testloader, num_classes, device):
    """Evaluate confidence calibration for a DAEDL model.

    For every batch the model output ``z`` is scaled by a density score:
    ``logsumexp`` of ``gmm_forward(model, gda, x)`` over the last dim,
    floored at ``p_z_train.min()`` and min-max normalised with the range of
    ``p_z_train``. ``alpha = exp(z * p_z_norm)`` and its row-normalised form
    ``pi`` give the predictions.

    Confidence scores are the per-sample maximum of ``pi`` (reported under
    ``"AU"``) and the per-sample maximum of ``alpha`` (reported under
    ``"EU"``), ranked against "prediction is correct" as the positive class.

    Args:
        model: callable mapping an input batch to ``z``.
        gda: density model passed through to ``gmm_forward``.
        p_z_train: tensor of training-set density scores; only its min and
            max are used.
        testloader: iterable of ``(inputs, labels)`` batches.
        num_classes: unused; the class count is read from ``alpha.shape[1]``.
        device: device on which inputs and labels are placed.

    Returns:
        ``(AUROC, AUPR, BRIER)`` where ``AUROC`` and ``AUPR`` are dicts with
        keys ``"AU"`` and ``"EU"`` and ``BRIER`` is the mean Brier score.
    """
    # Neither `F` nor the `sklearn.metrics` submodule is imported explicitly
    # at file level, so import them here.
    import torch.nn.functional as F
    from sklearn.metrics import average_precision_score, roc_auc_score

    brier_scores = []
    label_chunks, prob_chunks, alpha_chunks = [], [], []

    d_min, d_max = p_z_train.min(), p_z_train.max()

    with torch.no_grad():
        for x, y in tqdm(testloader):
            x, y = x.to(device), y.to(device)

            z = model(x)
            p_z = torch.logsumexp(gmm_forward(model, gda, x), dim=-1)

            # Cast the training bounds once per batch to the score dtype.
            lower = d_min.to(p_z.dtype)
            upper = d_max.to(p_z.dtype)

            # NOTE(review): only the lower bound is clipped here, whereas
            # get_alpha clips to both bounds -- confirm this is intended.
            p_z[p_z < lower] = lower
            p_z_norm = (p_z - lower) / (upper - lower)

            alpha = torch.exp(z * p_z_norm.reshape(-1, 1))
            pi = alpha / alpha.sum(1).reshape(-1, 1)

            y_oh = F.one_hot(y, num_classes=alpha.shape[1]).to(device)
            brier_batch = torch.mean((y_oh - pi) ** 2, dim=1)  # mean over classes
            brier_scores.extend(brier_batch.cpu().numpy())

            label_chunks.append(y)
            alpha_chunks.append(alpha)
            prob_chunks.append(pi)

    BRIER = np.mean(brier_scores)

    labels = torch.cat(label_chunks)
    prob = torch.cat(prob_chunks)
    alpha = torch.cat(alpha_chunks)

    # The comparison already yields a tensor; no torch.tensor() re-wrap needed.
    correct = (prob.argmax(1) == labels).cpu().numpy()
    scores_alea = prob.max(1).values.cpu().numpy()
    scores_epis = alpha.max(1).values.cpu().numpy()

    AUROC = {
        "AU": roc_auc_score(correct, scores_alea),
        "EU": roc_auc_score(correct, scores_epis),
    }
    AUPR = {
        "AU": average_precision_score(correct, scores_alea),
        "EU": average_precision_score(correct, scores_epis),
    }

    return AUROC, AUPR, BRIER