"""
Description: Code for conducting OOD detection and distribution shift detection
"""
"""
[Table of Contents]
* [1] Uncertainty Measures
* [2] Code for F-EDL
* [3] Code for EDL Methods (EDL, I-EDL, R-EDL)
* [4] Code for Softmax and Dropout
* [5] Code for DDU
* [6] Code for DAEDL
"""
import os
import math
import numpy as np
import numpy.random as npr
import pandas as pd
import scipy
import matplotlib.pyplot as plt
from tqdm import tqdm  

import sklearn 
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors

from scipy.stats import beta
from scipy.stats import dirichlet
from scipy.special import gammaln
from scipy.special import digamma
from scipy.stats import multivariate_normal as mvn

from torch.distributions.dirichlet import Dirichlet
from torch.distributions.kl import kl_divergence as kl_div

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.kl import kl_divergence as kl_div
from torch.nn.utils import spectral_norm
import torchvision
import torchvision.transforms as transforms

from utils.ensemble_utils import ensemble_forward_pass
from metrics.classification_metrics import get_logits_labels, get_logits_labels2
from metrics.uncertainty_confidence import entropy, logsumexp, confidence

import warnings
# Suppress every Python warning for the whole process. This is global: it also
# hides warnings raised by any module imported after this point.
warnings.filterwarnings('ignore')

import utility, density_estimation
from utility import *
from density_estimation import *


############### [1] Uncertainty Measures

def logsumexp(logits):
    """Return log(sum(exp(logits))) taken over dim 1, with that dim removed."""
    reduced = torch.logsumexp(logits, 1)
    return reduced

def identity(x):
    """Return the argument unchanged (no-op score transform)."""
    result = x
    return result

def check(x):
    """Return ``x`` as a tensor with NaN / +-inf replaced by finite values.

    NumPy arrays are first copied into a new tensor via ``torch.tensor``.
    If the tensor holds no NaN or infinite entries it is returned as-is
    (the same object); otherwise a cleaned copy from ``torch.nan_to_num``
    is returned.
    """
    x_tensor = torch.tensor(x) if isinstance(x, np.ndarray) else x

    n_bad = torch.isnan(x_tensor).sum() + torch.isinf(x_tensor).sum()
    if n_bad == 0:
        return x_tensor
    return torch.nan_to_num(x_tensor)

def entropy(logits, densities):
    """Shannon entropy of softmax(logits), one value per row.

    ``densities`` is accepted for signature compatibility with the
    ``*_density`` measures but is not used. Non-finite values are cleaned
    by ``check``.
    """
    log_probs = F.log_softmax(logits, dim=1)
    probs = F.softmax(logits, dim=1)
    return check(-(probs * log_probs).sum(dim=1))

def entropy_density(logits, densities):
    """Shannon entropy of softmax(logits * densities), one value per row.

    Each row of ``logits`` is scaled by the matching entry of ``densities``
    before the softmax. Non-finite results are cleaned by ``check``.
    """
    scaled = logits * densities.reshape(-1, 1)
    probs = F.softmax(scaled, dim=1)
    log_probs = F.log_softmax(scaled, dim=1)
    return check(-(probs * log_probs).sum(dim=1))

def alea_unc_density(logits, densities):
    """Aleatoric uncertainty: expected categorical entropy under Dir(alpha*).

    With z = logits * densities, alpha* = exp(z), alpha0* = sum_k alpha*_k
    and p* = softmax(z), returns
    -sum_k p*_k * (digamma(alpha*_k + 1) - digamma(alpha0* + 1)) per row.
    """
    scaled = logits * densities.reshape(-1, 1)
    expected_probs = F.softmax(scaled, dim=1)

    concentration = torch.exp(scaled)
    strength = concentration.sum(dim=1)

    digamma_gap = torch.digamma(concentration + 1) - torch.digamma(strength + 1).reshape(-1, 1)
    return check(-(expected_probs * digamma_gap).sum(dim=1))

def maxP_density(logits, densities):
    """Largest class probability of softmax(logits * densities), per row."""
    probs = F.softmax(logits * densities.reshape(-1, 1), dim=1)
    top_prob, _ = probs.max(1)
    return check(top_prob)

def identity_density(logits, densities):
    """Use the density itself as the score; ``logits`` is ignored."""
    score = densities
    return score

def max_alpha_density(logits, densities):
    """Largest concentration per row, with alpha* = 1e-6 + exp(logits * densities)."""
    concentration = 1e-6 + torch.exp(logits * densities.reshape(-1, 1))
    return check(torch.max(concentration, dim=1).values)

def alpha0_density(logits, densities):
    """Total concentration alpha0* = sum_k (1e-6 + exp(logits * densities))_k per row."""
    concentration = 1e-6 + torch.exp(logits * densities.reshape(-1, 1))
    strength = concentration.sum(dim=1)
    return check(strength)


def dist_unc_density(logits, densities):
    """Distributional uncertainty: total entropy minus expected (aleatoric) entropy."""
    total = entropy_density(logits, densities)
    aleatoric = alea_unc_density(logits, densities)
    return check(total - aleatoric)

def max_logits_density(logits, densities):
    """Largest density-scaled logit (logits * densities) per row."""
    scaled = logits * densities.reshape(-1, 1)
    return check(scaled.max(dim=1)[0])


def logsumexp(logits):
    """Row-wise log-sum-exp over dim 1 (the reduced dim is dropped)."""
    return torch.logsumexp(logits, dim=1)

def identity(x):
    """Pass-through: hand back exactly the object received."""
    return (x)


def TU(logits, densities):
    """Total uncertainty 1 - sum_k (alpha_k / alpha0)^2, with alpha = exp(logits * densities)."""
    alpha = torch.exp(logits * densities.reshape(-1, 1))
    return 1 - (alpha ** 2).sum(1) / (alpha.sum(1) ** 2)

def EU(logits, densities):
    """Epistemic uncertainty K / alpha0, where K is the number of classes (dim 1)."""
    num_classes = logits.shape[1]
    strength = torch.exp(logits * densities.reshape(-1, 1)).sum(1)
    return num_classes / strength

def DU(logits, densities):
    """Distributional uncertainty: summed per-class variance of Dir(alpha).

    With alpha = exp(logits * densities) and alpha0 = sum_k alpha_k,
    sum_k Var[pi_k] = (alpha0^2 - sum_k alpha_k^2) / (alpha0^2 * (alpha0 + 1)).
    """
    alpha = torch.exp(logits * densities.reshape(-1, 1))
    alpha0 = alpha.sum(1)
    alpha_l2 = (alpha ** 2).sum(1)
    # Dirichlet variance has (alpha0 + 1) in the denominator, not (alpha0**2 + 1).
    return (alpha0 ** 2 - alpha_l2) / (alpha0 ** 2 * (alpha0 + 1))

def AU(logits, densities):
    """Aleatoric uncertainty: expected Gini impurity E[1 - sum_k pi_k^2] under Dir(alpha).

    With alpha = exp(logits * densities) and alpha0 = sum_k alpha_k this equals
    total minus distributional uncertainty (TU - DU) and has the closed form
    1 - sum_k alpha_k (alpha_k + 1) / (alpha0 (alpha0 + 1)).

    Computed directly rather than via ``TU``/``DU``: those names are looked up
    globally at call time and this module later rebinds ``TU`` to a NumPy
    ``(mu, var)`` variant.
    """
    alpha = torch.exp(logits * densities.reshape(-1, 1))
    alpha0 = alpha.sum(1)
    return 1 - (alpha * (alpha + 1)).sum(1) / (alpha0 * (alpha0 + 1))


def pred_entropy_density(logits, densities):
    """Predictive entropy of softmax(logits * densities), per row, cleaned by ``check``."""
    z = logits * densities.reshape(-1, 1)
    entropy_terms = F.softmax(z, dim=1) * F.log_softmax(z, dim=1)
    return check(-entropy_terms.sum(dim=1))

def maxP_density(logits, densities):
    """Max softmax probability after scaling each logit row by its density."""
    scaled = logits * densities.reshape(-1, 1)
    return check(F.softmax(scaled, dim=1).max(1).values)

def alea_unc_density(logits, densities):
    """Expected entropy of the categorical under Dir(alpha*), alpha* = exp(logits * densities).

    Returns -sum_k p*_k * (digamma(alpha*_k + 1) - digamma(alpha0* + 1)),
    where p* = softmax(logits * densities) and alpha0* = sum_k alpha*_k.
    """
    z = logits * densities.reshape(-1, 1)
    alpha_star = torch.exp(z)
    alpha0_star = torch.sum(alpha_star, dim=1).reshape(-1, 1)
    gap = torch.digamma(alpha_star + 1) - torch.digamma(alpha0_star + 1)
    return check(-torch.sum(F.softmax(z, dim=1) * gap, dim=1))


def logit_entropy_density(logits, densities):
    """Negated row sum of alpha + lgamma(alpha) + (1 - alpha) * digamma(alpha).

    alpha = exp(logits * densities). Each per-class term matches the
    differential entropy of a Gamma(alpha, 1) variable.
    """
    alpha = torch.exp(logits * densities.reshape(-1, 1))
    per_class = alpha + torch.lgamma(alpha) + (1 - alpha) * torch.digamma(alpha)
    return -per_class.sum(dim=1)

def sample_variance_density(logits, densities):
    """Across-class (population) variance of alpha = exp(logits * densities), per row.

    The original averaged the raw deviations ``alpha - mean``, which is
    identically zero. The deviations must be squared.
    """
    alpha = torch.exp(logits * densities.reshape(-1, 1))
    deviations = alpha - alpha.mean(1).reshape(-1, 1)
    return torch.mean(deviations ** 2, dim=1)

def max_alpha_density(logits, densities):
    """Per-row maximum of alpha* = 1e-6 + exp(logits * densities)."""
    z = logits * densities.reshape(-1, 1)
    largest, _ = (1e-6 + torch.exp(z)).max(1)
    return check(largest)


def EU(mu, var):
    """F-EDL epistemic uncertainty: sum of per-class predictive variances.

    ``mu`` is unused; it is kept so all (mu, var) measures share a signature.
    """
    total_variance = np.sum(var, axis=1)
    return total_variance


def TU(mu, var):
    """F-EDL total uncertainty: 1 - sum_k mu_k^2 (Gini impurity of the mean).

    ``var`` is unused; it is kept so all (mu, var) measures share a signature.
    """
    squared_mean = mu ** 2
    return 1 - np.sum(squared_mean, axis=1)


def AU(mu, var):
    """F-EDL aleatoric uncertainty: total minus epistemic uncertainty.

    Row-wise equal to (1 - sum_k mu_k^2) - sum_k var_k.
    """
    total = 1 - np.sum(mu ** 2, axis=1)
    epistemic = np.sum(var, axis=1)
    return total - epistemic

########################################################################################################################################################################
# [2] Code for F-EDL
# [2-1] Code for Getting the mean, variance, and the parameter of the FD distribution

def compute_mu(alpha, p, tau):
    """Flexible-Dirichlet mean: (alpha + tau * p) / (alpha0 + tau), alpha0 = sum over dim 1."""
    numerator = alpha + tau * p
    denominator = alpha.sum(dim=1, keepdim=True) + tau
    return numerator / denominator

def compute_var(alpha, p, tau):
    """Per-class variance of the flexible Dirichlet.

    With s = alpha0 + tau and mu = (alpha + tau * p) / s:
    var = mu (1 - mu) / (s + 1) + tau^2 p (1 - p) / (s (s + 1)).
    """
    s = alpha.sum(dim=1, keepdim=True) + tau
    mean = (alpha + tau * p) / s
    spread = mean * (1 - mean) / (s + 1)
    mixture = (tau ** 2) * p * (1 - p) / (s * (s + 1))
    return spread + mixture

def get_mean_var(model, loader, fix_tau, fix_p, device):
    """Run ``model`` over ``loader`` and collect flexible-Dirichlet statistics.

    Each batch is unpacked as ``(inputs, labels)`` and moved to ``device``.
    ``model(inputs, fix_tau, fix_p)`` must return ``(alpha, p, tau)``.
    Gradients are disabled.

    Returns:
        ``(MU, VAR, ALPHA, P, TAU)``: NumPy arrays concatenated over batches.
        MU and VAR come from ``compute_mu`` and ``compute_var``.
    """
    buckets = ([], [], [], [], [])  # mu, var, alpha, p, tau
    with torch.no_grad():
        for inputs, _labels in loader:
            alpha_b, p_b, tau_b = model(inputs.to(device), fix_tau, fix_p)
            batch_stats = (
                compute_mu(alpha_b, p_b, tau_b),
                compute_var(alpha_b, p_b, tau_b),
                alpha_b,
                p_b,
                tau_b,
            )
            for bucket, stat in zip(buckets, batch_stats):
                bucket.append(stat)

    return tuple(torch.cat(bucket).cpu().numpy() for bucket in buckets)

### [2-2] Code for obtaining AUROC, AUPR scores
def auroc_aupr_fedl(unc_id, unc_ood):
    """AUROC and AUPR for separating in-distribution from OOD samples.

    Inputs are uncertainty scores (higher = more uncertain) and are cleaned
    by ``check``. In-distribution samples are labelled 1 and OOD samples 0.
    Scores are the negated uncertainties, so lower uncertainty counts as
    more in-distribution.

    Returns:
        ``(auroc, aupr)``.
    """
    id_unc = check(unc_id)
    ood_unc = check(unc_ood)

    labels = np.concatenate([np.ones(id_unc.shape[0]), np.zeros(ood_unc.shape[0])])
    scores = -np.concatenate((id_unc, ood_unc))

    return (
        sklearn.metrics.roc_auc_score(labels, scores),
        sklearn.metrics.average_precision_score(labels, scores),
    )

### [2-3] Code for conducting OOD detection
def ood_detection_fedl(model, testloader, ood_loader1, ood_loader2, fix_tau, fix_p, device):
    """OOD detection for F-EDL against two OOD datasets.

    Aleatoric (``AU``) and epistemic (``EU``) uncertainty are computed from
    the predictive mean and variance returned by ``get_mean_var``. Each OOD
    set is scored against the in-distribution test set with
    ``auroc_aupr_fedl``.

    Returns:
        ``(AUROC, AUPR)``: two-element lists (one entry per OOD loader) of
        ``{"AU": score, "EU": score}`` dicts.
    """
    moments = [
        get_mean_var(model, loader, fix_tau, fix_p, device)[:2]
        for loader in (testloader, ood_loader1, ood_loader2)
    ]
    (mu_id, var_id), ood_moments = moments[0], moments[1:]
    au_id = AU(mu_id, var_id)
    eu_id = EU(mu_id, var_id)

    AUROC, AUPR = [], []
    for mu_ood, var_ood in ood_moments:
        auroc_au, aupr_au = auroc_aupr_fedl(au_id, AU(mu_ood, var_ood))
        auroc_eu, aupr_eu = auroc_aupr_fedl(eu_id, EU(mu_ood, var_ood))
        AUROC.append({"AU": auroc_au, "EU": auroc_eu})
        AUPR.append({"AU": aupr_au, "EU": aupr_eu})

    return AUROC, AUPR

### [2-4] Codes for conducting distribution shift detection
def dist_shift_detection_fedl_mnist(model, testloader, fedl_type, device, fix_tau=False, fix_p=False):
    """Distribution-shift detection for an F-EDL model: MNIST vs. MNIST-C.

    For each corruption, the clean test set (label 1) is scored against the
    corrupted test set (label 0). Scoring uses aleatoric (AU) and epistemic
    (EU) uncertainty computed from the predictive mean and variance. AUROC
    and AUPR are averaged over all corruptions.

    Args:
        model: F-EDL model, called as ``model(x, fix_tau, fix_p)`` by
            ``get_mean_var`` and expected to return ``(alpha, p, tau)``.
        testloader: loader over the clean in-distribution test set.
        fedl_type: unused; kept for backward compatibility of the signature.
        device: device each input batch is moved to.
        fix_tau, fix_p: forwarded unchanged to ``model``.
            NOTE(review): the ``False`` defaults assume these are boolean
            flags. Confirm against the model's forward signature.

    Returns:
        ``(AUROC, AUPR)`` dicts with keys ``"AU"`` and ``"EU"``.
    """
    normalize = transforms.Normalize((0.5,), (0.5,))
    CORRUPTIONS = ['shot_noise', 'impulse_noise', 'glass_blur', 'motion_blur', 'shear', 'scale', 'rotate',
                   'brightness', 'translate', 'stripe', 'fog', 'spatter', 'dotted_line', 'zigzag', 'canny_edges']
    base_path = "data/mnist_c/"

    MU_id, VAR_id, _, _, _ = get_mean_var(model, testloader, fix_tau, fix_p, device)
    alea_id = AU(MU_id, VAR_id)
    epis_id = EU(MU_id, VAR_id)

    AUROC1, AUPR1, AUROC2, AUPR2 = [], [], [], []
    for corruption_type in CORRUPTIONS:
        corruption_dir = os.path.join(base_path, corruption_type)
        data = torch.Tensor(np.load(os.path.join(corruption_dir, "test_images.npy"))).reshape(-1, 1, 28, 28)
        data = normalize(data / 255.0)
        label = torch.Tensor(np.load(os.path.join(corruption_dir, "test_labels.npy")))

        corrupted_dataset = torch.utils.data.TensorDataset(data, label)
        corrupted_loader = torch.utils.data.DataLoader(corrupted_dataset, shuffle=False, batch_size=64)

        # Score the *corrupted* set with its own statistics, not the ID ones.
        MU_ood, VAR_ood, _, _, _ = get_mean_var(model, corrupted_loader, fix_tau, fix_p, device)
        alea_ood = AU(MU_ood, VAR_ood)
        epis_ood = EU(MU_ood, VAR_ood)

        # auroc_aupr_fedl labels ID as 1 and negates the uncertainties.
        auroc1, aupr1 = auroc_aupr_fedl(alea_id, alea_ood)
        auroc2, aupr2 = auroc_aupr_fedl(epis_id, epis_ood)

        AUROC1.append(auroc1)
        AUPR1.append(aupr1)
        AUROC2.append(auroc2)
        AUPR2.append(aupr2)

    AUROC = {"AU": np.array(AUROC1).mean(), "EU": np.array(AUROC2).mean()}
    AUPR = {"AU": np.array(AUPR1).mean(), "EU": np.array(AUPR2).mean()}
    return AUROC, AUPR

from PIL import Image


class CIFARCorruption(torch.utils.data.Dataset):
    """Wrap an array of images as a Dataset that yields ``(transform(img), 0)``.

    NOTE(review): assumes each ``data[idx]`` is an array that
    ``PIL.Image.fromarray`` accepts (e.g. HWC uint8). Confirm against the
    CIFAR-C ``.npy`` files.
    """

    def __init__(self, data, transform):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = Image.fromarray(self.data[idx])
        # Labels are irrelevant for shift detection; every sample gets 0.
        return self.transform(img), 0


def dist_shift_detection_fedl_cifar(ID_dataset, model, testloader, fix_tau, fix_p, device):
    """Distribution-shift detection for F-EDL: CIFAR-10/100 vs. CIFAR-10/100-C.

    For every corruption and severity 1-5, the clean test set (label 1) is
    compared with the 10,000 corrupted images of that severity (label 0).
    "AU" uses the aleatoric uncertainty; "DU" uses the variance-based
    uncertainty ``EU(mu, var)``.

    Returns:
        ``(RESULT_AUROC, RESULT_AUPR)``: ``{severity: {"AU": mean, "DU": mean}}``
        averaged over all corruptions.
    """
    CORRUPTIONS = ["gaussian_noise", "shot_noise", "impulse_noise", "defocus_blur", "glass_blur", "motion_blur",
                   "zoom_blur", "snow", "frost", "fog", "brightness", "contrast", "elastic_transform", "pixelate",
                   "jpeg_compression", "speckle_noise", "gaussian_blur", "spatter", "saturate"]

    base_path = "data/CIFAR-10-C" if ID_dataset == "CIFAR-10" else "data/CIFAR-100-C"
    if ID_dataset == "CIFAR-10":
        MEAN = [0.4914, 0.4822, 0.4465]
        STD = [0.2023, 0.1994, 0.2010]
    else:
        MEAN = [0.5071, 0.4867, 0.4408]
        STD = [0.2675, 0.2565, 0.2761]

    normalize = transforms.Normalize(mean=MEAN, std=STD)
    transform = transforms.Compose([transforms.ToTensor(), normalize])
    severity_size = 10000

    # get_mean_var returns five arrays; only the mean and variance are needed.
    MU_id, VAR_id, _, _, _ = get_mean_var(model, testloader, fix_tau, fix_p, device)
    alea_id = AU(MU_id, VAR_id)
    # DU(logits, densities) expects logits, not (mu, var) arrays. The
    # variance-based EU(mu, var) is the score reported under "DU".
    dist_id = EU(MU_id, VAR_id)

    AUROC, AUPR = {}, {}
    for corruption_type in CORRUPTIONS:
        AUROC[corruption_type] = {}
        AUPR[corruption_type] = {}

        data = np.load(os.path.join(base_path, corruption_type) + ".npy")
        dataset_full = CIFARCorruption(data, transform)

        for severity in range(1, 6):
            start = severity_size * (severity - 1)
            subset = torch.utils.data.Subset(dataset_full, list(range(start, start + severity_size)))
            cifar_c_loader = torch.utils.data.DataLoader(subset, batch_size=128, shuffle=False)

            MU_ood, VAR_ood, _, _, _ = get_mean_var(model, cifar_c_loader, fix_tau, fix_p, device)
            alea_ood = AU(MU_ood, VAR_ood)
            dist_ood = EU(MU_ood, VAR_ood)

            auroc1, aupr1 = auroc_aupr_fedl(alea_id, alea_ood)
            auroc2, aupr2 = auroc_aupr_fedl(dist_id, dist_ood)

            AUROC[corruption_type][severity] = {"AU": auroc1, "DU": auroc2}
            AUPR[corruption_type][severity] = {"AU": aupr1, "DU": aupr2}

    RESULT_AUROC, RESULT_AUPR = {}, {}
    for severity_level in range(1, 6):
        RESULT_AUPR[severity_level] = {
            key: np.array([AUPR[c][severity_level][key] for c in CORRUPTIONS]).mean() for key in ("AU", "DU")
        }
        RESULT_AUROC[severity_level] = {
            key: np.array([AUROC[c][severity_level][key] for c in CORRUPTIONS]).mean() for key in ("AU", "DU")
        }

    return RESULT_AUROC, RESULT_AUPR


def dist_shift_detection_fedl(ID_dataset, model, testloader, device, fix_tau=False, fix_p=False):
    """Dispatch F-EDL distribution-shift detection to the MNIST or CIFAR routine.

    ``fix_tau`` and ``fix_p`` are forwarded to the model. NOTE(review): the
    ``False`` defaults assume boolean flags; confirm with the model.
    """
    if ID_dataset == "MNIST":
        return dist_shift_detection_fedl_mnist(model, testloader, None, device, fix_tau=fix_tau, fix_p=fix_p)
    return dist_shift_detection_fedl_cifar(ID_dataset, model, testloader, fix_tau, fix_p, device)
        
##################################################################################################################
### [3] Code for EDL methods (EDL (NeurIPS 2018), I-EDL (ICML 2023), R-EDL (ICLR 2024))
### [3-1] Code for getting the parameters of the Dirichlet distribution
def get_alpha(model, edl_type, trainloader, loader, lamb, device,
              gda=None, num_classes=None, p_z_train=None):
    """Compute Dirichlet concentration parameters for every sample in ``loader``.

    Raw model outputs are mapped to ``alpha`` according to ``edl_type``:

    * ``"EDL"``:   ``1 + relu(logits)``
    * ``"I-EDL"``: ``1 + softplus(logits)``
    * ``"R-EDL"``: ``1e-6 + softplus(logits) + lamb``
    * ``"DAEDL"``: ``1e-6 + exp(logits)``, then raised to the power ``p_z``.
      ``p_z`` is a log-sum-exp density score that is clipped to the range of
      ``p_z_train`` and min-max normalised to [0, 1].

    Args:
        model: callable mapping an input batch to logits.
        edl_type: one of the four names above.
        trainloader: unused here; kept for signature compatibility.
        loader: iterable of ``(x, y)`` batches.
        lamb: additive constant used by ``"R-EDL"``.
        device: device each input batch is moved to.
        gda, num_classes, p_z_train: required only for ``"DAEDL"``.
            ``gda``/``num_classes`` are forwarded to ``gmm_evaluate``;
            ``p_z_train`` must support ``.min()``/``.max()``.
            NOTE(review): presumably the fitted GMM, the class count and the
            training-set log densities. Confirm with the caller.

    Returns:
        NumPy array of concentrations, one row per sample.

    Raises:
        ValueError: for an unknown ``edl_type`` or missing DAEDL inputs.
    """
    if edl_type not in ("EDL", "I-EDL", "R-EDL", "DAEDL"):
        raise ValueError(f"Unknown edl_type: {edl_type!r}")

    if edl_type == "DAEDL":
        if gda is None or num_classes is None or p_z_train is None:
            raise ValueError("DAEDL requires gda, num_classes and p_z_train")
        logits_edl, _ = gmm_evaluate(model, gda, loader, device, num_classes, device)

        p_z = check(torch.logsumexp(logits_edl, dim=-1))
        p_z_train_min, p_z_train_max = p_z_train.min(), p_z_train.max()
        # Clip to the training range, then rescale to [0, 1].
        p_z[p_z < p_z_train_min] = p_z_train_min
        p_z[p_z > p_z_train_max] = p_z_train_max
        p_z = (p_z - p_z_train_min) / (p_z_train_max - p_z_train_min)

    ALPHA = []
    with torch.no_grad():
        for x, _ in loader:
            logits_t = model(x.to(device))
            if edl_type == "EDL":
                alpha_t = 1 + torch.relu(logits_t)
            elif edl_type == "I-EDL":
                alpha_t = 1 + F.softplus(logits_t)
            elif edl_type == "R-EDL":
                alpha_t = 1e-6 + F.softplus(logits_t) + lamb
            else:  # "DAEDL"
                alpha_t = 1e-6 + torch.exp(logits_t)
            ALPHA.append(alpha_t)

    ALPHA = torch.cat(ALPHA)
    if edl_type == "DAEDL":
        # alpha ** p_z, computed in log space.
        ALPHA = torch.exp(torch.log(ALPHA) * p_z.reshape(-1, 1))
    return ALPHA.cpu().numpy()


### [3-2] Code for conducting OOD detection
def ood_detection_edl(model, edl_type, trainloader, testloader, ood_loader1, ood_loader2, lamb, device):
    """OOD detection for EDL-family models against two OOD datasets.

    "AU" is scored by the maximum expected class probability alpha_k / alpha0.
    "EU" is scored by the Dirichlet strength alpha0. Both are confidences
    (higher = more in-distribution), so they are negated before being passed
    to ``auroc_aupr_fedl``, which expects uncertainties.

    Returns:
        ``(AUROC, AUPR)``: two-element lists (one per OOD loader) of
        ``{"AU": score, "EU": score}`` dicts.
    """
    def confidences(alpha):
        strength = alpha.sum(1)
        return (alpha / alpha.sum(1).reshape(-1, 1)).max(1), strength

    alpha_id, alpha_ood1, alpha_ood2 = (
        get_alpha(model, edl_type, trainloader, loader, lamb, device)
        for loader in (testloader, ood_loader1, ood_loader2)
    )
    max_prob_id, strength_id = confidences(alpha_id)

    AUROC, AUPR = [], []
    for alpha_ood in (alpha_ood1, alpha_ood2):
        max_prob_ood, strength_ood = confidences(alpha_ood)
        auroc_au, aupr_au = auroc_aupr_fedl(-max_prob_id, -max_prob_ood)
        auroc_eu, aupr_eu = auroc_aupr_fedl(-strength_id, -strength_ood)
        AUROC.append({"AU": auroc_au, "EU": auroc_eu})
        AUPR.append({"AU": aupr_au, "EU": aupr_eu})

    return AUROC, AUPR

### [3-3] Code for conducting distribution shift detection
def dist_shift_detection_edl_cifar10(model, edl_type, trainloader, testloader, lamb, device):
    """Distribution-shift detection for EDL-family models: CIFAR-10 vs. CIFAR-10-C.

    For every corruption and severity 1-5, the clean test set (label 1) is
    compared with the 10,000 corrupted images of that severity (label 0).
    "AU" is scored by the maximum expected class probability and "DU" by the
    Dirichlet strength alpha0. Both are confidences, so they are negated
    before ``auroc_aupr_fedl``, which expects uncertainties.

    Returns:
        ``(RESULT_AUROC, RESULT_AUPR)``: ``{severity: {"AU": mean, "DU": mean}}``
        averaged over all corruptions.
    """
    CORRUPTIONS = ["gaussian_noise", "shot_noise", "impulse_noise", "defocus_blur", "glass_blur", "motion_blur",
                   "zoom_blur", "snow", "frost", "fog", "brightness", "contrast", "elastic_transform", "pixelate",
                   "jpeg_compression", "speckle_noise", "gaussian_blur", "spatter", "saturate"]

    MEAN = [0.49139968, 0.48215841, 0.44653091]
    STD = [0.24703223, 0.24348513, 0.26158784]
    normalize = transforms.Normalize(mean=MEAN, std=STD)
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    base_path = "data/CIFAR-10-C"
    severity_size = 10000  # each file holds 5 severities x 10k images

    ALPHA_id = get_alpha(model, edl_type, trainloader, testloader, lamb, device)
    alea_id = (ALPHA_id / ALPHA_id.sum(1).reshape(-1, 1)).max(1)
    dist_id = ALPHA_id.sum(1)

    AUROC, AUPR = {}, {}
    for corruption_type in CORRUPTIONS:
        AUROC[corruption_type] = {}
        AUPR[corruption_type] = {}

        raw = np.load(os.path.join(base_path, corruption_type) + ".npy")
        # Stack the transformed images directly, with no NumPy round trip
        # over 50k tensors. Keep them on the CPU; get_alpha moves each batch
        # to `device`.
        data = torch.stack([transform(img) for img in raw[:5 * severity_size]]).float()

        for severity in range(1, 6):
            start = severity_size * (severity - 1)
            dataset = data[start:start + severity_size]
            labels = torch.zeros(len(dataset))

            cifar10c_dataset = torch.utils.data.TensorDataset(dataset, labels)
            cifar10c_loader = torch.utils.data.DataLoader(cifar10c_dataset, shuffle=False, batch_size=128)

            ALPHA_ood = get_alpha(model, edl_type, trainloader, cifar10c_loader, lamb, device)
            alea_ood = (ALPHA_ood / ALPHA_ood.sum(1).reshape(-1, 1)).max(1)
            dist_ood = ALPHA_ood.sum(1)

            auroc1, aupr1 = auroc_aupr_fedl(-alea_id, -alea_ood)
            auroc2, aupr2 = auroc_aupr_fedl(-dist_id, -dist_ood)

            AUROC[corruption_type][severity] = {"AU": auroc1, "DU": auroc2}
            AUPR[corruption_type][severity] = {"AU": aupr1, "DU": aupr2}

    RESULT_AUROC, RESULT_AUPR = {}, {}
    for severity_level in range(1, 6):
        RESULT_AUPR[severity_level] = {
            key: np.array([AUPR[c][severity_level][key] for c in CORRUPTIONS]).mean() for key in ("AU", "DU")
        }
        RESULT_AUROC[severity_level] = {
            key: np.array([AUROC[c][severity_level][key] for c in CORRUPTIONS]).mean() for key in ("AU", "DU")
        }

    return RESULT_AUROC, RESULT_AUPR

def dist_shift_detection_edl_mnist(model, edl_type, trainloader, testloader, lamb, device):
    """Distribution-shift detection for EDL-family models: MNIST vs. MNIST-C.

    For each corruption, the clean test set (label 1) is scored against the
    corrupted test set (label 0). "AU" uses the maximum expected class
    probability and "DU" the Dirichlet strength alpha0. Both are confidences
    (higher = more in-distribution), so they are passed negated to
    ``auroc_aupr_fedl``, matching ``dist_shift_detection_edl_cifar10``.
    The original used the negated confidences directly as scores, which
    inverted the AUROC.

    Returns:
        ``(AUROC, AUPR)`` dicts with keys ``"AU"`` and ``"DU"``, averaged
        over corruptions.
    """
    normalize = transforms.Normalize((0.5,), (0.5,))
    CORRUPTIONS = ['shot_noise', 'impulse_noise', 'glass_blur', 'motion_blur', 'shear', 'scale', 'rotate',
                   'brightness', 'translate', 'stripe', 'fog', 'spatter', 'dotted_line', 'zigzag', 'canny_edges']
    base_path = "data/mnist_c/"

    ALPHA_id = get_alpha(model, edl_type, trainloader, testloader, lamb, device)
    alea_id = (ALPHA_id / ALPHA_id.sum(1).reshape(-1, 1)).max(1)
    dist_id = ALPHA_id.sum(1)

    AUROC1, AUPR1, AUROC2, AUPR2 = [], [], [], []
    for corruption_type in CORRUPTIONS:
        corruption_dir = os.path.join(base_path, corruption_type)
        data = torch.Tensor(np.load(os.path.join(corruption_dir, "test_images.npy"))).reshape(-1, 1, 28, 28)
        data = normalize(data / 255.0)
        label = torch.Tensor(np.load(os.path.join(corruption_dir, "test_labels.npy")))

        corrupted_dataset = torch.utils.data.TensorDataset(data, label)
        mnist_c_loader = torch.utils.data.DataLoader(corrupted_dataset, shuffle=False, batch_size=64)

        ALPHA_ood = get_alpha(model, edl_type, trainloader, mnist_c_loader, lamb, device)
        alea_ood = (ALPHA_ood / ALPHA_ood.sum(1).reshape(-1, 1)).max(1)
        dist_ood = ALPHA_ood.sum(1)

        auroc1, aupr1 = auroc_aupr_fedl(-alea_id, -alea_ood)
        auroc2, aupr2 = auroc_aupr_fedl(-dist_id, -dist_ood)

        AUROC1.append(auroc1)
        AUPR1.append(aupr1)
        AUROC2.append(auroc2)
        AUPR2.append(aupr2)

    AUROC = {"AU": np.array(AUROC1).mean(), "DU": np.array(AUROC2).mean()}
    AUPR = {"AU": np.array(AUPR1).mean(), "DU": np.array(AUPR2).mean()}
    return AUROC, AUPR


def dist_shift_detection_edl(ID_dataset, model, edl_type, trainloader, testloader, lamb, device):
    """Run EDL distribution-shift detection: the CIFAR-10 routine for "CIFAR-10", else MNIST."""
    runner = dist_shift_detection_edl_cifar10 if ID_dataset == "CIFAR-10" else dist_shift_detection_edl_mnist
    auroc, aupr = runner(model, edl_type, trainloader, testloader, lamb, device)
    return auroc, aupr

#################################################################################################################
### [4] Code for Softmax (ICLR 2017) and Dropout (ICML 2016)
### [4-1] Code for conducting OOD detection
def ood_detection_softmax(model, sm_type, trainloader, testloader, ood_loader1, ood_loader2, num_passes, device):
    """OOD detection for a plain Softmax classifier or MC-Dropout.

    "AU" is scored by the maximum (mean) softmax probability. "EU" is only
    computed when ``sm_type == "Dropout"``. It is scored by the summed
    per-class variance of the softmax over ``num_passes`` stochastic forward
    passes. ``trainloader`` is unused.

    Returns:
        ``(AUROC, AUPR)``: lists with one ``{"AU": ..., "EU": ...}`` dict per
        OOD loader. ``"EU"`` is ``None`` unless ``sm_type == "Dropout"``.
    """
    use_dropout = sm_type == "Dropout"

    def get_softmax_probs(model, loader, sm_type, num_passes, device):
        # MC-Dropout keeps dropout active at inference by using train mode.
        if sm_type != "Dropout":
            model.eval()
        else:
            model.train()

        prob_chunks, var_chunks = [], []
        with torch.no_grad():
            for x, _ in loader:
                x = x.to(device)
                if sm_type != "Dropout":
                    prob_chunks.append(F.softmax(model(x), dim=1))
                    continue
                samples = torch.stack([F.softmax(model(x), dim=1) for _ in range(num_passes)])
                var_chunks.append(samples.var(dim=0))  # spread across stochastic passes
                prob_chunks.append(samples.mean(dim=0))

        probs = torch.cat(prob_chunks, dim=0)
        variances = torch.cat(var_chunks, dim=0) if var_chunks else None
        if probs.dim() < 2:
            raise ValueError(f"Expected probabilities to have at least 2 dimensions, got {probs.shape}")
        return probs, variances

    (probs_id, var_id), (probs_ood1, var_ood1), (probs_ood2, var_ood2) = [
        get_softmax_probs(model, loader, sm_type, num_passes, device)
        for loader in (testloader, ood_loader1, ood_loader2)
    ]

    def max_prob(probs):
        return probs.max(1)[0].cpu().numpy()

    def summed_var(var):
        return var.double().sum(dim=1).cpu().numpy()

    alea_id = max_prob(probs_id)
    epis_id = summed_var(var_id) if use_dropout else None

    AUROC, AUPR = [], []
    for probs_ood, var_ood in ((probs_ood1, var_ood1), (probs_ood2, var_ood2)):
        auroc_alea, aupr_alea = auroc_aupr_fedl(-alea_id, -max_prob(probs_ood))
        auroc_epis, aupr_epis = None, None
        if use_dropout:
            auroc_epis, aupr_epis = auroc_aupr_fedl(epis_id, summed_var(var_ood))
        AUROC.append({"AU": auroc_alea, "EU": auroc_epis})
        AUPR.append({"AU": aupr_alea, "EU": aupr_epis})

    return AUROC, AUPR

### [4-2] Code for conducting distribution shift detection
def dist_shift_detection_softmax(model, sm_type, testloader, num_passes, device):
    """Distribution-shift detection for Softmax / MC-Dropout: CIFAR-10 vs. CIFAR-10-C.

    "AU" is scored by the maximum (mean) softmax probability. "EU" is only
    computed when ``sm_type == "Dropout"``. It is scored by the summed
    per-class softmax variance over ``num_passes`` stochastic passes.
    In-distribution samples are labelled 1. Scoring goes through
    ``auroc_aupr_fedl``, consistent with ``ood_detection_softmax``.

    Fixes relative to the previous version:

    * undefined ``data`` / ``corruptions`` names;
    * a DataLoader over a bare tensor, which broke ``for x, _ in loader``;
    * inverted score signs for both AU and EU.

    Returns:
        ``(RESULT_AUROC, RESULT_AUPR)``: dicts with keys ``"AU"`` and
        ``"EU"``, each a list of five per-severity means over all corruptions.
        ``"EU"`` stays empty unless ``sm_type == "Dropout"``.
    """
    CORRUPTIONS = ["gaussian_noise", "shot_noise", "impulse_noise", "defocus_blur", "glass_blur", "motion_blur",
                   "zoom_blur", "snow", "frost", "fog", "brightness", "contrast", "elastic_transform", "pixelate",
                   "jpeg_compression", "speckle_noise", "gaussian_blur", "spatter", "saturate"]

    MEAN = [0.49139968, 0.48215841, 0.44653091]
    STD = [0.24703223, 0.24348513, 0.26158784]
    normalize = transforms.Normalize(mean=MEAN, std=STD)
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    base_path = "data/CIFAR-10-C"
    severity_size = 10000  # each file holds 5 severities x 10k images

    def get_uncertainties(loader):
        """Return (mean softmax probs, per-class variance or None) over ``loader``."""
        all_probs = []
        all_variances = []
        with torch.no_grad():
            for x, _ in loader:
                x = x.to(device)
                if sm_type == "Dropout":
                    batch_probs = torch.stack([torch.softmax(model(x), dim=1) for _ in range(num_passes)])
                    all_variances.append(batch_probs.var(dim=0))
                    all_probs.append(batch_probs.mean(dim=0))
                else:
                    all_probs.append(torch.softmax(model(x), dim=1))

        all_probs = torch.cat(all_probs, dim=0)
        all_variances = torch.cat(all_variances, dim=0) if all_variances else None
        return all_probs, all_variances

    # MC-Dropout keeps dropout active at inference by using train mode.
    if sm_type == "Dropout":
        model.train()
    else:
        model.eval()

    probs_id, var_id = get_uncertainties(testloader)
    alea_id = probs_id.max(1)[0].cpu().numpy()
    epis_id = var_id.double().sum(dim=1).cpu().numpy() if var_id is not None else None

    AUROC, AUPR = {}, {}
    for corruption_type in CORRUPTIONS:
        AUROC[corruption_type] = {"AU": {}, "EU": {}}
        AUPR[corruption_type] = {"AU": {}, "EU": {}}

        raw = np.load(os.path.join(base_path, corruption_type) + ".npy")
        corruption_data = torch.stack([transform(img) for img in raw[:5 * severity_size]]).float()

        for severity in range(1, 6):
            start = (severity - 1) * severity_size
            images = corruption_data[start:start + severity_size]
            # get_uncertainties unpacks (x, label) pairs, so attach dummy labels.
            corrupted_dataset = torch.utils.data.TensorDataset(images, torch.zeros(len(images)))
            corrupted_loader = torch.utils.data.DataLoader(corrupted_dataset, batch_size=64, shuffle=False)

            probs_ood, var_ood = get_uncertainties(corrupted_loader)
            alea_ood = probs_ood.max(1)[0].cpu().numpy()
            epis_ood = var_ood.double().sum(dim=1).cpu().numpy() if var_ood is not None else None

            # auroc_aupr_fedl negates its inputs: pass negated confidence for AU
            # and raw variance for EU, so a higher score means more in-distribution.
            auroc_alea, aupr_alea = auroc_aupr_fedl(-alea_id, -alea_ood)
            AUROC[corruption_type]["AU"][severity] = auroc_alea
            AUPR[corruption_type]["AU"][severity] = aupr_alea

            if epis_id is not None and epis_ood is not None:
                auroc_epis, aupr_epis = auroc_aupr_fedl(epis_id, epis_ood)
                AUROC[corruption_type]["EU"][severity] = auroc_epis
                AUPR[corruption_type]["EU"][severity] = aupr_epis

    RESULT_AUROC = {"AU": [], "EU": []}
    RESULT_AUPR = {"AU": [], "EU": []}
    for severity in range(1, 6):
        RESULT_AUROC["AU"].append(np.mean([AUROC[c]["AU"][severity] for c in CORRUPTIONS]))
        RESULT_AUPR["AU"].append(np.mean([AUPR[c]["AU"][severity] for c in CORRUPTIONS]))
        if epis_id is not None:
            RESULT_AUROC["EU"].append(np.mean([AUROC[c]["EU"][severity] for c in CORRUPTIONS]))
            RESULT_AUPR["EU"].append(np.mean([AUPR[c]["EU"][severity] for c in CORRUPTIONS]))

    return RESULT_AUROC, RESULT_AUPR



### [5] Code for DDU (CVPR 2023)
### [5-1] Code for conducting OOD detection
def auroc_aupr_ddu(net, test_loader, ood_test_loader, id_density, ood_density, uncertainty, device):
    """Score ID-vs-OOD separation for DDU and return ``(auroc, aupr)``.

    Logits for both loaders are collected with ``get_logits_labels``. They are
    turned into per-sample scores by ``uncertainty(logits, density)``, and
    the scores are evaluated with scikit-learn. ID samples form the positive
    class (label 1) and OOD samples the negative class (label 0). Higher
    scores therefore count as evidence for "in-distribution".

    Args:
        net: model forwarded to ``get_logits_labels``.
        test_loader: loader over in-distribution test samples.
        ood_test_loader: loader over out-of-distribution samples.
        id_density: density values passed to ``uncertainty`` with the ID logits.
        ood_density: density values passed to ``uncertainty`` with the OOD logits.
        uncertainty: callable ``(logits, density) -> tensor`` of scores.
        device: device on which the binary label tensor is built.

    Returns:
        Tuple ``(auroc, aupr)`` from ``roc_auc_score`` / ``average_precision_score``.
    """
    id_logits, _ = get_logits_labels(net, test_loader, device)
    ood_logits, _ = get_logits_labels(net, ood_test_loader, device)

    id_scores = uncertainty(id_logits, id_density)
    ood_scores = uncertainty(ood_logits, ood_density)

    # Debug output of the raw scores.
    print("ID unc:", id_scores)
    print("OOD unc:", ood_scores)

    positives = torch.ones(id_scores.shape[0]).to(device)
    negatives = torch.zeros(ood_scores.shape[0]).to(device)
    labels_np = torch.cat((positives, negatives)).cpu().numpy()
    scores_np = torch.cat((id_scores, ood_scores)).cpu().numpy()

    auroc = sklearn.metrics.roc_auc_score(labels_np, scores_np)
    aupr = sklearn.metrics.average_precision_score(labels_np, scores_np)
    return auroc, aupr


def ood_detection_ddu(model, gda, p_z_train, testloader, ood_loader1, ood_loader2, num_classes, device):     
    logits_edl, _ = gmm_evaluate(model, gda, testloader, device, num_classes, device)
    ood_logits_edl1, _ = gmm_evaluate(model, gda, ood_loader1, device, num_classes,device)
    ood_logits_edl2, _ = gmm_evaluate(model, gda, ood_loader2, device, num_classes,device)
    
    p_z_test = check(torch.logsumexp(logits_edl, dim = -1))
    p_z_ood1 = check(torch.logsumexp(ood_logits_edl1, dim = -1))  
    p_z_ood2 = check(torch.logsumexp(ood_logits_edl2, dim = -1))
          
    auroc_alea1, aupr_alea1 = auroc_aupr_ddu(model, testloader, ood_loader1, p_z_test, p_z_ood1, -entropy, device)   
    auroc_alea2, aupr_alea2 = auroc_aupr_ddu(model, testloader, ood_loader2, p_z_test, p_z_ood2, -entropy, device)  
    auroc_epis1, aupr_epis1 = auroc_aupr_ddu(model, testloader, ood_loader1, p_z_test, p_z_ood1, identity_density, device)
    auroc_epis2, aupr_epis2 = auroc_aupr_ddu(model, testloader, ood_loader2, p_z_test, p_z_ood2, identity_density, device)

    OOD1 = {}
    OOD1["AU"] = auroc_alea1
    OOD1["EU"] = auroc_epis1
     
    OOD2 = {}
    OOD2["AU"] = auroc_alea2
    OOD2["EU"] = auroc_epis2
    
    ## AUPR
    OOD1_ = {}
    OOD1_["AU"] = aupr_alea1
    OOD1_["EU"] = aupr_epis1

    OOD2_ = {}
    OOD2_["AU"] = aupr_alea2
    OOD2_["EU"] = aupr_epis2
       
    AUROC = [OOD1, OOD2]
    AUPR = [OOD1_, OOD2_]   

    return AUROC, AUPR


##################################################################################################################
#### [6] Code for DAEDL (ICML 2024)
### [6-1] Code for computing AUROC, AUPR
def auroc_aupr_daedl(net, test_loader, ood_test_loader, train_density, id_density, ood_density, uncertainty, device):
    
    logits, _ = get_logits_labels(net, test_loader, device)
    ood_logits, _ = get_logits_labels(net, ood_test_loader, device)
    
    min = train_density.min().to(device)
    max = train_density.max().to(device)
    
    ood_density = ood_density.to(device)
    id_density = id_density.to(device)

    ood_density.clamp_(min=min, max=max)
    id_density.clamp_(min=min, max=max)

    id_norm_density = (id_density - min) / (max - min)
    ood_norm_density = (ood_density - min) / (max - min) 
    train_density_norm = (train_density - min) / (max - min)
    
    with torch.no_grad():
        plt.title("Density")
        plt.hist(id_norm_density.cpu().numpy(), bins = 50, label = "ID")
        plt.hist(ood_norm_density.cpu().numpy(), bins = 50, label = "OOD")
        plt.legend()

    uncertainties = uncertainty(logits, id_norm_density)
    ood_uncertainties = uncertainty(ood_logits, ood_norm_density)

    bin_labels = torch.ones(uncertainties.shape[0]).to(device)
    bin_labels = torch.cat((bin_labels, torch.zeros(ood_uncertainties.shape[0]).to(device)))
                      
    in_scores = uncertainties
    ood_scores = ood_uncertainties

    scores = torch.cat((in_scores, ood_scores))
    auroc = sklearn.metrics.roc_auc_score(bin_labels.cpu().numpy(), scores.cpu().numpy())
    aupr = sklearn.metrics.average_precision_score(bin_labels.cpu().numpy(), scores.cpu().numpy())

    return auroc, aupr

### [6-2] Code for conducting OOD detection
def ood_detection_daedl(model, gda, p_z_train, testloader, ood_loader1, ood_loader2, num_classes, device):
    """Run DAEDL OOD detection of the ID test set against two OOD test sets.

    Each loader's density is the log-sum-exp over the last axis of the
    ``gmm_evaluate`` outputs, sanitised with ``check``. The scores are then
    evaluated with ``auroc_aupr_daedl``, normalised against ``p_z_train``:

    * ``"AU"`` uses ``maxP_density``.
    * ``"EU"`` uses ``alpha0_density``.

    Returns:
        ``(AUROC, AUPR)``. Each is a list ``[ood1_result, ood2_result]`` of
        dicts with keys ``"AU"`` and ``"EU"``.
    """
    def log_density(loader):
        gmm_logits, _ = gmm_evaluate(model, gda, loader, device, num_classes, device)
        return check(torch.logsumexp(gmm_logits, dim=-1))

    p_z_test = log_density(testloader)
    ood_sets = [
        (ood_loader1, log_density(ood_loader1)),
        (ood_loader2, log_density(ood_loader2)),
    ]

    def evaluate(loader, p_z_ood, measure):
        return auroc_aupr_daedl(model, testloader, loader, p_z_train, p_z_test, p_z_ood, measure, device)

    # All maxP evaluations run before the alpha0 ones, matching the original order.
    alea_results = [evaluate(loader, dens, maxP_density) for loader, dens in ood_sets]
    epis_results = [evaluate(loader, dens, alpha0_density) for loader, dens in ood_sets]

    AUROC = [{"AU": alea[0], "EU": epis[0]} for alea, epis in zip(alea_results, epis_results)]
    AUPR = [{"AU": alea[1], "EU": epis[1]} for alea, epis in zip(alea_results, epis_results)]
    return AUROC, AUPR

### [6-3] Code for conducting distribution shift detection
def dist_shift_detection_daedl_cifar10(model, gda, p_z_train, testloader, num_classes, device):
    """Evaluate DAEDL distribution-shift detection on CIFAR-10-C.

    For every corruption in ``CORRUPTIONS``, ``data/CIFAR-10-C/<corruption>.npy``
    is loaded, normalised with the CIFAR-10 mean/std, and split into 5
    equal consecutive chunks, one per severity level (1..5). Each chunk is
    scored against the ID test set with ``auroc_aupr_daedl``:

    * ``"AU"`` uses ``maxP_density``.
    * ``"DU"`` uses ``alpha0_density``.

    Returns:
        ``(RESULT_AUROC, RESULT_AUPR)``. Each maps ``severity -> {"AU": mean,
        "DU": mean}``, where the mean is taken over all corruptions.
    """
    CORRUPTIONS = ["gaussian_noise", "shot_noise", "impulse_noise", "defocus_blur", "glass_blur", "motion_blur", "zoom_blur", "snow", "frost", "fog", "brightness", "contrast", "elastic_transform", "pixelate", "jpeg_compression", "speckle_noise", "gaussian_blur", "spatter", "saturate"]

    MEAN = [0.49139968, 0.48215841, 0.44653091]
    STD = [0.24703223, 0.24348513, 0.26158784]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=MEAN, std=STD)])

    base_path = "data/CIFAR-10-C"
    num_severities = 5

    AUROC = {}
    AUPR = {}

    logits_daedl, _ = gmm_evaluate(model, gda, testloader, device, num_classes, device)
    p_z_test = check(torch.logsumexp(logits_daedl, dim=-1))

    for corruption_type in CORRUPTIONS:
        AUROC[corruption_type] = {}
        AUPR[corruption_type] = {}

        raw = np.load(os.path.join(base_path, corruption_type) + ".npy")
        # torch.stack assembles the batch directly. Wrapping a list of tensors
        # in np.array is slow and relies on per-element tensor->numpy conversion.
        data = torch.stack([transform(img) for img in raw]).float().to(device)
        # Severities are stored as consecutive equal-sized chunks (10000 each
        # in the standard 50000-image files).
        chunk = len(data) // num_severities

        for severity in range(1, num_severities + 1):
            dataset = data[chunk * (severity - 1):chunk * severity]
            labels = torch.zeros(len(dataset))

            cifar10c_dataset = torch.utils.data.TensorDataset(dataset, labels)
            cifar10c_loader = torch.utils.data.DataLoader(cifar10c_dataset, shuffle=False, batch_size=128)

            ood_logits_daedl, _ = gmm_evaluate(model, gda, cifar10c_loader, device, num_classes, device)
            p_z_ood = check(torch.logsumexp(ood_logits_daedl, dim=-1))

            auroc1, aupr1 = auroc_aupr_daedl(model, testloader, cifar10c_loader, p_z_train, p_z_test, p_z_ood, maxP_density, device)
            auroc2, aupr2 = auroc_aupr_daedl(model, testloader, cifar10c_loader, p_z_train, p_z_test, p_z_ood, alpha0_density, device)

            AUROC[corruption_type][severity] = {"AU": auroc1, "DU": auroc2}
            AUPR[corruption_type][severity] = {"AU": aupr1, "DU": aupr2}

    RESULT_AUROC = {}
    RESULT_AUPR = {}
    for severity in range(1, num_severities + 1):
        RESULT_AUROC[severity] = {
            key: np.array([AUROC[c][severity][key] for c in CORRUPTIONS]).mean() for key in ("AU", "DU")
        }
        RESULT_AUPR[severity] = {
            key: np.array([AUPR[c][severity][key] for c in CORRUPTIONS]).mean() for key in ("AU", "DU")
        }

    return RESULT_AUROC, RESULT_AUPR


def dist_shift_detection_daedl_mnist(model, gda, p_z_train, testloader, num_classes, device):
    """Evaluate DAEDL distribution-shift detection on MNIST-C.

    For every corruption in ``CORRUPTIONS``, the images are loaded from
    ``data/mnist_c/<corruption>/test_images.npy`` and reshaped to
    ``(-1, 1, 28, 28)``. They are scaled by 1/255 and normalised with the
    MNIST mean/std. The set's density is computed with ``gmm_evaluate`` and
    scored against the ID test set with ``auroc_aupr_daedl``:

    * ``"AU"`` uses ``maxP_density``.
    * ``"DU"`` uses ``alpha0_density``.

    Returns:
        ``(AUROC, AUPR)``: dicts ``{"AU": mean, "DU": mean}``, where the mean
        is taken over all corruptions.
    """
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    CORRUPTIONS = ['shot_noise', 'impulse_noise', 'glass_blur', 'motion_blur', 'shear', 'scale', 'rotate', 'brightness', 'translate',
                   'stripe', 'fog', 'spatter', 'dotted_line', 'zigzag', 'canny_edges',]
    base_path = "data/mnist_c/"

    AUROC1 = []
    AUPR1 = []
    AUROC2 = []
    AUPR2 = []

    logits_daedl, _ = gmm_evaluate(model, gda, testloader, device, num_classes, device)
    p_z_test = check(torch.logsumexp(logits_daedl, dim=-1))

    for corruption_type in CORRUPTIONS:
        corruption_dir = os.path.join(base_path, corruption_type)

        data = torch.Tensor(np.load(os.path.join(corruption_dir, "test_images.npy"))).reshape(-1, 1, 28, 28)
        data = normalize(data / 255.0)
        label = torch.Tensor(np.load(os.path.join(corruption_dir, "test_labels.npy")))

        corrupted_dataset = torch.utils.data.TensorDataset(data, label)
        mnist_c_loader = torch.utils.data.DataLoader(corrupted_dataset, shuffle=False, batch_size=64)

        # Density of the corrupted set. It was previously referenced without
        # ever being computed, which raised NameError.
        ood_logits_daedl, _ = gmm_evaluate(model, gda, mnist_c_loader, device, num_classes, device)
        p_z_ood = check(torch.logsumexp(ood_logits_daedl, dim=-1))

        auroc1, aupr1 = auroc_aupr_daedl(model, testloader, mnist_c_loader, p_z_train, p_z_test, p_z_ood, maxP_density, device)
        auroc2, aupr2 = auroc_aupr_daedl(model, testloader, mnist_c_loader, p_z_train, p_z_test, p_z_ood, alpha0_density, device)

        AUROC1.append(auroc1)
        AUPR1.append(aupr1)
        AUROC2.append(auroc2)
        AUPR2.append(aupr2)

    AUROC = {"AU": np.array(AUROC1).mean(), "DU": np.array(AUROC2).mean()}
    AUPR = {"AU": np.array(AUPR1).mean(), "DU": np.array(AUPR2).mean()}

    return AUROC, AUPR

def dist_shift_detection_daedl(ID_dataset, model, gda, p_z_train, testloader, num_classes, device):
    """Dispatch DAEDL distribution-shift detection by in-distribution dataset.

    ``"MNIST"`` (case-insensitive) is evaluated on MNIST-C with
    ``dist_shift_detection_daedl_mnist``. That function returns flat
    ``{"AU", "DU"}`` dicts. Any other value, including ``"CIFAR-10"``, falls
    back to the CIFAR-10-C evaluation, as before. That path returns
    per-severity dicts.

    Previously both branches called the CIFAR-10 routine, so the MNIST-C
    evaluator was unreachable and MNIST models were fed CIFAR-10-C images.

    Returns:
        ``(dist_auroc, dist_aupr)`` from the selected routine.
    """
    if str(ID_dataset).upper() == "MNIST":
        return dist_shift_detection_daedl_mnist(model, gda, p_z_train, testloader, num_classes, device)
    return dist_shift_detection_daedl_cifar10(model, gda, p_z_train, testloader, num_classes, device)
        
