#!/usr/bin/env python3
"""
This script aggregates the per-processor, single-replicate SMC result files
for CIFAR-10 (in-domain classes 0–9) and computes comprehensive analysis
metrics on both in-domain and out-of-domain (OOD) data. The file-naming
pattern is assumed to be:

  "BayesianNN_CIFAR_MAP_psmc_SimpleMLP_N{N_particles}_M{M_val}_node{node}.mat"

for each replicate and processor, where node = r + p*R + 1 for the 0-based
replicate r and processor p. For each selected number of processors P
(select_index = [1, 2, 4, 8]) and each of the R replicates, the results of the
P processors are combined by first averaging over each processor's particles
and then over the available processors. From these aggregated predictions the
in-domain metrics are computed: accuracy, macro F1, macro precision/recall,
macro AUC-PR, NLL, Brier score, ECE, predictive entropy and epistemic
uncertainty. OOD uncertainty metrics (predictive entropy and epistemic
uncertainty) are computed on ResNet-50 embeddings of three OOD sets
(CIFAR-100 "close", CIFAR-10-C "corrupt", SVHN "far"). These embeddings are
computed once and cached (e.g. as "ood_all_embeddings.pt"). Finally, for each
selected P the across-replicate means and standard errors are printed and
saved to "psmc_aggregated_results_P{P}.mat".
"""

import os
import numpy as np
from scipy.io import loadmat, savemat
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
import sys
from sklearn.metrics import f1_score, average_precision_score, precision_score, recall_score
from sklearn.preprocessing import label_binarize
from torch.func import functional_call
from multiprocessing import freeze_support

from PIL import Image
import tarfile, urllib.request
import random
import torchvision.datasets as datasets
from torchvision import models

############################################
#      Device and Prior Settings
############################################
# Run on the GPU whenever one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Prior standard deviations (variance 0.2 each).
sigma_w = np.sqrt(0.2)   # weights (used in SimpleMLP)
sigma_b = np.sqrt(0.2)   # biases

############################################
#       Helper Functions
############################################
def unflatten_params(flat, net):
    """Split a flat parameter vector into a ``{name: tensor}`` dict for ``net``.

    Consecutive slices of ``flat`` are assigned to ``net``'s parameters in
    ``named_parameters()`` order. Each slice is reshaped (via ``view``) to the
    matching parameter's shape. Entries of ``flat`` beyond the total parameter
    count are ignored.

    Args:
        flat: 1-D tensor holding all parameters back to back.
        net: module whose parameter names and shapes define the layout.

    Returns:
        Dict mapping each parameter name to a view into ``flat``.
    """
    rebuilt = {}
    start = 0
    for pname, template in net.named_parameters():
        stop = start + template.numel()
        rebuilt[pname] = flat[start:stop].view(template.shape)
        start = stop
    return rebuilt

def softmax_np(x):
    """Return the softmax of ``x`` along its last axis.

    The row maximum is subtracted before exponentiating so large logits do
    not overflow.
    """
    peak = np.max(x, axis=-1, keepdims=True)
    unnormalised = np.exp(x - peak)
    normaliser = np.sum(unnormalised, axis=-1, keepdims=True)
    return unnormalised / normaliser

############################################
#      CIFAR-10 Filtered Dataset
############################################
class FilteredDataset(Dataset):
    """In-memory subset of ``dataset`` keeping only samples with allowed labels.

    The wrapped dataset is iterated once at construction time. Every
    ``(sample, label)`` pair whose label is in ``allowed_labels`` is stored in
    ``self.data``, in the original order. Any transform of the wrapped dataset
    has therefore already been applied to the stored samples.
    """

    def __init__(self, dataset, allowed_labels):
        kept = []
        for sample, target in dataset:
            if target in allowed_labels:
                kept.append((sample, target))
        self.data = kept

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

############################################
#  OOD Embedding Extraction Function
############################################
# --- 1) CIFAR-100 “close” OOD: classes not in CIFAR-10 ---
def sample_cifar100_not_in_cifar10(n1, root='./data', seed=42):
    """Sample ``n1`` CIFAR-100 test images whose class name is not a CIFAR-10 class.

    Args:
        n1: number of images to return.
        root: dataset directory (CIFAR-10 and CIFAR-100 are downloaded there
            if missing).
        seed: seed of the local RNG that selects the images. The draw is
            identical to ``random.seed(seed); random.sample(...)``, but the
            process-wide ``random`` state is left untouched.

    Returns:
        List of ``n1`` ``PIL.Image`` objects built from ``cifar100.data``.

    Raises:
        ValueError: if ``n1`` exceeds the number of eligible images (raised by
            ``random.Random.sample``).

    NOTE(review): the filter compares class *names* exactly. If no CIFAR-100
    fine-class name equals a CIFAR-10 class name, every CIFAR-100 image is
    eligible. Confirm whether a semantic filter was intended.
    """
    # Download both datasets (CIFAR-10 is only needed for its class names).
    cifar10 = datasets.CIFAR10(root=root, train=False, download=True)
    cifar100 = datasets.CIFAR100(root=root, train=False, download=True)

    # Integer ids of the CIFAR-100 fine classes whose names are not CIFAR-10 classes.
    cif10_names = set(cifar10.classes)
    ood_label_ids = {
        label_id for label_id, name in enumerate(cifar100.classes)
        if name not in cif10_names
    }

    idxs = [idx for idx, label in enumerate(cifar100.targets) if label in ood_label_ids]

    rng = random.Random(seed)
    chosen = rng.sample(idxs, n1)

    # Return as PIL Images
    return [Image.fromarray(cifar100.data[i]) for i in chosen]

# --- 2) CIFAR-10-C “corrupt” OOD: all corruption types/severities ---
def sample_cifar10c(n2, root='./data', seed=42):
    """Sample ``n2`` CIFAR-10-C images across all corruption types/severities.

    On first use the CIFAR-10-C archive is downloaded to ``root`` and
    extracted to ``root/CIFAR-10-C``. Every ``.npy`` file in that directory,
    except the labels file, is treated as one block of images. ``n2`` distinct
    images are drawn uniformly from the (virtual) concatenation of all blocks.

    Args:
        n2: number of images to return.
        root: directory for the archive and the extracted folder.
        seed: seed of the local RNG that selects the images (the process-wide
            ``random`` state is not modified).

    Returns:
        List of ``n2`` ``PIL.Image`` objects built from the pixel arrays
        (cast to uint8).

    Raises:
        ValueError: if no corruption files are found, or if ``n2`` exceeds the
            number of available images (raised by ``random.Random.sample``).
    """
    url = 'https://zenodo.org/record/2535967/files/CIFAR-10-C.tar'
    tar_path = os.path.join(root, 'CIFAR-10-C.tar')
    extract_dir = os.path.join(root, 'CIFAR-10-C')

    # Download if needed. Fetch to a temporary name and rename on success, so
    # an interrupted download never leaves a truncated archive that later runs
    # would trust.
    if not os.path.exists(tar_path):
        os.makedirs(root, exist_ok=True)
        print("Downloading CIFAR-10-C (≈2.9 GB)...")
        part_path = tar_path + '.part'
        urllib.request.urlretrieve(url, part_path)
        os.replace(part_path, tar_path)

    # Extract if needed. When available (Python 3.12+ and security backports),
    # the 'data' filter refuses members that would land outside `root`, as
    # well as special files.
    if not os.path.isdir(extract_dir):
        print("Extracting CIFAR-10-C...")
        with tarfile.open(tar_path) as tar:
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(path=root, filter='data')
            else:
                tar.extractall(path=root)

    # Sort the listing so the concatenation order, and therefore the seeded
    # sample, does not depend on the filesystem's directory-listing order.
    fnames = sorted(
        fname for fname in os.listdir(extract_dir)
        if fname.endswith('.npy') and 'labels' not in fname
    )
    if not fnames:
        raise ValueError(f"No CIFAR-10-C corruption files found in {extract_dir}")

    # Memory-map each corruption file (assumed shape [n_rows, 32, 32, 3])
    # instead of loading and stacking all of them. Only sampled rows are read.
    blocks = [np.load(os.path.join(extract_dir, fname), mmap_mode='r') for fname in fnames]
    offsets = np.cumsum([0] + [len(block) for block in blocks])

    rng = random.Random(seed)
    chosen = rng.sample(range(int(offsets[-1])), n2)

    images = []
    for idx in chosen:
        # Locate the block that holds global row `idx` of the virtual concatenation.
        k = int(np.searchsorted(offsets, idx, side='right')) - 1
        row = blocks[k][idx - offsets[k]]
        images.append(Image.fromarray(np.asarray(row).astype(np.uint8)))
    return images

# --- 3) SVHN “far” OOD: street-view digits ---
def sample_svhn(n3, root='./data', seed=42):
    """Sample ``n3`` images from the SVHN test split.

    Args:
        n3: number of images to return.
        root: dataset directory (SVHN is downloaded there if missing).
        seed: seed of the local RNG that selects the images. The draw is
            identical to ``random.seed(seed); random.sample(...)``, but the
            process-wide ``random`` state is left untouched.

    Returns:
        List holding ``svhn[i][0]`` for each sampled index. With
        ``transform=None`` these are the untransformed images; downstream code
        feeds them to PIL-based transforms (PIL images in current torchvision
        — TODO confirm for the pinned version).
    """
    svhn = datasets.SVHN(root=root, split='test', download=True, transform=None)

    rng = random.Random(seed)
    chosen = rng.sample(range(len(svhn)), n3)
    return [svhn[i][0] for i in chosen]

def create_resnet50_embedded_ood_datasets(
        n_close, n_corrupt, n_far,
        cache_path='ood_embeddings.pt', batch_size=128):
    """Return ResNet-50 embeddings of the three OOD image sets, using a cache.

    If ``cache_path`` exists, its contents (as written by ``torch.save`` below)
    are returned unchanged. Otherwise the function does the following:

    * samples ``n_close`` CIFAR-100 images, ``n_corrupt`` CIFAR-10-C images
      and ``n_far`` SVHN images;
    * passes them through an ImageNet-pretrained ResNet-50 whose final
      fully-connected layer has been removed;
    * saves the flattened features to ``cache_path``.

    Args:
        n_close, n_corrupt, n_far: number of images in each OOD set.
        cache_path: file used to load/store the ``(X_close, X_corrupt, X_far)``
            tuple.
        batch_size: number of images transformed and embedded per forward pass.

    Returns:
        Tuple ``(X_close, X_corrupt, X_far)`` of CPU float tensors, each of
        shape ``(n, feature_dim)``.

    NOTE: the cache is keyed only by ``cache_path``. An existing cache is
    returned even if it was built with different ``n_*`` values.
    """
    if os.path.exists(cache_path):
        print("Loading cached ood embeddings...")
        return torch.load(cache_path)

    print("Cached embeddings not found. Computing ood embeddings...")

    # 1) Preprocessing for the OOD PIL images: upsample to ResNet-50's input
    #    size and normalise with the ImageNet mean/std.
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    # 2) ResNet-50 feature extractor: every layer except the final fc.
    #    `pretrained=True` is deprecated since torchvision 0.13 and maps to the
    #    IMAGENET1K_V1 weights, which are requested explicitly when available.
    try:
        resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    except AttributeError:  # older torchvision without the weights enums
        resnet50 = models.resnet50(pretrained=True)
    feature_dim = resnet50.fc.in_features
    feature_extractor = nn.Sequential(*list(resnet50.children())[:-1])
    feature_extractor.eval()
    feature_extractor.to(device)

    # 3) Sample PIL images.
    close_imgs = sample_cifar100_not_in_cifar10(n_close)
    corrupt_imgs = sample_cifar10c(n_corrupt)
    far_imgs = sample_svhn(n_far)

    # 4) Embed in batches. Images are transformed one batch at a time, so the
    #    224x224 tensors of a whole set are never held in memory at once.
    def imgs_to_embeddings(img_list):
        """Return the (len(img_list), feature_dim) CPU feature matrix."""
        if not img_list:
            return torch.empty(0, feature_dim)
        feats = []
        with torch.no_grad():
            for start in range(0, len(img_list), batch_size):
                chunk = img_list[start:start + batch_size]
                batch = torch.stack([transform(img) for img in chunk]).to(device)
                out = feature_extractor(batch).view(batch.size(0), -1)
                feats.append(out.cpu())
        return torch.cat(feats, dim=0)

    X_close = imgs_to_embeddings(close_imgs)
    X_corrupt = imgs_to_embeddings(corrupt_imgs)
    X_far = imgs_to_embeddings(far_imgs)

    torch.save((X_close, X_corrupt, X_far), cache_path)
    print("ood embeddings computed and saved.")
    return X_close, X_corrupt, X_far

############################################
#      Metrics computation: Brier and ECE
############################################
def compute_brier(probs, labels):
    """Return the multi-class Brier score.

    The score is (1/N) ∑_i ∑_c (p_{i,c} – 1{y_i=c})^2.

    Args:
        probs: (N, C) array of predicted class probabilities.
        labels: (N,) array of integer class indices in [0, C-1].
    """
    n_samples, _ = probs.shape  # also enforces a 2-D input
    targets = np.zeros_like(probs)
    targets[np.arange(n_samples), labels] = 1
    per_sample = np.sum((probs - targets) ** 2, axis=1)
    return np.mean(per_sample)


# def compute_ece(probs, labels, n_bins=10):
#     """Expected Calibration Error over `n_bins` equally‐spaced bins."""
#     bins = np.linspace(0.0, 1.0, n_bins + 1)
#     ece = 0.0
#     for i in range(n_bins):
#         mask = (probs >= bins[i]) & (probs < bins[i+1])
#         if not np.any(mask):
#             continue
#         preds_bin = (probs[mask] >= 0.5).astype(int)
#         acc_bin   = (labels[mask] == preds_bin).mean()
#         conf_bin  = probs[mask].mean()
#         ece      += np.abs(acc_bin - conf_bin) * (mask.sum() / len(probs))
#     return ece
def compute_ece(probs, labels, n_bins=10):
    """
    Expected Calibration Error for multiclass outputs.

    Each example is assigned to a bin by its top-class confidence. Bins are
    [b_i, b_{i+1}), except the last one, which is closed ([b_{n-1}, 1.0]), so
    a confidence of exactly 1.0 is counted rather than silently dropped.

    Args:
      probs: array of shape (N, C) of predicted class probabilities (each row sums to 1)
      labels: array of shape (N,) of integer true labels in [0..C-1]
      n_bins: number of equal-width confidence bins in [0,1]

    Returns:
      scalar ECE = sum_b (|acc(b) – conf(b)| · |bin_b|/N); NaN for empty input.
    """
    probs = np.asarray(probs)
    labels = np.asarray(labels)

    # 1) For each example, get predicted class and its confidence
    preds       = np.argmax(probs, axis=1)                             # shape (N,)
    confidences = probs[np.arange(len(probs)), preds]                 # shape (N,)
    correct     = (preds == labels).astype(float)                     # shape (N,)
    if confidences.size == 0:
        return float('nan')

    # 2) Build bins over [0,1]
    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0

    # 3) Loop through bins
    for i in range(n_bins):
        low, high = bin_edges[i], bin_edges[i + 1]
        if i == n_bins - 1:
            # Close the final bin so confidence == 1.0 (common for saturated
            # softmax outputs) falls inside it.
            in_bin = (confidences >= low) & (confidences <= high)
        else:
            in_bin = (confidences >= low) & (confidences < high)
        prop_in_bin = in_bin.mean()
        if prop_in_bin == 0:
            continue

        # average accuracy and average confidence in this bin
        acc_bin  = correct[in_bin].mean()
        conf_bin = confidences[in_bin].mean()

        # accumulate weighted gap
        ece += prop_in_bin * abs(acc_bin - conf_bin)

    return ece

############################################
#      Define SimpleMLP (CIFAR-10 Version)
############################################
class SimpleMLP(nn.Module):
    """Linear classifier or one-hidden-layer ReLU MLP over feature vectors.

    With ``hidden_dim`` of ``None`` or ``0`` the model is a single
    ``nn.Linear`` registered as ``fc``. Otherwise it is ``fc1 -> ReLU -> fc2``.
    The attribute names and registration order fix the ``named_parameters()``
    order. ``unflatten_params`` relies on that order to map flat particle
    vectors back onto the network.

    Args:
        input_dim: size of each input feature vector.
        hidden_dim: hidden width; ``None``/``0`` selects the linear model.
        num_classes: number of output logits.
    """

    def __init__(self, input_dim=2048, hidden_dim=0, num_classes=10):
        super().__init__()
        linear_only = hidden_dim is None or hidden_dim == 0
        if linear_only:
            self.fc = nn.Linear(input_dim, num_classes)
        else:
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, num_classes)
        # NOTE: layers keep PyTorch's default initialisation. An earlier
        # N(0, sigma_w^2) / N(0, sigma_b^2) prior initialisation was disabled.

    def forward(self, x):
        """Return unnormalised class logits for ``x``."""
        if not hasattr(self, 'fc'):
            return self.fc2(F.relu(self.fc1(x)))
        return self.fc(x)
        

############################################
#   Main Aggregation Code
############################################
if __name__ == '__main__':
    freeze_support()  # For multiprocessing on platforms using 'spawn'

    # ------------------------------------------------------------------
    # In-domain ground truth: CIFAR-10 test labels (classes 0-9)
    # ------------------------------------------------------------------
    # Only the labels are needed; the in-domain predictions come from the
    # .mat files. Reading `targets` directly avoids decoding every test image,
    # resizing it to 224x224 and stacking all of them into a tensor that was
    # never used.
    full_val_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)
    allowed_labels = list(range(10))
    y_val_np = np.asarray(
        [label for label in full_val_dataset.targets if label in allowed_labels],
        dtype=np.int64,
    )  # ground truth labels, in dataset order
    n_classes = 10

    # ------------------------------------------------------------------
    # Out-of-domain (OOD) embeddings: loaded from cache or computed once
    # ------------------------------------------------------------------
    n_close, n_corrupt, n_far = 1000, 1000, 1000
    X_close, X_corrupt, X_far = create_resnet50_embedded_ood_datasets(
        n_close, n_corrupt, n_far,
        cache_path="ood_all_embeddings.pt"
    )

    # User-adjustable parameters for summarization
    R = 5                        # number of replicates
    select_index = [1, 2, 4, 8]  # selected processor counts P
    # File naming / shape parameters
    d_val = 20490                # length of one flattened particle
    N_val = len(y_val_np)        # number of in-domain evaluation samples
    N_particles = 10
    M_val = 5
    eps = 1e-12                  # guards log(0) in the NLL and entropy terms

    # OOD embeddings live on the same device as the evaluation network.
    ood_set_names = ("close", "corrupt", "far")
    ood_sets = [
        ("close", X_close.to(device)),
        ("corrupt", X_corrupt.to(device)),
        ("far", X_far.to(device)),
    ]

    # Template network for functional_call. Its own weights are never used:
    # every forward pass substitutes one particle's parameters.
    net_ood = SimpleMLP(input_dim=2048, hidden_dim=0, num_classes=n_classes).to(device)
    net_ood.eval()
    n_params = sum(p.numel() for p in net_ood.parameters())
    if n_params != d_val:
        raise ValueError(f"d_val={d_val} does not match the SimpleMLP parameter count {n_params}")

    y_val_bin = label_binarize(y_val_np, classes=np.arange(n_classes))

    def compute_stats(metric_list):
        """Return ``(mean, standard error)`` over the non-NaN entries of ``metric_list``.

        NaN entries are excluded from both the mean and the sample count used
        for the standard error. Examples are a replicate with no incorrect
        predictions, or files without ``Lsum``. The standard error is NaN when
        fewer than two valid entries exist.
        """
        arr = np.asarray(metric_list, dtype=float)
        valid = arr[~np.isnan(arr)]
        n_valid = valid.size
        if n_valid == 0:
            return np.nan, np.nan
        if n_valid == 1:
            return valid.mean(), np.nan
        return valid.mean(), valid.std(ddof=1) / np.sqrt(n_valid)

    def print_stat(label, mean_se):
        """Print ``label`` followed by ``mean ± stderr`` with 8 decimals."""
        mean_val, se_val = mean_se
        print(f"{label}{mean_val:.8f} ± {se_val:.8f}")

    # Per-replicate in-domain metric -> (mean key, stderr key) in the saved .mat.
    id_result_keys = {
        'accuracy':  ('mean_acc', 'stderr_acc'),
        'f1':        ('mean_f1_id', 'stderr_f1_id'),
        'aucpr':     ('mean_aucpr_id', 'stderr_aucpr_id'),
        'nll':       ('mean_nll', 'stderr_nll'),
        'tot_ent':   ('mean_tot_ent', 'stderr_tot_ent'),
        'epi':       ('mean_epi', 'stderr_epi'),
        'Lsum':      ('mean_Lsum', 'stderr_Lsum'),
        'brier':     ('mean_brier', 'stderr_brier'),
        'ece':       ('mean_ece', 'stderr_ece'),
        'precision': ('mean_precision', 'stderr_precision'),
        'recall':    ('mean_recall', 'stderr_recall'),
        'tot_corr':  ('mean_tot_corr', 'se_tot_corr'),
        'epi_corr':  ('mean_epi_corr', 'se_epi_corr'),
        'tot_inc':   ('mean_tot_inc', 'se_tot_inc'),
        'epi_inc':   ('mean_epi_inc', 'se_epi_inc'),
    }
    id_print_rows = [
        ("Accuracy:      ", 'accuracy'),
        ("F1 Score:      ", 'f1'),
        ("AUC-PR:        ", 'aucpr'),
        ("NLL:           ", 'nll'),
        ("Brier Score:   ", 'brier'),
        ("ECE:           ", 'ece'),
        ("Precision:     ", 'precision'),
        ("Recall:        ", 'recall'),
        ("Total Entropy (all): ", 'tot_ent'),
        ("Epistemic (all):     ", 'epi'),
        ("Lsum:          ", 'Lsum'),
    ]
    split_print_rows = [
        ("Correct Predictions - Total Entropy: ", 'tot_corr'),
        ("Correct Predictions - Epistemic:     ", 'epi_corr'),
        ("Incorrect Predictions - Total Entropy: ", 'tot_inc'),
        ("Incorrect Predictions - Epistemic:     ", 'epi_inc'),
    ]

    for P_use in select_index:
        print(f"\n=== Evaluating for P = {P_use} processors ===")
        rep = {key: [] for key in id_result_keys}
        rep_total_entropy_ood = {name: [] for name in ood_set_names}
        rep_epistemic_ood = {name: [] for name in ood_set_names}

        # For the 0-based replicate r, processor p reads node_index = r + p * R + 1.
        for r in range(R):
            proc_preds_list = []    # each element: psmc_single_pred, (N_particles, N_val, 10)
            proc_params_list = []   # each element: psmc_single_x, (N_particles, d_val)
            Lsum_list = []          # one Lsum value per loaded file

            for p in range(P_use):
                node_index = r + p * R + 1
                filename = (f"BayesianNN_CIFAR_MAP_psmc_SimpleMLP"
                            f"_N{N_particles}_M{M_val}_node{node_index}.mat")
                if not os.path.exists(filename):
                    print(f"File {filename} not found; skipping this processor.")
                    continue
                data = loadmat(filename)
                proc_preds_list.append(data['psmc_single_pred'])
                proc_params_list.append(data['psmc_single_x'])
                # loadmat returns MATLAB scalars as 1x1 arrays. float() on an
                # ndim>0 array is deprecated in NumPy >= 1.25, so use .item().
                Lsum_list.append(float(np.asarray(data['Lsum']).item()) if 'Lsum' in data else np.nan)

            if not proc_preds_list:
                print(f"No files loaded for replicate {r+1}; skipping replicate.")
                continue

            proc_preds = np.stack(proc_preds_list, axis=0)    # (P_actual, N_particles, N_val, 10)
            proc_params = np.stack(proc_params_list, axis=0)  # (P_actual, N_particles, d_val)
            if proc_preds.shape[2] != N_val:
                raise ValueError(
                    f"Predictions cover {proc_preds.shape[2]} samples but {N_val} labels were loaded")
            if proc_params.shape[-1] != d_val:
                raise ValueError(
                    f"Particles have length {proc_params.shape[-1]}, expected d_val={d_val}")
            rep['Lsum'].append(np.mean(Lsum_list))

            # --- In-Domain (ID) Analysis ---
            # The stored predictions are treated as logits (softmax is applied below).
            # NOTE(review): two different predictive distributions are used.
            #   * Accuracy, NLL, F1/precision/recall, AUC-PR and the "total entropy"
            #     metrics use softmax(mean of logits).
            #   * Brier, ECE and the epistemic decomposition use mean(softmax(logits)),
            #     the same model average as the OOD section.
            # Confirm this mix is intended before comparing total entropy with
            # the epistemic term.
            mean_logits = np.mean(np.mean(proc_preds, axis=1), axis=0)  # (N_val, 10)
            aggregated_prob_soft = softmax_np(mean_logits)

            pred_labels = np.argmax(aggregated_prob_soft, axis=1)
            accuracy = np.mean(pred_labels == y_val_np)
            nll = -np.mean(np.log(aggregated_prob_soft[np.arange(len(y_val_np)), y_val_np] + eps))

            proc_preds_soft = softmax_np(proc_preds)                   # (P_actual, N_particles, N_val, 10)
            ensemble_prob_all = np.mean(proc_preds_soft, axis=(0, 1))  # (N_val, 10)
            brier = compute_brier(ensemble_prob_all, y_val_np)
            ece = compute_ece(ensemble_prob_all, y_val_np, n_bins=10)

            total_entropy = -np.sum(aggregated_prob_soft * np.log(aggregated_prob_soft + eps), axis=1)
            total_entropy_all = -np.sum(ensemble_prob_all * np.log(ensemble_prob_all + eps), axis=1)
            particle_entropy = -np.sum(proc_preds_soft * np.log(proc_preds_soft + eps), axis=-1)
            epistemic_uncertainty = total_entropy_all - np.mean(particle_entropy, axis=(0, 1))

            rep['accuracy'].append(accuracy)
            rep['f1'].append(f1_score(y_val_np, pred_labels, average='macro', zero_division=0))
            rep['aucpr'].append(average_precision_score(y_val_bin, aggregated_prob_soft, average='macro'))
            rep['nll'].append(nll)
            rep['precision'].append(precision_score(y_val_np, pred_labels, average='macro', zero_division=0))
            rep['recall'].append(recall_score(y_val_np, pred_labels, average='macro', zero_division=0))
            rep['brier'].append(brier)
            rep['ece'].append(ece)
            rep['tot_ent'].append(np.mean(total_entropy))
            rep['epi'].append(np.mean(epistemic_uncertainty))

            # Uncertainty split by prediction correctness (NaN when a group is empty).
            correct_mask = pred_labels == y_val_np
            for suffix, mask in (('corr', correct_mask), ('inc', ~correct_mask)):
                if mask.any():
                    rep[f'tot_{suffix}'].append(np.mean(total_entropy[mask]))
                    rep[f'epi_{suffix}'].append(np.mean(epistemic_uncertainty[mask]))
                else:
                    rep[f'tot_{suffix}'].append(np.nan)
                    rep[f'epi_{suffix}'].append(np.nan)

            # --- Out-of-Domain (OOD) Analysis ---
            # Flatten particles across processors and build each particle's
            # parameter dict once, reusing it for all three OOD sets.
            rep_particles = proc_params.reshape(-1, d_val)
            particle_params = [
                unflatten_params(torch.tensor(flat, dtype=torch.float32, device=device), net_ood)
                for flat in rep_particles
            ]

            for name, X_ood in ood_sets:
                with torch.no_grad():
                    particle_logits_od = np.stack([
                        functional_call(net_ood, params, X_ood).cpu().numpy()
                        for params in particle_params
                    ], axis=0)  # (total_particles, N_ood, 10)
                particle_probs_od = softmax_np(particle_logits_od)
                ensemble_probs_od = np.mean(particle_probs_od, axis=0)  # (N_ood, 10)
                total_entropy_od = -np.sum(ensemble_probs_od * np.log(ensemble_probs_od + eps), axis=1)
                particle_entropy_od = -np.sum(particle_probs_od * np.log(particle_probs_od + eps), axis=2)
                epistemic_uncertainty_od = total_entropy_od - np.mean(particle_entropy_od, axis=0)
                rep_total_entropy_ood[name].append(np.mean(total_entropy_od))
                rep_epistemic_ood[name].append(np.mean(epistemic_uncertainty_od))

        # Across-replicate statistics
        stats = {key: compute_stats(values) for key, values in rep.items()}
        ood_stats = {
            name: (compute_stats(rep_total_entropy_ood[name]),
                   compute_stats(rep_epistemic_ood[name]))
            for name in ood_set_names
        }
        n_used = len(rep['accuracy'])

        print(f"\nAggregated In-Domain metrics for P = {P_use} over {n_used} of {R} replicates:")
        for label, key in id_print_rows:
            print_stat(label, stats[key])
        print("\nFor ID predictions split by correctness:")
        for label, key in split_print_rows:
            print_stat(label, stats[key])

        print(f"\nAggregated OOD metrics for P = {P_use} over {n_used} of {R} replicates:")
        for name in ood_set_names:
            tot_stats, epi_stats = ood_stats[name]
            print_stat(f"{name.capitalize()} Total Entropy: ", tot_stats)
            print_stat(f"{name.capitalize()} Epistemic:     ", epi_stats)

        results = {}
        for key, (mean_key, se_key) in id_result_keys.items():
            results[mean_key], results[se_key] = stats[key]
        for name in ood_set_names:
            tot_stats, epi_stats = ood_stats[name]
            # Keys are derived from `name`, so each OOD set is saved under its
            # own label (previously the "close" keys held the "far" values).
            results[f'ood_{name}_mean_tot'], results[f'ood_{name}_se_tot'] = tot_stats
            results[f'ood_{name}_mean_epi'], results[f'ood_{name}_se_epi'] = epi_stats

        savemat(f'psmc_aggregated_results_P{P_use}.mat', results)

