#!/usr/bin/env python3
"""
PHMC Summarization Script

This script summarizes HMC runs (PHMC) where each file contains one sample.
For each replicate, the script loads the corresponding files, aggregates softmax outputs,
computes standard metrics for in‐domain (ID) samples (filtered MNIST digits 0–7) and performs
out‑of‑domain (OOD) analysis for separate groups:
    - MNIST digit 8
    - MNIST digit 9
    - perturbed in‑domain images (digits 0–7 with added noise)
    - White noise images
and then computes combined metrics for digits 8+9 and all OOD.

In addition, it computes a per‐digit breakdown for in‐domain data so that later you can
plot bar figures and generate LaTeX tables.
All printed metrics are formatted with 8‐decimal precision.
"""

import os
import numpy as np
from scipy.io import loadmat, savemat
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
import random
import sys
import pyro
from sklearn.metrics import f1_score, average_precision_score
from sklearn.preprocessing import label_binarize
from torch.func import functional_call

device_used = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# — the same CNN you used in psmc_single
class SimpleCNN(torch.nn.Module):
    """Small CNN: 3x3 conv (1 -> 4 channels) -> ReLU -> 2x2 max-pool -> linear (8 outputs).

    The linear layer expects 4*14*14 features, i.e. a 14x14 map after pooling.
    Layers keep PyTorch's default initialization; no custom weight init is
    applied here (it is deliberately left out to match the original model).
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        # Registration order (conv, then fc) determines named_parameters() order.
        self.conv = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(in_features=4 * 14 * 14, out_features=8)

    def forward(self, x):
        """Return raw logits of shape (batch, 8) for the input batch `x`."""
        pooled = self.pool(F.relu(self.conv(x)))
        flattened = pooled.view(pooled.size(0), -1)
        return self.fc(flattened)

def hmc_predict_batch(x_data, particles_flat, NetClass):
    """Evaluate a batch of inputs under every flat parameter vector (particle).

    A single `NetClass` instance supplies the architecture and parameter
    layout; its own weights are never used because each forward pass goes
    through `functional_call` with the parameters rebuilt from one particle.

    Args:
        x_data: input tensor accepted by `NetClass.forward`.
        particles_flat: iterable of flat parameter vectors (e.g. the rows of a
            2-D NumPy array), each laid out in `named_parameters()` order.
        NetClass: zero-argument callable returning the `torch.nn.Module`.

    Returns:
        NumPy array of raw *logits* with shape (N_particles, N_samples, n_outputs).

    Raises:
        ValueError: from `np.stack` when `particles_flat` is empty.
    """
    net = NetClass().to(device_used)
    net.eval()
    # Cast every particle to the network's parameter dtype: arrays read from
    # .mat files may be float64, which would not match float32 inputs.
    param_dtype = next(net.parameters()).dtype
    # Move the inputs once instead of once per particle.
    inputs = x_data.to(device_used)
    logits_list = []
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for flat in particles_flat:
            flat_tensor = torch.tensor(flat, dtype=param_dtype, device=device_used)
            params = unflatten_params(flat_tensor, net)
            logits = functional_call(net, params, inputs)
            logits_list.append(logits.cpu().numpy())
    return np.stack(logits_list, axis=0)

# ----------------------------
# Helper function: unflatten parameters
# ----------------------------
def unflatten_params(flat, net):
    """Slice a flat tensor into a ``{name: tensor}`` dict shaped like `net`'s parameters.

    Parameters are consumed in `net.named_parameters()` order; each takes the
    next `numel()` entries of `flat`, viewed with that parameter's shape.
    Entries of `flat` beyond the total parameter count are ignored.
    """
    rebuilt = {}
    start = 0
    for param_name, template in net.named_parameters():
        stop = start + template.numel()
        rebuilt[param_name] = flat[start:stop].view(template.shape)
        start = stop
    return rebuilt

############################################
#      Metrics computation: Brier and ECE
############################################
def compute_brier(probs, labels):
    """
    Multi-class Brier score, averaged over samples:
      (1/N) * sum_i sum_c (p_{i,c} - 1{y_i = c})^2

    Args:
      probs: array of shape (N, C) with per-class probabilities
      labels: integer class indices of length N, each in [0..C-1]

    Returns:
      scalar mean of the per-sample squared distances to the one-hot target
    """
    n_samples, _ = probs.shape
    # Build the one-hot targets with the same shape and dtype as `probs`.
    targets = np.zeros_like(probs)
    row_idx = np.arange(n_samples)
    targets[row_idx, labels] = 1
    per_sample = np.sum((probs - targets) ** 2, axis=1)
    return np.mean(per_sample)


# def compute_ece(probs, labels, n_bins=10):
#     """Expected Calibration Error over `n_bins` equally‐spaced bins."""
#     bins = np.linspace(0.0, 1.0, n_bins + 1)
#     ece = 0.0
#     for i in range(n_bins):
#         mask = (probs >= bins[i]) & (probs < bins[i+1])
#         if not np.any(mask):
#             continue
#         preds_bin = (probs[mask] >= 0.5).astype(int)
#         acc_bin   = (labels[mask] == preds_bin).mean()
#         conf_bin  = probs[mask].mean()
#         ece      += np.abs(acc_bin - conf_bin) * (mask.sum() / len(probs))
#     return ece
def compute_ece(probs, labels, n_bins=10):
    """
    Expected Calibration Error (top-label) for multiclass predictions.

    Each sample contributes its arg-max class and that class's probability
    (the confidence). Confidences are grouped into `n_bins` equal-width bins
    on [0, 1]; bins are half-open [low, high) except the last one, which also
    includes its upper edge so that a confidence of exactly 1.0 is counted.

    Args:
      probs: array of shape (N, C) of predicted class probabilities
      labels: array of shape (N,) of integer true labels in [0..C-1]
      n_bins: number of equal-width confidence bins

    Returns:
      scalar ECE = sum_b (|bin_b| / N) * |acc(b) - conf(b)|
    """
    top_class = np.argmax(probs, axis=1)
    top_conf = probs[np.arange(len(probs)), top_class]
    is_hit = (top_class == labels).astype(float)

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    last_bin = n_bins - 1
    total = 0.0

    for b, (lower, upper) in enumerate(zip(edges[:-1], edges[1:])):
        # Only the final bin is closed on the right.
        below_upper = (top_conf <= upper) if b == last_bin else (top_conf < upper)
        members = (top_conf >= lower) & below_upper
        weight = members.mean()
        if weight == 0:
            # Empty bin: contributes nothing.
            continue
        gap = abs(is_hit[members].mean() - top_conf[members].mean())
        total += weight * gap

    return total

# ----------------------------
# Helper function: stable softmax in NumPy
# ----------------------------
def softmax_np(x):
    """Softmax over the last axis, shifted by the per-row maximum for numerical stability."""
    row_max = np.max(x, axis=-1, keepdims=True)
    exps = np.exp(x - row_max)
    normalizer = np.sum(exps, axis=-1, keepdims=True)
    return exps / normalizer

# ----------------------------
# Load validation dataset for ID analysis (Filtered MNIST: digits 0–7)
# ----------------------------
class FilteredDataset(Dataset):
    """In-memory view of `dataset` restricted to samples whose label is allowed.

    The source dataset is iterated once at construction time and every
    matching (image, label) pair is kept in a list.
    """

    def __init__(self, dataset, allowed_labels):
        kept = []
        for img, label in dataset:
            if label in allowed_labels:
                kept.append((img, label))
        self.data = kept

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Set up datasets.
# MNIST test split, converted to tensors by ToTensor.
transform = transforms.Compose([transforms.ToTensor()])
full_val_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# In-domain evaluation set: digits 0-7 only, truncated to the first 7000 matches.
allowed_labels = list(range(8))
filtered_val_dataset = FilteredDataset(full_val_dataset, allowed_labels)
subset_indices = list(range(7000))
filtered_val_dataset = Subset(filtered_val_dataset, subset_indices)

# Materialize the whole in-domain set as a single batch (x_val, y_val).
val_loader = DataLoader(filtered_val_dataset, batch_size=len(filtered_val_dataset), shuffle=False)
x_val, y_val = next(iter(val_loader))
y_val_np = y_val.numpy().flatten()  # shape: (N_val,)

# For OOD, use full validation set.
# Reuse the dataset loaded above rather than constructing an identical second copy.
full_val = full_val_dataset

# ----------------------------
# User-adjustable parameters for PHMC summarization.
# ----------------------------
# One replicate at processor count P is assembled from N * P consecutive result files.
N = 10        # number of samples per processor (each file contains one sample)
R = 5         # number of replicates (fixed across all processor settings)
selected_P = [1,2,4,8]  # processor counts to summarize; the summarization loop runs once per entry

# File naming parameters (must match your HMC run settings).
# In the code shown here these values are only interpolated into the input .mat file names.
d_val      = 6320   # dimension of flattened parameters
N_tr       = 1000   # number of training samples
N_val      = 200   # validation-size tag in the file name (this is not the size of the x_val batch built above)
thin_all   = 160     # thinning period used in HMC run
burnin_all = 160     # burn-in period used in HMC run

# ----------------------------
# Containers for final aggregated results.
# ----------------------------
# NOTE(review): nothing in the code shown here reads or writes this dict — confirm it is still needed.
results_by_P = {}

# ----------------------------
# Initialize containers for OOD metrics (for each replicate, store per-group metrics).
# Groups: "digit8", "digit9", "perturbed", "white_noise", "combined_8_9", "all_ood"
def init_ood_dict():
    """Return a fresh metric store: one independent empty list per OOD metric name."""
    metric_names = ("total_entropy", "epistemic", "f1", "aucpr", "nll")
    return {metric: [] for metric in metric_names}

# One independent metric store per OOD group: the four individual groups
# followed by the two combined ones.
rep_ood_results = {
    group_name: init_ood_dict()
    for group_name in (
        "digit8",
        "digit9",
        "perturbed",
        "white_noise",
        "combined_8_9",
        "all_ood",
    )
}

# ----------------------------
# NEW: Initialize containers for per-digit in-domain breakdown.
# ----------------------------
# Module-level accumulators; the summarization loop appends one entry per loaded replicate.
per_digit_total_entropy_list = []   # each element: array of mean total entropy for labels 0-7, for one replicate
per_digit_epistemic_entropy_list = [] # same layout as above, holding mean epistemic uncertainty per label
inID_total_list = []  # overall in-domain mean total entropy, one value per replicate
inID_epi_list = []    # overall in-domain mean epistemic uncertainty, one value per replicate

# ----------------------------
# Helper function for computing OOD metrics given a batch of images & labels.
# ----------------------------
def compute_ood_metrics(images, labels, rep_particles, net_ood, eps=1e-12, valid_label=False):
    """Summarize ensemble uncertainty of a particle set on one batch of images.

    Each row of `rep_particles` is unflattened into `net_ood`'s parameter
    layout and used for one forward pass via `functional_call`; the network's
    own weights are not used. The per-particle softmax outputs are averaged
    into an ensemble prediction.

    Args:
        images: input batch for `net_ood`, already on the same device as the
            network's parameters.
        labels: integer labels, one per image; only read when `valid_label`
            is True (may be None otherwise).
        rep_particles: 2-D array with one flat parameter vector per row.
        net_ood: module that provides the architecture and parameter layout.
        eps: small constant added inside the logarithms to avoid log(0).
        valid_label: when True, the macro F1 of the ensemble arg-max against
            `labels` is computed; otherwise F1 is NaN.

    Returns:
        Tuple ``(avg_total_entropy, avg_epistemic, f1_val, aucpr_val, nll_val)``:
        batch means of the ensemble predictive entropy and of
        (ensemble entropy - mean per-particle entropy), the F1 described
        above, and two placeholders (AUC-PR fixed at 0.0, NLL fixed at NaN).
    """
    # Use the network's own device and dtype so that particles loaded as
    # float64 (e.g. from .mat files) still match the inputs.
    ref_param = next(net_ood.parameters())
    particle_logits = []
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for i in range(rep_particles.shape[0]):
            flat_tensor = torch.tensor(rep_particles[i], dtype=ref_param.dtype, device=ref_param.device)
            param_dict = unflatten_params(flat_tensor, net_ood)
            logits = functional_call(net_ood, param_dict, images)
            particle_logits.append(logits.cpu().numpy())
    particle_logits = np.stack(particle_logits, axis=0)  # (particles, images, classes)
    particle_probs = softmax_np(particle_logits)
    ensemble_probs = np.mean(particle_probs, axis=0)

    # Entropy of the averaged prediction (total uncertainty), per image.
    total_entropy = -np.sum(ensemble_probs * np.log(ensemble_probs + eps), axis=1)
    avg_total_entropy = np.mean(total_entropy)

    # Mean entropy of the individual particles, per image.
    particle_entropy = -np.sum(particle_probs * np.log(particle_probs + eps), axis=2)
    avg_particle_entropy = np.mean(particle_entropy, axis=0)

    # Epistemic part = total entropy minus the expected per-particle entropy.
    epistemic_uncertainty = total_entropy - avg_particle_entropy
    avg_epistemic = np.mean(epistemic_uncertainty)

    if valid_label:
        ens_preds = np.argmax(ensemble_probs, axis=1)
        f1_val = f1_score(labels, ens_preds, average='macro', zero_division=0)
    else:
        f1_val = np.nan
    aucpr_val = 0.0  # placeholder: not meaningful for OOD
    nll_val = np.nan  # placeholder: not computed for OOD
    return avg_total_entropy, avg_epistemic, f1_val, aucpr_val, nll_val

# ----------------------------
# Loop over selected processor settings.
# ----------------------------
# ----------------------------
# Aggregation helpers shared by every processor setting.
# ----------------------------
def compute_stats(metric_list):
    """Return ``(mean, standard error)`` of per-replicate values, ignoring NaN.

    The standard error divides the sample standard deviation (ddof=1) by the
    square root of the number of non-NaN entries, not by the raw list length.
    The mean is NaN when there is no non-NaN entry; the standard error is NaN
    when there are fewer than two.
    """
    arr = np.array(metric_list, dtype=float)
    n_valid = int(np.count_nonzero(~np.isnan(arr)))
    if n_valid == 0:
        return np.nan, np.nan
    mean_val = np.nanmean(arr)
    if n_valid == 1:
        return mean_val, np.nan
    stderr_val = np.nanstd(arr, ddof=1) / np.sqrt(n_valid)
    return mean_val, stderr_val


def _columnwise_stats(rows):
    """Stack equal-length rows and return NaN-aware per-column (mean, standard error)."""
    arr = np.vstack(rows)
    n_valid = np.count_nonzero(~np.isnan(arr), axis=0)
    col_mean = np.nanmean(arr, axis=0)
    # With fewer than two valid values the standard error is undefined (NaN).
    divisor = np.sqrt(np.where(n_valid > 1, n_valid, np.nan))
    col_se = np.nanstd(arr, ddof=1, axis=0) / divisor
    return col_mean, col_se


def _first_indices_with_label(dataset, target_label, limit=100):
    """Return the indices of the first `limit` samples of `dataset` labelled `target_label`."""
    found = []
    for idx, (_, label) in enumerate(dataset):
        if label == target_label:
            found.append(idx)
            if len(found) == limit:
                break
    return found


def _load_full_batch(dataset):
    """Return the whole dataset as one collated (images, labels) batch."""
    loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
    return next(iter(loader))


# ----------------------------
# OOD evaluation data. It does not depend on P or on the replicate (the noise
# is drawn right after seeding with a fixed seed), so it is built once.
# ----------------------------
_OOD_METRIC_KEYS = ("total_entropy", "epistemic", "f1", "aucpr", "nll")
_OOD_GROUP_ORDER = ("digit8", "digit9", "perturbed", "white_noise", "combined_8_9", "all_ood")

eps = 1e-12

# Groups 1 and 2: first 100 test images of digit 8 and of digit 9.
group8 = Subset(full_val, _first_indices_with_label(full_val, 8))
group9 = Subset(full_val, _first_indices_with_label(full_val, 9))

random.seed(2)
np.random.seed(2)
torch.manual_seed(2)
pyro.set_rng_seed(2)
# Group 3: perturbed (first 100 in-domain images with added Gaussian noise).
perturbed_imgs = x_val[:100] + 0.5 * torch.randn_like(x_val[:100])
perturbed_labels = y_val[:100].numpy()
group_perturbed = TensorDataset(perturbed_imgs, torch.tensor(perturbed_labels))
# Group 4: white_noise (uniform noise images with a dummy label).
white_noise_imgs = torch.rand(100, 1, 28, 28)
white_noise_labels = -2 * torch.ones(100, dtype=torch.long)
group_white = TensorDataset(white_noise_imgs, white_noise_labels)

ood_groups = {
    "digit8": group8,
    "digit9": group9,
    "perturbed": group_perturbed,
    "white_noise": group_white
}

# Collate every group once. Each plan entry is (name, images, labels, valid_label);
# labels are only passed for the digit groups, as F1 is only requested there.
_ood_eval_plan = []
_ood_cpu_images = {}
_ood_np_labels = {}
for _group_name, _dataset in ood_groups.items():
    _imgs, _lbls = _load_full_batch(_dataset)
    _ood_cpu_images[_group_name] = _imgs
    _ood_np_labels[_group_name] = _lbls.numpy()
    _has_labels = _group_name in ("digit8", "digit9")
    _ood_eval_plan.append((
        _group_name,
        _imgs.to(device_used),
        _ood_np_labels[_group_name] if _has_labels else None,
        _has_labels,
    ))

# Combined group: digits 8 and 9 together (labels available).
combined_8_9_images = torch.cat([_ood_cpu_images["digit8"], _ood_cpu_images["digit9"]], dim=0).to(device_used)
combined_8_9_labels = np.concatenate([_ood_np_labels["digit8"], _ood_np_labels["digit9"]])
_ood_eval_plan.append(("combined_8_9", combined_8_9_images, combined_8_9_labels, True))

# Combined group: every OOD image, in the order of `ood_groups` (no labels).
all_ood_images = torch.cat([_ood_cpu_images[name] for name in ood_groups], dim=0).to(device_used)
_ood_eval_plan.append(("all_ood", all_ood_images, None, False))

# The network only supplies the parameter layout; every forward pass replaces
# all of its weights with one HMC sample, so its initialization is irrelevant.
net_ood = SimpleCNN().to(device_used)
net_ood.eval()

# ----------------------------
# Loop over selected processor settings.
# ----------------------------
for P_use in selected_P:
    print(f"\n=== Evaluating for P = {P_use} processors ===")
    # For each replicate, we load N*P_use files.
    files_to_load = N * P_use

    # Per-replicate in-domain metrics for this processor setting.
    rep_accuracy = []
    rep_f1 = []
    rep_aucpr = []
    rep_nll = []
    rep_total_entropy = []
    rep_epistemic = []
    rep_Lsum = []
    rep_brier = []
    rep_ece = []

    # Separate metrics for ID correct vs. incorrect predictions.
    rep_total_entropy_correct = []
    rep_epistemic_correct = []
    rep_total_entropy_incorrect = []
    rep_epistemic_incorrect = []

    # The per-digit breakdown is saved per P, so it must restart for every P;
    # otherwise replicates from earlier settings leak into this one's statistics.
    per_digit_total_entropy_list = []
    per_digit_epistemic_entropy_list = []

    # Reset OOD replicate storage.
    rep_ood_results = {name: init_ood_dict() for name in _OOD_GROUP_ORDER}

    # Loop over replicates (r_idx).
    for r_idx in range(R):
        rep_base_idx = r_idx * files_to_load + 1
        file_indices = range(rep_base_idx, rep_base_idx + files_to_load)

        rep_preds_list = []
        rep_params_list = []
        Lsum_list = []

        for node in file_indices:
            filename = (f"BayesianNN_MNIST_hmc_results_d{d_val}_train{N_tr}_val{N_val}_"
                        f"thin{thin_all}_burnin{burnin_all}_node{node}.mat")
            if not os.path.exists(filename):
                print(f"File {filename} not found; skipping this file.")
                continue
            data = loadmat(filename)
            params = data['hmc_single_x']       # one flat parameter vector per row
            preds = hmc_predict_batch(x_val, params, SimpleCNN)
            # loadmat returns scalars as 2-D arrays; squeeze before converting,
            # because float() on an array with ndim > 0 is deprecated in NumPy.
            Lsum_val = float(np.squeeze(data['Lsum'])) if 'Lsum' in data else np.nan
            rep_preds_list.append(preds)
            rep_params_list.append(params)
            Lsum_list.append(Lsum_val)

        if len(rep_preds_list) == 0:
            print(f"No files loaded for replicate {r_idx+1}; skipping replicate.")
            continue

        rep_preds = np.concatenate(rep_preds_list, axis=0)    # (samples, N_eval, 8) logits
        rep_params = np.concatenate(rep_params_list, axis=0)  # (samples, d_val)
        rep_Lsum.append(np.mean(Lsum_list))

        # --- In-Domain Analysis ---
        rep_preds_soft = softmax_np(rep_preds)
        aggregated_prob = np.mean(rep_preds_soft, axis=0)     # ensemble probabilities, (N_eval, 8)
        pred_labels = np.argmax(aggregated_prob, axis=1)
        accuracy = np.mean(pred_labels == y_val_np)
        nll = -np.mean(np.log(aggregated_prob[np.arange(len(y_val_np)), y_val_np] + eps))

        brier = compute_brier(aggregated_prob, y_val_np)
        ece = compute_ece(aggregated_prob, y_val_np, n_bins=10)

        total_entropy = -np.sum(aggregated_prob * np.log(aggregated_prob + eps), axis=1)
        avg_total_entropy = np.mean(total_entropy)
        avg_particle_entropy = np.mean(-np.sum(rep_preds_soft * np.log(rep_preds_soft + eps), axis=2), axis=0)
        epistemic_uncertainty = total_entropy - avg_particle_entropy
        avg_epistemic = np.mean(epistemic_uncertainty)
        y_val_bin = label_binarize(y_val_np, classes=np.arange(8))
        f1_val = f1_score(y_val_np, pred_labels, average='macro', zero_division=0)
        aucpr_val = average_precision_score(y_val_bin, aggregated_prob, average='macro')

        rep_accuracy.append(accuracy)
        rep_f1.append(f1_val)
        rep_aucpr.append(aucpr_val)
        rep_nll.append(nll)
        rep_total_entropy.append(avg_total_entropy)
        rep_epistemic.append(avg_epistemic)
        rep_brier.append(brier)
        rep_ece.append(ece)

        # Breakdown for correct vs. incorrect predictions (NaN when a side is empty).
        correct_mask = (pred_labels == y_val_np)
        incorrect_mask = ~correct_mask
        has_correct = correct_mask.any()
        has_incorrect = incorrect_mask.any()
        rep_total_entropy_correct.append(np.mean(total_entropy[correct_mask]) if has_correct else np.nan)
        rep_epistemic_correct.append(np.mean(epistemic_uncertainty[correct_mask]) if has_correct else np.nan)
        rep_total_entropy_incorrect.append(np.mean(total_entropy[incorrect_mask]) if has_incorrect else np.nan)
        rep_epistemic_incorrect.append(np.mean(epistemic_uncertainty[incorrect_mask]) if has_incorrect else np.nan)

        # Per-digit ID breakdown (NaN for a digit with no samples).
        per_digit_total = []
        per_digit_epi = []
        for d in allowed_labels:
            mask = (y_val_np == d)
            if mask.any():
                per_digit_total.append(np.mean(total_entropy[mask]))
                per_digit_epi.append(np.mean(epistemic_uncertainty[mask]))
            else:
                per_digit_total.append(np.nan)
                per_digit_epi.append(np.nan)
        per_digit_total_entropy_list.append(np.array(per_digit_total))
        per_digit_epistemic_entropy_list.append(np.array(per_digit_epi))
        inID_total_list.append(avg_total_entropy)
        inID_epi_list.append(avg_epistemic)

        # --- OOD Analysis ---
        for group_name, images, labels, valid_label in _ood_eval_plan:
            metrics_vals = compute_ood_metrics(images, labels, rep_params, net_ood,
                                               eps=eps, valid_label=valid_label)
            for key, value in zip(_OOD_METRIC_KEYS, metrics_vals):
                rep_ood_results[group_name][key].append(value)

    # Nothing to aggregate (and nothing meaningful to save) without any replicate.
    n_loaded = len(rep_accuracy)
    if n_loaded == 0:
        print(f"No replicates could be loaded for P = {P_use}; skipping aggregation.")
        continue
    if n_loaded < R:
        print(f"Note: only {n_loaded} of {R} replicates were loaded for P = {P_use}.")

    # ----------------------------
    # Compute overall averages and standard errors over replicates (for ID and OOD metrics).
    # ----------------------------
    mean_acc, stderr_acc = compute_stats(rep_accuracy)
    mean_f1, stderr_f1 = compute_stats(rep_f1)
    mean_aucpr, stderr_aucpr = compute_stats(rep_aucpr)
    mean_nll, stderr_nll = compute_stats(rep_nll)
    mean_total_entropy, stderr_total_entropy = compute_stats(rep_total_entropy)
    mean_epistemic, stderr_epistemic = compute_stats(rep_epistemic)
    mean_Lsum, stderr_Lsum = compute_stats(rep_Lsum)

    mean_brier, stderr_brier = compute_stats(rep_brier)
    mean_ece,   stderr_ece   = compute_stats(rep_ece)

    mean_total_entropy_correct, stderr_total_entropy_correct = compute_stats(rep_total_entropy_correct)
    mean_epistemic_correct, stderr_epistemic_correct = compute_stats(rep_epistemic_correct)
    mean_total_entropy_incorrect, stderr_total_entropy_incorrect = compute_stats(rep_total_entropy_incorrect)
    mean_epistemic_incorrect, stderr_epistemic_incorrect = compute_stats(rep_epistemic_incorrect)

    print(f"\nAggregated ID metrics for P = {P_use} over {R} replicates:")
    print(f"Accuracy:      {mean_acc:.8f} ± {stderr_acc:.8f}")
    print(f"F1 Score:      {mean_f1:.8f} ± {stderr_f1:.8f}")
    print(f"AUC-PR:        {mean_aucpr:.8f} ± {stderr_aucpr:.8f}")
    print(f"NLL:           {mean_nll:.8f} ± {stderr_nll:.8f}")
    print(f"Brier Score:   {mean_brier:.8f} ± {stderr_brier:.8f}")
    print(f"ECE:           {mean_ece:.8f} ± {stderr_ece:.8f}")
    print(f"Total Entropy (all): {mean_total_entropy:.8f} ± {stderr_total_entropy:.8f}")
    print(f"Epistemic (all):     {mean_epistemic:.8f} ± {stderr_epistemic:.8f}")
    print(f"Lsum:          {mean_Lsum:.8f} ± {stderr_Lsum:.8f}")
    print("\nFor ID predictions split by correctness:")
    print(f"Correct Predictions - Total Entropy: {mean_total_entropy_correct:.8f} ± {stderr_total_entropy_correct:.8f}")
    print(f"Correct Predictions - Epistemic:     {mean_epistemic_correct:.8f} ± {stderr_epistemic_correct:.8f}")
    print(f"Incorrect Predictions - Total Entropy: {mean_total_entropy_incorrect:.8f} ± {stderr_total_entropy_incorrect:.8f}")
    print(f"Incorrect Predictions - Epistemic:     {mean_epistemic_incorrect:.8f} ± {stderr_epistemic_incorrect:.8f}")

    print(f"\nAggregated OOD metrics for P = {P_use} over {R} replicates:")
    aggregated_ood = {}
    for group, m in rep_ood_results.items():
        m_total, s_total = compute_stats(m["total_entropy"])
        m_epi, s_epi = compute_stats(m["epistemic"])
        m_f1, s_f1 = compute_stats(m["f1"])
        m_aucpr, s_aucpr = compute_stats(m["aucpr"])
        m_nll, s_nll = compute_stats(m["nll"])
        # Only the entropy summaries are written to the .mat file below.
        aggregated_ood[group] = (m_total, s_total, m_epi, s_epi)
        print(f"{group}:")
        print(f"  Total Entropy: {m_total:.8f} ± {s_total:.8f}")
        print(f"  Epistemic:     {m_epi:.8f} ± {s_epi:.8f}")
        print(f"  F1 Score:      {m_f1:.8f} ± {s_f1:.8f}")
        print(f"  AUC-PR:        {m_aucpr:.8f} ± {s_aucpr:.8f}")
        print(f"  NLL:           {m_nll:.8f} ± {s_nll:.8f}")

    # ----------------------------
    # Save aggregated results to .mat
    # ----------------------------
    results = {
        # ID metrics
        'mean_acc': mean_acc,
        'stderr_acc': stderr_acc,
        'mean_f1_id': mean_f1,
        'stderr_f1_id': stderr_f1,
        'mean_aucpr_id': mean_aucpr,
        'stderr_aucpr_id': stderr_aucpr,
        'mean_nll': mean_nll,
        'stderr_nll': stderr_nll,
        'mean_tot_ent': mean_total_entropy,
        'stderr_tot_ent': stderr_total_entropy,
        'mean_epi': mean_epistemic,
        'stderr_epi': stderr_epistemic,
        'mean_Lsum': mean_Lsum,
        'stderr_Lsum': stderr_Lsum,
        'mean_brier': mean_brier,
        'stderr_brier': stderr_brier,
        'mean_ece': mean_ece,
        'stderr_ece': stderr_ece,

        # Correct vs incorrect
        'mean_tot_corr': mean_total_entropy_correct,
        'se_tot_corr': stderr_total_entropy_correct,
        'mean_epi_corr': mean_epistemic_correct,
        'se_epi_corr': stderr_epistemic_correct,
        'mean_tot_inc': mean_total_entropy_incorrect,
        'se_tot_inc': stderr_total_entropy_incorrect,
        'mean_epi_inc': mean_epistemic_incorrect,
        'se_epi_inc': stderr_epistemic_incorrect,
    }

    # OOD per-category entropy summaries.
    for key, (m_tot, s_tot, m_epi, s_epi) in aggregated_ood.items():
        results[f'ood_{key}_mean_tot'] = m_tot
        results[f'ood_{key}_se_tot']   = s_tot
        results[f'ood_{key}_mean_epi'] = m_epi
        results[f'ood_{key}_se_epi']   = s_epi

    # Per-digit (0-7) ID breakdown: one row per loaded replicate of this P,
    # saved as 8-element vectors.
    mean_digit_tot, se_digit_tot = _columnwise_stats(per_digit_total_entropy_list)
    mean_digit_epi, se_digit_epi = _columnwise_stats(per_digit_epistemic_entropy_list)
    results['mean_digit_tot_ent']   = mean_digit_tot
    results['stderr_digit_tot_ent'] = se_digit_tot
    results['mean_digit_epi_ent']   = mean_digit_epi
    results['stderr_digit_epi_ent'] = se_digit_epi

    savemat(f'phmc_aggregated_results_P{P_use}.mat', results)
