#!/usr/bin/env python3
"""
This script summarizes RP replicates of single‐replicate SMC results into a PSMC analysis,
but only for selected processor settings specified by select_index (e.g. [1, 2, 4, 8, 16]).
For each selected P (number of processors per replicate) and for R replicates, the processor‐specific
results are combined by first averaging over each processor’s particles and then over the available processors.
Using the aggregated estimated probability, standard in‐domain (ID) metrics are computed (accuracy, F1,
AUC-PR, and NLL). Then, using the full set of processor results, out-of-domain (OOD) metrics are computed.
Finally, a binary-detection F1 and AUROC (ID vs. OOD) using entropy as the score are computed and printed.
"""

import os
import numpy as np
from scipy.io import loadmat, savemat
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset
import random
import sys
import pyro
from sklearn.metrics import f1_score, average_precision_score, roc_auc_score, precision_recall_curve
from sklearn.preprocessing import label_binarize
from torch.func import functional_call

# Seed every random number generator the script relies on (Python, NumPy,
# PyTorch and Pyro's seeding helper) with the same value for reproducibility.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, pyro.set_rng_seed):
    _seed_fn(2)
del _seed_fn

# — device used when re-predicting with the loaded particles
device_used = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# — the same CNN you used in psmc_single
class SimpleCNN(torch.nn.Module):
    """
    Small CNN for 1x28x28 images producing 8 logits.

    Layout: conv(1->4, 3x3, padding 1) -> ReLU -> 2x2 max-pool -> linear(4*14*14 -> 8).
    The attribute names and their creation order fix the `named_parameters()`
    order (conv.weight, conv.bias, fc.weight, fc.bias; 6320 values in total)
    that flattened particle vectors are unpacked against. PyTorch's default
    initialisation is used; in this script the weights are supplied per
    particle through functional_call.
    """

    def __init__(self):
        super().__init__()
        # Creation order matters: it determines the flattened parameter layout.
        self.conv = torch.nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.fc = torch.nn.Linear(4 * 14 * 14, 8)

    def forward(self, x):
        features = self.pool(F.relu(self.conv(x)))
        return self.fc(features.view(features.size(0), -1))

# — batch-SMC-predict: evaluate the network once per SMC particle
def smc_predict_batch(x_data, particles_flat, NetClass):
    """
    Run `NetClass` once per particle and return the raw logits.

    Args:
      x_data: input tensor accepted by `NetClass`; it is moved to `device_used`
              once, before the particle loop.
      particles_flat: iterable of flattened parameter vectors (one per particle),
              laid out in `named_parameters()` order as expected by unflatten_params.
      NetClass: zero-argument constructor of the network architecture.

    Returns:
      np.ndarray stacked over particles, shape (N_particles, N_samples, n_outputs),
      containing *logits* (apply softmax_np to get class probabilities).
    """
    net = NetClass().to(device_used)
    net.eval()
    # Copy the (possibly large) input batch to the device once, not per particle.
    x_dev = x_data.to(device_used)
    logits_list = []
    # Pure inference: no autograd bookkeeping is needed.
    with torch.no_grad():
        for flat in particles_flat:
            # Cast to the input dtype: particles loaded from .mat files may be
            # float64 while image tensors (e.g. from ToTensor) are float32, and
            # conv/linear layers require matching dtypes.
            flat_t = torch.as_tensor(flat, dtype=x_dev.dtype, device=device_used)
            params = unflatten_params(flat_t, net)
            logits = functional_call(net, params, x_dev)
            logits_list.append(logits.cpu().numpy())
    return np.stack(logits_list, axis=0)

############################################
#      Metrics computation: Brier and ECE
############################################
def compute_brier(probs, labels):
    """
    Multi-class Brier score.

      BS = (1/N) * sum_i sum_c (p_{i,c} - 1{y_i = c})^2

    Args:
      probs: (N, C) array of predicted class probabilities.
      labels: (N,) integer class indices in [0, C-1].

    Returns:
      Mean over samples of the squared distance to the one-hot target.
    """
    n_samples, _ = probs.shape
    # One-hot targets with the same dtype as `probs`.
    target = np.zeros_like(probs)
    target[np.arange(n_samples), labels] = 1
    per_sample = np.square(probs - target).sum(axis=1)
    return per_sample.mean()


def compute_ece(probs, labels, n_bins=10):
    """
    Expected Calibration Error (ECE) for multiclass predictions.

    Each sample is binned by its top-class confidence into `n_bins` equal-width
    bins over [0, 1] (half-open [lo, hi), except the last bin, which is closed
    so a confidence of exactly 1.0 is counted). ECE is the bin-size-weighted
    mean absolute gap between accuracy and mean confidence:

      ECE = sum_b (|B_b| / N) * |acc(B_b) - conf(B_b)|

    Args:
      probs: (N, C) predicted class probabilities.
      labels: (N,) integer true labels in [0, C-1].
      n_bins: number of confidence bins.

    Returns:
      Scalar ECE; empty bins contribute nothing.
    """
    n = len(probs)
    top_class = probs.argmax(axis=1)
    top_prob = probs[np.arange(n), top_class]
    hits = (top_class == labels).astype(float)

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    total = 0.0
    for b, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        # Only the final bin includes its upper edge.
        below_hi = top_prob <= hi if b == n_bins - 1 else top_prob < hi
        members = (top_prob >= lo) & below_hi
        weight = members.mean()
        if weight == 0:
            continue
        gap = abs(hits[members].mean() - top_prob[members].mean())
        total += weight * gap
    return total


# ----------------------------
# Helper: unflatten parameters
# ----------------------------
def unflatten_params(flat, net):
    """
    Split a 1-D parameter vector into a {name: tensor} dict shaped like `net`'s parameters.

    Consecutive slices are taken in `net.named_parameters()` order and reshaped
    with `.view`, so the returned tensors share storage with `flat`. Entries of
    `flat` past the total parameter count are ignored; a vector that is too
    short makes `.view` raise.
    """
    layout = [(name, p.shape, p.numel()) for name, p in net.named_parameters()]
    out = {}
    start = 0
    for name, shape, count in layout:
        stop = start + count
        out[name] = flat[start:stop].view(shape)
        start = stop
    return out

# ----------------------------
# Helper: stable softmax in NumPy
# ----------------------------
def softmax_np(x):
    """Softmax along the last axis; the row max is subtracted first to avoid overflow."""
    shifted = np.subtract(x, np.max(x, axis=-1, keepdims=True))
    weights = np.exp(shifted)
    return weights / np.sum(weights, axis=-1, keepdims=True)

# ----------------------------
# Load validation dataset for ID analysis (digits 0–7)
# ----------------------------
class FilteredDataset(Dataset):
    """
    In-memory dataset keeping only the (img, label) samples whose label is allowed.

    The wrapped dataset is iterated once, eagerly, in __init__ (so any
    per-sample loading/transform of `dataset` happens at construction time).
    Kept samples are stored as (img, label) tuples in their original order.
    """

    def __init__(self, dataset, allowed_labels):
        kept = []
        for img, label in dataset:
            if label in allowed_labels:
                kept.append((img, label))
        self.data = kept

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# ID validation set: MNIST test split restricted to digits 0-7.
transform = transforms.Compose([transforms.ToTensor()])
full_val_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)
allowed_labels = list(range(8))
# Keep the first 7000 in-domain test images, in dataset order.
filtered_val_dataset = Subset(
    FilteredDataset(full_val_dataset, allowed_labels), list(range(7000))
)
# One batch holding the whole subset; shuffle=False keeps the order fixed.
val_loader = DataLoader(filtered_val_dataset, batch_size=len(filtered_val_dataset), shuffle=False)
x_val, y_val = next(iter(val_loader))
y_val_np = y_val.numpy().flatten()  # shape (7000,) integer labels

# For OOD, full MNIST validation set (the ID set above only uses digits 0-7,
# so digits 8 and 9 from this split are out-of-domain).
full_val = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# ----------------------------
# Define OOD subsets
# ----------------------------
# Collect the first 100 test-set indices of digit 8 and of digit 9, in dataset
# order; the scan stops as soon as both lists are full.
ood_digit8_indices = []; count8 = 0
ood_digit9_indices = []; count9 = 0
for idx, (img, label) in enumerate(full_val):
    if label == 8 and count8 < 100:
        ood_digit8_indices.append(idx); count8 += 1
    if label == 9 and count9 < 100:
        ood_digit9_indices.append(idx); count9 += 1
    if count8 == 100 and count9 == 100:
        break

# batch_size=100 loads each digit subset as a single batch.
ood_digit8_images, ood_digit8_labels = next(iter(DataLoader(
    Subset(full_val, ood_digit8_indices), batch_size=100, shuffle=False)))
ood_digit9_images, ood_digit9_labels = next(iter(DataLoader(
    Subset(full_val, ood_digit9_indices), batch_size=100, shuffle=False)))

# Digits 8 and 9 together (8s first, then 9s).
ood_combined_images = torch.cat([ood_digit8_images, ood_digit9_images], dim=0)
ood_combined_labels = torch.cat([ood_digit8_labels, ood_digit9_labels], dim=0)

# Re-seed immediately before drawing noise so the perturbed and white-noise
# sets are reproducible. NOTE(review): the statements below are RNG-order
# sensitive — reordering randn_like / rand changes the generated sets.
random.seed(2)
np.random.seed(2)
torch.manual_seed(2)
pyro.set_rng_seed(2)
# Perturbed set: first 100 ID validation images plus Gaussian noise
# (std 0.5), clipped back to [0, 1]. Label -1 marks "no real class".
perturbed_images, _ = next(iter(DataLoader(
    Subset(filtered_val_dataset, list(range(100))), batch_size=100, shuffle=False)))
perturbed_images = torch.clamp(perturbed_images + 0.5*torch.randn_like(perturbed_images), 0, 1)
perturbed_labels = -1 * torch.ones(len(perturbed_images), dtype=torch.int)

# White-noise set: 100 images with i.i.d. uniform [0, 1) pixels, label -2.
white_noise_images = torch.rand(100,1,28,28)
white_noise_labels = -2 * torch.ones(100, dtype=torch.int)

# All OOD images in the order: digit 8, digit 9, perturbed, white noise.
# NOTE(review): the OOD labels are placeholders (compute_ood_metrics ignores
# its labels argument); the synthetic ones are torch.int while the MNIST ones
# come from the default collate — torch.cat relies on dtype promotion here.
all_ood_images = torch.cat([ood_digit8_images, ood_digit9_images,
                            perturbed_images, white_noise_images], dim=0)
all_ood_labels = torch.cat([ood_digit8_labels, ood_digit9_labels,
                            perturbed_labels, white_noise_labels], dim=0)

# ----------------------------
# User-adjustable parameters
# ----------------------------
R = 5                        # replicates per processor setting
select_index = [1, 2, 4, 8]  # processor counts P to evaluate

# Settings encoded in the per-node result filenames; d_val is also the length
# of each flattened particle (parameter vector).
d_val, N_tr, N_val, N_particles, M_val = 6320, 1000, 200, 10, 10

# ----------------------------
# Containers for final results
# ----------------------------
results_by_P = {}

# OOD categories, in the order they are evaluated and reported.
ood_keys = ['digit8', 'digit9', 'combined_8_9', 'perturbed', 'white_noise', 'all_ood']
# Three independent dicts, each with its own fresh list per category.
rep_total_entropy_ood, rep_epistemic_ood, rep_nll_ood = (
    {key: [] for key in ood_keys} for _ in range(3)
)

# Per-digit ID breakdown and per-replicate ID averages.
per_digit_total_entropy_list, per_digit_epistemic_entropy_list = [], []
inID_total_list, inID_epi_list = [], []

# ----------------------------
# Helper: per-category OOD metrics
# ----------------------------
def compute_ood_metrics(images, labels, rep_particles, net_ood, eps=1e-12):
    """
    Average total and epistemic predictive entropy of a particle ensemble on one image set.

    Args:
      images: input batch passed directly to `net_ood` (callers move it to the device).
      labels: accepted for a uniform call signature; not used.
      rep_particles: 2-D array with one flattened parameter vector per row.
      net_ood: module giving the architecture; each particle's parameters are
               injected with functional_call, so its own weights are unused.
      eps: added inside log() to avoid log(0).

    Returns:
      (mean total entropy, mean epistemic entropy, None, None), where total
      entropy is H[mean_k p_k] and epistemic = total - mean_k H[p_k], per sample.
    """
    stacked_logits = np.stack(
        [
            functional_call(
                net_ood,
                unflatten_params(torch.tensor(rep_particles[i], device=device_used), net_ood),
                images,
            ).detach().cpu().numpy()
            for i in range(rep_particles.shape[0])
        ],
        axis=0,
    )
    # presumably (K particles, N images, C classes) — softmax over the last axis
    probs = softmax_np(stacked_logits)
    mean_probs = np.mean(probs, axis=0)
    total = -np.sum(mean_probs * np.log(mean_probs + eps), axis=1)
    expected = np.mean(-np.sum(probs * np.log(probs + eps), axis=2), axis=0)
    return np.mean(total), np.mean(total - expected), None, None

# ----------------------------
# Helpers for the aggregation loop
# ----------------------------
EPS = 1e-12  # added inside log() to avoid log(0)


def _entropy(probs, eps=EPS):
    """Shannon entropy (nats) of probability vectors along the last axis."""
    return -np.sum(probs * np.log(probs + eps), axis=-1)


def compute_stats(lst):
    """
    Return (mean, standard error) of the values in `lst`, ignoring NaNs.

    The standard error is the NaN-aware sample std (ddof=1) divided by the
    square root of the number of non-NaN values. The mean is NaN when there
    are no valid values; the standard error is NaN with fewer than two.
    """
    arr = np.asarray(lst, dtype=float).ravel()
    n_valid = int(np.count_nonzero(~np.isnan(arr)))
    if n_valid == 0:
        return np.nan, np.nan
    mean = np.nanmean(arr)
    if n_valid < 2:
        return mean, np.nan
    return mean, np.nanstd(arr, ddof=1) / np.sqrt(n_valid)


def _column_stats(rows):
    """Column-wise compute_stats over equal-length 1-D arrays (one row per replicate)."""
    arr = np.vstack(rows).astype(float)
    pairs = [compute_stats(arr[:, j]) for j in range(arr.shape[1])]
    means = np.array([m for m, _ in pairs])
    ses = np.array([s for _, s in pairs])
    return means, ses


# ----------------------------
# Loop over P settings
# ----------------------------
for P_use in select_index:
    print(f"\n=== Evaluating for P = {P_use} processors ===")

    # ID metric containers (one entry per replicate that loaded at least one file)
    rep_accuracy = []; rep_f1 = []; rep_aucpr = []; rep_nll = []
    rep_total_entropy = []; rep_epistemic = []; rep_Lsum = []
    rep_total_entropy_correct = []; rep_epistemic_correct = []
    rep_total_entropy_incorrect = []; rep_epistemic_incorrect = []
    rep_brier = []; rep_ece = []

    # Detection containers
    rep_f1_det = []; rep_aucroc_det = []

    # OOD and per-digit accumulators must start empty for every P; otherwise the
    # statistics reported and saved for one P would also include the replicates
    # of every previously evaluated P.
    rep_total_entropy_ood = {k: [] for k in ood_keys}
    rep_epistemic_ood = {k: [] for k in ood_keys}
    rep_nll_ood = {k: [] for k in ood_keys}
    per_digit_total_entropy_list = []
    per_digit_epistemic_entropy_list = []
    inID_total_list = []
    inID_epi_list = []

    # ----------------------------
    # Loop over replicates
    # ----------------------------
    for r in range(R):
        proc_preds_list = []; proc_params_list = []; Lsum_list = []

        for p in range(P_use):
            node_index = r + p * R + 1
            filename = (
                f"BayesianNN_MNIST_psmc_SimpleNN_results_d{d_val}_"
                f"train{N_tr}_val{N_val}_N{N_particles}_M{M_val}_node{node_index}.mat"
            )
            if not os.path.exists(filename):
                print(f"File {filename} not found; skipping this processor.")
                continue
            data = loadmat(filename)
            particles = data['psmc_single_x']  # one flattened parameter vector per row
            # Re-predict the ID validation set with every particle -> logits (M, N, C)
            proc_preds_list.append(smc_predict_batch(x_val, particles, SimpleCNN))
            proc_params_list.append(particles)
            # loadmat returns scalars as 1x1 arrays; .item() extracts the value
            # (and still fails loudly if 'Lsum' is not a single number).
            Lsum_list.append(float(np.asarray(data.get('Lsum', np.nan)).item()))

        if not proc_preds_list:
            print(f"No files loaded for replicate {r+1}; skipping replicate.")
            continue

        proc_logits = np.stack(proc_preds_list, axis=0)   # (P_loaded, M, N, C) logits
        proc_params = np.stack(proc_params_list, axis=0)  # (P_loaded, M, d_val)
        rep_Lsum.append(np.mean(Lsum_list))

        # --- In-Domain (ID) Analysis ---
        # Aggregated predictive distribution: average of the per-particle class
        # probabilities over particles and processors. Every ID metric below
        # (accuracy, NLL, F1, AUC-PR, Brier, ECE, entropies) uses this single
        # distribution, so the metrics are mutually consistent and comparable
        # with the OOD entropies produced by compute_ood_metrics.
        proc_soft = softmax_np(proc_logits)                 # (P_loaded, M, N, C)
        predictive_probs = np.mean(proc_soft, axis=(0, 1))  # (N, C)
        n_classes = predictive_probs.shape[1]
        if predictive_probs.shape[0] != len(y_val_np) or n_classes != len(allowed_labels):
            raise ValueError(
                f"Unexpected predictive shape {predictive_probs.shape}; "
                f"expected {(len(y_val_np), len(allowed_labels))}"
            )
        if not np.isfinite(predictive_probs).all():
            raise ValueError("Predictive probabilities contain NaN or inf values.")

        pred_labels = np.argmax(predictive_probs, axis=1)
        accuracy = np.mean(pred_labels == y_val_np)
        nll = -np.mean(np.log(predictive_probs[np.arange(len(y_val_np)), y_val_np] + EPS))
        brier = compute_brier(predictive_probs, y_val_np)
        ece = compute_ece(predictive_probs, y_val_np, n_bins=10)

        # Entropy decomposition: total = H[mean p], expected = mean H[p],
        # epistemic = total - expected (all per sample).
        total_entropy = _entropy(predictive_probs)                    # (N,)
        expected_entropy = np.mean(_entropy(proc_soft), axis=(0, 1))  # (N,)
        epi_unc = total_entropy - expected_entropy
        avg_total_entropy_id = np.mean(total_entropy)
        avg_epistemic_id = np.mean(epi_unc)

        y_val_bin = label_binarize(y_val_np, classes=np.arange(n_classes))
        f1_in = f1_score(y_val_np, pred_labels, average='macro', zero_division=0)
        aucpr_in = average_precision_score(y_val_bin, predictive_probs, average='macro')

        rep_accuracy.append(accuracy)
        rep_f1.append(f1_in)
        rep_aucpr.append(aucpr_in)
        rep_nll.append(nll)
        rep_total_entropy.append(avg_total_entropy_id)
        rep_epistemic.append(avg_epistemic_id)
        rep_brier.append(brier)
        rep_ece.append(ece)

        # --- ID breakdown (correct vs. incorrect) ---
        correct_mask = (pred_labels == y_val_np)
        incorrect_mask = ~correct_mask
        rep_total_entropy_correct.append(
            np.mean(total_entropy[correct_mask]) if correct_mask.any() else np.nan)
        rep_epistemic_correct.append(
            np.mean(epi_unc[correct_mask]) if correct_mask.any() else np.nan)
        rep_total_entropy_incorrect.append(
            np.mean(total_entropy[incorrect_mask]) if incorrect_mask.any() else np.nan)
        rep_epistemic_incorrect.append(
            np.mean(epi_unc[incorrect_mask]) if incorrect_mask.any() else np.nan)

        # --- Per-digit ID breakdown ---
        per_tot = []; per_epi = []
        for d in allowed_labels:
            mask = (y_val_np == d)
            per_tot.append(np.mean(total_entropy[mask]) if mask.any() else np.nan)
            per_epi.append(np.mean(epi_unc[mask]) if mask.any() else np.nan)
        per_digit_total_entropy_list.append(np.array(per_tot))
        per_digit_epistemic_entropy_list.append(np.array(per_epi))
        inID_total_list.append(avg_total_entropy_id)
        inID_epi_list.append(avg_epistemic_id)

        # ----------------------------
        # OOD Analysis per category (particles of all loaded processors)
        # ----------------------------
        rep_particles = proc_params.reshape(-1, d_val)
        net_ood = SimpleCNN().to(device_used)
        net_ood.eval()

        ood_categories = {
            'digit8':       (ood_digit8_images.to(device_used),   ood_digit8_labels),
            'digit9':       (ood_digit9_images.to(device_used),   ood_digit9_labels),
            'combined_8_9': (ood_combined_images.to(device_used), ood_combined_labels),
            'perturbed':    (perturbed_images.to(device_used),    perturbed_labels),
            'white_noise':  (white_noise_images.to(device_used),  white_noise_labels),
            'all_ood':      (all_ood_images.to(device_used),      all_ood_labels),
        }
        for key, (imgs_cat, labels_cat) in ood_categories.items():
            avg_tot, avg_epi, _, _ = compute_ood_metrics(imgs_cat, labels_cat, rep_particles, net_ood)
            rep_total_entropy_ood[key].append(avg_tot)
            rep_epistemic_ood[key].append(avg_epi)
            rep_nll_ood[key].append(np.nan)  # NLL is not computed for OOD sets

        # ——— BINARY OOD DETECTION (ID = 0, OOD = 1; score = total entropy) ———
        ood_logits = smc_predict_batch(all_ood_images, rep_particles, SimpleCNN)
        total_entropy_ood = _entropy(np.mean(softmax_np(ood_logits), axis=0))

        y_det = np.concatenate([np.zeros_like(total_entropy), np.ones_like(total_entropy_ood)])
        scores = np.concatenate([total_entropy, total_entropy_ood])

        auroc = roc_auc_score(y_det, scores)
        prec, rec, _ = precision_recall_curve(y_det, scores)
        f1_scores = 2 * prec * rec / (prec + rec + 1e-12)
        det_f1 = f1_scores[np.nanargmax(f1_scores)]  # best F1 over all thresholds

        rep_f1_det.append(det_f1)
        rep_aucroc_det.append(auroc)

    # ----------------------------
    # Compute overall stats
    # ----------------------------
    n_rep = len(rep_accuracy)
    if n_rep == 0:
        print(f"No replicates could be loaded for P = {P_use}; nothing to summarize or save.")
        continue

    mean_acc, stderr_acc           = compute_stats(rep_accuracy)
    mean_f1_id, stderr_f1_id       = compute_stats(rep_f1)
    mean_aucpr_id, stderr_aucpr_id = compute_stats(rep_aucpr)
    mean_nll, stderr_nll           = compute_stats(rep_nll)
    mean_tot_ent, stderr_tot_ent   = compute_stats(rep_total_entropy)
    mean_epi, stderr_epi           = compute_stats(rep_epistemic)
    mean_Lsum, stderr_Lsum         = compute_stats(rep_Lsum)
    mean_brier, stderr_brier       = compute_stats(rep_brier)
    mean_ece, stderr_ece           = compute_stats(rep_ece)

    # Aggregated ID prints
    print(f"\nAggregated ID metrics for P = {P_use} over {n_rep} replicates:")
    print(f"Accuracy:      {mean_acc:.8f} ± {stderr_acc:.8f}")
    print(f"F1 Score:      {mean_f1_id:.8f} ± {stderr_f1_id:.8f}")
    print(f"AUC-PR:        {mean_aucpr_id:.8f} ± {stderr_aucpr_id:.8f}")
    print(f"NLL:           {mean_nll:.8f} ± {stderr_nll:.8f}")
    print(f"Brier Score:   {mean_brier:.8f} ± {stderr_brier:.8f}")
    print(f"ECE:           {mean_ece:.8f} ± {stderr_ece:.8f}")
    print(f"Total Entropy (all): {mean_tot_ent:.8f} ± {stderr_tot_ent:.8f}")
    print(f"Epistemic (all):     {mean_epi:.8f} ± {stderr_epi:.8f}")
    print(f"Lsum:          {mean_Lsum:.8f} ± {stderr_Lsum:.8f}")

    # Correct/Incorrect breakdown
    mean_tot_corr, se_tot_corr = compute_stats(rep_total_entropy_correct)
    mean_epi_corr, se_epi_corr = compute_stats(rep_epistemic_correct)
    mean_tot_inc, se_tot_inc = compute_stats(rep_total_entropy_incorrect)
    mean_epi_inc, se_epi_inc = compute_stats(rep_epistemic_incorrect)
    print("\nFor ID predictions split by correctness:")
    print(f"Correct Predictions - Total Entropy: {mean_tot_corr:.8f} ± {se_tot_corr:.8f}")
    print(f"Correct Predictions - Epistemic:     {mean_epi_corr:.8f} ± {se_epi_corr:.8f}")
    print(f"Incorrect Predictions - Total Entropy: {mean_tot_inc:.8f} ± {se_tot_inc:.8f}")
    print(f"Incorrect Predictions - Epistemic:     {mean_epi_inc:.8f} ± {se_epi_inc:.8f}")

    # Aggregated OOD per-category stats
    aggregated_ood = {}
    for key in ood_keys:
        m_tot, s_tot = compute_stats(rep_total_entropy_ood[key])
        m_epi, s_epi = compute_stats(rep_epistemic_ood[key])
        aggregated_ood[key] = (m_tot, s_tot, m_epi, s_epi)

    print(f"\nAggregated OOD metrics for P = {P_use} over {n_rep} replicates:")
    for key in ood_keys:
        m_tot, s_tot, m_epi, s_epi = aggregated_ood[key]
        print(f"\n-- OOD Category: {key} --")
        print(f"Total Entropy: {m_tot:.8f} ± {s_tot:.8f}")
        print(f"Epistemic:     {m_epi:.8f} ± {s_epi:.8f}")

    # Detection metrics
    mean_det_f1, stderr_det_f1       = compute_stats(rep_f1_det)
    mean_det_auroc, stderr_det_auroc = compute_stats(rep_aucroc_det)
    print("\n-- OOD Detection (ID vs OOD) --")
    print(f"Detection F1:    {mean_det_f1:.8f} ± {stderr_det_f1:.8f}")
    print(f"Detection AUROC: {mean_det_auroc:.8f} ± {stderr_det_auroc:.8f}")

    # ----------------------------
    # Save aggregated results to .mat
    # ----------------------------
    results = {
        # ID metrics
        'mean_acc': mean_acc,           'stderr_acc': stderr_acc,
        'mean_f1_id': mean_f1_id,       'stderr_f1_id': stderr_f1_id,
        'mean_aucpr_id': mean_aucpr_id, 'stderr_aucpr_id': stderr_aucpr_id,
        'mean_nll': mean_nll,           'stderr_nll': stderr_nll,
        'mean_tot_ent': mean_tot_ent,   'stderr_tot_ent': stderr_tot_ent,
        'mean_epi': mean_epi,           'stderr_epi': stderr_epi,
        'mean_Lsum': mean_Lsum,         'stderr_Lsum': stderr_Lsum,
        'mean_brier': mean_brier,       'stderr_brier': stderr_brier,
        'mean_ece': mean_ece,           'stderr_ece': stderr_ece,

        # Correct vs incorrect
        'mean_tot_corr': mean_tot_corr, 'se_tot_corr': se_tot_corr,
        'mean_epi_corr': mean_epi_corr, 'se_epi_corr': se_epi_corr,
        'mean_tot_inc': mean_tot_inc,   'se_tot_inc': se_tot_inc,
        'mean_epi_inc': mean_epi_inc,   'se_epi_inc': se_epi_inc,

        # Detection
        'mean_det_f1': mean_det_f1,       'stderr_det_f1': stderr_det_f1,
        'mean_det_auroc': mean_det_auroc, 'stderr_det_auroc': stderr_det_auroc,

        # Number of replicates that actually contributed to the statistics
        'n_replicates': n_rep,
    }

    # OOD per-category
    for key, (m_tot, s_tot, m_epi, s_epi) in aggregated_ood.items():
        results[f'ood_{key}_mean_tot'] = m_tot
        results[f'ood_{key}_se_tot']   = s_tot
        results[f'ood_{key}_mean_epi'] = m_epi
        results[f'ood_{key}_se_epi']   = s_epi

    # Per-digit (0–7) ID breakdown: mean and NaN-aware standard error per digit,
    # saved as 8-element vectors.
    mean_digit_tot, se_digit_tot = _column_stats(per_digit_total_entropy_list)
    mean_digit_epi, se_digit_epi = _column_stats(per_digit_epistemic_entropy_list)
    results['mean_digit_tot_ent']   = mean_digit_tot
    results['stderr_digit_tot_ent'] = se_digit_tot
    results['mean_digit_epi_ent']   = mean_digit_epi
    results['stderr_digit_epi_ent'] = se_digit_epi

    savemat(f'psmc_aggregated_results_P{P_use}.mat', results)


