#!/usr/bin/env python3
"""
PHMC Summarization Script for CIFAR-10 using Precomputed OOD Embeddings

This script aggregates PHMC runs for Bayesian inference on CIFAR-10 
(using SimpleMLP with MAP prior incorporation). Each file (from a single run)
contains one sample (i.e. one set of HMC particle parameters, predictions, and Lsum).
For each replicate, the script loads the corresponding block of files,
concatenates the samples, and computes:
  - Aggregated estimated probabilities (averaging softmax outputs over all samples)
  - Standard in-domain (ID) metrics: accuracy, precision, recall, F1, AUC‑PR, Negative Log Likelihood (NLL), Brier score, and Expected Calibration Error (ECE)
  - Total predictive entropy and epistemic uncertainty (ensemble entropy minus average per-sample entropy)
  - Average Lsum (cumulative “epoch” count) across files
For out‑of‑domain (OOD) analysis, the script evaluates the same HMC particles on precomputed ResNet‑50 embeddings of three OOD sets: CIFAR-100 (“close”), CIFAR-10-C (“corrupt”), and SVHN (“far”).
Results are aggregated over replicates and saved.
"""

import os
import numpy as np
from scipy.io import loadmat, savemat
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
import random
import sys
from sklearn.metrics import f1_score, average_precision_score, precision_score, recall_score
from sklearn.preprocessing import label_binarize
from torch.func import functional_call

from PIL import Image
import tarfile, urllib.request
import random
import torchvision.datasets as datasets
from torchvision import models

############################################
#      Device and Prior Settings
############################################
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Prior standard deviations (variance 0.2) for weights and biases. In the visible
# code they are only referenced by the commented-out initialisation in SimpleMLP.
sigma_w = np.sqrt(0.2)
sigma_b = np.sqrt(0.2)

# ----------------------------
# Helper Functions
# ----------------------------
def unflatten_params(flat, net):
    """Split a flat parameter vector into per-parameter tensors.

    Walks ``net.named_parameters()`` in order and, for each parameter, takes
    the next ``numel()`` entries of ``flat`` and views them with that
    parameter's shape.

    Args:
        flat: 1-D tensor holding all parameter values back to back, in
            ``named_parameters()`` order.
        net: module whose parameter names and shapes define the layout.

    Returns:
        dict mapping parameter name -> tensor view into ``flat``.
    """
    rebuilt = {}
    offset = 0
    for key, tensor in net.named_parameters():
        stop = offset + tensor.numel()
        rebuilt[key] = flat[offset:stop].view(tensor.shape)
        offset = stop
    return rebuilt

def softmax_np(x):
    """Numerically stable softmax over the last axis of a NumPy array."""
    # Subtracting the per-row maximum keeps exp() from overflowing.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)

# ----------------------------
# Dataset Class for Filtering
# ----------------------------
class FilteredDataset(Dataset):
    """In-memory dataset holding only the samples whose label is allowed.

    The wrapped dataset is iterated once at construction time, so every
    retained ``(img, label)`` pair is materialised and kept in ``self.data``.
    """

    def __init__(self, dataset, allowed_labels):
        kept = []
        for img, label in dataset:
            if label in allowed_labels:
                kept.append((img, label))
        self.data = kept

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# ----------------------------
# OOD Embedding Extraction Function
# ----------------------------
# --- 1) CIFAR-100 “close” OOD: classes not in CIFAR-10 ---
def sample_cifar100_not_in_cifar10(n1, root='./data'):
    """Draw ``n1`` CIFAR-100 test images whose class name is not a CIFAR-10 class.

    Both test sets are downloaded to ``root`` if missing. Sampling is without
    replacement and reseeds the global ``random`` module with 42.

    Returns:
        list of ``n1`` PIL images.
    """
    cifar10 = datasets.CIFAR10(root=root, train=False, download=True)
    cifar100 = datasets.CIFAR100(root=root, train=False, download=True)

    # Class names are compared as exact strings.
    # NOTE(review): confirm this excludes the classes you intend; near-synonyms
    # between the two label sets would not be caught by an exact match.
    known_names = set(cifar10.classes)
    fine_names = cifar100.classes  # the 100 fine-grained class names
    ood_names = [fine for fine in fine_names if fine not in known_names]

    # Positions of all CIFAR-100 test images belonging to an OOD class.
    candidates = []
    for position, target in enumerate(cifar100.targets):
        if fine_names[target] in ood_names:
            candidates.append(position)

    random.seed(42)
    picked = random.sample(candidates, n1)

    return [Image.fromarray(cifar100.data[k]) for k in picked]

# --- 2) CIFAR-10-C “corrupt” OOD: all corruption types/severities ---
def sample_cifar10c(n2, root='./data'):
    """Draw ``n2`` images uniformly from all CIFAR-10-C corruption arrays.

    The archive is downloaded to ``root`` and extracted if not already there.
    Every ``.npy`` file in the extracted folder except the labels file is
    treated as one corruption array, and images are sampled without
    replacement from the concatenation of those arrays.

    Files are processed in sorted name order so that the seeded sample is the
    same on every machine (``os.listdir`` order is filesystem-dependent). The
    arrays are memory-mapped and only the sampled images are read, instead of
    stacking every array in RAM.

    Args:
        n2: number of images to draw.
        root: directory holding (or receiving) the CIFAR-10-C archive.

    Returns:
        list of ``n2`` PIL images.

    Raises:
        ValueError: if no corruption arrays are found, or if ``n2`` exceeds
            the number of available images.
    """
    url = 'https://zenodo.org/record/2535967/files/CIFAR-10-C.tar'
    tar_path = os.path.join(root, 'CIFAR-10-C.tar')
    extract_dir = os.path.join(root, 'CIFAR-10-C')

    # Download if needed
    if not os.path.exists(tar_path):
        os.makedirs(root, exist_ok=True)
        print("Downloading CIFAR-10-C (≈180 MB)...")
        urllib.request.urlretrieve(url, tar_path)

    # Extract if needed
    if not os.path.isdir(extract_dir):
        print("Extracting CIFAR-10-C...")
        with tarfile.open(tar_path) as tar:
            # NOTE(review): extractall() trusts member paths in the archive.
            # Consider an extraction filter (see tarfile docs) if the archive
            # source is ever not fully trusted.
            tar.extractall(path=root)

    # Every corruption .npy (skip the labels file), in a deterministic order.
    corruption_files = sorted(
        fname for fname in os.listdir(extract_dir)
        if fname.endswith('.npy') and 'labels' not in fname
    )
    if not corruption_files:
        raise ValueError(f"No corruption .npy files found in {extract_dir}")

    # Memory-map each array; offsets[k] is the global index of the first image
    # of file k in the virtual concatenation of all files.
    arrays = [
        np.load(os.path.join(extract_dir, fname), mmap_mode='r')
        for fname in corruption_files
    ]
    offsets = np.cumsum([0] + [len(arr) for arr in arrays])
    total = int(offsets[-1])

    random.seed(42)
    chosen = random.sample(range(total), n2)

    images = []
    for idx in chosen:
        # Map the global index back to (file, row within file).
        file_no = int(np.searchsorted(offsets, idx, side='right')) - 1
        local = idx - int(offsets[file_no])
        pixels = np.asarray(arrays[file_no][local]).astype(np.uint8)
        images.append(Image.fromarray(pixels))
    return images

# --- 3) SVHN “far” OOD: street-view digits ---
def sample_svhn(n3, root='./data'):
    """Draw ``n3`` images from the SVHN test split ("far" OOD set).

    The split is downloaded to ``root`` if missing. Sampling is without
    replacement and reseeds the global ``random`` module with 42.

    Returns:
        list with the first element (the image) of each sampled dataset item.
    """
    # No transform is applied, so items come back exactly as the dataset
    # yields them (presumably PIL images — confirm against torchvision's SVHN).
    svhn = datasets.SVHN(root=root, split='test', download=True, transform=None)

    random.seed(42)
    picked = random.sample(range(len(svhn)), n3)

    images = []
    for k in picked:
        item = svhn[k]
        images.append(item[0])
    return images

def create_resnet50_embedded_ood_datasets(
        n_close, n_corrupt, n_far,
        cache_path='ood_embeddings.pt'):
    """Return ResNet-50 embeddings for the three OOD image sets.

    If ``cache_path`` exists, its contents are loaded and returned as-is.
    Otherwise images are sampled from CIFAR-100 ("close"), CIFAR-10-C
    ("corrupt") and SVHN ("far"), pushed through an ImageNet-pretrained
    ResNet-50 with its final fully-connected layer removed, and the flattened
    features are saved to ``cache_path``.

    Args:
        n_close: number of CIFAR-100 images to embed.
        n_corrupt: number of CIFAR-10-C images to embed.
        n_far: number of SVHN images to embed.
        cache_path: file used to cache the ``(X_close, X_corrupt, X_far)`` tuple.

    Returns:
        Tuple ``(X_close, X_corrupt, X_far)`` of CPU tensors, one row per image.
    """
    if os.path.exists(cache_path):
        print("Loading cached ood embeddings...")
        # Embeddings are stored as CPU tensors; map_location keeps loading
        # robust if a cache was ever written from GPU tensors.
        return torch.load(cache_path, map_location='cpu')

    print("Cached embeddings not found. Computing ood embeddings...")

    # 1) Image preprocessing for the OOD PIL images.
    transform = transforms.Compose([
        transforms.Resize(224),                  # match ResNet50's input size
        transforms.ToTensor(),                   # convert PIL -> [0,1] tensor
        transforms.Normalize(                    # ImageNet mean/std
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    # 2) ResNet-50 feature extractor. IMAGENET1K_V1 is the weight set the
    #    deprecated `pretrained=True` flag used to select.
    resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    # Drop the final fully-connected layer. `torch.nn` is spelled out because
    # this file never imports a bare `nn` name.
    feature_extractor = torch.nn.Sequential(*list(resnet50.children())[:-1])
    feature_extractor.eval()                        # inference mode (fixed BatchNorm statistics)
    feature_extractor.to(device)                    # move to GPU if available

    # 3) Sample PIL images
    close_imgs   = sample_cifar100_not_in_cifar10(n_close)
    corrupt_imgs = sample_cifar10c(n_corrupt)
    far_imgs     = sample_svhn(n_far)

    # 4) Embed in batches; images are transformed per batch so that only one
    #    batch of 224x224 tensors is alive at a time.
    def imgs_to_embeddings(img_list, batch_size=128):
        feats = []
        with torch.no_grad():
            for start in range(0, len(img_list), batch_size):
                chunk = img_list[start:start + batch_size]
                batch = torch.stack([transform(img) for img in chunk]).to(device)
                out = feature_extractor(batch).view(batch.size(0), -1)
                feats.append(out.cpu())
        return torch.cat(feats, dim=0)

    X_close   = imgs_to_embeddings(close_imgs)
    X_corrupt = imgs_to_embeddings(corrupt_imgs)
    X_far     = imgs_to_embeddings(far_imgs)

    torch.save((X_close, X_corrupt, X_far), cache_path)
    print("ood embeddings computed and saved.")
    return X_close, X_corrupt, X_far


#-------------------------------------------
#      Metrics computation: Brier and ECE
#-------------------------------------------
def compute_brier(probs, labels):
    """
    Multi-class Brier score:
      (1/N) ∑_i ∑_c (p_{i,c} – 1{y_i=c})^2.

    Args:
      probs: array of shape (N, C) of predicted class probabilities
      labels: array of shape (N,) of integer true labels in [0..C-1]

    Returns:
      scalar mean over examples of the summed squared error against the
      one-hot encoding of ``labels``.
    """
    n_samples, _ = probs.shape
    # One-hot targets with the same shape and dtype as probs.
    targets = np.zeros_like(probs)
    targets[np.arange(n_samples), labels] = 1
    squared_error = (probs - targets) ** 2
    per_example = np.sum(squared_error, axis=1)
    return np.mean(per_example)

# def compute_ece(probs, labels, n_bins=10):
#     """Expected Calibration Error over `n_bins` equally‐spaced bins."""
#     bins = np.linspace(0.0, 1.0, n_bins + 1)
#     ece = 0.0
#     for i in range(n_bins):
#         mask = (probs >= bins[i]) & (probs < bins[i+1])
#         if not np.any(mask):
#             continue
#         preds_bin = (probs[mask] >= 0.5).astype(int)
#         acc_bin   = (labels[mask] == preds_bin).mean()
#         conf_bin  = probs[mask].mean()
#         ece      += np.abs(acc_bin - conf_bin) * (mask.sum() / len(probs))
#     return ece
def compute_ece(probs, labels, n_bins=10):
    """
    Expected Calibration Error for multiclass outputs.

    Each example is assigned to a confidence bin by the probability of its
    predicted (argmax) class. Bins are [low, high) except the last one, which
    is closed on the right so that a confidence of exactly 1.0 is counted
    rather than silently dropped.

    Args:
      probs: array of shape (N, C) of predicted class probabilities (each row sums to 1)
      labels: array of shape (N,) of integer true labels in [0..C-1]
      n_bins: number of equal-width confidence bins in [0,1]

    Returns:
      scalar ECE = sum_b (|acc(b) – conf(b)| · |bin_b|/N); NaN if N == 0.
    """
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    if len(probs) == 0:
        # ECE is undefined without examples.
        return float('nan')

    # 1) For each example, get predicted class and its confidence
    preds       = np.argmax(probs, axis=1)                  # shape (N,)
    confidences = probs[np.arange(len(probs)), preds]       # shape (N,)
    correct     = (preds == labels).astype(float)           # shape (N,)

    # 2) Build bins over [0,1]
    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    last_bin = n_bins - 1
    ece = 0.0

    # 3) Loop through bins
    for i in range(n_bins):
        low, high = bin_edges[i], bin_edges[i + 1]
        if i == last_bin:
            # Include the upper edge so confidence == 1.0 lands in a bin.
            in_bin = (confidences >= low) & (confidences <= high)
        else:
            in_bin = (confidences >= low) & (confidences < high)
        prop_in_bin = in_bin.mean()
        if prop_in_bin == 0:
            continue

        # average accuracy and average confidence in this bin
        acc_bin  = correct[in_bin].mean()
        conf_bin = confidences[in_bin].mean()

        # accumulate weighted gap
        ece += prop_in_bin * abs(acc_bin - conf_bin)

    return ece


# ----------------------------
# Define SimpleMLP for CIFAR-10 (Used in PHMC Runs)
# ----------------------------
class SimpleMLP(torch.nn.Module):
    """
    Classifier head over fixed-size input features.

    With ``hidden_dim`` equal to 0 or None the model is a single linear layer
    ``fc`` (multinomial logistic regression); otherwise it is
    ``fc1 -> ReLU -> fc2``. All layers keep PyTorch's default initialisation.
    """

    def __init__(self, input_dim=2048, hidden_dim=0, num_classes=10):
        super().__init__()
        linear_only = hidden_dim is None or hidden_dim == 0
        if linear_only:
            self.fc = torch.nn.Linear(input_dim, num_classes)
        else:
            self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
            self.fc2 = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits for a batch of input features."""
        # The presence of `fc` identifies the single-layer configuration.
        if hasattr(self, 'fc'):
            return self.fc(x)
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

# ----------------------------
# Main Execution
# ----------------------------
def compute_stats(metric_list):
    """Return ``(mean, standard error)`` of per-replicate metrics, ignoring NaNs.

    The standard error is the sample standard deviation (ddof=1) divided by
    the square root of the number of non-NaN values, so NaN replicates affect
    neither the spread nor the divisor.

    Args:
        metric_list: sequence of scalar metric values, possibly containing NaN.

    Returns:
        ``(mean, stderr)``. Both are NaN when there is no valid value; the
        standard error is NaN when there are fewer than two valid values.
    """
    arr = np.asarray(metric_list, dtype=float)
    valid = arr[~np.isnan(arr)]
    if valid.size == 0:
        return np.nan, np.nan
    mean_val = valid.mean()
    if valid.size < 2:
        return mean_val, np.nan
    stderr_val = valid.std(ddof=1) / np.sqrt(valid.size)
    return mean_val, stderr_val


def predictive_uncertainty(member_probs, eps=1e-12):
    """Decompose predictive entropy of an ensemble into total and epistemic parts.

    Args:
        member_probs: array of shape (S, N, C) with the class probabilities of
            each of the S samples/particles for N inputs.
        eps: small constant added inside the logarithm to avoid log(0).

    Returns:
        Tuple ``(ensemble_probs, total_entropy, epistemic)`` where
        ``ensemble_probs`` has shape (N, C) and is the mean over members,
        ``total_entropy`` has shape (N,) and is the entropy of that mean, and
        ``epistemic`` has shape (N,) and is ``total_entropy`` minus the mean
        per-member entropy.
    """
    ensemble_probs = np.mean(member_probs, axis=0)
    total_entropy = -np.sum(ensemble_probs * np.log(ensemble_probs + eps), axis=1)
    member_entropy = -np.sum(member_probs * np.log(member_probs + eps), axis=2)
    epistemic = total_entropy - np.mean(member_entropy, axis=0)
    return ensemble_probs, total_entropy, epistemic


if __name__ == '__main__':

    eps = 1e-12  # guards log(0) in the NLL and entropy computations

    # Load in-domain CIFAR-10 labels.
    # Only the ground-truth labels are used below (predictions come from the
    # .mat files), so they are read from the dataset's `targets` list instead
    # of decoding, resizing and batching all test images.
    full_val_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)
    allowed_labels = set(range(10))
    y_val_np = np.array(
        [label for label in full_val_dataset.targets if label in allowed_labels],
        dtype=np.int64,
    )  # ground truth labels, in dataset order

    # Load Out-of-Domain (OOD) Data
    n_close, n_corrupt, n_far = 1000, 1000, 1000
    X_close, X_corrupt, X_far = create_resnet50_embedded_ood_datasets(
        n_close, n_corrupt, n_far,
        cache_path="ood_all_embeddings.pt"
        )

    # User-Adjustable Parameters for Summarization.
    N = 10        # number of samples per processor (each file contains one PHMC sample)
    R = 5         # number of replicates
    selected_P = [1, 2, 4, 8]  # list of processor counts (number of files per replicate)
    # File naming parameters.
    d_val      = 20490   # flat parameter count: 2048*10 weights + 10 biases
    N_val      = len(y_val_np)  # number of validation samples
    thin_all   = 200     # thinning period used in the PHMC run
    burnin_all = 200     # burn-in period used in the PHMC run

    # Ensure OOD embeddings are on the same device as the evaluation network.
    ood_sets = [
        ("close", X_close.to(device)),
        ("corrupt", X_corrupt.to(device)),
        ("far", X_far.to(device)),
    ]

    # The network only supplies parameter names/shapes and the forward pass;
    # its own weights are replaced by each particle through functional_call,
    # so a single instance is shared by all replicates.
    net_ood = SimpleMLP(input_dim=2048, hidden_dim=0, num_classes=10).to(device)
    net_ood.eval()

    # Loop over selected processor settings.
    for P_use in selected_P:
        print(f"\n=== Evaluating for P = {P_use} processors ===")
        # Containers for replicate metrics (ID).
        rep_accuracy = []         # accuracy per replicate
        rep_f1 = []               # F1 score per replicate
        rep_aucpr = []            # AUC-PR per replicate
        rep_nll = []              # Negative log likelihood per replicate
        rep_total_entropy = []    # average total predictive entropy per replicate
        rep_epistemic = []        # average epistemic uncertainty per replicate
        rep_Lsum = []             # average Lsum (epoch count) per replicate
        rep_brier = []
        rep_ece   = []
        rep_precision = []
        rep_recall    = []

        # Containers for splitting ID metrics by prediction correctness.
        rep_total_entropy_correct = []
        rep_epistemic_correct = []
        rep_total_entropy_incorrect = []
        rep_epistemic_incorrect = []

        # Containers for replicate metrics (OOD)
        rep_total_entropy_ood = {"close": [], "corrupt": [], "far": []}
        rep_epistemic_ood = {"close": [], "corrupt": [], "far": []}

        # For each replicate, load the next N*P_use files (node indices are 1-based).
        files_to_load = N * P_use
        for r_idx in range(R):
            rep_base_idx = r_idx * files_to_load + 1
            file_indices = range(rep_base_idx, rep_base_idx + files_to_load)

            rep_preds_list = []   # hmc_single_pred from each file; shape: (1, N_val, 10)
            rep_params_list = []  # hmc_single_x from each file; shape: (1, d_val)
            Lsum_list = []        # Lsum value from each file

            for node in file_indices:
                filename = f"BayesianNN_CIFAR_hmc_SimpleMLP_MAP_d{d_val}_thin{thin_all}_burnin{burnin_all}_node{node}.mat"
                if not os.path.exists(filename):
                    print(f"File {filename} not found; skipping this file.")
                    continue
                data = loadmat(filename)
                rep_preds_list.append(data['hmc_single_pred'])
                rep_params_list.append(data['hmc_single_x'])
                # loadmat returns scalars as 2-D arrays; take the single element
                # explicitly instead of calling float() on an array.
                Lsum_val = float(np.ravel(data['Lsum'])[0]) if 'Lsum' in data else np.nan
                Lsum_list.append(Lsum_val)

            if len(rep_preds_list) == 0:
                print(f"No files loaded for replicate {r_idx+1}; skipping replicate.")
                continue

            # Concatenate along axis 0.
            rep_preds = np.concatenate(rep_preds_list, axis=0)    # shape: (files, N_val, 10)
            rep_params = np.concatenate(rep_params_list, axis=0)  # shape: (files, d_val)
            if rep_preds.shape[1] != N_val:
                raise ValueError(
                    f"Predictions cover {rep_preds.shape[1]} samples but "
                    f"{N_val} validation labels were loaded."
                )
            rep_Lsum.append(np.mean(Lsum_list))

            # --- In-Domain (ID) Analysis ---
            rep_preds_soft = softmax_np(rep_preds)  # shape: (files, N_val, 10)
            aggregated_prob, total_entropy, epistemic_uncertainty = predictive_uncertainty(
                rep_preds_soft, eps
            )
            pred_labels = np.argmax(aggregated_prob, axis=1)
            accuracy = np.mean(pred_labels == y_val_np)
            nll = -np.mean(np.log(aggregated_prob[np.arange(len(y_val_np)), y_val_np] + eps))
            avg_total_entropy = np.mean(total_entropy)
            avg_epistemic = np.mean(epistemic_uncertainty)

            brier = compute_brier(aggregated_prob, y_val_np)
            ece = compute_ece(aggregated_prob, y_val_np, n_bins=10)

            y_val_bin = label_binarize(y_val_np, classes=np.arange(10))
            f1_val = f1_score(y_val_np, pred_labels, average='macro', zero_division=0)
            aucpr_val = average_precision_score(y_val_bin, aggregated_prob, average='macro')
            precision_val = precision_score(y_val_np, pred_labels, average='macro', zero_division=0)
            recall_val = recall_score(y_val_np, pred_labels, average='macro', zero_division=0)

            rep_accuracy.append(accuracy)
            rep_f1.append(f1_val)
            rep_aucpr.append(aucpr_val)
            rep_nll.append(nll)
            rep_total_entropy.append(avg_total_entropy)
            rep_epistemic.append(avg_epistemic)
            rep_brier.append(brier)
            rep_ece.append(ece)
            rep_precision.append(precision_val)
            rep_recall.append(recall_val)

            # Uncertainty split by whether the ensemble prediction was right;
            # NaN marks a replicate with no examples on that side.
            correct_mask = (pred_labels == y_val_np)
            incorrect_mask = ~correct_mask
            if correct_mask.any():
                total_entropy_correct = np.mean(total_entropy[correct_mask])
                epistemic_correct = np.mean(epistemic_uncertainty[correct_mask])
            else:
                total_entropy_correct = np.nan
                epistemic_correct = np.nan
            if incorrect_mask.any():
                total_entropy_incorrect = np.mean(total_entropy[incorrect_mask])
                epistemic_incorrect = np.mean(epistemic_uncertainty[incorrect_mask])
            else:
                total_entropy_incorrect = np.nan
                epistemic_incorrect = np.nan
            rep_total_entropy_correct.append(total_entropy_correct)
            rep_epistemic_correct.append(epistemic_correct)
            rep_total_entropy_incorrect.append(total_entropy_incorrect)
            rep_epistemic_incorrect.append(epistemic_incorrect)

            # --- Out-of-Domain (OOD) Analysis ---
            # One row per particle; converted once and reused for all OOD sets.
            particles = torch.tensor(rep_params, dtype=torch.float32, device=device)

            for name, X_ood in ood_sets:
                particle_logits_od = []
                with torch.no_grad():
                    for flat_tensor in particles:
                        param_dict = unflatten_params(flat_tensor, net_ood)
                        # Use precomputed embeddings X_ood (shape: [N_ood, 2048])
                        logits = functional_call(net_ood, param_dict, X_ood)
                        particle_logits_od.append(logits.detach().cpu().numpy())
                particle_logits_od = np.stack(particle_logits_od, axis=0)  # (files, N_ood, 10)
                particle_probs_od = softmax_np(particle_logits_od)
                _, total_entropy_od, epistemic_uncertainty_od = predictive_uncertainty(
                    particle_probs_od, eps
                )
                rep_total_entropy_ood[name].append(np.mean(total_entropy_od))
                rep_epistemic_ood[name].append(np.mean(epistemic_uncertainty_od))

        mean_acc, stderr_acc = compute_stats(rep_accuracy)
        mean_f1, stderr_f1 = compute_stats(rep_f1)
        mean_aucpr, stderr_aucpr = compute_stats(rep_aucpr)
        mean_nll, stderr_nll = compute_stats(rep_nll)
        mean_total_entropy, stderr_total_entropy = compute_stats(rep_total_entropy)
        mean_epistemic, stderr_epistemic = compute_stats(rep_epistemic)
        mean_Lsum, stderr_Lsum = compute_stats(rep_Lsum)
        mean_total_entropy_correct, stderr_total_entropy_correct = compute_stats(rep_total_entropy_correct)
        mean_epistemic_correct, stderr_epistemic_correct = compute_stats(rep_epistemic_correct)
        mean_total_entropy_incorrect, stderr_total_entropy_incorrect = compute_stats(rep_total_entropy_incorrect)
        mean_epistemic_incorrect, stderr_epistemic_incorrect = compute_stats(rep_epistemic_incorrect)

        mean_brier, stderr_brier = compute_stats(rep_brier)
        mean_ece,   stderr_ece   = compute_stats(rep_ece)
        mean_precision, stderr_precision = compute_stats(rep_precision)
        mean_recall,    stderr_recall    = compute_stats(rep_recall)

        mean_total_entropy_ood_close, stderr_total_entropy_ood_close = compute_stats(rep_total_entropy_ood["close"])
        mean_epistemic_ood_close, stderr_epistemic_ood_close = compute_stats(rep_epistemic_ood["close"])
        mean_total_entropy_ood_corrupt, stderr_total_entropy_ood_corrupt = compute_stats(rep_total_entropy_ood["corrupt"])
        mean_epistemic_ood_corrupt, stderr_epistemic_ood_corrupt = compute_stats(rep_epistemic_ood["corrupt"])
        mean_total_entropy_ood_far, stderr_total_entropy_ood_far = compute_stats(rep_total_entropy_ood["far"])
        mean_epistemic_ood_far, stderr_epistemic_ood_far = compute_stats(rep_epistemic_ood["far"])

        print(f"\nAggregated In-Domain metrics for P = {P_use} over {R} replicates:")
        print(f"Accuracy:      {mean_acc:.8f} ± {stderr_acc:.8f}")
        print(f"F1 Score:      {mean_f1:.8f} ± {stderr_f1:.8f}")
        print(f"AUC-PR:        {mean_aucpr:.8f} ± {stderr_aucpr:.8f}")
        print(f"NLL:           {mean_nll:.8f} ± {stderr_nll:.8f}")
        print(f"Brier Score:   {mean_brier:.8f} ± {stderr_brier:.8f}")
        print(f"ECE:           {mean_ece:.8f} ± {stderr_ece:.8f}")
        print(f"Precision:     {mean_precision:.8f} ± {stderr_precision:.8f}")
        print(f"Recall:        {mean_recall:.8f} ± {stderr_recall:.8f}")

        print(f"Total Entropy (all): {mean_total_entropy:.8f} ± {stderr_total_entropy:.8f}")
        print(f"Epistemic (all):     {mean_epistemic:.8f} ± {stderr_epistemic:.8f}")
        print(f"Lsum:          {mean_Lsum:.8f} ± {stderr_Lsum:.8f}")
        print("\nFor ID predictions split by correctness:")
        print(f"Correct Predictions - Total Entropy: {mean_total_entropy_correct:.8f} ± {stderr_total_entropy_correct:.8f}")
        print(f"Correct Predictions - Epistemic:     {mean_epistemic_correct:.8f} ± {stderr_epistemic_correct:.8f}")
        print(f"Incorrect Predictions - Total Entropy: {mean_total_entropy_incorrect:.8f} ± {stderr_total_entropy_incorrect:.8f}")
        print(f"Incorrect Predictions - Epistemic:     {mean_epistemic_incorrect:.8f} ± {stderr_epistemic_incorrect:.8f}")

        print(f"\nAggregated OOD metrics for P = {P_use} over {R} replicates:")
        print(f"Close Total Entropy: {mean_total_entropy_ood_close:.8f} ± {stderr_total_entropy_ood_close:.8f}")
        print(f"Close Epistemic:     {mean_epistemic_ood_close:.8f} ± {stderr_epistemic_ood_close:.8f}")
        print(f"Corrupt Total Entropy: {mean_total_entropy_ood_corrupt:.8f} ± {stderr_total_entropy_ood_corrupt:.8f}")
        print(f"Corrupt Epistemic:     {mean_epistemic_ood_corrupt:.8f} ± {stderr_epistemic_ood_corrupt:.8f}")
        print(f"Far Total Entropy: {mean_total_entropy_ood_far:.8f} ± {stderr_total_entropy_ood_far:.8f}")
        print(f"Far Epistemic:     {mean_epistemic_ood_far:.8f} ± {stderr_epistemic_ood_far:.8f}")

        results = {
            'mean_acc': mean_acc,
            'stderr_acc': stderr_acc,
            'mean_f1_id': mean_f1,
            'stderr_f1_id': stderr_f1,
            'mean_aucpr_id': mean_aucpr,
            'stderr_aucpr_id': stderr_aucpr,
            'mean_nll': mean_nll,
            'stderr_nll': stderr_nll,
            'mean_tot_ent': mean_total_entropy,
            'stderr_tot_ent': stderr_total_entropy,
            'mean_epi': mean_epistemic,
            'stderr_epi': stderr_epistemic,
            'mean_Lsum': mean_Lsum,
            'stderr_Lsum': stderr_Lsum,
            'mean_brier': mean_brier,
            'stderr_brier': stderr_brier,
            'mean_ece': mean_ece,
            'stderr_ece': stderr_ece,
            'mean_precision': mean_precision,
            'stderr_precision': stderr_precision,
            'mean_recall': mean_recall,
            'stderr_recall': stderr_recall,
            # correct vs incorrect
            'mean_tot_corr': mean_total_entropy_correct,
            'se_tot_corr': stderr_total_entropy_correct,
            'mean_epi_corr': mean_epistemic_correct,
            'se_epi_corr': stderr_epistemic_correct,
            'mean_tot_inc': mean_total_entropy_incorrect,
            'se_tot_inc': stderr_total_entropy_incorrect,
            'mean_epi_inc': mean_epistemic_incorrect,
            'se_epi_inc': stderr_epistemic_incorrect,
            # ood (each key group stores the statistics of its own OOD set)
            'ood_close_mean_tot': mean_total_entropy_ood_close,
            'ood_close_se_tot': stderr_total_entropy_ood_close,
            'ood_close_mean_epi': mean_epistemic_ood_close,
            'ood_close_se_epi': stderr_epistemic_ood_close,
            'ood_corrupt_mean_tot': mean_total_entropy_ood_corrupt,
            'ood_corrupt_se_tot': stderr_total_entropy_ood_corrupt,
            'ood_corrupt_mean_epi': mean_epistemic_ood_corrupt,
            'ood_corrupt_se_epi': stderr_epistemic_ood_corrupt,
            'ood_far_mean_tot': mean_total_entropy_ood_far,
            'ood_far_se_tot': stderr_total_entropy_ood_far,
            'ood_far_mean_epi': mean_epistemic_ood_far,
            'ood_far_se_epi': stderr_epistemic_ood_far,
            }

        savemat(f'phmc_aggregated_results_P{P_use}.mat', results)