#!/usr/bin/env python3
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
from scipy.io import savemat
import pyro

# For F1 and AUC-PR computation.
from sklearn.metrics import f1_score, average_precision_score
from sklearn.preprocessing import label_binarize

# Set device: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

############################################
# Global Prior Parameters (same as in your SMC code)
############################################
# Shared prior variance. These values are used both to initialise the network
# parameters and to scale the L2 (Gaussian-prior) penalty in the MAP objective.
v = 0.1
sdb = np.sqrt(v)         # standard deviation for biases (sqrt(0.1) ≈ 0.316)
var_conv = v    # variance for conv weights (std = sqrt(0.1) ≈ 0.316)
var_fc = v      # variance for fc weights   (std = sqrt(0.1) ≈ 0.316)

############################################
# Define the CNN Architecture for Filtered MNIST
############################################
class SimpleCNN(nn.Module):
    """Small CNN producing 8 logits from single-channel 28x28 images.

    Layout: 3x3 convolution (1 -> 4 channels, padding 1) -> ReLU ->
    2x2 max-pool -> flatten -> linear layer (4*14*14 -> 8).

    Every weight and bias is re-drawn from a zero-mean Normal whose scale is
    taken from the module-level prior settings (``var_conv``, ``var_fc``,
    ``sdb``), so the starting point is a sample from the regularization prior.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(4 * 14 * 14, 8)

        # Initialise conv first, then fc, weight before bias: this keeps the
        # order in which random numbers are consumed fixed for a given seed.
        for layer, variance in ((self.conv, var_conv), (self.fc, var_fc)):
            nn.init.normal_(layer.weight, mean=0, std=np.sqrt(variance))
            if layer.bias is not None:
                nn.init.normal_(layer.bias, mean=0, std=sdb)

    def forward(self, x):
        """Return the raw class logits (no softmax) for a batch ``x``."""
        activated = F.relu(self.conv(x))
        pooled = self.pool(activated)
        # Collapse channel and spatial dimensions into one feature vector per sample.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

############################################
# Define a filtered dataset (keep digits 0–7)
############################################
class FilteredDataset(Dataset):
    """In-memory copy of ``dataset`` restricted to samples with an allowed label.

    The source dataset is iterated once at construction time; every
    ``(img, label)`` pair whose ``label`` is contained in ``allowed_labels``
    is kept, in the original order.
    """

    def __init__(self, dataset, allowed_labels):
        kept = []
        for img, label in dataset:
            if label in allowed_labels:
                kept.append((img, label))
        self.data = kept

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

############################################
# Helper function to compute mean and standard error.
############################################
def compute_stats(metric_list, R):
    """Return the mean and standard error of a list of per-replicate metrics.

    Args:
        metric_list: sequence of scalar metric values, one per replicate.
        R: number of replicates; used as the sample size in the standard
            error, i.e. ``std(ddof=1) / sqrt(R)``.

    Returns:
        Tuple ``(mean, stderr)``. The standard error is NaN when it is
        undefined (fewer than two values, or a non-positive ``R``); the mean
        is NaN as well when ``metric_list`` is empty.
    """
    arr = np.asarray(metric_list, dtype=float)
    if arr.size == 0:
        return np.nan, np.nan

    mean_val = np.mean(arr)
    # The unbiased sample standard deviation needs at least two observations.
    # Return NaN explicitly instead of letting NumPy divide by zero degrees of
    # freedom (which yields the same NaN but with a RuntimeWarning).
    if arr.size < 2 or R <= 0:
        return mean_val, np.nan

    stderr_val = np.std(arr, ddof=1) / np.sqrt(R)
    return mean_val, stderr_val

############################################
#      Metrics computation: Brier and ECE
############################################
def compute_brier(probs, labels):
    """
    Multi-class Brier score:
      (1/N) ∑_i ∑_c (p_{i,c} – 1{y_i=c})^2.

    Args:
      probs: array of shape (N, C) with predicted class probabilities
      labels: array of shape (N,) with integer true labels in [0..C-1]

    Returns:
      scalar mean over samples of the squared distance between the predicted
      probability vector and the one-hot encoding of the true label.
    """
    n_samples, _ = probs.shape
    # Subtracting 1 at the true-class position is the same as subtracting the
    # one-hot target from the whole probability matrix.
    residual = probs.copy()
    residual[np.arange(n_samples), labels] -= 1
    per_sample = np.sum(residual**2, axis=1)
    return np.mean(per_sample)


# def compute_ece(probs, labels, n_bins=10):
#     """Expected Calibration Error over `n_bins` equally‐spaced bins."""
#     bins = np.linspace(0.0, 1.0, n_bins + 1)
#     ece = 0.0
#     for i in range(n_bins):
#         mask = (probs >= bins[i]) & (probs < bins[i+1])
#         if not np.any(mask):
#             continue
#         preds_bin = (probs[mask] >= 0.5).astype(int)
#         acc_bin   = (labels[mask] == preds_bin).mean()
#         conf_bin  = probs[mask].mean()
#         ece      += np.abs(acc_bin - conf_bin) * (mask.sum() / len(probs))
#     return ece
def compute_ece(probs, labels, n_bins=10):
    """
    Expected Calibration Error for multiclass outputs.

    Each sample is represented by its top-1 prediction and the probability
    assigned to it. Samples are grouped into `n_bins` equal-width confidence
    bins over [0, 1]; every bin contributes |accuracy - mean confidence|
    weighted by the fraction of samples it contains.

    Args:
      probs: array of shape (N, C) of predicted class probabilities (each row sums to 1)
      labels: array of shape (N,) of integer true labels in [0..C-1]
      n_bins: number of equal-width confidence bins in [0,1]

    Returns:
      scalar ECE = sum_b (|acc(b) – conf(b)| · |bin_b|/N)
    """
    top_class = np.argmax(probs, axis=1)
    top_prob = probs[np.arange(len(probs)), top_class]
    is_hit = (top_class == labels).astype(float)

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    last_bin = n_bins - 1
    total_gap = 0.0
    for b, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        # Bins are half-open [lo, hi) except the last one, which also
        # includes its upper edge so that a confidence of exactly 1.0 counts.
        below_hi = (top_prob <= hi) if b == last_bin else (top_prob < hi)
        members = (top_prob >= lo) & below_hi
        weight = members.mean()
        if weight == 0:
            continue
        gap = abs(is_hit[members].mean() - top_prob[members].mean())
        total_gap += weight * gap

    return total_gap


if __name__ == '__main__':
    ############################################
    # Number of replicates.
    ############################################
    # Each replicate retrains the MAP network from scratch with its own seed.
    R = 1

    # Lists to store replicate metrics.
    # Every list below receives one entry per replicate inside the main loop.
    # Training bookkeeping and in-domain (digits 0–7 test set) metrics:
    replicate_epochs            = []
    replicate_accuracy          = []
    replicate_nll_in            = []
    replicate_total_entropy_in  = []
    replicate_f1_in             = []
    replicate_aucpr_in          = []
    # In-domain entropy split by whether the prediction was correct:
    replicate_total_entropy_in_correct   = []
    replicate_epistemic_in_correct       = []
    replicate_total_entropy_in_incorrect = []
    replicate_epistemic_in_incorrect     = []
    replicate_map_time          = []
    # Aggregate out-of-domain metrics:
    replicate_total_entropy_od  = []
    replicate_f1_od             = []
    replicate_aucpr_od          = []
    # Per-OOD-set metrics (digit 8, digit 9, perturbed, white noise, combinations):
    replicate_total_entropy_ood_8        = []
    replicate_f1_ood_8                   = []
    replicate_total_entropy_ood_9        = []
    replicate_f1_ood_9                   = []
    replicate_total_entropy_ood_perturbed = []
    replicate_total_entropy_ood_whitenoise = []
    replicate_total_entropy_ood_combined_8_9 = []
    replicate_f1_ood_combined_8_9           = []
    replicate_total_entropy_ood_all         = []
    # Per-digit (0–7) entropy arrays and overall in-domain entropy per replicate:
    per_digit_total_entropy_list           = []
    per_digit_epistemic_entropy_list       = []
    inID_total_list                        = []
    inID_epi_list                          = []

    # Calibration metrics on the in-domain test set:
    replicate_brier = []
    replicate_ece   = []

    ############################################
    # Data Loading (shared across replicates)
    ############################################
    N_tr   = 1000#42175 #1000   # MAP train
    N_val  = 200#6025 #200    # MAP early-stop validation
    N_test = 7000 #1000   # ID test

    # Images are only converted to tensors; no normalisation is applied.
    transform = transforms.Compose([transforms.ToTensor()])

    full_train_dataset = torchvision.datasets.MNIST(
        root='./data', train=True,  download=True, transform=transform)
    test_dataset       = torchvision.datasets.MNIST(
        root='./data', train=False, download=True, transform=transform)

    # In-domain classes are digits 0–7; digits 8 and 9 are held out as OOD.
    allowed_labels = list(range(8))

    # Keep only digits 0–7 from the training set, take the first
    # N_tr + N_val of them, and split those into N_tr training samples
    # followed by N_val validation samples.
    filtered_all = FilteredDataset(full_train_dataset, allowed_labels)
    #print(len(filtered_all)) 48200
    filtered_all = Subset(filtered_all, list(range(N_tr + N_val)))
    filtered_train_dataset = Subset(filtered_all, list(range(N_tr)))
    filtered_val_dataset   = Subset(filtered_all, list(range(N_tr, N_tr + N_val)))

    # DataLoaders for MAP training & validation
    # Only the training loader shuffles. The *_full loaders yield the whole
    # split as a single batch; they are not consumed in the visible code.
    train_loader      = DataLoader(filtered_train_dataset, batch_size=64, shuffle=True)
    val_loader        = DataLoader(filtered_val_dataset,   batch_size=64, shuffle=False)
    train_loader_full = DataLoader(filtered_train_dataset, batch_size=len(filtered_train_dataset), shuffle=False)
    #x_train, y_train  = next(iter(train_loader_full))
    #x_train, y_train  = x_train.to(device), y_train.to(device)
    val_loader_full   = DataLoader(filtered_val_dataset, batch_size=len(filtered_val_dataset), shuffle=False)
    #x_val, y_val      = next(iter(val_loader_full))
    #x_val, y_val      = x_val.to(device), y_val.to(device)

    # ID Test set: first N_test filtered test images (0–7)
    filtered_test = FilteredDataset(test_dataset, allowed_labels)
    #print(len(filtered_test)) 8017
    filtered_test = Subset(filtered_test, list(range(N_test)))
    test_loader   = DataLoader(filtered_test, batch_size=64, shuffle=False)

    # OOD sets (from same test_dataset)
    # Digit 8: indices of the first 100 test images labelled 8.
    od8_indices = []
    count8 = 0
    for idx, (_, label) in enumerate(test_dataset):
        if label == 8 and count8 < 100:
            od8_indices.append(idx)
            count8 += 1
        if count8 == 100:
            break
    od8_dataset = Subset(test_dataset, od8_indices)
    od8_loader  = DataLoader(od8_dataset, batch_size=len(od8_dataset), shuffle=False)

    # Digit 9: indices of the first 100 test images labelled 9.
    od9_indices = []
    count9 = 0
    for idx, (_, label) in enumerate(test_dataset):
        if label == 9 and count9 < 100:
            od9_indices.append(idx)
            count9 += 1
        if count9 == 100:
            break
    od9_dataset = Subset(test_dataset, od9_indices)
    od9_loader  = DataLoader(od9_dataset, batch_size=len(od9_dataset), shuffle=False)

    # Fixed seed so the synthetic OOD inputs below are generated from the same
    # RNG state on every run (the replicate loop re-seeds afterwards).
    random.seed(2)
    np.random.seed(2)
    torch.manual_seed(2)
    pyro.set_rng_seed(2)
    # perturbed images: first 100 ID test images plus Gaussian noise (std 0.5),
    # clamped back to the [0, 1] pixel range.
    pert_imgs, _ = next(iter(DataLoader(Subset(filtered_test, list(range(100))), batch_size=100)))
    pert_imgs = torch.clamp(pert_imgs + 0.5*torch.randn_like(pert_imgs), 0, 1).to(device)
    # white noise: 100 images with i.i.d. uniform pixel values from torch.rand
    wn = torch.rand(100,1,28,28).to(device)

    ############################################
    # MAIN: Run R replicates of MAP training on Filtered MNIST
    ############################################
    for r in range(1, R+1):
        print(f"\n===== Starting replicate {r} =====\n")
        # Set seed for this replicate.
        # The replicate index doubles as the seed for all RNGs.
        torch.manual_seed(r)
        np.random.seed(r)
        random.seed(r)
        pyro.set_rng_seed(r)

        ############################################
        # MAP (Deterministic) CNN Training on Filtered MNIST
        ############################################
        model_cnn = SimpleCNN().to(device)
        optimizer_cnn = optim.Adam(model_cnn.parameters(), lr=0.001)
        criterion_cnn = nn.CrossEntropyLoss()

        # Early stopping tracks a moving average of the validation loss over
        # the last `moving_avg_window` epochs and stops after `patience`
        # consecutive epochs without a new best average.
        train_losses = []
        val_losses = []
        moving_avg_window = 10
        best_moving_avg = float('inf')
        patience = 5
        no_improve_count = 0

        start_time = time.time()
        for epoch in range(1000):
            model_cnn.train()
            running_loss = 0.0
            for images, labels in train_loader:
                optimizer_cnn.zero_grad()
                outputs = model_cnn(images.to(device))
                ce_loss = criterion_cnn(outputs, labels.to(device))
                # Regularization.
                # Negative log of the zero-mean Gaussian prior (up to a
                # constant): sum(w^2) / (2 * variance) for every parameter group.
                reg_conv = (torch.sum(model_cnn.conv.weight**2) / (2 * var_conv) +
                            torch.sum(model_cnn.conv.bias**2) / (2 * sdb**2))
                reg_fc = (torch.sum(model_cnn.fc.weight**2) / (2 * var_fc) +
                          torch.sum(model_cnn.fc.bias**2) / (2 * sdb**2))
                reg_loss = reg_conv + reg_fc
                # The cross-entropy is a per-sample mean, so the prior term is
                # divided by the training-set size to put both on the same scale.
                loss = ce_loss + reg_loss/len(filtered_train_dataset)
                loss.backward()
                optimizer_cnn.step()
                running_loss += loss.item()
            # Average of the per-batch losses (not weighted by batch size).
            train_loss = running_loss / len(train_loader)
            train_losses.append(train_loss)

            model_cnn.eval()
            val_running_loss = 0.0
            with torch.no_grad():
                for images, labels in val_loader:
                    outputs = model_cnn(images.to(device))
                    ce_loss = criterion_cnn(outputs, labels.to(device))
                    # Same penalty as in training; the weights are fixed here,
                    # so this value is identical for every validation batch.
                    reg_conv = (torch.sum(model_cnn.conv.weight**2) / (2 * var_conv) +
                                torch.sum(model_cnn.conv.bias**2) / (2 * sdb**2))
                    reg_fc = (torch.sum(model_cnn.fc.weight**2) / (2 * var_fc) +
                              torch.sum(model_cnn.fc.bias**2) / (2 * sdb**2))
                    reg_loss = reg_conv + reg_fc
                    # Scaled by the training-set size, as in the training objective.
                    loss = ce_loss + reg_loss/len(filtered_train_dataset)
                    val_running_loss += loss.item()
            val_loss = val_running_loss / len(val_loader)
            val_losses.append(val_loss)

            print(f"Replicate {r}, MAP Epoch {epoch+1}: Train Loss = {train_loss:.8f}, Val Loss = {val_loss:.8f}")

            # Early-stopping check starts once a full window of validation
            # losses is available.
            if epoch >= moving_avg_window - 1:
                moving_avg = sum(val_losses[-moving_avg_window:]) / moving_avg_window
                if moving_avg < best_moving_avg:
                    best_moving_avg = moving_avg
                    no_improve_count = 0
                else:
                    no_improve_count += 1
                if no_improve_count >= patience:
                    print(f"Replicate {r}, MAP: Early stopping at epoch {epoch+1}")
                    break

        # No checkpoint is restored: the weights from the final epoch run are
        # the ones evaluated below.
        map_epochs = epoch + 1
        total_map_time = time.time() - start_time
        print(f"Replicate {r}, MAP Total execution time: {total_map_time:.2f} seconds")
        print(f"Replicate {r}, MAP training stopped at epoch {map_epochs}")

        ############################################
        # Evaluate Test Accuracy on Filtered Test Set (ID)
        ############################################
        model_cnn.eval()
        correct = total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                outputs = model_cnn(images.to(device))
                # Predicted class is the index of the largest logit.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted.cpu() == labels).sum().item()
        test_accuracy = 100 * correct / total
        print(f"Replicate {r}, MAP Test Accuracy (ID, classes 0–7): {test_accuracy:.2f}%")

        ############################################
        # In-Domain Analysis (Filtered Test Set: Digits 0–7)
        ############################################
        # Second pass over the same test loader, this time collecting the
        # softmax probabilities and labels as NumPy arrays.
        all_probs = []
        all_labels = []
        with torch.no_grad():
            for images, labels in test_loader:
                outputs = model_cnn(images.to(device))
                probs = F.softmax(outputs, dim=1)
                all_probs.append(probs.cpu().numpy())
                all_labels.append(labels.numpy())
        all_probs = np.concatenate(all_probs, axis=0)
        all_labels = np.concatenate(all_labels, axis=0)

        # Per-sample negative log-likelihood of the true class; the 1e-12
        # offset guards against log(0).
        nlls = -np.log(all_probs[np.arange(len(all_labels)), all_labels] + 1e-12)
        avg_nll = np.mean(nlls)
        # Per-sample entropy of the predictive distribution.
        total_entropy = -np.sum(all_probs * np.log(all_probs + 1e-12), axis=1)
        avg_total_entropy = np.mean(total_entropy)
        # Epistemic entropy is hard-coded to zero throughout this script
        # (presumably because a single MAP network is used; kept so the output
        # has the same fields as a Bayesian variant — TODO confirm).
        avg_epistemic = 0.0

        y_pred_in = np.argmax(all_probs, axis=1)
        f1_in = f1_score(all_labels, y_pred_in, average='macro', zero_division=0)
        # One-vs-rest label matrix for the macro-averaged AUC-PR over the 8 classes.
        all_labels_bin = label_binarize(all_labels, classes=np.arange(8))
        aucpr_in = average_precision_score(all_labels_bin, all_probs, average='macro')

        brier = compute_brier(all_probs, all_labels)
        # after computing y_pred_in = argmax
        #confidences = all_probs[np.arange(len(all_labels)), y_pred_in]
        #correctness  = (y_pred_in == all_labels).astype(int)
        ece = compute_ece(all_probs, all_labels, n_bins=10)
        replicate_brier.append(brier)
        replicate_ece.append(ece)

        print(f"\nReplicate {r}, MAP In-Domain Analysis (Test Set):")
        print(f"  Average NLL:            {avg_nll:.8f}")
        print(f"  Average Total Entropy:  {avg_total_entropy:.8f}")
        print(f"  F1 Score (macro):       {f1_in:.8f}")
        print(f"  AUC-PR (macro):         {aucpr_in:.8f}")

        # In-domain breakdown
        # Mean entropy separately for correct and incorrect predictions;
        # NaN when the corresponding group is empty.
        correct_mask = (y_pred_in == all_labels)
        incorrect_mask = (y_pred_in != all_labels)
        total_entropy_correct = np.mean(total_entropy[correct_mask]) if correct_mask.any() else np.nan
        epistemic_correct = 0.0 if correct_mask.any() else np.nan
        total_entropy_incorrect = np.mean(total_entropy[incorrect_mask]) if incorrect_mask.any() else np.nan
        epistemic_incorrect = 0.0 if incorrect_mask.any() else np.nan

        print("Breakdown for In-Domain Predictions:")
        print(f"  Correct - Total Entropy:  {total_entropy_correct:.8f}")
        print(f"  Correct - Epistemic Entropy: {epistemic_correct:.8f}")
        print(f"  Incorrect - Total Entropy: {total_entropy_incorrect:.8f}")
        print(f"  Incorrect - Epistemic Entropy: {epistemic_incorrect:.8f}")

        # Per-digit breakdown
        # Mean entropy per true digit 0–7 (NaN if a digit is absent).
        per_digit_total = []
        per_digit_epi = []
        for d in allowed_labels:
            mask = (all_labels == d)
            per_digit_total.append(np.mean(total_entropy[mask]) if mask.any() else np.nan)
            per_digit_epi.append(0.0 if mask.any() else np.nan)
        per_digit_total_entropy_list.append(np.array(per_digit_total))
        per_digit_epistemic_entropy_list.append(np.array(per_digit_epi))
        inID_total_list.append(avg_total_entropy)
        inID_epi_list.append(avg_epistemic)

        ############################################
        # Out-Of-Domain Analysis (all from test_dataset)
        ############################################
        # NOTE: the network has 8 outputs (classes 0–7), so its argmax can
        # never equal a true label of 8 or 9 in the F1 computations below.
        # Digit 8
        od8_images, od8_labels = next(iter(od8_loader))
        od8_images = od8_images.to(device)
        with torch.no_grad():
            probs_8 = F.softmax(model_cnn(od8_images), dim=1).cpu().numpy()
        entropy_8 = -np.sum(probs_8 * np.log(probs_8 + 1e-12), axis=1)
        f1_8 = f1_score(od8_labels.numpy(), np.argmax(probs_8,1), average='macro', zero_division=0)
        print(f"\nReplicate {r}, MAP OOD Analysis – Digit 8:     Entropy={entropy_8.mean():.8f}, F1={f1_8:.8f}")

        # Digit 9
        od9_images, od9_labels = next(iter(od9_loader))
        od9_images = od9_images.to(device)
        with torch.no_grad():
            probs_9 = F.softmax(model_cnn(od9_images), dim=1).cpu().numpy()
        entropy_9 = -np.sum(probs_9 * np.log(probs_9 + 1e-12), axis=1)
        f1_9 = f1_score(od9_labels.numpy(), np.argmax(probs_9,1), average='macro', zero_division=0)
        print(f"\nReplicate {r}, MAP OOD Analysis – Digit 9:     Entropy={entropy_9.mean():.8f}, F1={f1_9:.8f}")

        # Perturbed In-Domain
        with torch.no_grad():
            p_pert = F.softmax(model_cnn(pert_imgs), dim=1).cpu().numpy()
        entropy_pert = -np.sum(p_pert * np.log(p_pert + 1e-12), axis=1)
        print(f"\nReplicate {r}, MAP OOD Analysis – Perturbed ID: Entropy={entropy_pert.mean():.8f}")

        # White Noise
        with torch.no_grad():
            p_wn = F.softmax(model_cnn(wn), dim=1).cpu().numpy()
        entropy_wn = -np.sum(p_wn * np.log(p_wn + 1e-12), axis=1)
        print(f"\nReplicate {r}, MAP OOD Analysis – White Noise: Entropy={entropy_wn.mean():.8f}")

        # Combined 8 & 9
        comb = torch.cat([od8_images, od9_images],0)
        with torch.no_grad():
            p_comb = F.softmax(model_cnn(comb), dim=1).cpu().numpy()
        f1_comb = f1_score(
            np.concatenate([od8_labels.numpy(), od9_labels.numpy()]),
            np.argmax(p_comb,1), average='macro', zero_division=0
        )
        entropy_comb = -np.sum(p_comb * np.log(p_comb + 1e-12), axis=1)
        print(f"\nReplicate {r}, MAP OOD Analysis – Combined 8&9: Entropy={entropy_comb.mean():.8f}, F1={f1_comb:.8f}")

        # All OOD: digits 8 and 9, perturbed images and white noise together.
        all_ood = torch.cat([od8_images, od9_images, pert_imgs, wn],0)
        with torch.no_grad():
            p_all = F.softmax(model_cnn(all_ood), dim=1).cpu().numpy()
        entropy_all = -np.sum(p_all * np.log(p_all + 1e-12), axis=1)
        print(f"\nReplicate {r}, MAP OOD Analysis – All OOD:     Entropy={entropy_all.mean():.8f}")

        ############################################
        # Store replicate metrics.
        ############################################
        # Aggregate out-of-domain metrics for the held-out digits 8 & 9.
        # These were previously filled with the in-domain entropy / F1 / AUC-PR
        # as placeholders, which made the "Out-of-Domain" summary lines and the
        # saved file report in-domain numbers.
        #   * entropy and F1 come from the combined 8&9 batch;
        #   * AUC-PR is an OOD-detection score: predictive entropy is used to
        #     rank samples, with in-domain test samples labelled 0 and the
        #     8&9 samples labelled 1.
        ood_detection_targets = np.concatenate(
            [np.zeros(len(total_entropy)), np.ones(len(entropy_comb))])
        ood_detection_scores = np.concatenate([total_entropy, entropy_comb])
        aucpr_od = average_precision_score(ood_detection_targets, ood_detection_scores)

        replicate_epochs.append(map_epochs)
        replicate_map_time.append(total_map_time)
        replicate_accuracy.append(test_accuracy)
        replicate_nll_in.append(avg_nll)
        replicate_total_entropy_in.append(avg_total_entropy)
        replicate_f1_in.append(f1_in)
        replicate_aucpr_in.append(aucpr_in)
        replicate_total_entropy_in_correct.append(total_entropy_correct)
        replicate_epistemic_in_correct.append(epistemic_correct)
        replicate_total_entropy_in_incorrect.append(total_entropy_incorrect)
        replicate_epistemic_in_incorrect.append(epistemic_incorrect)
        replicate_total_entropy_od.append(entropy_comb.mean())
        replicate_f1_od.append(f1_comb)
        replicate_aucpr_od.append(aucpr_od)
        replicate_total_entropy_ood_8.append(entropy_8.mean())
        replicate_f1_ood_8.append(f1_8)
        replicate_total_entropy_ood_9.append(entropy_9.mean())
        replicate_f1_ood_9.append(f1_9)
        replicate_total_entropy_ood_perturbed.append(entropy_pert.mean())
        replicate_total_entropy_ood_whitenoise.append(entropy_wn.mean())
        replicate_total_entropy_ood_combined_8_9.append(entropy_comb.mean())
        replicate_f1_ood_combined_8_9.append(f1_comb)
        replicate_total_entropy_ood_all.append(entropy_all.mean())

    ############################################
    # Compute overall statistics over replicates.
    ############################################
    # Each call returns (mean, standard error) across the R replicates. The
    # standard error uses the sample standard deviation (ddof=1), so it is
    # NaN when only a single replicate was run.
    epochs_mean, epochs_stderr = compute_stats(replicate_epochs, R)
    accuracy_mean, accuracy_stderr = compute_stats(replicate_accuracy, R)
    nll_in_mean, nll_in_stderr = compute_stats(replicate_nll_in, R)
    tot_ent_in_mean, tot_ent_in_stderr = compute_stats(replicate_total_entropy_in, R)
    f1_in_mean, f1_in_stderr = compute_stats(replicate_f1_in, R)
    aucpr_in_mean, aucpr_in_stderr = compute_stats(replicate_aucpr_in, R)
    # Aggregate OOD statistics, taken from the replicate_*_od lists as filled
    # in the replicate loop.
    tot_ent_od_mean, tot_ent_od_stderr = compute_stats(replicate_total_entropy_od, R)
    f1_od_mean, f1_od_stderr = compute_stats(replicate_f1_od, R)
    aucpr_od_mean, aucpr_od_stderr = compute_stats(replicate_aucpr_od, R)
    time_mean, time_stderr = compute_stats(replicate_map_time, R)

    # In-domain entropy split by prediction correctness.
    tot_ent_in_corr_mean, tot_ent_in_corr_stderr = compute_stats(replicate_total_entropy_in_correct, R)
    epistemic_in_corr_mean, epistemic_in_corr_stderr = compute_stats(replicate_epistemic_in_correct, R)
    tot_ent_in_inc_mean, tot_ent_in_inc_stderr = compute_stats(replicate_total_entropy_in_incorrect, R)
    epistemic_in_inc_mean, epistemic_in_inc_stderr = compute_stats(replicate_epistemic_in_incorrect, R)

    # Per-OOD-set statistics.
    tot_ent_ood8_mean, tot_ent_ood8_stderr = compute_stats(replicate_total_entropy_ood_8, R)
    f1_ood8_mean, f1_ood8_stderr = compute_stats(replicate_f1_ood_8, R)
    tot_ent_ood9_mean, tot_ent_ood9_stderr = compute_stats(replicate_total_entropy_ood_9, R)
    f1_ood9_mean, f1_ood9_stderr = compute_stats(replicate_f1_ood_9, R)
    tot_ent_ood_pert_mean, tot_ent_ood_pert_stderr = compute_stats(replicate_total_entropy_ood_perturbed, R)
    tot_ent_ood_white_mean, tot_ent_ood_white_stderr = compute_stats(replicate_total_entropy_ood_whitenoise, R)
    tot_ent_ood_comb8_9_mean, tot_ent_ood_comb8_9_stderr = compute_stats(replicate_total_entropy_ood_combined_8_9, R)
    f1_ood_comb8_9_mean, f1_ood_comb8_9_stderr = compute_stats(replicate_f1_ood_combined_8_9, R)
    tot_ent_ood_all_mean, tot_ent_ood_all_stderr = compute_stats(replicate_total_entropy_ood_all, R)

    # Per-digit entropies averaged over replicates (one value per digit 0–7),
    # plus the replicate-averaged overall in-domain entropies.
    per_digit_total_entropy = np.mean(np.stack(per_digit_total_entropy_list, axis=0), axis=0)
    per_digit_epistemic_entropy = np.mean(np.stack(per_digit_epistemic_entropy_list, axis=0), axis=0)
    total_entropy_inID = np.mean(inID_total_list)
    epistemic_inID = np.mean(inID_epi_list)

    # Calibration metrics.
    mean_brier, stderr_brier = compute_stats(replicate_brier, R)
    mean_ece,   stderr_ece   = compute_stats(replicate_ece, R)

    # Human-readable report of every aggregated metric.
    print("\n===== Summary over {} replicates =====".format(R))
    print(f"MAP Epochs: Mean = {epochs_mean:.2f}, SE = {epochs_stderr:.8f}")
    print(f"MAP Test Accuracy (%): Mean = {accuracy_mean:.2f}, SE = {accuracy_stderr:.8f}")
    print(f"In-Domain NLL: Mean = {nll_in_mean:.8f}, SE = {nll_in_stderr:.8f}")
    print(f"In-Domain Total Entropy: Mean = {tot_ent_in_mean:.8f}, SE = {tot_ent_in_stderr:.8f}")
    print(f"In-Domain F1 Score (macro): Mean = {f1_in_mean:.8f}, SE = {f1_in_stderr:.8f}")
    print(f"In-Domain AUC-PR (macro): Mean = {aucpr_in_mean:.8f}, SE = {aucpr_in_stderr:.8f}")
    print(f"Brier Score:   {mean_brier:.8f} ± {stderr_brier:.8f}")
    print(f"ECE:           {mean_ece:.8f} ± {stderr_ece:.8f}")
    print(f"Out-of-Domain Total Entropy (Digits 8&9 combined): Mean = {tot_ent_od_mean:.8f}, SE = {tot_ent_od_stderr:.8f}")
    print(f"Out-of-Domain F1 Score (Digits 8&9): Mean = {f1_od_mean:.8f}, SE = {f1_od_stderr:.8f}")
    print(f"Out-of-Domain AUC-PR (Digits 8&9): Mean = {aucpr_od_mean:.8f}, SE = {aucpr_od_stderr:.8f}")
    print(f"MAP Training Time (s): Mean = {time_mean:.2f}, SE = {time_stderr:.8f}")
    print("\nBreakdown for In-Domain Predictions:")
    print(f"  Correct Predictions - Total Entropy: Mean = {tot_ent_in_corr_mean:.8f}, SE = {tot_ent_in_corr_stderr:.8f}")
    print(f"  Correct Predictions - Epistemic Entropy: Mean = {epistemic_in_corr_mean:.8f}, SE = {epistemic_in_corr_stderr:.8f}")
    print(f"  Incorrect Predictions - Total Entropy: Mean = {tot_ent_in_inc_mean:.8f}, SE = {tot_ent_in_inc_stderr:.8f}")
    print(f"  Incorrect Predictions - Epistemic Entropy: Mean = {epistemic_in_inc_mean:.8f}, SE = {epistemic_in_inc_stderr:.8f}")
    print("\nDetailed OOD Analysis:")
    print(f"  Digit 8 - Total Entropy: Mean = {tot_ent_ood8_mean:.8f}, SE = {tot_ent_ood8_stderr:.8f}, F1: Mean = {f1_ood8_mean:.8f}, SE = {f1_ood8_stderr:.8f}")
    print(f"  Digit 9 - Total Entropy: Mean = {tot_ent_ood9_mean:.8f}, SE = {tot_ent_ood9_stderr:.8f}, F1: Mean = {f1_ood9_mean:.8f}, SE = {f1_ood9_stderr:.8f}")
    print(f"  Perturbed - Total Entropy: Mean = {tot_ent_ood_pert_mean:.8f}, SE = {tot_ent_ood_pert_stderr:.8f}")
    print(f"  White Noise - Total Entropy: Mean = {tot_ent_ood_white_mean:.8f}, SE = {tot_ent_ood_white_stderr:.8f}")
    print(f"  Combined (Digit 8 & 9) - Total Entropy: Mean = {tot_ent_ood_comb8_9_mean:.8f}, SE = {tot_ent_ood_comb8_9_stderr:.8f}, F1: Mean = {f1_ood_comb8_9_mean:.8f}, SE = {f1_ood_comb8_9_stderr:.8f}")
    print(f"  All OOD - Total Entropy: Mean = {tot_ent_ood_all_mean:.8f}, SE = {tot_ent_ood_all_stderr:.8f}")

    ############################################
    # Save all computed metrics and replicate results.
    ############################################
    # One flat dictionary written to a MATLAB .mat file: first the raw
    # per-replicate arrays, then the per-digit averages, then every
    # mean / standard-error pair computed above.
    metrics = {
        'R': R,
        # Raw per-replicate values.
        'replicate_epochs': np.array(replicate_epochs),
        'replicate_map_time': np.array(replicate_map_time),
        'replicate_accuracy': np.array(replicate_accuracy),
        'replicate_nll_in': np.array(replicate_nll_in),
        'replicate_total_entropy_in': np.array(replicate_total_entropy_in),
        'replicate_f1_in': np.array(replicate_f1_in),
        'replicate_aucpr_in': np.array(replicate_aucpr_in),
        'replicate_total_entropy_od': np.array(replicate_total_entropy_od),
        'replicate_f1_od': np.array(replicate_f1_od),
        'replicate_aucpr_od': np.array(replicate_aucpr_od),
        'replicate_total_entropy_in_correct': np.array(replicate_total_entropy_in_correct),
        'replicate_epistemic_in_correct': np.array(replicate_epistemic_in_correct),
        'replicate_total_entropy_in_incorrect': np.array(replicate_total_entropy_in_incorrect),
        'replicate_epistemic_in_incorrect': np.array(replicate_epistemic_in_incorrect),
        'replicate_total_entropy_ood_8': np.array(replicate_total_entropy_ood_8),
        'replicate_f1_ood_8': np.array(replicate_f1_ood_8),
        'replicate_total_entropy_ood_9': np.array(replicate_total_entropy_ood_9),
        'replicate_f1_ood_9': np.array(replicate_f1_ood_9),
        'replicate_total_entropy_ood_perturbed': np.array(replicate_total_entropy_ood_perturbed),
        'replicate_total_entropy_ood_whitenoise': np.array(replicate_total_entropy_ood_whitenoise),
        'replicate_total_entropy_ood_combined_8_9': np.array(replicate_total_entropy_ood_combined_8_9),
        'replicate_f1_ood_combined_8_9': np.array(replicate_f1_ood_combined_8_9),
        'replicate_total_entropy_ood_all': np.array(replicate_total_entropy_ood_all),
        # Replicate-averaged per-digit and overall in-domain entropies.
        'per_digit_total_entropy': per_digit_total_entropy,
        'per_digit_epistemic_entropy': per_digit_epistemic_entropy,
        'total_entropy_inID': total_entropy_inID,
        'epistemic_entropy_inID': epistemic_inID,
        # Mean / standard-error summaries.
        'epochs_mean': epochs_mean,
        'epochs_stderr': epochs_stderr,
        'accuracy_mean': accuracy_mean,
        'accuracy_stderr': accuracy_stderr,
        'nll_in_mean': nll_in_mean,
        'nll_in_stderr': nll_in_stderr,
        'tot_ent_in_mean': tot_ent_in_mean,
        'tot_ent_in_stderr': tot_ent_in_stderr,
        'f1_in_mean': f1_in_mean,
        'f1_in_stderr': f1_in_stderr,
        'aucpr_in_mean': aucpr_in_mean,
        'aucpr_in_stderr': aucpr_in_stderr,
        'tot_ent_od_mean': tot_ent_od_mean,
        'tot_ent_od_stderr': tot_ent_od_stderr,
        'f1_od_mean': f1_od_mean,
        'f1_od_stderr': f1_od_stderr,
        'aucpr_od_mean': aucpr_od_mean,
        'aucpr_od_stderr': aucpr_od_stderr,
        'time_mean': time_mean,
        'time_stderr': time_stderr,
        'tot_ent_in_corr_mean': tot_ent_in_corr_mean,
        'tot_ent_in_corr_stderr': tot_ent_in_corr_stderr,
        'epistemic_in_corr_mean': epistemic_in_corr_mean,
        'epistemic_in_corr_stderr': epistemic_in_corr_stderr,
        'tot_ent_in_inc_mean': tot_ent_in_inc_mean,
        'tot_ent_in_inc_stderr': tot_ent_in_inc_stderr,
        'epistemic_in_inc_mean': epistemic_in_inc_mean,
        'epistemic_in_inc_stderr': epistemic_in_inc_stderr,
        'tot_ent_ood8_mean': tot_ent_ood8_mean,
        'tot_ent_ood8_stderr': tot_ent_ood8_stderr,
        'f1_ood8_mean': f1_ood8_mean,
        'f1_ood8_stderr': f1_ood8_stderr,
        'tot_ent_ood9_mean': tot_ent_ood9_mean,
        'tot_ent_ood9_stderr': tot_ent_ood9_stderr,
        'f1_ood9_mean': f1_ood9_mean,
        'f1_ood9_stderr': f1_ood9_stderr,
        'tot_ent_ood_pert_mean': tot_ent_ood_pert_mean,
        'tot_ent_ood_pert_stderr': tot_ent_ood_pert_stderr,
        'tot_ent_ood_white_mean': tot_ent_ood_white_mean,
        'tot_ent_ood_white_stderr': tot_ent_ood_white_stderr,
        'tot_ent_ood_comb8_9_mean': tot_ent_ood_comb8_9_mean,
        'tot_ent_ood_comb8_9_stderr': tot_ent_ood_comb8_9_stderr,
        'f1_ood_comb8_9_mean': f1_ood_comb8_9_mean,
        'f1_ood_comb8_9_stderr': f1_ood_comb8_9_stderr,
        'tot_ent_ood_all_mean': tot_ent_ood_all_mean,
        'tot_ent_ood_all_stderr': tot_ent_ood_all_stderr,
        'mean_brier': mean_brier,
        'stderr_brier': stderr_brier,
        'mean_ece': mean_ece,
        'stderr_ece': stderr_ece,
    }
    # Written to the current working directory.
    savemat('BayesianNN_MNIST_MAP_metrics.mat', metrics)
    print("\nAll computed metrics (and raw replicate results) have been saved to 'BayesianNN_MNIST_MAP_metrics.mat'.")
