#!/usr/bin/env python3
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
from scipy.io import savemat
import pyro

# For F1 and AUC-PR metrics.
from sklearn.metrics import f1_score, average_precision_score
from sklearn.preprocessing import label_binarize

# Set device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

############################################
# Global Prior Parameters (same as MAP)
############################################
# Zero-mean Gaussian prior variances. The same values are used both for the
# weight initialisation of the CNN and for the L2 (negative log prior) penalty
# in the training objective.
v = 0.1
sdb = np.sqrt(v)  # standard deviation for biases (sqrt(0.1) ~= 0.316)
var_conv = v      # variance for conv weights (std = sqrt(0.1) ~= 0.316)
var_fc = v        # variance for fc weights   (std = sqrt(0.1) ~= 0.316)

############################################
# Define the CNN Architecture for Filtered MNIST
############################################
class SimpleCNN(nn.Module):
    """Small CNN for single-channel 28x28 inputs producing 8 logits.

    Architecture: conv(1->4, 3x3, padding 1) -> ReLU -> 2x2 max-pool ->
    linear(4*14*14 -> 8). After the layers are built, every weight and bias
    is re-drawn from a zero-mean Gaussian whose spread comes from the
    module-level prior parameters (var_conv, var_fc, sdb).
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(4 * 14 * 14, 8)

        # Re-draw parameters from the prior. Order matters for RNG consumption:
        # conv before fc, and within each layer the weight before the bias.
        for layer, weight_variance in ((self.conv, var_conv), (self.fc, var_fc)):
            nn.init.normal_(layer.weight, mean=0, std=np.sqrt(weight_variance))
            if layer.bias is not None:
                nn.init.normal_(layer.bias, mean=0, std=sdb)

    def forward(self, x):
        """Map a (N, 1, 28, 28) batch to (N, 8) unnormalised logits."""
        pooled = self.pool(F.relu(self.conv(x)))
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

############################################
# Define a filtered dataset (keep digits 0–7)
############################################
class FilteredDataset(Dataset):
    """In-memory copy of ``dataset`` restricted to the allowed labels.

    The source dataset is iterated once at construction time. Every
    ``(image, label)`` pair whose label is in ``allowed_labels`` is stored, in
    the original order, and served by index afterwards.
    """

    def __init__(self, dataset, allowed_labels):
        kept = []
        for image, label in dataset:
            if label in allowed_labels:
                kept.append((image, label))
        self.data = kept

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

############################################
# Deep Ensemble Prediction Function
############################################
def deep_ensemble_predict(x, models):
    """Return the deep-ensemble predictive distribution for the batch ``x``.

    Each member is put in eval mode and run without gradient tracking on
    ``x`` (moved to the module-level ``device``). The members' softmax
    outputs (over dim 1) are averaged elementwise. ``torch.stack`` raises if
    ``models`` is empty.
    """
    def member_softmax(member):
        member.eval()
        with torch.no_grad():
            return F.softmax(member(x.to(device)), dim=1)

    per_member = [member_softmax(member) for member in models]
    return torch.stack(per_member, dim=0).mean(dim=0)

############################################
#      Metrics computation: Brier and ECE
############################################
def compute_brier(probs, labels):
    """
    Multi-class Brier score, averaged over samples.

    For each sample the squared distance between its probability row and the
    one-hot encoding of its true label is summed over classes; the result is
    the mean of that quantity over all samples:
      (1/N) * sum_i sum_c (p_{i,c} - 1{y_i = c})^2

    Args:
      probs: 2-D array (N, C) of predicted class probabilities.
      labels: (N,) integer labels indexing the columns of ``probs``.
    """
    num_samples, _num_classes = probs.shape
    targets = np.zeros_like(probs)
    targets[np.arange(num_samples), labels] = 1
    squared_error = (probs - targets) ** 2
    per_sample = squared_error.sum(axis=1)
    return per_sample.mean()


# def compute_ece(probs, labels, n_bins=10):
#     """Expected Calibration Error over `n_bins` equally‐spaced bins."""
#     bins = np.linspace(0.0, 1.0, n_bins + 1)
#     ece = 0.0
#     for i in range(n_bins):
#         mask = (probs >= bins[i]) & (probs < bins[i+1])
#         if not np.any(mask):
#             continue
#         preds_bin = (probs[mask] >= 0.5).astype(int)
#         acc_bin   = (labels[mask] == preds_bin).mean()
#         conf_bin  = probs[mask].mean()
#         ece      += np.abs(acc_bin - conf_bin) * (mask.sum() / len(probs))
#     return ece
def compute_ece(probs, labels, n_bins=10):
    """
    Expected Calibration Error (ECE) of multiclass predictions.

    Each sample is assigned to one of ``n_bins`` equal-width bins over [0, 1]
    according to the probability of its arg-max class. Bins are half-open
    [lo, hi), except the last, which is closed so that a confidence of exactly
    1.0 is counted. ECE is the sample-weighted sum over non-empty bins of
    |accuracy in bin - mean confidence in bin|.

    Args:
      probs: (N, C) array of predicted class probabilities, one row per sample.
      labels: (N,) array of integer true labels.
      n_bins: number of confidence bins.

    Returns:
      The ECE as a scalar.
    """
    predicted = np.argmax(probs, axis=1)
    top_confidence = probs[np.arange(len(probs)), predicted]
    hit = (predicted == labels).astype(float)

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    final_bin = n_bins - 1
    calibration_gap = 0.0

    for bin_idx, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        if bin_idx == final_bin:
            below_upper = top_confidence <= hi
        else:
            below_upper = top_confidence < hi
        members = (top_confidence >= lo) & below_upper
        weight = members.mean()
        if weight == 0:
            continue
        calibration_gap += weight * abs(hit[members].mean() - top_confidence[members].mean())

    return calibration_gap

############################################
# Main: Run R replicates of DE on Filtered MNIST
############################################
if __name__ == '__main__':
    # Experiment size: R independent replicates, each an ensemble of
    # `ensemble_size` separately trained SimpleCNN members.
    R = 5
    ensemble_size = 80

    # Per-replicate accumulators; each list receives one entry per replicate.
    # -- training effort and in-domain (digits 0-7) summary metrics
    (replicate_epochs, replicate_accuracy, replicate_nll_in,
     replicate_total_entropy_in, replicate_epistemic_in,
     replicate_f1_in, replicate_aucpr_in) = ([] for _ in range(7))
    # -- OOD metrics aggregated over every OOD group
    (replicate_total_entropy_od, replicate_epistemic_od,
     replicate_f1_od, replicate_aucpr_od) = ([] for _ in range(4))
    # -- in-domain uncertainty split by correct / incorrect prediction
    (replicate_total_entropy_in_correct, replicate_epistemic_in_correct,
     replicate_total_entropy_in_incorrect,
     replicate_epistemic_in_incorrect) = ([] for _ in range(4))
    # -- per-group OOD uncertainty (digit 8, digit 9, perturbed, white noise)
    (replicate_total_entropy_ood_8, replicate_epistemic_ood_8,
     replicate_total_entropy_ood_9, replicate_epistemic_ood_9,
     replicate_total_entropy_ood_perturbed, replicate_epistemic_ood_perturbed,
     replicate_total_entropy_ood_whitenoise,
     replicate_epistemic_ood_whitenoise) = ([] for _ in range(8))
    # -- group breakdown that is later averaged across replicates
    (per_digit_total_entropy_list, per_digit_epistemic_entropy_list,
     total_entropy_perturbed_list, epistemic_entropy_perturbed_list,
     total_entropy_whitenoise_list, epistemic_entropy_whitenoise_list,
     inID_total_list, inID_epi_list) = ([] for _ in range(8))
    # -- calibration metrics
    replicate_brier, replicate_ece = [], []

     ############################################
     # Data Loading (shared across replicates)
     ############################################
    transform = transforms.Compose([transforms.ToTensor()])
    full_train_dataset = torchvision.datasets.MNIST(root='./data', train=True,  download=True, transform=transform)
    test_dataset       = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    # The pool used for the digit-8/9 OOD groups is the same MNIST test split,
    # so reuse the already-loaded dataset instead of loading it a second time.
    full_val_dataset   = test_dataset
    allowed_labels = list(range(8))

    # Sizes
    N_tr   = 1000   # training images per DE member (digits 0-7)
    N_val  = 200    # early-stopping validation images (digits 0-7)
    N_test = 7000   # in-domain test images (digits 0-7)

    # Filter the training split to digits 0-7, keep the first N_tr + N_val
    # images, and split them into N_tr training / N_val validation images.
    filtered_all = FilteredDataset(full_train_dataset, allowed_labels)
    filtered_all = Subset(filtered_all, list(range(N_tr + N_val)))
    filtered_train_dataset = Subset(filtered_all, list(range(N_tr)))
    filtered_val_dataset   = Subset(filtered_all, list(range(N_tr, N_tr + N_val)))

    train_loader = DataLoader(filtered_train_dataset, batch_size=64, shuffle=True)
    val_loader   = DataLoader(filtered_val_dataset,   batch_size=64, shuffle=False)

    # In-domain test set: first N_test filtered test images (digits 0-7),
    # evaluated as one full batch.
    filtered_test = FilteredDataset(test_dataset, allowed_labels)
    filtered_test = Subset(filtered_test, list(range(N_test)))
    test_loader_full = DataLoader(filtered_test, batch_size=len(filtered_test), shuffle=False)
    x_test, y_test   = next(iter(test_loader_full))
    x_test, y_test   = x_test.to(device), y_test.to(device)
    # NOTE: despite its name, this holds the in-domain *test* labels.
    labels_val_np = y_test.cpu().numpy().flatten()

    # Fixed seed so the perturbed and white-noise inputs below are reproducible
    # and shared by every replicate.
    random.seed(2)
    np.random.seed(2)
    torch.manual_seed(2)
    pyro.set_rng_seed(2)
    # Perturbed images: first 100 early-stopping validation images plus
    # Gaussian noise (std 0.5), clipped back to [0, 1].
    perturbed_loader = DataLoader(filtered_val_dataset, batch_size=100, shuffle=False)
    images_perturbed, _ = next(iter(perturbed_loader))
    images_perturbed = images_perturbed + 0.5 * torch.randn_like(images_perturbed)
    images_perturbed = torch.clamp(images_perturbed, 0, 1).to(device)
    # White noise: i.i.d. uniform [0, 1) pixels.
    white_noise_images = torch.rand(100, 1, 28, 28).to(device)

    ############################################
    # Helpers shared by every replicate
    ############################################
    def prior_penalty(model):
        """Gaussian-prior penalty sum(w^2 / (2 * var)) over all SimpleCNN parameters.

        Up to an additive constant this is the negative log prior implied by
        var_conv / var_fc (weights) and sdb (biases).
        """
        reg_conv = (torch.sum(model.conv.weight**2) / (2 * var_conv) +
                    torch.sum(model.conv.bias**2) / (2 * sdb**2))
        reg_fc = (torch.sum(model.fc.weight**2) / (2 * var_fc) +
                  torch.sum(model.fc.bias**2) / (2 * sdb**2))
        return reg_conv + reg_fc

    def ensemble_uncertainty(images, models):
        """Run every ensemble member exactly once on ``images``.

        Returns:
          ensemble_probs_np: (N, C) mean of the members' softmax outputs.
          total_entropy:     (N,) entropy of the ensemble predictive.
          mi:                (N,) total entropy minus mean member entropy
                             (mutual information, i.e. the epistemic part).
        """
        images = images.to(device)
        member_probs = []
        for model in models:
            model.eval()
            with torch.no_grad():
                member_probs.append(F.softmax(model(images), dim=1))
        stacked = torch.stack(member_probs, dim=0)  # (members, N, C)
        ensemble_probs_np = stacked.mean(dim=0).cpu().numpy()
        all_model_probs = stacked.cpu().numpy()
        total_entropy = -np.sum(ensemble_probs_np * np.log(ensemble_probs_np + 1e-12), axis=1)
        model_entropies = -np.sum(all_model_probs * np.log(all_model_probs + 1e-12), axis=2)
        mi = total_entropy - np.mean(model_entropies, axis=0)
        return ensemble_probs_np, total_entropy, mi

    def compute_ood_metrics(images, models):
        """Return (mean total entropy, mean epistemic uncertainty) over ``images``."""
        _, total_entropy_group, mi_group = ensemble_uncertainty(images, models)
        return np.mean(total_entropy_group), np.mean(mi_group)

    def first_images_with_label(dataset, target_label, limit=100):
        """Return the first ``limit`` images of ``dataset`` labelled ``target_label``, on ``device``."""
        indices = []
        for idx, (_, label) in enumerate(dataset):
            if label == target_label:
                indices.append(idx)
                if len(indices) == limit:
                    break
        subset = Subset(dataset, indices)
        loader = DataLoader(subset, batch_size=len(subset), shuffle=False)
        images, _ = next(iter(loader))
        return images.to(device)

    # The digit-8 / digit-9 OOD groups do not depend on the replicate, so build
    # them once instead of rescanning the dataset after every ensemble. Doing it
    # here changes no result: every replicate reseeds all RNGs before training.
    images_ood8 = first_images_with_label(full_val_dataset, 8)
    images_ood9 = first_images_with_label(full_val_dataset, 9)

    for r in range(1, R+1):
        print(f"\n===== Starting replicate {r} =====\n")
        # Set replicate seed (each ensemble member below reseeds as well).
        torch.manual_seed(r)
        np.random.seed(r)
        random.seed(r)
        pyro.set_rng_seed(r)

        ############################################
        # Deep Ensembles (DE) Training for replicate r
        ############################################
        ensemble_models = []
        ensemble_epochs = []
        for m in range(ensemble_size):
            # Reproducible, distinct seed per (replicate, member) pair.
            member_seed = r * 1000 + m
            torch.manual_seed(member_seed)
            np.random.seed(member_seed)
            random.seed(member_seed)
            pyro.set_rng_seed(member_seed)

            model_de = SimpleCNN().to(device)
            optimizer_de = optim.Adam(model_de.parameters(), lr=0.001)
            criterion_de = nn.CrossEntropyLoss()

            train_losses_de = []
            val_losses_de = []
            moving_avg_window = 10
            best_moving_avg_de = float('inf')
            patience = 5
            no_improve_count_de = 0

            start_time_de = time.time()
            for epoch in range(1000):
                model_de.train()
                running_loss = 0.0
                for images, labels in train_loader:
                    optimizer_de.zero_grad()
                    outputs = model_de(images.to(device))
                    ce_loss = criterion_de(outputs, labels.to(device))
                    # MAP objective: batch-mean CE plus the prior penalty
                    # divided by the training-set size.
                    loss = ce_loss + prior_penalty(model_de)/len(filtered_train_dataset)
                    loss.backward()
                    optimizer_de.step()
                    running_loss += loss.item()
                train_loss = running_loss / len(train_loader)
                train_losses_de.append(train_loss)

                model_de.eval()
                val_running_loss = 0.0
                with torch.no_grad():
                    for images, labels in val_loader:
                        outputs = model_de(images.to(device))
                        ce_loss = criterion_de(outputs, labels.to(device))
                        loss = ce_loss + prior_penalty(model_de)/len(filtered_train_dataset)
                        val_running_loss += loss.item()
                val_loss = val_running_loss / len(val_loader)
                val_losses_de.append(val_loss)

                print(f"Replicate {r}, DE Model (seed={member_seed}) Epoch {epoch+1}: Train Loss = {train_loss:.8f}, Val Loss = {val_loss:.8f}")

                # Early stopping on the moving average of the last
                # `moving_avg_window` validation losses. The model kept is the
                # one from the final epoch (no best-checkpoint restore).
                if epoch >= moving_avg_window - 1:
                    moving_avg = sum(val_losses_de[-moving_avg_window:]) / moving_avg_window
                    if moving_avg < best_moving_avg_de:
                        best_moving_avg_de = moving_avg
                        no_improve_count_de = 0
                    else:
                        no_improve_count_de += 1
                    if no_improve_count_de >= patience:
                        print(f"Replicate {r}, DE Model (seed={member_seed}) Early stopping at epoch {epoch+1}")
                        break

            total_time_de = time.time() - start_time_de
            epochs_trained = epoch + 1
            ensemble_epochs.append(epochs_trained)
            print(f"Replicate {r}, DE Model (seed={member_seed}) Total training time: {total_time_de:.2f} seconds")
            print(f"Trained DE model with seed {member_seed} in {epochs_trained} epochs\n")
            ensemble_models.append(model_de)

        avg_epochs = sum(ensemble_epochs) / len(ensemble_epochs)
        print(f"Replicate {r}, Average training epochs over ensembles: {avg_epochs:.2f}")

        ############################################
        # Evaluate Deep Ensemble on In-Domain Data (digits 0-7)
        ############################################
        # A single forward pass per member yields both the ensemble predictive
        # and the per-member entropies needed for the epistemic term.
        ensemble_probs_np, total_entropy, mi = ensemble_uncertainty(x_test, ensemble_models)
        nlls = -np.log(ensemble_probs_np[np.arange(len(labels_val_np)), labels_val_np] + 1e-12)
        avg_nll_in = np.mean(nlls)
        avg_total_entropy_in = np.mean(total_entropy)
        avg_epistemic_in = np.mean(mi)

        # In-domain breakdown: Correct vs. Incorrect.
        ensemble_preds = np.argmax(ensemble_probs_np, axis=1)
        correct_mask = (ensemble_preds == labels_val_np)
        incorrect_mask = (ensemble_preds != labels_val_np)
        if np.sum(correct_mask) > 0:
            total_entropy_in_correct = np.mean(total_entropy[correct_mask])
            epistemic_in_correct = np.mean(mi[correct_mask])
        else:
            total_entropy_in_correct = np.nan
            epistemic_in_correct = np.nan
        if np.sum(incorrect_mask) > 0:
            total_entropy_in_incorrect = np.mean(total_entropy[incorrect_mask])
            epistemic_in_incorrect = np.mean(mi[incorrect_mask])
        else:
            total_entropy_in_incorrect = np.nan
            epistemic_in_incorrect = np.nan

        # Compute in-domain classification metrics.
        f1_in = f1_score(labels_val_np, ensemble_preds, average='macro', zero_division=0)
        labels_val_bin = label_binarize(labels_val_np, classes=np.arange(8))
        aucpr_in = average_precision_score(labels_val_bin, ensemble_probs_np, average='macro')
        accuracy_in = np.mean(ensemble_preds == labels_val_np)

        # Calibration of the ensemble predictive.
        brier = compute_brier(ensemble_probs_np, labels_val_np)
        ece = compute_ece(ensemble_probs_np, labels_val_np, n_bins=10)
        replicate_brier.append(brier)
        replicate_ece.append(ece)

        # Per-digit breakdown for in-domain (digits 0-7).
        per_digit_total = []
        per_digit_epi = []
        for d in allowed_labels:
            mask = (labels_val_np == d)
            if np.sum(mask) > 0:
                per_digit_total.append(np.mean(total_entropy[mask]))
                per_digit_epi.append(np.mean(mi[mask]))
            else:
                per_digit_total.append(np.nan)
                per_digit_epi.append(np.nan)
        per_digit_total_entropy_list.append(np.array(per_digit_total))
        per_digit_epistemic_entropy_list.append(np.array(per_digit_epi))

        # Save overall in-domain averages.
        inID_total_list.append(avg_total_entropy_in)
        inID_epi_list.append(avg_epistemic_in)

        print("\nReplicate {} In-Domain Analysis (Digits 0-7):".format(r))
        print("Average NLL: {:.8f}".format(avg_nll_in))
        print("Average Total Entropy: {:.8f}".format(avg_total_entropy_in))
        print("Average Epistemic Uncertainty: {:.8f}".format(avg_epistemic_in))
        print("F1 Score (macro): {:.8f}".format(f1_in))
        print("AUC-PR (macro): {:.8f}".format(aucpr_in))
        # NOTE: this accuracy is measured on the in-domain test set.
        print("Validation Accuracy: {:.8f}%".format(accuracy_in * 100))
        print("\nBreakdown for In-Domain Predictions:")
        print("Correct Predictions - Total Entropy: {:.8f}".format(total_entropy_in_correct))
        print("Correct Predictions - Epistemic Uncertainty: {:.8f}".format(epistemic_in_correct))
        print("Incorrect Predictions - Total Entropy: {:.8f}".format(total_entropy_in_incorrect))
        print("Incorrect Predictions - Epistemic Uncertainty: {:.8f}".format(epistemic_in_incorrect))

        ############################################
        # Extended Out-Of-Domain Analysis
        ############################################
        # OOD Groups: Digit 8 and Digit 9 (built once before the replicate loop).
        entropy_ood8, epistemic_ood8 = compute_ood_metrics(images_ood8, ensemble_models)
        entropy_ood9, epistemic_ood9 = compute_ood_metrics(images_ood9, ensemble_models)

        # OOD Group: Perturbed in-domain images.
        entropy_perturbed, epistemic_perturbed = compute_ood_metrics(images_perturbed, ensemble_models)
        total_entropy_perturbed_list.append(entropy_perturbed)
        epistemic_entropy_perturbed_list.append(epistemic_perturbed)

        # OOD Group: White noise images.
        entropy_white, epistemic_white = compute_ood_metrics(white_noise_images, ensemble_models)
        total_entropy_whitenoise_list.append(entropy_white)
        epistemic_entropy_whitenoise_list.append(epistemic_white)

        # Combined groups (for backward compatibility, aggregate all OOD).
        images_all_ood = torch.cat([images_ood8, images_ood9, images_perturbed, white_noise_images], dim=0)
        entropy_all_ood, epistemic_all_ood = compute_ood_metrics(images_all_ood, ensemble_models)

        print("\nReplicate {} Extended Out-Of-Domain Analysis:".format(r))
        print("Digit 8 - Average Total Entropy: {:.8f}, Average Epistemic Uncertainty: {:.8f}".format(entropy_ood8, epistemic_ood8))
        print("Digit 9 - Average Total Entropy: {:.8f}, Average Epistemic Uncertainty: {:.8f}".format(entropy_ood9, epistemic_ood9))
        print("Perturbed - Average Total Entropy: {:.8f}, Average Epistemic Uncertainty: {:.8f}".format(entropy_perturbed, epistemic_perturbed))
        print("White Noise - Average Total Entropy: {:.8f}, Average Epistemic Uncertainty: {:.8f}".format(entropy_white, epistemic_white))
        print("All OOD - Average Total Entropy: {:.8f}, Average Epistemic Uncertainty: {:.8f}".format(entropy_all_ood, epistemic_all_ood))

        # Save detailed OOD metrics per group.
        replicate_total_entropy_ood_8.append(entropy_ood8)
        replicate_epistemic_ood_8.append(epistemic_ood8)
        replicate_total_entropy_ood_9.append(entropy_ood9)
        replicate_epistemic_ood_9.append(epistemic_ood9)
        replicate_total_entropy_ood_perturbed.append(entropy_perturbed)
        replicate_epistemic_ood_perturbed.append(epistemic_perturbed)
        replicate_total_entropy_ood_whitenoise.append(entropy_white)
        replicate_epistemic_ood_whitenoise.append(epistemic_white)

        # For backward compatibility, store aggregated OOD metrics.
        # OOD F1 / AUC-PR are placeholders (always 0.0).
        replicate_total_entropy_od.append(entropy_all_ood)
        replicate_epistemic_od.append(epistemic_all_ood)
        replicate_f1_od.append(0.0)
        replicate_aucpr_od.append(0.0)

        # Store replicate metrics.
        replicate_epochs.append(avg_epochs)
        replicate_nll_in.append(avg_nll_in)
        replicate_total_entropy_in.append(avg_total_entropy_in)
        replicate_epistemic_in.append(avg_epistemic_in)
        replicate_f1_in.append(f1_in)
        replicate_aucpr_in.append(aucpr_in)
        replicate_accuracy.append(accuracy_in * 100)
        replicate_total_entropy_in_correct.append(total_entropy_in_correct)
        replicate_epistemic_in_correct.append(epistemic_in_correct)
        replicate_total_entropy_in_incorrect.append(total_entropy_in_incorrect)
        replicate_epistemic_in_incorrect.append(epistemic_in_incorrect)
    
    ############################################
    # After all replicates: Compute overall averages and standard errors.
    ############################################
    def compute_stats(metric_list):
        """Return (mean, standard error of the mean) across replicates.

        NaN entries are ignored. A replicate records NaN for a group with no
        samples (e.g. no incorrect in-domain predictions), and one such
        replicate would otherwise make the whole summary NaN. The standard
        error is computed from the number of non-NaN values rather than R, and
        is NaN when fewer than two values are available.
        """
        metric_array = np.asarray(metric_list, dtype=float)
        valid = metric_array[~np.isnan(metric_array)]
        if valid.size == 0:
            return np.nan, np.nan
        mean_val = np.mean(valid)
        if valid.size < 2:
            return mean_val, np.nan
        stderr_val = np.std(valid, ddof=1) / np.sqrt(valid.size)
        return mean_val, stderr_val

    avg_epochs_mean, avg_epochs_stderr = compute_stats(replicate_epochs)
    nll_in_mean, nll_in_stderr = compute_stats(replicate_nll_in)
    tot_ent_in_mean, tot_ent_in_stderr = compute_stats(replicate_total_entropy_in)
    epistemic_in_mean, epistemic_in_stderr = compute_stats(replicate_epistemic_in)
    f1_in_mean, f1_in_stderr = compute_stats(replicate_f1_in)
    aucpr_in_mean, aucpr_in_stderr = compute_stats(replicate_aucpr_in)
    accuracy_mean, accuracy_stderr = compute_stats(replicate_accuracy)
    tot_ent_od_mean, tot_ent_od_stderr = compute_stats(replicate_total_entropy_od)
    epistemic_od_mean, epistemic_od_stderr = compute_stats(replicate_epistemic_od)
    f1_od_mean, f1_od_stderr = compute_stats(replicate_f1_od)
    aucpr_od_mean, aucpr_od_stderr = compute_stats(replicate_aucpr_od)

    tot_ent_in_corr_mean, tot_ent_in_corr_stderr = compute_stats(replicate_total_entropy_in_correct)
    epistemic_in_corr_mean, epistemic_in_corr_stderr = compute_stats(replicate_epistemic_in_correct)
    tot_ent_in_inc_mean, tot_ent_in_inc_stderr = compute_stats(replicate_total_entropy_in_incorrect)
    epistemic_in_inc_mean, epistemic_in_inc_stderr = compute_stats(replicate_epistemic_in_incorrect)

    mean_brier, stderr_brier = compute_stats(replicate_brier)
    mean_ece,   stderr_ece   = compute_stats(replicate_ece)

    # Average per-digit breakdown for in-domain.
    per_digit_total_entropy = np.mean(np.stack(per_digit_total_entropy_list, axis=0), axis=0)
    per_digit_epistemic_entropy = np.mean(np.stack(per_digit_epistemic_entropy_list, axis=0), axis=0)
    # Averages for perturbed and white noise.
    total_entropy_perturbed = np.mean(total_entropy_perturbed_list)
    epistemic_entropy_perturbed = np.mean(epistemic_entropy_perturbed_list)
    total_entropy_whitenoise = np.mean(total_entropy_whitenoise_list)
    epistemic_entropy_whitenoise = np.mean(epistemic_entropy_whitenoise_list)
    # Overall in-domain (ID) metrics.
    total_entropy_inID = np.mean(inID_total_list)
    epistemic_inID = np.mean(inID_epi_list)
    
    print("\n===== Summary over {} replicates =====".format(R))
    print("Avg. Ensemble Epochs: Mean = {:.2f}, SE = {:.8f}".format(avg_epochs_mean, avg_epochs_stderr))

    def _report(rows):
        # Print each (label, mean, stderr) row as "<label>: Mean = ..., SE = ..." (8 decimals).
        for label, mean_v, se_v in rows:
            print("{}: Mean = {:.8f}, SE = {:.8f}".format(label, mean_v, se_v))

    _report([
        ("In-Domain NLL", nll_in_mean, nll_in_stderr),
        ("In-Domain Total Entropy", tot_ent_in_mean, tot_ent_in_stderr),
        ("In-Domain Epistemic Uncertainty", epistemic_in_mean, epistemic_in_stderr),
        ("In-Domain F1 Score", f1_in_mean, f1_in_stderr),
        ("In-Domain AUC-PR", aucpr_in_mean, aucpr_in_stderr),
        ("In-Domain Accuracy (%)", accuracy_mean, accuracy_stderr),
    ])
    print(f"Brier Score:   {mean_brier:.8f} ± {stderr_brier:.8f}")
    print(f"ECE:           {mean_ece:.8f} ± {stderr_ece:.8f}")
    print("\nBreakdown for In-Domain Predictions:")
    _report([
        ("Correct Predictions - Total Entropy", tot_ent_in_corr_mean, tot_ent_in_corr_stderr),
        ("Correct Predictions - Epistemic Uncertainty", epistemic_in_corr_mean, epistemic_in_corr_stderr),
        ("Incorrect Predictions - Total Entropy", tot_ent_in_inc_mean, tot_ent_in_inc_stderr),
        ("Incorrect Predictions - Epistemic Uncertainty", epistemic_in_inc_mean, epistemic_in_inc_stderr),
        ("OOD Total Entropy (aggregated)", tot_ent_od_mean, tot_ent_od_stderr),
        ("OOD Epistemic Uncertainty (aggregated)", epistemic_od_mean, epistemic_od_stderr),
        ("OOD F1 Score", f1_od_mean, f1_od_stderr),
        ("OOD AUC-PR", aucpr_od_mean, aucpr_od_stderr),
    ])
    
    ############################################
    # Save all metrics and replicate results using savemat.
    ############################################
    metrics = {
        'R': R,
        'ensemble_size': ensemble_size,
        'replicate_epochs': np.array(replicate_epochs),
        'replicate_nll_in': np.array(replicate_nll_in),
        'replicate_total_entropy_in': np.array(replicate_total_entropy_in),
        'replicate_epistemic_in': np.array(replicate_epistemic_in),
        'replicate_f1_in': np.array(replicate_f1_in),
        'replicate_aucpr_in': np.array(replicate_aucpr_in),
        'replicate_accuracy': np.array(replicate_accuracy),
        'replicate_total_entropy_in_correct': np.array(replicate_total_entropy_in_correct),
        'replicate_epistemic_in_correct': np.array(replicate_epistemic_in_correct),
        'replicate_total_entropy_in_incorrect': np.array(replicate_total_entropy_in_incorrect),
        'replicate_epistemic_in_incorrect': np.array(replicate_epistemic_in_incorrect),
        'replicate_total_entropy_od': np.array(replicate_total_entropy_od),
        'replicate_epistemic_od': np.array(replicate_epistemic_od),
        'replicate_f1_od': np.array(replicate_f1_od),
        'replicate_aucpr_od': np.array(replicate_aucpr_od),
        # Detailed OOD breakdown.
        'replicate_total_entropy_ood_8': np.array(replicate_total_entropy_ood_8),
        'replicate_epistemic_ood_8': np.array(replicate_epistemic_ood_8),
        'replicate_total_entropy_ood_9': np.array(replicate_total_entropy_ood_9),
        'replicate_epistemic_ood_9': np.array(replicate_epistemic_ood_9),
        'replicate_total_entropy_ood_perturbed': np.array(replicate_total_entropy_ood_perturbed),
        'replicate_epistemic_ood_perturbed': np.array(replicate_epistemic_ood_perturbed),
        'replicate_total_entropy_ood_whitenoise': np.array(replicate_total_entropy_ood_whitenoise),
        'replicate_epistemic_ood_whitenoise': np.array(replicate_epistemic_ood_whitenoise),
        # In-domain group breakdown (averaged over replicates).
        'per_digit_total_entropy': per_digit_total_entropy,
        'per_digit_epistemic_entropy': per_digit_epistemic_entropy,
        'total_entropy_perturbed': total_entropy_perturbed,
        'epistemic_entropy_perturbed': epistemic_entropy_perturbed,
        'total_entropy_whitenoise': total_entropy_whitenoise,
        'epistemic_entropy_whitenoise': epistemic_entropy_whitenoise,
        'total_entropy_inID': total_entropy_inID,
        'epistemic_entropy_inID': epistemic_inID,
        # Summary statistics.
        'avg_epochs_mean': avg_epochs_mean,
        'avg_epochs_stderr': avg_epochs_stderr,
        'nll_in_mean': nll_in_mean,
        'nll_in_stderr': nll_in_stderr,
        'tot_ent_in_mean': tot_ent_in_mean,
        'tot_ent_in_stderr': tot_ent_in_stderr,
        'epistemic_in_mean': epistemic_in_mean,
        'epistemic_in_stderr': epistemic_in_stderr,
        'f1_in_mean': f1_in_mean,
        'f1_in_stderr': f1_in_stderr,
        'aucpr_in_mean': aucpr_in_mean,
        'aucpr_in_stderr': aucpr_in_stderr,
        'accuracy_mean': accuracy_mean,
        'accuracy_stderr': accuracy_stderr,
        'tot_ent_od_mean': tot_ent_od_mean,
        'tot_ent_od_stderr': tot_ent_od_stderr,
        'epistemic_od_mean': epistemic_od_mean,
        'epistemic_od_stderr': epistemic_od_stderr,
        'f1_od_mean': f1_od_mean,
        'f1_od_stderr': f1_od_stderr,
        'aucpr_od_mean': aucpr_od_mean,
        'aucpr_od_stderr': aucpr_od_stderr,
        'tot_ent_in_correct_mean': tot_ent_in_corr_mean,
        'tot_ent_in_correct_stderr': tot_ent_in_corr_stderr,
        'epistemic_in_correct_mean': epistemic_in_corr_mean,
        'epistemic_in_correct_stderr': epistemic_in_corr_stderr,
        'tot_ent_in_incorrect_mean': tot_ent_in_inc_mean,
        'tot_ent_in_incorrect_stderr': tot_ent_in_inc_stderr,
        'epistemic_in_incorrect_mean': epistemic_in_inc_mean,
        'epistemic_in_incorrect_stderr': epistemic_in_inc_stderr,
        # Calibration metrics (Brier score and ECE): raw per-replicate values
        # plus their mean / standard error across replicates.
        'replicate_brier': np.array(replicate_brier),
        'replicate_ece': np.array(replicate_ece),
        'brier_mean': mean_brier,
        'brier_stderr': stderr_brier,
        'ece_mean': mean_ece,
        'ece_stderr': stderr_ece,
    }

    savemat('BayesianNN_MNIST_de_metrics.mat', metrics)
    print("\nAll computed metrics (and raw replicate results) have been saved to 'BayesianNN_MNIST_de_metrics.mat'.")
