#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Deep Ensemble (DE) training for approximate Bayesian inference on the full CIFAR-10
dataset (classes 0–9), using frozen ResNet-50 embeddings and an ensemble of SimpleMLP models.

This script loads (or computes and caches) ResNet-50 embeddings for the CIFAR-10
training and test images (labels 0–9, in-domain) and for three out-of-domain (OOD)
sets: CIFAR-100 images ("close"), CIFAR-10-C corrupted images ("corrupt") and SVHN
digits ("far"). Training uses a fixed subset of the CIFAR-10 training embeddings; the
script trains R replicates, each an ensemble of independently seeded models, and
evaluates every replicate on the CIFAR-10 test embeddings. For each replicate it computes:
  - In-domain metrics: Negative Log Likelihood (NLL), Total Entropy, Epistemic
    Uncertainty (Mutual Information), macro F1 score, macro AUC-PR, macro precision and
    recall, Brier score, Expected Calibration Error (ECE) and Accuracy, plus a breakdown
    of the uncertainty measures for correct vs. incorrect predictions.
  - Out-of-domain metrics (per OOD set): Total Entropy and Epistemic Uncertainty.
The per-replicate metrics are then aggregated (mean and standard error over replicates)
and saved to a MAT file.
"""

############################################
#           Imports
############################################
import os
import sys
import time
import random
import numpy as np
from scipy.io import savemat

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
from torchvision import models

from sklearn.metrics import f1_score, average_precision_score, precision_score, recall_score
from sklearn.preprocessing import label_binarize

import pyro
from PIL import Image
import tarfile, urllib.request

############################################
#      Device and Prior Settings
############################################
# Computation device: the current CUDA device when one is available, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Standard deviations of the zero-mean Gaussian prior (variance 0.2). They are
# used both to initialise SimpleMLP parameters and in the training loop's
# negative-log-prior (L2) penalty.
sigma_w = np.sqrt(0.2)   # weights
sigma_b = np.sqrt(0.2)   # biases

############################################
#      whole CIFAR-10 Dataset Class
############################################
class FilteredCIFAR10(Dataset):
    """CIFAR-10 restricted to a subset of class labels.

    Wraps ``torchvision.datasets.CIFAR10`` and exposes only the samples whose
    label is in ``allowed_labels``. The filter is computed from the wrapped
    dataset's ``targets`` list, so images are decoded and passed through
    ``transform`` lazily, one at a time, when indexed. Nothing is materialised
    up front; with a 224-px resize that would mean holding every transformed
    image in RAM.

    Args:
        root, train, transform, download: forwarded to ``datasets.CIFAR10``.
        allowed_labels: iterable of integer labels to keep.
    """

    def __init__(self, root, train, transform, download, allowed_labels):
        self.dataset = datasets.CIFAR10(root=root, train=train, transform=transform, download=download)
        self.allowed_labels = allowed_labels
        # Positions in self.dataset of the samples that pass the label filter.
        self.indices = self._filter_indices(self.dataset.targets, allowed_labels)

    @staticmethod
    def _filter_indices(targets, allowed_labels):
        """Return the positions in ``targets`` whose label is in ``allowed_labels``."""
        allowed = set(allowed_labels)  # O(1) membership test per sample
        return [i for i, label in enumerate(targets) if label in allowed]

    def __getitem__(self, idx):
        # The wrapped dataset applies ``transform`` at access time.
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

############################################
#   ResNet-50 Embedding Extraction for whole CIFAR-10 (In-Domain: Labels 0–9)
############################################
def _extract_resnet50_features(loader, feature_extractor):
    """Run ``feature_extractor`` over ``loader``; return (features, labels) on the CPU.

    Each batch's features are flattened to (batch, D) and all batches are
    concatenated in loader order; labels are concatenated alongside.
    """
    feats, labels = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            out = feature_extractor(inputs.to(device))
            feats.append(torch.flatten(out, 1).cpu())
            labels.append(targets)
    return torch.cat(feats, dim=0), torch.cat(labels, dim=0)


def create_resnet50_embedded_cifar10_dataset(
    train_cache_path="cifar10_train_embeddings.pt",
    test_cache_path="cifar10_test_embeddings.pt",
    allowed_labels=None
):
    """Return ResNet-50 embeddings of CIFAR-10 as (X_train, y_train, X_test, y_test).

    If both cache files exist they are loaded with ``torch.load`` and returned
    as-is. Otherwise the CIFAR-10 train and test splits are downloaded to
    ``./data`` and restricted to ``allowed_labels`` (default: all ten classes).
    The images are resized to 224 px, normalised with ImageNet statistics and
    passed through an ImageNet-pretrained ResNet-50 whose final fully connected
    layer is removed, giving one flattened feature vector per image. The
    results are saved to the two cache paths.

    Note: the cache is keyed only by file path, not by ``allowed_labels``.
    Delete the cache files after changing the label set.
    """
    if allowed_labels is None:
        allowed_labels = list(range(10))

    if os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
        print("Loading cached ResNet-50 embeddings for whole CIFAR-10...")
        X_train, y_train = torch.load(train_cache_path)
        X_test, y_test = torch.load(test_cache_path)
        return X_train, y_train, X_test, y_test

    print("Cached embeddings not found. Computing ResNet-50 embeddings for whole CIFAR-10...")
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_dataset = FilteredCIFAR10(root='./data', train=True, transform=transform, download=True, allowed_labels=allowed_labels)
    test_dataset = FilteredCIFAR10(root='./data', train=False, transform=transform, download=True, allowed_labels=allowed_labels)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

    resnet50 = models.resnet50(pretrained=True)
    # Remove the final fc layer to obtain 2048-dim pooled features.
    feature_extractor = nn.Sequential(*list(resnet50.children())[:-1])
    feature_extractor.eval()
    feature_extractor.to(device)

    X_train, y_train = _extract_resnet50_features(train_loader, feature_extractor)
    X_test, y_test = _extract_resnet50_features(test_loader, feature_extractor)

    torch.save((X_train, y_train), train_cache_path)
    torch.save((X_test, y_test), test_cache_path)
    print("ResNet-50 embeddings computed and saved.")
    return X_train, y_train, X_test, y_test

############################################
#   ResNet-50 Embedding Extraction for OOD
############################################
# --- 1) CIFAR-100 “close” OOD: classes not in CIFAR-10 ---
def sample_cifar100_not_in_cifar10(n1, root='./data', seed=42):
    """Sample ``n1`` CIFAR-100 test images whose class name is not a CIFAR-10 class name.

    Both datasets are downloaded to ``root`` if needed; CIFAR-10 is only used
    for its list of class names. Sampling uses a private ``random.Random(seed)``,
    so the global ``random`` state is left untouched. That private seeding draws
    the same indices as ``random.seed(seed)`` followed by ``random.sample``.

    NOTE(review): exclusion is by exact class-name string match only, so
    semantically related CIFAR-100 classes are kept unless their names are
    identical to a CIFAR-10 name — confirm this is the intended "close" OOD set.

    Returns:
        list of ``PIL.Image.Image`` built from the raw CIFAR-100 arrays.

    Raises:
        ValueError: (from ``random.sample``) if ``n1`` exceeds the number of
            eligible images.
    """
    cifar10 = datasets.CIFAR10(root=root, train=False, download=True)
    cifar100 = datasets.CIFAR100(root=root, train=False, download=True)

    cif10_set = set(cifar10.classes)
    cif100_names = cifar100.classes  # list of 100 fine-class names
    # A set keeps the per-image membership test O(1).
    ood_names = {name for name in cif100_names if name not in cif10_set}

    idxs = [
        idx for idx, label in enumerate(cifar100.targets)
        if cif100_names[label] in ood_names
    ]
    chosen = random.Random(seed).sample(idxs, n1)

    return [Image.fromarray(cifar100.data[i]) for i in chosen]

# --- 2) CIFAR-10-C “corrupt” OOD: all corruption types/severities ---
def sample_cifar10c(n2, root='./data', seed=42):
    """Sample ``n2`` images uniformly from all CIFAR-10-C corruption arrays.

    The archive is downloaded to ``root`` and extracted on first use. Every
    ``.npy`` file in the extracted directory is used, except those with
    "labels" in the name. Files are visited in sorted order, so the selection is
    reproducible across filesystems. The arrays are memory-mapped and only the
    chosen images are read, so the full corruption set is never copied into RAM.
    Sampling uses a private ``random.Random(seed)`` and leaves the global
    ``random`` state untouched.

    Returns:
        list of ``PIL.Image.Image`` (uint8).

    Raises:
        ValueError: if no corruption arrays are found, or (from
            ``random.sample``) if ``n2`` exceeds the number of images.
    """
    url = 'https://zenodo.org/record/2535967/files/CIFAR-10-C.tar'
    tar_path = os.path.join(root, 'CIFAR-10-C.tar')
    extract_dir = os.path.join(root, 'CIFAR-10-C')

    # Download if needed. Write to a temporary name first so an interrupted
    # download is not mistaken for a complete archive on the next run.
    if not os.path.exists(tar_path):
        os.makedirs(root, exist_ok=True)
        print("Downloading CIFAR-10-C (≈180 MB)...")
        part_path = tar_path + '.part'
        urllib.request.urlretrieve(url, part_path)
        os.replace(part_path, tar_path)

    # Extract if needed.
    if not os.path.isdir(extract_dir):
        print("Extracting CIFAR-10-C...")
        with tarfile.open(tar_path) as tar:
            if hasattr(tarfile, 'data_filter'):
                # Reject absolute paths, '..' components and links escaping root.
                tar.extractall(path=root, filter='data')
            else:
                tar.extractall(path=root)

    # Memory-map every corruption array (skip the labels file).
    arrays = [
        np.load(os.path.join(extract_dir, fname), mmap_mode='r')
        for fname in sorted(os.listdir(extract_dir))
        if fname.endswith('.npy') and 'labels' not in fname
    ]
    if not arrays:
        raise ValueError(f"No CIFAR-10-C corruption .npy files found in {extract_dir}")

    # Global index g addresses the virtual concatenation of all arrays:
    # array k holds indices offsets[k] .. offsets[k+1]-1.
    offsets = np.cumsum([0] + [len(a) for a in arrays])
    total = int(offsets[-1])

    chosen = random.Random(seed).sample(range(total), n2)
    images = []
    for g in chosen:
        k = int(np.searchsorted(offsets, g, side='right')) - 1
        img = np.asarray(arrays[k][g - offsets[k]]).astype(np.uint8)
        images.append(Image.fromarray(img))
    return images

# --- 3) SVHN “far” OOD: street-view digits ---
def sample_svhn(n3, root='./data', seed=42):
    """Sample ``n3`` images from the SVHN test split ("far" OOD set).

    The split is downloaded to ``root`` if needed and loaded with
    ``transform=None``, so ``svhn[i][0]`` is the untransformed image; torchvision
    returns a PIL image in that case. Sampling uses a private
    ``random.Random(seed)`` and leaves the global ``random`` state untouched.

    Raises:
        ValueError: (from ``random.sample``) if ``n3`` exceeds the split size.
    """
    svhn = datasets.SVHN(root=root, split='test', download=True, transform=None)

    chosen = random.Random(seed).sample(range(len(svhn)), n3)
    return [svhn[i][0] for i in chosen]

def create_resnet50_embedded_ood_datasets(
        n_close, n_corrupt, n_far,
        cache_path='ood_embeddings.pt'):
    """Return ResNet-50 embeddings ``(X_close, X_corrupt, X_far)`` of the three OOD sets.

    If ``cache_path`` exists, its contents are loaded with ``torch.load`` and
    returned. Otherwise the function does the following:
      * samples ``n_close`` CIFAR-100, ``n_corrupt`` CIFAR-10-C and ``n_far``
        SVHN images;
      * applies the same 224-px / ImageNet-normalisation preprocessing as the
        in-domain data;
      * embeds them with a pretrained ResNet-50 whose final fc layer is removed;
      * saves the tuple of CPU tensors to ``cache_path``.

    Note: the cache is keyed only by path; the sample sizes are not checked
    when loading from it.
    """
    if os.path.exists(cache_path):
        print("Loading cached ood embeddings...")
        return torch.load(cache_path)

    print("Cached embeddings not found. Computing ood embeddings...")

    # 1) Image preprocessing for OOD PIL images (matches the in-domain pipeline).
    transform = transforms.Compose([
        transforms.Resize(224),                  # match ResNet50's input size
        transforms.ToTensor(),                   # convert PIL -> [0,1] tensor
        transforms.Normalize(                    # ImageNet mean/std
            mean=[0.485, 0.456, 0.406],
            std =[0.229, 0.224, 0.225]
        )
    ])

    # 2) ResNet-50 feature extractor without the final fully connected layer.
    resnet50 = models.resnet50(pretrained=True)
    feature_extractor = nn.Sequential(*list(resnet50.children())[:-1])
    feature_extractor.eval()
    feature_extractor.to(device)
    feature_dim = resnet50.fc.in_features        # width of the pooled features

    # 3) Sample PIL images.
    close_imgs   = sample_cifar100_not_in_cifar10(n_close)
    corrupt_imgs = sample_cifar10c(n_corrupt)
    far_imgs     = sample_svhn(n_far)

    batch_size = 128

    def imgs_to_embeddings(img_list):
        """Embed PIL images batch by batch; return a (len(img_list), feature_dim) CPU tensor."""
        feats = []
        with torch.no_grad():
            for start in range(0, len(img_list), batch_size):
                # Transform only the current batch, so at most batch_size
                # 224x224 tensors are alive at once.
                batch = torch.stack([transform(img) for img in img_list[start:start + batch_size]])
                out = feature_extractor(batch.to(device))
                feats.append(torch.flatten(out, 1).cpu())
        if not feats:
            return torch.empty((0, feature_dim))
        return torch.cat(feats, dim=0)

    X_close   = imgs_to_embeddings(close_imgs)
    X_corrupt = imgs_to_embeddings(corrupt_imgs)
    X_far     = imgs_to_embeddings(far_imgs)

    torch.save((X_close, X_corrupt, X_far), cache_path)
    print("ood embeddings computed and saved.")
    return X_close, X_corrupt, X_far

############################################
#      Metrics computation: Brier and ECE
############################################
def compute_brier(probs, labels):
    """
    Return the multi-class Brier score.

        BS = (1/N) * sum_i sum_c (p_{i,c} - 1[y_i = c])^2

    Args:
      probs:  (N, C) array of predicted class probabilities.
      labels: (N,) array of integer class labels in [0..C-1].
    """
    n_rows, _ = probs.shape            # also rejects non-2D input
    indicator = np.zeros_like(probs)   # one-hot targets, same dtype as probs
    indicator[np.arange(n_rows), labels] = 1
    per_sample = np.square(probs - indicator).sum(axis=1)
    return per_sample.mean()


def compute_ece(probs, labels, n_bins=10):
    """
    Expected Calibration Error (ECE) of multiclass predictions.

    Each sample is placed in one of ``n_bins`` equal-width bins on [0, 1]
    according to its top-class probability. Bins are half-open [lo, hi),
    except the last, which also includes 1.0. The ECE is the sample-weighted
    absolute gap between accuracy and mean confidence in each bin:

        ECE = sum_b (|B_b| / N) * |acc(B_b) - conf(B_b)|

    Args:
      probs:  (N, C) array of predicted class probabilities (rows sum to 1).
      labels: (N,) array of integer true labels in [0..C-1].
      n_bins: number of confidence bins.

    Returns:
      The scalar ECE.
    """
    top_class = np.argmax(probs, axis=1)
    top_prob = probs[np.arange(len(probs)), top_class]
    hit = (top_class == labels).astype(float)

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    last_bin = n_bins - 1
    total_gap = 0.0
    for b, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        # Only the last bin is closed on the right, so confidence 1.0 counts.
        under_hi = (top_prob <= hi) if b == last_bin else (top_prob < hi)
        members = (top_prob >= lo) & under_hi
        weight = members.mean()
        if weight == 0:
            continue
        total_gap += weight * abs(hit[members].mean() - top_prob[members].mean())
    return total_gap


############################################
#         SimpleMLP Model Definition (Logistic Regression)
############################################
class SimpleMLP(nn.Module):
    """Logistic regression, or a one-hidden-layer ReLU MLP, with Gaussian-prior init.

    With ``hidden_dim`` of None or 0 the model is a single linear layer
    ``self.fc``. Otherwise it is ``fc1 -> ReLU -> fc2``. All weights are drawn
    from N(0, sigma_w^2) and all biases from N(0, sigma_b^2); the weights are
    initialised before the biases, in layer order.
    """

    def __init__(self, input_dim=2048, hidden_dim=0, num_classes=10):
        super(SimpleMLP, self).__init__()
        if hidden_dim in (None, 0):
            # Logistic regression: no hidden layer.
            self.fc = nn.Linear(input_dim, num_classes)
            layers = (self.fc,)
        else:
            # Single-hidden-layer MLP.
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, num_classes)
            layers = (self.fc1, self.fc2)
        # All weights first, then all biases: this fixes the RNG draw order.
        for layer in layers:
            nn.init.normal_(layer.weight, mean=0, std=sigma_w)
        for layer in layers:
            if layer.bias is not None:
                nn.init.normal_(layer.bias, mean=0, std=sigma_b)

    def forward(self, x):
        if not hasattr(self, 'fc'):
            return self.fc2(F.relu(self.fc1(x)))
        return self.fc(x)

############################################
#    Deep Ensemble Prediction Function
############################################
def deep_ensemble_predict(x, models, target_device=None):
    """Return the ensemble predictive distribution: the mean of member softmax outputs.

    Args:
        x: input tensor; copied to ``target_device`` once, not once per member.
        models: non-empty iterable of classifiers returning logits of shape
            (N, C). Each member is switched to eval mode (side effect).
        target_device: device to evaluate on; defaults to the module-level
            ``device``.

    Returns:
        Tensor of shape (N, C) on ``target_device`` with the averaged probabilities.

    Raises:
        ValueError: if ``models`` is empty.
    """
    if target_device is None:
        target_device = device
    models = list(models)
    if not models:
        raise ValueError("deep_ensemble_predict needs at least one model")

    x = x.to(target_device)  # hoisted: the input is the same for every member
    member_probs = []
    with torch.no_grad():
        for model in models:
            model.eval()
            member_probs.append(F.softmax(model(x), dim=1))
    # Average the softmax outputs of all ensemble members.
    return torch.stack(member_probs, dim=0).mean(dim=0)

############################################
#    Main Execution: DE Training on CIFAR-10 Embeddings
############################################
if __name__ == '__main__':
    ############################################
    #       Load ResNet-50 Embeddings for CIFAR-10 (Classes 0–9)
    ############################################
    allowed_labels = list(range(10))
    train_cache = "cifar10_train_embeddings.pt"
    test_cache  = "cifar10_test_embeddings.pt"
    X_train, y_train, X_test, y_test = create_resnet50_embedded_cifar10_dataset(
         train_cache_path=train_cache,
         test_cache_path=test_cache,
         allowed_labels=allowed_labels
    )
    print(f"Embedding dimension: {X_train.shape[1]}")

    # Split the CIFAR-10 training embeddings into a training part and an
    # early-stopping validation part.
    n_total = X_train.size(0)        # should be 50 000 for CIFAR-10
    print(f"Total CIFAR-10 train embeddings: {n_total}")

    # Early stopping is disabled in the training loop, so every training
    # embedding is used for training and the validation part is empty.
    # Lower n_train_es (e.g. to 40 000) to hold out a validation split again.
    # min() keeps n_val_es non-negative when fewer embeddings are available.
    n_train_es = min(50000, n_total)
    n_val_es   = n_total - n_train_es    # 0 with the setting above

    X_train_subset = X_train[:n_train_es]
    y_train_subset = y_train[:n_train_es]
    X_val_subset   = X_train[n_train_es:]
    y_val_subset   = y_train[n_train_es:]

    # Create TensorDatasets and DataLoaders (val_dataset may be empty, see above).
    train_dataset = TensorDataset(X_train_subset, y_train_subset)
    val_dataset   = TensorDataset(X_val_subset, y_val_subset)
    batch_size = 128
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # The whole CIFAR-10 TEST set (labels 0–9) is the in-domain evaluation set.
    # It is loaded as a single batch and moved to `device`; despite the
    # "val" names, x_val_full / y_val_full hold the test embeddings.
    test_dataset = TensorDataset(X_test, y_test)
    test_loader_full = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)
    x_val_full, y_val_full = next(iter(test_loader_full))
    x_val_full, y_val_full = x_val_full.to(device), y_val_full.to(device)

    ############################################
    #      DE Replicates and Ensemble Training Settings
    ############################################
    R = 5                # number of independent replicates
    ensemble_size = 10   # ensemble members per replicate

    # Lists to store metrics per replicate (one entry appended per replicate).
    replicate_epochs       = []  # average number of epochs per replicate (over ensemble members)
    replicate_accuracy     = []
    replicate_nll_in       = []
    replicate_total_entropy_in = []
    replicate_epistemic_in = []
    replicate_f1_in        = []
    replicate_aucpr_in     = []

    replicate_brier = []
    replicate_ece   = []
    replicate_precision = []
    replicate_recall    = []

    # OOD metrics, keyed by OOD set name.
    replicate_total_entropy_od = {
        "close":   [],
        "corrupt": [],
        "far":     []
    }
    replicate_epistemic_od = {
        "close":   [],
        "corrupt": [],
        "far":     []
    }
    # Breakdown for in-domain predictions.
    replicate_total_entropy_in_correct = []
    replicate_epistemic_in_correct = []
    replicate_total_entropy_in_incorrect = []
    replicate_epistemic_in_incorrect = []

    ############################################
    # Start DE replicates.
    ############################################
    eps = 1e-10  # keeps log() finite for zero probabilities

    def _gaussian_prior_penalty(model):
        """Return sum(w**2)/(2*sigma_w**2) + sum(b**2)/(2*sigma_b**2) over ``model``.

        This is the negative log of the zero-mean Gaussian prior, up to an
        additive constant. Parameters whose name ends in 'bias' use sigma_b and
        all others use sigma_w. The penalty therefore covers both the
        logistic-regression and the hidden-layer variants of SimpleMLP.
        """
        weight_sq = 0.0
        bias_sq = 0.0
        for param_name, param in model.named_parameters():
            if param_name.endswith('bias'):
                bias_sq = bias_sq + torch.sum(param**2)
            else:
                weight_sq = weight_sq + torch.sum(param**2)
        return weight_sq / (2 * sigma_w**2) + bias_sq / (2 * sigma_b**2)

    def _member_probs(x, members):
        """Return per-member softmax probabilities as a numpy array (n_members, N, C).

        ``x`` is moved to ``device`` first, so CPU tensors such as the OOD
        embeddings can be passed directly.
        """
        x = x.to(device)
        per_member = []
        with torch.no_grad():
            for member in members:
                member.eval()
                per_member.append(F.softmax(member(x), dim=1).cpu().numpy())
        return np.stack(per_member, axis=0)

    # The OOD embeddings do not depend on the replicate, so they are loaded (or
    # computed and cached) once here instead of inside the loop. Every
    # replicate and member re-seeds all RNGs before training, so this does not
    # change the trained models.
    print("\nComputing OOD embeddings...")
    # sample sizes per OOD set
    n_close, n_corrupt, n_far = 1000, 1000, 1000
    X_close, X_corrupt, X_far = create_resnet50_embedded_ood_datasets(
        n_close, n_corrupt, n_far,
        cache_path="ood_all_embeddings.pt"
    )
    ood_sets = [("close", X_close), ("corrupt", X_corrupt), ("far", X_far)]

    for r in range(1, R+1):
        print(f"\n===== Starting replicate {r} =====\n")
        # Set replicate seed.
        torch.manual_seed(r)
        np.random.seed(r)
        random.seed(r)
        pyro.set_rng_seed(r)

        # (The training subset stays fixed; only the seeds change per replicate.)

        # Deep Ensembles: train ensemble_size models for replicate r.
        ensemble_models = []   # trained ensemble members
        ensemble_epochs = []   # number of epochs trained per member
        replicate_de_time = []

        max_epochs = 200
        # Early-stopping settings. Validation-based early stopping is currently
        # disabled, so every member trains for max_epochs epochs.
        moving_avg_window = 10
        patience = 5

        for m in range(ensemble_size):
            member_seed = r * 1000 + m
            torch.manual_seed(member_seed)
            np.random.seed(member_seed)
            random.seed(member_seed)
            pyro.set_rng_seed(member_seed)

            model_de = SimpleMLP(input_dim=X_train.shape[1], hidden_dim=0, num_classes=10).to(device)
            optimizer_de = optim.Adam(model_de.parameters(), lr=0.001)
            criterion_de = nn.CrossEntropyLoss()

            train_losses_de = []

            start_time_de = time.time()
            for epoch in range(max_epochs):
                model_de.train()
                running_loss = 0.0
                for features, labels in train_loader:
                    features = features.to(device)
                    labels = labels.to(device)
                    optimizer_de.zero_grad()
                    outputs = model_de(features)
                    ce_loss = criterion_de(outputs, labels)
                    # MAP objective: mean CE plus the negative log Gaussian
                    # prior scaled by 1 / (number of training samples).
                    reg_loss = _gaussian_prior_penalty(model_de)
                    loss = ce_loss + reg_loss / len(train_loader.dataset)
                    loss.backward()
                    optimizer_de.step()
                    running_loss += loss.item()
                train_loss = running_loss / len(train_loader)
                train_losses_de.append(train_loss)

                print(f"Replicate {r}, DE Model (seed={member_seed}) Epoch {epoch+1}: Train Loss = {train_loss:.4f}")

            total_time_de = time.time() - start_time_de
            epochs_trained = epoch + 1
            ensemble_epochs.append(epochs_trained)
            print(f"Replicate {r}, DE Model (seed={member_seed}) Total training time: {total_time_de:.2f} seconds")
            print(f"Trained DE model with seed {member_seed} in {epochs_trained} epochs\n")
            ensemble_models.append(model_de)
            replicate_de_time.append(total_time_de)

        avg_epochs = np.mean(ensemble_epochs)
        print(f"Replicate {r}, Average training epochs over ensembles: {avg_epochs:.2f}")

        ############################################
        # Evaluate Deep Ensemble on In-Domain Data (Classes 0–9)
        ############################################
        ensemble_probs = deep_ensemble_predict(x_val_full, ensemble_models)  # shape: (N_val, 10)
        ensemble_probs_np = ensemble_probs.cpu().numpy()
        labels_val_np = y_val_full.cpu().numpy().flatten()

        # Negative Log Likelihood (NLL) of the true labels.
        nlls = -np.log(ensemble_probs_np[np.arange(len(labels_val_np)), labels_val_np] + eps)
        avg_nll_in = np.mean(nlls)

        # Total (predictive) entropy of the ensemble distribution.
        total_entropy = -np.sum(ensemble_probs_np * np.log(ensemble_probs_np + eps), axis=1)
        avg_total_entropy_in = np.mean(total_entropy)

        # Epistemic uncertainty (mutual information): total entropy minus the
        # average entropy of the individual members.
        all_model_probs = _member_probs(x_val_full, ensemble_models)  # (ensemble_size, N_val, 10)
        model_entropies = -np.sum(all_model_probs * np.log(all_model_probs + eps), axis=2)  # (ensemble_size, N_val)
        avg_model_entropy = np.mean(model_entropies, axis=0)  # (N_val,)
        mi = total_entropy - avg_model_entropy
        avg_epistemic_in = np.mean(mi)

        # Breakdown for correct and incorrect predictions.
        ensemble_preds = np.argmax(ensemble_probs_np, axis=1)
        correct_mask = (ensemble_preds == labels_val_np)
        incorrect_mask = (ensemble_preds != labels_val_np)
        if np.sum(correct_mask) > 0:
            total_entropy_in_correct = np.mean(total_entropy[correct_mask])
            epistemic_in_correct = np.mean(mi[correct_mask])
        else:
            total_entropy_in_correct = np.nan
            epistemic_in_correct = np.nan
        if np.sum(incorrect_mask) > 0:
            total_entropy_in_incorrect = np.mean(total_entropy[incorrect_mask])
            epistemic_in_incorrect = np.mean(mi[incorrect_mask])
        else:
            total_entropy_in_incorrect = np.nan
            epistemic_in_incorrect = np.nan

        # Classification metrics. Every sklearn input is a CPU numpy array:
        # y_val_full may live on the GPU, and the scores must be the ensemble's
        # probabilities, not those of the last member evaluated.
        f1_in = f1_score(labels_val_np, ensemble_preds, average='macro', zero_division=0)
        labels_val_bin = label_binarize(labels_val_np, classes=list(range(10)))
        aucpr_in = average_precision_score(labels_val_bin, ensemble_probs_np, average='macro')
        accuracy_in = np.mean(ensemble_preds == labels_val_np)
        precision_val = precision_score(labels_val_np, ensemble_preds, average='macro', zero_division=0)
        recall_val = recall_score(labels_val_np, ensemble_preds, average='macro', zero_division=0)
        # *_val aliases of the values above, kept for any later code using these names.
        y_val_bin = labels_val_bin
        f1_val = f1_in
        aucpr_val = aucpr_in

        brier = compute_brier(ensemble_probs_np, labels_val_np)
        ece = compute_ece(ensemble_probs_np, labels_val_np, n_bins=10)

        print("\nReplicate {} In-Domain Analysis (Classes 0–9):".format(r))
        print("Average NLL: {:.6f}".format(avg_nll_in))
        print("Average Total Entropy: {:.6f}".format(avg_total_entropy_in))
        print("Average Epistemic Uncertainty: {:.6f}".format(avg_epistemic_in))
        print("F1 Score (macro): {:.6f}".format(f1_in))
        print("AUC-PR (macro): {:.6f}".format(aucpr_in))
        print("Validation Accuracy: {:.6f}%".format(accuracy_in * 100))
        print("\nBreakdown for In-Domain Predictions:")
        print("Correct Predictions - Total Entropy: {:.6f}".format(total_entropy_in_correct))
        print("Correct Predictions - Epistemic Uncertainty: {:.6f}".format(epistemic_in_correct))
        print("Incorrect Predictions - Total Entropy: {:.6f}".format(total_entropy_in_incorrect))
        print("Incorrect Predictions - Epistemic Uncertainty: {:.6f}".format(epistemic_in_incorrect))

        ############################################
        # Out-Of-Domain Analysis
        ############################################
        for name, X_ood in ood_sets:
            ensemble_probs_od = deep_ensemble_predict(X_ood, ensemble_models)
            ensemble_probs_od_np = ensemble_probs_od.cpu().numpy()
            total_entropy_od = -np.sum(ensemble_probs_od_np * np.log(ensemble_probs_od_np + eps), axis=1)
            avg_total_entropy_od = np.mean(total_entropy_od)

            replicate_total_entropy_od[name].append(avg_total_entropy_od)

            # _member_probs moves the CPU OOD embeddings to `device` before the
            # forward pass (calling the models on CPU tensors fails on CUDA).
            all_model_probs_od = _member_probs(X_ood, ensemble_models)
            model_entropies_od = -np.sum(all_model_probs_od * np.log(all_model_probs_od + eps), axis=2)
            avg_model_entropy_od = np.mean(model_entropies_od, axis=0)
            mi_od = total_entropy_od - avg_model_entropy_od
            avg_epistemic_od = np.mean(mi_od)

            replicate_epistemic_od[name].append(avg_epistemic_od)

            print(f"Replicate {r}, OOD ({name})")
            print("Average Total Entropy: {:.6f}".format(avg_total_entropy_od))
            print("Average Epistemic Uncertainty: {:.6f}".format(avg_epistemic_od))

        ############################################
        # Store replicate metrics.
        ############################################
        replicate_epochs.append(avg_epochs)
        replicate_nll_in.append(avg_nll_in)
        replicate_total_entropy_in.append(avg_total_entropy_in)
        replicate_epistemic_in.append(avg_epistemic_in)
        replicate_f1_in.append(f1_in)
        replicate_aucpr_in.append(aucpr_in)
        replicate_accuracy.append(accuracy_in * 100)  # percentage

        replicate_precision.append(precision_val)
        replicate_recall.append(recall_val)
        replicate_brier.append(brier)
        replicate_ece.append(ece)

        replicate_total_entropy_in_correct.append(total_entropy_in_correct)
        replicate_epistemic_in_correct.append(epistemic_in_correct)
        replicate_total_entropy_in_incorrect.append(total_entropy_in_incorrect)
        replicate_epistemic_in_incorrect.append(epistemic_in_incorrect)
    
    ############################################
    # After all replicates: Compute overall averages and standard errors.
    ############################################
    def compute_stats(metric_list):
        metric_array = np.array(metric_list)
        mean_val = np.mean(metric_array)
        stderr_val = np.std(metric_array, ddof=1) / np.sqrt(R)
        return mean_val, stderr_val

    # Reduce every per-replicate metric list to (mean, standard error).
    # In-domain metrics.
    avg_epochs_mean, avg_epochs_stderr = compute_stats(replicate_epochs)
    nll_in_mean, nll_in_stderr = compute_stats(replicate_nll_in)
    tot_ent_in_mean, tot_ent_in_stderr = compute_stats(replicate_total_entropy_in)
    epistemic_in_mean, epistemic_in_stderr = compute_stats(replicate_epistemic_in)
    f1_in_mean, f1_in_stderr = compute_stats(replicate_f1_in)
    aucpr_in_mean, aucpr_in_stderr = compute_stats(replicate_aucpr_in)
    accuracy_mean, accuracy_stderr = compute_stats(replicate_accuracy)

    # OOD metrics: one list per OOD split ('close', 'corrupt', 'far').
    tot_ent_od_close_mean, tot_ent_od_close_stderr = compute_stats(replicate_total_entropy_od["close"])
    tot_ent_od_corrupt_mean, tot_ent_od_corrupt_stderr = compute_stats(replicate_total_entropy_od["corrupt"])
    tot_ent_od_far_mean, tot_ent_od_far_stderr = compute_stats(replicate_total_entropy_od["far"])
    epistemic_od_close_mean, epistemic_od_close_stderr = compute_stats(replicate_epistemic_od["close"])
    epistemic_od_corrupt_mean, epistemic_od_corrupt_stderr = compute_stats(replicate_epistemic_od["corrupt"])
    epistemic_od_far_mean, epistemic_od_far_stderr = compute_stats(replicate_epistemic_od["far"])

    # Classification / calibration metrics.
    mean_precision, stderr_precision = compute_stats(replicate_precision)
    mean_recall, stderr_recall = compute_stats(replicate_recall)
    mean_brier, stderr_brier = compute_stats(replicate_brier)
    mean_ece, stderr_ece = compute_stats(replicate_ece)

    # In-domain breakdown by whether the ensemble prediction was correct.
    tot_ent_in_corr_mean, tot_ent_in_corr_stderr = compute_stats(replicate_total_entropy_in_correct)
    epistemic_in_corr_mean, epistemic_in_corr_stderr = compute_stats(replicate_epistemic_in_correct)
    tot_ent_in_inc_mean, tot_ent_in_inc_stderr = compute_stats(replicate_total_entropy_in_incorrect)
    epistemic_in_inc_mean, epistemic_in_inc_stderr = compute_stats(replicate_epistemic_in_incorrect)

    print("\n===== Summary over {} replicates =====".format(R))
    print("Avg. Ensemble Epochs: Mean = {:.6f}, SE = {:.6f}".format(avg_epochs_mean, avg_epochs_stderr))
    print("In-Domain NLL: Mean = {:.6f}, SE = {:.6f}".format(nll_in_mean, nll_in_stderr))
    print(f"Brier Score:   {mean_brier:.4f} ± {stderr_brier:.4f}")
    print(f"ECE:           {mean_ece:.4f} ± {stderr_ece:.4f}")
    print(f"Precision:     {mean_precision:.4f} ± {stderr_precision:.4f}")
    print(f"Recall:        {mean_recall:.4f} ± {stderr_recall:.4f}")
    print("In-Domain Total Entropy: Mean = {:.6f}, SE = {:.6f}".format(tot_ent_in_mean, tot_ent_in_stderr))
    print("In-Domain Epistemic Uncertainty: Mean = {:.6f}, SE = {:.6f}".format(epistemic_in_mean, epistemic_in_stderr))
    print("In-Domain F1 Score: Mean = {:.6f}, SE = {:.6f}".format(f1_in_mean, f1_in_stderr))
    print("In-Domain AUC-PR: Mean = {:.6f}, SE = {:.6f}".format(aucpr_in_mean, aucpr_in_stderr))
    print("In-Domain Accuracy (%): Mean = {:.6f}, SE = {:.6f}".format(accuracy_mean, accuracy_stderr))
    print("\nBreakdown for In-Domain Predictions:")
    print("Correct Predictions - Total Entropy: Mean = {:.6f}, SE = {:.6f}".format(tot_ent_in_corr_mean, tot_ent_in_corr_stderr))
    print("Correct Predictions - Epistemic Uncertainty: Mean = {:.6f}, SE = {:.6f}".format(epistemic_in_corr_mean, epistemic_in_corr_stderr))
    print("Incorrect Predictions - Total Entropy: Mean = {:.6f}, SE = {:.6f}".format(tot_ent_in_inc_mean, tot_ent_in_inc_stderr))
    print("Incorrect Predictions - Epistemic Uncertainty: Mean = {:.6f}, SE = {:.6f}".format(epistemic_in_inc_mean, epistemic_in_inc_stderr))
    print("OOD Close Total Entropy: Mean = {:.6f}, SE = {:.6f}".format(tot_ent_od_close_mean, tot_ent_od_close_stderr))
    print("OOD Close Epistemic Uncertainty: Mean = {:.6f}, SE = {:.6f}".format(epistemic_od_close_mean, epistemic_od_close_stderr))
    print("OOD Corrupt Total Entropy: Mean = {:.6f}, SE = {:.6f}".format(tot_ent_od_corrupt_mean, tot_ent_od_corrupt_stderr))
    print("OOD Corrupt Epistemic Uncertainty: Mean = {:.6f}, SE = {:.6f}".format(epistemic_od_corrupt_mean, epistemic_od_corrupt_stderr))
    print("OOD Far Total Entropy: Mean = {:.6f}, SE = {:.6f}".format(tot_ent_od_far_mean, tot_ent_od_far_stderr))
    print("OOD Far Epistemic Uncertainty: Mean = {:.6f}, SE = {:.6f}".format(epistemic_od_far_mean, epistemic_od_far_stderr))

    ############################################
    # Save all metrics and replicate results.
    # Each dict key becomes one top-level variable in the MAT file.
    ############################################
    metrics = {
        'R': R,
        'ensemble_size': ensemble_size,
        'replicate_de_time': replicate_de_time,
        # Raw per-replicate values.
        'replicate_epochs': np.array(replicate_epochs),
        'replicate_nll': np.array(replicate_nll_in),
        'replicate_total_entropy': np.array(replicate_total_entropy_in),
        'replicate_epistemic_entropy': np.array(replicate_epistemic_in),
        'replicate_f1_score': np.array(replicate_f1_in),
        'replicate_auc_pr': np.array(replicate_aucpr_in),
        'replicate_accuracy': np.array(replicate_accuracy),
        'replicate_correct_total_entropy': np.array(replicate_total_entropy_in_correct),
        'replicate_correct_epistemic_entropy': np.array(replicate_epistemic_in_correct),
        'replicate_incorrect_total_entropy': np.array(replicate_total_entropy_in_incorrect),
        'replicate_incorrect_epistemic_entropy': np.array(replicate_epistemic_in_incorrect),
        'replicate_total_entropy_od_close': np.array(replicate_total_entropy_od['close']),
        'replicate_epistemic_od_close': np.array(replicate_epistemic_od['close']),
        'replicate_total_entropy_od_corrupt': np.array(replicate_total_entropy_od['corrupt']),
        'replicate_epistemic_od_corrupt': np.array(replicate_epistemic_od['corrupt']),
        'replicate_total_entropy_od_far': np.array(replicate_total_entropy_od['far']),
        'replicate_epistemic_od_far': np.array(replicate_epistemic_od['far']),
        'replicate_brier': np.array(replicate_brier),
        'replicate_ece': np.array(replicate_ece),
        'replicate_precision': np.array(replicate_precision),
        'replicate_recall': np.array(replicate_recall),
        # Summary statistics: mean and standard error over replicates.
        # These were previously commented out (and referenced OOD names that
        # no longer exist), so the final "all computed metrics" message was
        # inaccurate. The OOD entries now use the per-split names.
        'avg_epochs_mean': avg_epochs_mean,
        'avg_epochs_stderr': avg_epochs_stderr,
        'nll_in_mean': nll_in_mean,
        'nll_in_stderr': nll_in_stderr,
        'tot_ent_in_mean': tot_ent_in_mean,
        'tot_ent_in_stderr': tot_ent_in_stderr,
        'epistemic_in_mean': epistemic_in_mean,
        'epistemic_in_stderr': epistemic_in_stderr,
        'f1_in_mean': f1_in_mean,
        'f1_in_stderr': f1_in_stderr,
        'aucpr_in_mean': aucpr_in_mean,
        'aucpr_in_stderr': aucpr_in_stderr,
        'accuracy_mean': accuracy_mean,
        'accuracy_stderr': accuracy_stderr,
        'precision_mean': mean_precision,
        'precision_stderr': stderr_precision,
        'recall_mean': mean_recall,
        'recall_stderr': stderr_recall,
        'brier_mean': mean_brier,
        'brier_stderr': stderr_brier,
        'ece_mean': mean_ece,
        'ece_stderr': stderr_ece,
        'tot_ent_od_close_mean': tot_ent_od_close_mean,
        'tot_ent_od_close_stderr': tot_ent_od_close_stderr,
        'tot_ent_od_corrupt_mean': tot_ent_od_corrupt_mean,
        'tot_ent_od_corrupt_stderr': tot_ent_od_corrupt_stderr,
        'tot_ent_od_far_mean': tot_ent_od_far_mean,
        'tot_ent_od_far_stderr': tot_ent_od_far_stderr,
        'epistemic_od_close_mean': epistemic_od_close_mean,
        'epistemic_od_close_stderr': epistemic_od_close_stderr,
        'epistemic_od_corrupt_mean': epistemic_od_corrupt_mean,
        'epistemic_od_corrupt_stderr': epistemic_od_corrupt_stderr,
        'epistemic_od_far_mean': epistemic_od_far_mean,
        'epistemic_od_far_stderr': epistemic_od_far_stderr,
        'tot_ent_in_correct_mean': tot_ent_in_corr_mean,
        'tot_ent_in_correct_stderr': tot_ent_in_corr_stderr,
        'epistemic_in_correct_mean': epistemic_in_corr_mean,
        'epistemic_in_correct_stderr': epistemic_in_corr_stderr,
        'tot_ent_in_incorrect_mean': tot_ent_in_inc_mean,
        'tot_ent_in_incorrect_stderr': tot_ent_in_inc_stderr,
        'epistemic_in_incorrect_mean': epistemic_in_inc_mean,
        'epistemic_in_incorrect_stderr': epistemic_in_inc_stderr,
    }

    savemat('BayesianNN_CIFAR_DE_metrics.mat', metrics)
    print("\nAll computed metrics (and raw replicate results) have been saved to 'BayesianNN_CIFAR_DE_metrics.mat'.")
