#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MAP Training for Bayesian inference on whole CIFAR-10 (classes 0–9)
using ResNet-50 embeddings and SimpleMLP.

This script extracts (or loads cached) ResNet-50 embeddings from CIFAR-10
(images with labels 0–9) and then performs MAP training (logistic regression)
with SimpleMLP on the training-split embeddings. The code runs R replicates and
computes, for each replicate:
    - Training epochs and total training time.
    - Test accuracy on the in-domain (classes 0–9) validation set.
    - In-domain metrics: Negative Log-Likelihood (NLL), Total Entropy, macro F1,
      precision and recall, AUC-PR, Brier score, Expected Calibration Error (ECE),
      and a breakdown of total entropy for correct vs. incorrect predictions.
    - Out-of-domain metrics: average predictive (total) entropy on Close OOD
      (CIFAR-100 test images), Corrupt OOD (CIFAR-10-C) and Far OOD (SVHN).
Overall replicate statistics (mean and standard error) are printed and saved.
"""

############################################
#           Imports
############################################
import os
import time
import random
import numpy as np
from scipy.io import savemat

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
from torchvision import models

from sklearn.metrics import f1_score, average_precision_score, precision_score, recall_score
from sklearn.preprocessing import label_binarize

import pyro
from PIL import Image
import tarfile, urllib.request

############################################
#      Device and Prior Settings
############################################
# Prefer the GPU whenever CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Zero-mean Gaussian prior with variance 0.2 for every weight and bias.
# These standard deviations drive SimpleMLP's initialisation and the MAP
# (L2) penalty applied during training.
_PRIOR_VARIANCE = 0.2
sigma_w = np.sqrt(_PRIOR_VARIANCE)   # prior std for weights
sigma_b = np.sqrt(_PRIOR_VARIANCE)   # prior std for biases

############################################
#      Filtered CIFAR-10 Dataset Class
############################################
class FilteredCIFAR10(Dataset):
    """CIFAR-10 view restricted to samples whose label is in ``allowed_labels``.

    The filter is computed from ``CIFAR10.targets``, so no image is decoded or
    transformed up front. Each sample is loaded (and ``transform`` applied)
    lazily in ``__getitem__``. This avoids holding every transformed image
    in memory at once.

    Args:
        root, train, transform, download: forwarded to
            ``torchvision.datasets.CIFAR10``.
        allowed_labels: iterable of integer class ids to keep.
    """

    def __init__(self, root, train, transform, download, allowed_labels):
        self.dataset = datasets.CIFAR10(root=root, train=train, transform=transform, download=download)
        self.allowed_labels = allowed_labels
        allowed = set(allowed_labels)  # O(1) membership per sample
        # Positions in the underlying dataset of samples that pass the filter,
        # kept in their original order.
        self.indices = [i for i, label in enumerate(self.dataset.targets) if label in allowed]

    @property
    def data(self):
        """All filtered ``(image, label)`` pairs as a list.

        Kept for backward compatibility with code that read the eager ``data``
        attribute. Note that this loads and transforms every sample.
        """
        return [self[i] for i in range(len(self))]

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

############################################
#   ResNet-50 Embedding Extraction (In-Domain: Labels 0-9)
############################################
def _extract_features(feature_extractor, loader, device):
    """Return ``(features, targets)`` gathered over every batch of ``loader``.

    Each batch is moved to ``device``, passed through ``feature_extractor`` and
    flattened to ``(batch, -1)``. Results are collected on the CPU and
    concatenated along dim 0 in loader order.
    """
    feature_chunks, target_chunks = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            features = feature_extractor(inputs.to(device))
            feature_chunks.append(features.view(features.size(0), -1).cpu())
            target_chunks.append(targets)
    return torch.cat(feature_chunks, dim=0), torch.cat(target_chunks, dim=0)


def create_resnet50_embedded_cifar10_dataset(
    train_cache_path="cifar10_train_embeddings.pt",
    test_cache_path="cifar10_test_embeddings.pt",
    allowed_labels=None
):
    """Return ResNet-50 embeddings ``(X_train, y_train, X_test, y_test)`` for CIFAR-10.

    If both cache files exist, they are loaded with ``torch.load`` and returned
    as-is. Otherwise the pipeline runs as follows:

    1. Filter CIFAR-10 to ``allowed_labels``.
    2. Resize to 224 and normalise with ImageNet statistics.
    3. Pass through a pretrained ResNet-50 whose final fc layer is removed.
    4. Save the flattened features to the two cache paths.

    Args:
        train_cache_path: cache file holding the train-split ``(X, y)`` tuple.
        test_cache_path: cache file holding the test-split ``(X, y)`` tuple.
        allowed_labels: class ids to keep; ``None`` (default) means all 10.

    NOTE(review): the cache is keyed only by file path. A cache built with a
    different ``allowed_labels`` is returned unchanged.
    """
    if os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
        print("Loading cached ResNet-50 embeddings for whole CIFAR-10...")
        X_train, y_train = torch.load(train_cache_path)
        X_test, y_test = torch.load(test_cache_path)
        return X_train, y_train, X_test, y_test

    if allowed_labels is None:
        allowed_labels = list(range(10))

    print("Cached embeddings not found. Computing ResNet-50 embeddings for whole CIFAR-10...")
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_dataset = FilteredCIFAR10(root='./data', train=True, transform=transform, download=True, allowed_labels=allowed_labels)
    test_dataset = FilteredCIFAR10(root='./data', train=False, transform=transform, download=True, allowed_labels=allowed_labels)
    # shuffle=False keeps embeddings in dataset order, aligned with their labels.
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

    resnet50 = models.resnet50(pretrained=True)
    # Remove the final fc layer to obtain 2048-dim pooled features.
    feature_extractor = nn.Sequential(*list(resnet50.children())[:-1])
    feature_extractor.eval()
    feature_extractor.to(device)

    X_train, y_train = _extract_features(feature_extractor, train_loader, device)
    X_test, y_test = _extract_features(feature_extractor, test_loader, device)

    torch.save((X_train, y_train), train_cache_path)
    torch.save((X_test, y_test), test_cache_path)
    print("ResNet-50 embeddings computed and saved.")
    return X_train, y_train, X_test, y_test

############################################
#   Out-of-Domain Embedding Extraction
############################################
# --- 1) CIFAR-100 “close” OOD: classes not in CIFAR-10 ---
def sample_cifar100_not_in_cifar10(n1, root='./data', seed=42):
    """Sample ``n1`` CIFAR-100 test images whose class name is not a CIFAR-10 class.

    Args:
        n1: number of images to return.
        root: dataset directory (both datasets are downloaded there if missing).
        seed: seed for a private ``random.Random``. This gives the same draw as
            ``random.seed(seed); random.sample(...)`` but leaves the caller's
            global RNG state untouched.

    Returns:
        List of ``n1`` ``PIL.Image`` objects.

    NOTE(review): filtering uses exact class-name equality. CIFAR-100 fine
    labels appear not to share exact names with CIFAR-10 classes (e.g.
    'pickup_truck' vs 'truck'), so this may exclude nothing. Use an explicit
    exclusion list if stricter semantic separation is intended.
    """
    cifar10 = datasets.CIFAR10(root=root, train=False, download=True)
    cifar100 = datasets.CIFAR100(root=root, train=False, download=True)

    in_domain_names = set(cifar10.classes)
    # Integer CIFAR-100 label ids whose name is not a CIFAR-10 class name.
    ood_label_ids = {
        label_id for label_id, name in enumerate(cifar100.classes)
        if name not in in_domain_names
    }
    idxs = [idx for idx, label in enumerate(cifar100.targets) if label in ood_label_ids]

    chosen = random.Random(seed).sample(idxs, n1)
    return [Image.fromarray(cifar100.data[i]) for i in chosen]

# --- 2) CIFAR-10-C “corrupt” OOD: all corruption types/severities ---
def sample_cifar10c(n2, root='./data', seed=42):
    """Sample ``n2`` distinct images from all CIFAR-10-C corruption files.

    Downloads and extracts ``CIFAR-10-C.tar`` under ``root`` if needed. It then
    draws ``n2`` distinct images from the concatenation (in sorted file-name
    order) of every ``*.npy`` corruption file; the ``labels`` file is skipped.

    Args:
        n2: number of images to return.
        root: directory holding (or receiving) the tarball and extracted folder.
        seed: seed for a private ``random.Random``; the global RNG is untouched.

    Returns:
        List of ``n2`` ``PIL.Image`` objects.

    Raises:
        FileNotFoundError: if the extracted folder has no corruption files.
        ValueError: if ``n2`` exceeds the number of available images.
    """
    url = 'https://zenodo.org/record/2535967/files/CIFAR-10-C.tar'
    tar_path = os.path.join(root, 'CIFAR-10-C.tar')
    extract_dir = os.path.join(root, 'CIFAR-10-C')

    # Download if needed. Write to a temporary name and rename on success, so
    # an interrupted download is never mistaken for a complete tarball.
    if not os.path.exists(tar_path):
        os.makedirs(root, exist_ok=True)
        print("Downloading CIFAR-10-C (≈2.9 GB)...")
        part_path = tar_path + '.part'
        urllib.request.urlretrieve(url, part_path)
        os.replace(part_path, tar_path)

    # Extract if needed, using tarfile's safe 'data' filter where available.
    if not os.path.isdir(extract_dir):
        print("Extracting CIFAR-10-C...")
        with tarfile.open(tar_path) as tar:
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(path=root, filter='data')
            else:
                tar.extractall(path=root)

    # Sort the file list. The global index -> (file, row) mapping, and therefore
    # the seeded sample, must not depend on os.listdir's arbitrary order.
    fnames = sorted(
        fname for fname in os.listdir(extract_dir)
        if fname.endswith('.npy') and 'labels' not in fname
    )
    if not fnames:
        raise FileNotFoundError(f"No CIFAR-10-C corruption files found in {extract_dir}")

    # Memory-map each file instead of loading and stacking all of them.
    # Images run along axis 0. NOTE(review): each file is expected to hold
    # 5 severities x 10 000 images of shape (32, 32, 3); confirm.
    arrays = [np.load(os.path.join(extract_dir, fname), mmap_mode='r') for fname in fnames]
    offsets = np.cumsum([0] + [len(arr) for arr in arrays])   # start index of each file

    chosen = random.Random(seed).sample(range(int(offsets[-1])), n2)
    file_ids = np.searchsorted(offsets, chosen, side='right') - 1
    return [
        Image.fromarray(np.asarray(arrays[f][g - offsets[f]]).astype(np.uint8))
        for f, g in zip(file_ids, chosen)
    ]

# --- 3) SVHN “far” OOD: street-view digits ---
def sample_svhn(n3, root='./data', seed=42):
    """Sample ``n3`` distinct images from the SVHN test split.

    With ``transform=None``, indexing the torchvision SVHN dataset yields
    ``(image, label)`` pairs whose image is a PIL image, so ``svhn[i][0]`` is
    returned directly.

    Args:
        n3: number of images to return.
        root: dataset directory (downloaded there if missing).
        seed: seed for a private ``random.Random``; the global RNG is untouched.
    """
    svhn = datasets.SVHN(root=root, split='test', download=True, transform=None)
    chosen = random.Random(seed).sample(range(len(svhn)), n3)
    return [svhn[i][0] for i in chosen]

def create_resnet50_embedded_ood_datasets(
        n_close, n_corrupt, n_far,
        cache_path='ood_embeddings.pt'):
    """Return ResNet-50 embeddings ``(X_close, X_corrupt, X_far)`` for three OOD sets.

    If ``cache_path`` exists, its contents are loaded with ``torch.load`` and
    returned. Otherwise the three OOD sets are built from these samples:

    - close: ``n_close`` CIFAR-100 images.
    - corrupt: ``n_corrupt`` CIFAR-10-C images.
    - far: ``n_far`` SVHN images.

    The images are embedded with the same preprocessing and ResNet-50 feature
    extractor as the in-domain data, and the result is saved to ``cache_path``.

    NOTE(review): the cache is keyed only by path. Changing ``n_*`` does not
    invalidate an existing cache file.
    """
    if os.path.exists(cache_path):
        print("Loading cached ood embeddings...")
        return torch.load(cache_path)

    print("Cached embeddings not found. Computing ood embeddings...")

    # 1) Image preprocessing for OOD PIL images (matches the in-domain pipeline).
    transform = transforms.Compose([
        transforms.Resize(224),                  # match ResNet50's input size
        transforms.ToTensor(),                   # convert PIL -> [0,1] tensor
        transforms.Normalize(                    # ImageNet mean/std
            mean=[0.485, 0.456, 0.406],
            std =[0.229, 0.224, 0.225]
        )
    ])

    # 2) ResNet-50 feature extractor without its final fully-connected layer.
    resnet50 = models.resnet50(pretrained=True)
    feature_extractor = nn.Sequential(*list(resnet50.children())[:-1])
    feature_extractor.eval()
    feature_extractor.to(device)

    # 3) Sample PIL images.
    close_imgs   = sample_cifar100_not_in_cifar10(n_close)
    corrupt_imgs = sample_cifar10c(n_corrupt)
    far_imgs     = sample_svhn(n_far)

    # 4) Embed in batches. Only one batch of 224x224 tensors is materialised
    #    at a time, rather than the whole list.
    def imgs_to_embeddings(img_list, batch_size=128):
        feats = []
        with torch.no_grad():
            for start in range(0, len(img_list), batch_size):
                chunk = img_list[start:start + batch_size]
                batch = torch.stack([transform(img) for img in chunk]).to(device)
                out = feature_extractor(batch).view(batch.size(0), -1)
                feats.append(out.cpu())
        if not feats:
            # Empty request: return a (0, 2048) tensor matching ResNet-50's pooled width.
            return torch.empty(0, 2048)
        return torch.cat(feats, dim=0)

    X_close   = imgs_to_embeddings(close_imgs)
    X_corrupt = imgs_to_embeddings(corrupt_imgs)
    X_far     = imgs_to_embeddings(far_imgs)

    torch.save((X_close, X_corrupt, X_far), cache_path)
    print("ood embeddings computed and saved.")
    return X_close, X_corrupt, X_far


############################################
#         SimpleMLP Model Definition
############################################
class SimpleMLP(nn.Module):
    """Logistic-regression head, or a one-hidden-layer ReLU MLP, on fixed embeddings.

    With ``hidden_dim`` equal to ``None`` or ``0``, the model is a single
    linear layer ``self.fc``. Otherwise it is ``fc1 -> ReLU -> fc2``. Every
    weight is drawn from N(0, sigma_w^2) and every bias from N(0, sigma_b^2),
    using the module-level prior scales.
    """

    def __init__(self, input_dim=2048, hidden_dim=0, num_classes=10):
        super().__init__()
        if hidden_dim in (None, 0):
            # Logistic regression: no hidden layer.
            self.fc = nn.Linear(input_dim, num_classes)
            layers = (self.fc,)
        else:
            # Single-hidden-layer MLP.
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, num_classes)
            layers = (self.fc1, self.fc2)
        # Re-draw parameters from the prior: all weights first, then all biases,
        # so the RNG is consumed in a fixed order.
        for layer in layers:
            nn.init.normal_(layer.weight, mean=0, std=sigma_w)
        for layer in layers:
            if layer.bias is not None:
                nn.init.normal_(layer.bias, mean=0, std=sigma_b)

    def forward(self, x):
        if not hasattr(self, 'fc'):
            hidden = F.relu(self.fc1(x))
            return self.fc2(hidden)
        return self.fc(x)

############################################
# Helper function: Compute Accuracy
############################################
def compute_accuracy(model, x, y):
    """Return the arg-max accuracy of ``model`` on ``(x, y)`` as a percentage.

    Puts ``model`` into eval mode (and leaves it there), flattens ``y`` with
    ``view(-1)`` for the comparison, and divides by ``y.size(0)``.
    """
    model.eval()
    targets = y.view(-1)
    with torch.no_grad():
        predicted = torch.argmax(model(x), dim=1)
        n_correct = torch.eq(predicted, targets).sum().item()
    return 100 * n_correct / y.size(0)
    
############################################
#      Metrics computation: Brier and ECE
############################################
def compute_brier(probs, labels):
    """
    Multi-class Brier score: (1/N) * sum_i sum_c (p_{i,c} - 1{y_i = c})^2.

    Args:
      probs: (N, C) predicted class probabilities (rows sum to 1). Either a
             torch tensor or anything ``torch.as_tensor`` accepts (e.g. a
             NumPy array).
      labels: N integer true labels in [0..C-1]. Shape (N,) or (N, 1), any
              integer dtype, on any device.

    Returns:
      Mean over examples of the squared error against the one-hot target, as
      a Python float.
    """
    probs = torch.as_tensor(probs)
    # scatter_ needs int64 indices of shape (N, 1) on the same device as probs.
    labels = torch.as_tensor(labels, device=probs.device).long().reshape(-1)
    one_hot = torch.zeros_like(probs).scatter_(1, labels.unsqueeze(1), 1.0)
    per_example = torch.sum((probs - one_hot) ** 2, dim=1)
    return per_example.mean().item()



def compute_ece(probs, labels, n_bins=10):
    """
    Expected Calibration Error (top-label) for multiclass outputs.

    Args:
      probs: (N, C) predicted class probabilities (each row sums to 1). Either
             a torch tensor, a NumPy array or a nested list.
      labels: N integer true labels in [0..C-1], shape (N,) or (N, 1).
      n_bins: number of equal-width confidence bins over [0, 1]. The last bin
              is closed on the right, so a confidence of exactly 1.0 counts.

    Returns:
      ECE = sum_b (|bin_b| / N) * |acc(b) - conf(b)|. Here conf is the maximum
      predicted probability and acc is whether the arg-max equals the label.
      Returns NaN when N == 0.
    """
    if isinstance(probs, torch.Tensor):
        probs = probs.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    probs = np.asarray(probs)
    # Flatten so (N, 1) labels compare element-wise instead of broadcasting
    # `preds == labels` to an (N, N) matrix.
    labels = np.asarray(labels).reshape(-1)

    n = len(probs)
    if n == 0:
        return float('nan')

    # 1) Predicted class, its confidence, and whether it is correct.
    preds       = np.argmax(probs, axis=1)
    confidences = probs[np.arange(n), preds]
    correct     = (preds == labels).astype(float)

    # 2) Equal-width bins over [0, 1].
    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0

    # 3) Accumulate the proportion-weighted |accuracy - confidence| gap per bin.
    for i in range(n_bins):
        low, high = bin_edges[i], bin_edges[i + 1]
        if i == n_bins - 1:
            in_bin = (confidences >= low) & (confidences <= high)
        else:
            in_bin = (confidences >= low) & (confidences < high)
        prop_in_bin = in_bin.mean()
        if prop_in_bin == 0:
            continue
        acc_bin  = correct[in_bin].mean()
        conf_bin = confidences[in_bin].mean()
        ece += prop_in_bin * abs(acc_bin - conf_bin)

    return ece

############################################
# Helper function: Compute mean and standard error.
############################################
def compute_stats(metric_list, R=None):
    """
    Mean and standard error of the mean for a list of per-replicate values.

    Args:
      metric_list: sequence of numeric values, one per replicate.
      R: replicate count used in the SE denominator sqrt(R). Defaults to
         ``len(metric_list)``; callers normally pass that same value.

    Returns:
      (mean, stderr), where stderr = std(ddof=1) / sqrt(R).
      - Both values are NaN for an empty list.
      - stderr is NaN with fewer than two values (the sample std is undefined).
      - NaN entries in ``metric_list`` propagate.
    """
    arr = np.asarray(metric_list, dtype=float)
    if arr.size == 0:
        return float('nan'), float('nan')
    n = arr.size if R is None else R
    mean_val = np.mean(arr)
    if arr.size < 2:
        return mean_val, float('nan')
    stderr_val = np.std(arr, ddof=1) / np.sqrt(n)
    return mean_val, stderr_val

############################################
#          Main Execution: MAP Training over Replicates
############################################
if __name__ == '__main__':
    # Number of independent replicates (each run with its own seed).
    R = 5

    # Per-replicate metric lists.
    replicate_epochs = []            # MAP training epochs per replicate.
    replicate_map_time = []          # Total training time per replicate.
    replicate_accuracy = []          # Test accuracy (%) on the in-domain test set.
    replicate_nll_in = []            # In-domain Negative Log Likelihood.
    replicate_total_entropy_in = []  # In-domain Total Entropy.
    replicate_f1_in = []             # In-domain F1 score (macro).
    replicate_aucpr_in = []          # In-domain AUC-PR (macro).
    replicate_brier = []             # In-domain Brier score.
    replicate_ece = []               # In-domain Expected Calibration Error.
    replicate_precision = []         # In-domain precision (macro).
    replicate_recall = []            # In-domain recall (macro).
    # In-domain total-entropy breakdown (NaN if a replicate has no such predictions).
    replicate_total_entropy_in_correct = []
    replicate_total_entropy_in_incorrect = []
    # Out-of-domain average total entropy, one list per OOD split.
    replicate_total_entropy_od = {
        "close":   [],
        "corrupt": [],
        "far":     [],
    }

    ############################################
    # In-domain embeddings (labels 0-9)
    ############################################
    allowed_labels = list(range(10))
    train_cache = "cifar10_train_embeddings.pt"
    test_cache  = "cifar10_test_embeddings.pt"
    X_all_train, y_all_train, X_all_val, y_all_val = create_resnet50_embedded_cifar10_dataset(
        train_cache_path=train_cache,
        test_cache_path=test_cache,
        allowed_labels=allowed_labels,
    )
    n_total = X_all_train.size(0)        # 50 000 for the full CIFAR-10 train split
    print(f"Total CIFAR-10 train embeddings: {n_total}")

    # Training set = the first n_train_es embeddings (no shuffling before the
    # split). For the full CIFAR-10 train split this is every example. Early
    # stopping is disabled, so no held-out validation split is built and
    # every replicate trains for exactly max_epochs epochs.
    n_train_es = min(50000, n_total)
    X_train_subset = X_all_train[:n_train_es]
    y_train_subset = y_all_train[:n_train_es]

    batch_size = 128
    train_loader = DataLoader(TensorDataset(X_train_subset, y_train_subset),
                              batch_size=batch_size, shuffle=True)
    n_train = len(train_loader.dataset)

    # The whole CIFAR-10 test split is the in-domain evaluation set; move it once.
    x_test_full = X_all_val.to(device)
    y_test_full = y_all_val.to(device)
    y_true_in = y_test_full.cpu().numpy()   # host copy for sklearn / NumPy metrics

    ############################################
    # Out-of-domain embeddings (identical for every replicate)
    ############################################
    print("\nComputing out-of-domain embeddings...")
    n_close, n_corrupt, n_far = 1000, 1000, 1000
    X_close, X_corrupt, X_far = create_resnet50_embedded_ood_datasets(
        n_close, n_corrupt, n_far,
        cache_path="ood_all_embeddings.pt",
    )
    ood_sets = [
        ("close", X_close.to(device)),
        ("corrupt", X_corrupt.to(device)),
        ("far", X_far.to(device)),
    ]

    eps = 1e-10        # guards log(0) in NLL / entropy
    max_epochs = 200

    for r in range(1, R + 1):
        print(f"\n===== Starting replicate {r} =====\n")
        # Set replicate seed.
        torch.manual_seed(r)
        np.random.seed(r)
        random.seed(r)
        pyro.set_rng_seed(r)

        ############################################
        # MAP Training (Logistic Regression with SimpleMLP)
        ############################################
        model_mlp = SimpleMLP(input_dim=X_train_subset.shape[1], hidden_dim=0, num_classes=10).to(device)
        print(f'dimension d={sum(p.numel() for p in model_mlp.parameters())}')
        optimizer_mlp = optim.Adam(model_mlp.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()

        start_time = time.time()
        for epoch in range(max_epochs):
            model_mlp.train()
            running_loss = 0.0
            for inputs, labels in train_loader:
                # The model lives on `device`; batches must be moved there too.
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer_mlp.zero_grad()
                outputs = model_mlp(inputs)
                ce_loss = criterion(outputs, labels)

                # Gaussian-prior (L2) penalty divided by the training-set size, so
                # loss ~ negative log-posterior / N (CE is a per-example mean).
                reg_loss = (torch.sum(model_mlp.fc.weight ** 2) / (2 * sigma_w ** 2) +
                            torch.sum(model_mlp.fc.bias ** 2)   / (2 * sigma_b ** 2))
                loss = ce_loss + reg_loss / n_train
                loss.backward()
                optimizer_mlp.step()
                running_loss += loss.item()
            train_loss = running_loss / len(train_loader)   # epoch-average loss
            print(f"Replicate {r}, MAP Epoch {epoch+1:04d}: Train Loss = {train_loss:.4f}")

        map_epochs = epoch + 1
        total_map_time = time.time() - start_time
        print(f"\nReplicate {r}, MAP Total execution time: {total_map_time:.2f} seconds")
        replicate_epochs.append(map_epochs)
        replicate_map_time.append(total_map_time)

        ############################################
        # Test accuracy on the in-domain test set
        ############################################
        model_mlp.eval()
        acc = compute_accuracy(model_mlp, x_test_full, y_test_full)
        print(f"Replicate {r}, MAP Test Accuracy (classes 0-9): {acc:.2f}%")
        replicate_accuracy.append(acc)

        ############################################
        # In-domain analysis (full test set)
        ############################################
        with torch.no_grad():
            probs = F.softmax(model_mlp(x_test_full), dim=1)
        probs_np = probs.cpu().numpy()

        # Negative Log Likelihood per sample.
        nll_samples = -torch.log(probs[range(len(y_test_full)), y_test_full] + eps)
        avg_nll = nll_samples.mean().item()
        replicate_nll_in.append(avg_nll)

        # Total (predictive) entropy per sample.
        total_entropy = -(probs * torch.log(probs + eps)).sum(dim=1).cpu().numpy()
        avg_total_entropy = np.mean(total_entropy)
        replicate_total_entropy_in.append(avg_total_entropy)

        # Classification metrics; sklearn gets host (NumPy) arrays only.
        y_pred_in = probs.argmax(dim=1).cpu().numpy()
        f1_in = f1_score(y_true_in, y_pred_in, average='macro', zero_division=0)
        all_labels_bin = label_binarize(y_true_in, classes=list(range(10)))
        aucpr_in = average_precision_score(all_labels_bin, probs_np, average='macro')
        precision_val = precision_score(y_true_in, y_pred_in, average='macro', zero_division=0)
        recall_val = recall_score(y_true_in, y_pred_in, average='macro', zero_division=0)
        replicate_f1_in.append(f1_in)
        replicate_aucpr_in.append(aucpr_in)
        replicate_precision.append(precision_val)
        replicate_recall.append(recall_val)

        # Calibration metrics.
        brier = compute_brier(probs, y_test_full)
        ece = compute_ece(probs, y_test_full, n_bins=10)
        replicate_brier.append(brier)
        replicate_ece.append(ece)

        # Entropy breakdown (MAP is deterministic, so there is no epistemic term).
        correct_mask = (y_pred_in == y_true_in)
        incorrect_mask = ~correct_mask
        tot_ent_correct = np.mean(total_entropy[correct_mask]) if correct_mask.any() else np.nan
        tot_ent_incorrect = np.mean(total_entropy[incorrect_mask]) if incorrect_mask.any() else np.nan
        replicate_total_entropy_in_correct.append(tot_ent_correct)
        replicate_total_entropy_in_incorrect.append(tot_ent_incorrect)

        print(f"\nReplicate {r}, MAP In-Domain Analysis (Classes 0-9):")
        print(f"  Average NLL: {avg_nll:.6f}")
        print(f"  Average Total Entropy: {avg_total_entropy:.6f}")
        print(f"  F1 Score (macro): {f1_in:.6f}")
        print(f"  AUC-PR (macro): {aucpr_in:.6f}")
        print(f"  Brier Score: {brier:.6f}")
        print(f"  ECE: {ece:.6f}")
        print("Breakdown for In-Domain Predictions:")
        print(f"  Correct Predictions - Total Entropy: {tot_ent_correct:.6f}")
        print(f"  Incorrect Predictions - Total Entropy: {tot_ent_incorrect:.6f}")

        ############################################
        # Out-of-domain analysis (average predictive entropy only)
        ############################################
        for name, X_ood in ood_sets:
            with torch.no_grad():
                probs_ood = F.softmax(model_mlp(X_ood), dim=1)
            ent = -(probs_ood * torch.log(probs_ood + eps)).sum(dim=1).cpu().numpy()
            avg_ent = ent.mean()
            print(f"Replicate {r}, OOD ({name}) Avg. Entropy: {avg_ent:.6f}")
            replicate_total_entropy_od[name].append(avg_ent)

    ############################################
    # After replicates: overall statistics.
    ############################################
    epochs_mean, epochs_stderr = compute_stats(replicate_epochs, R)
    accuracy_mean, accuracy_stderr = compute_stats(replicate_accuracy, R)
    nll_in_mean, nll_in_stderr = compute_stats(replicate_nll_in, R)
    tot_ent_in_mean, tot_ent_in_stderr = compute_stats(replicate_total_entropy_in, R)
    f1_in_mean, f1_in_stderr = compute_stats(replicate_f1_in, R)
    aucpr_in_mean, aucpr_in_stderr = compute_stats(replicate_aucpr_in, R)
    time_mean, time_stderr = compute_stats(replicate_map_time, R)

    mean_precision, stderr_precision = compute_stats(replicate_precision, R)
    mean_recall,    stderr_recall    = compute_stats(replicate_recall, R)
    mean_brier, stderr_brier = compute_stats(replicate_brier, R)
    mean_ece,   stderr_ece   = compute_stats(replicate_ece, R)

    # In-domain breakdown.
    tot_ent_in_corr_mean, tot_ent_in_corr_stderr = compute_stats(replicate_total_entropy_in_correct, R)
    tot_ent_in_inc_mean, tot_ent_in_inc_stderr = compute_stats(replicate_total_entropy_in_incorrect, R)

    # Out-of-domain.
    tot_ent_od_close_mean, tot_ent_od_close_stderr = compute_stats(replicate_total_entropy_od["close"], R)
    tot_ent_od_corrupt_mean, tot_ent_od_corrupt_stderr = compute_stats(replicate_total_entropy_od["corrupt"], R)
    tot_ent_od_far_mean, tot_ent_od_far_stderr = compute_stats(replicate_total_entropy_od["far"], R)

    print("\n===== Summary over {} replicates =====".format(R))
    print(f"MAP Epochs: Mean = {epochs_mean:.6f}, SE = {epochs_stderr:.6f}")
    print(f"MAP Test Accuracy (%): Mean = {accuracy_mean:.6f}, SE = {accuracy_stderr:.6f}")
    print(f"In-Domain NLL: Mean = {nll_in_mean:.6f}, SE = {nll_in_stderr:.6f}")
    print(f"Brier Score:   {mean_brier:.6f} ± {stderr_brier:.6f}")
    print(f"ECE:           {mean_ece:.6f} ± {stderr_ece:.6f}")
    print(f"Precision:     {mean_precision:.4f} ± {stderr_precision:.4f}")
    print(f"Recall:        {mean_recall:.4f} ± {stderr_recall:.4f}")
    print(f"In-Domain Total Entropy: Mean = {tot_ent_in_mean:.6f}, SE = {tot_ent_in_stderr:.6f}")
    print(f"In-Domain F1 Score (macro): Mean = {f1_in_mean:.6f}, SE = {f1_in_stderr:.6f}")
    print(f"In-Domain AUC-PR (macro): Mean = {aucpr_in_mean:.6f}, SE = {aucpr_in_stderr:.6f}")
    print(f"Out-of-Domain Close Total Entropy: Mean = {tot_ent_od_close_mean:.6f}, SE = {tot_ent_od_close_stderr:.6f}")
    print(f"Out-of-Domain Corrupt Total Entropy: Mean = {tot_ent_od_corrupt_mean:.6f}, SE = {tot_ent_od_corrupt_stderr:.6f}")
    print(f"Out-of-Domain Far Total Entropy: Mean = {tot_ent_od_far_mean:.6f}, SE = {tot_ent_od_far_stderr:.6f}")
    print(f"MAP Training Time (s): Mean = {time_mean:.6f}, SE = {time_stderr:.6f}")
    print("\nBreakdown for In-Domain Predictions:")
    print(f"  Correct Predictions - Total Entropy: Mean = {tot_ent_in_corr_mean:.6f}, SE = {tot_ent_in_corr_stderr:.6f}")
    print(f"  Incorrect Predictions - Total Entropy: Mean = {tot_ent_in_inc_mean:.6f}, SE = {tot_ent_in_inc_stderr:.6f}")

    ############################################
    # Save all computed metrics and replicate results.
    ############################################
    metrics = {
        'R': R,
        'replicate_epochs': np.array(replicate_epochs),
        'replicate_map_time': np.array(replicate_map_time),
        'replicate_accuracy': np.array(replicate_accuracy),
        'replicate_nll': np.array(replicate_nll_in),
        'replicate_total_entropy': np.array(replicate_total_entropy_in),
        'replicate_f1_score': np.array(replicate_f1_in),
        'replicate_auc_pr': np.array(replicate_aucpr_in),
        'replicate_brier': np.array(replicate_brier),
        'replicate_ece': np.array(replicate_ece),
        'replicate_precision': np.array(replicate_precision),
        'replicate_recall': np.array(replicate_recall),
        'replicate_total_entropy_od_close': np.array(replicate_total_entropy_od['close']),
        'replicate_total_entropy_od_corrupt': np.array(replicate_total_entropy_od['corrupt']),
        'replicate_total_entropy_od_far': np.array(replicate_total_entropy_od['far']),
        'replicate_correct_entropy': np.array(replicate_total_entropy_in_correct),
        'replicate_incorrect_entropy': np.array(replicate_total_entropy_in_incorrect),
        # summary
        'epochs_mean': epochs_mean,
        'epochs_se': epochs_stderr,
        'accuracy_mean': accuracy_mean,
        'accuracy_se': accuracy_stderr,
        'nll_mean': nll_in_mean,
        'nll_se': nll_in_stderr,
        'total_entropy_mean': tot_ent_in_mean,
        'total_entropy_se': tot_ent_in_stderr,
        'f1_mean': f1_in_mean,
        'f1_se': f1_in_stderr,
        'auc_pr_mean': aucpr_in_mean,
        'auc_pr_se': aucpr_in_stderr,
        'time_mean': time_mean,
        'time_se': time_stderr,
        'mean_brier': mean_brier,
        'stderr_brier': stderr_brier,
        'mean_ece': mean_ece,
        'stderr_ece': stderr_ece,
        'mean_precision': mean_precision,
        'stderr_precision': stderr_precision,
        'mean_recall': mean_recall,
        'stderr_recall': stderr_recall,
        'correct_entropy_mean': tot_ent_in_corr_mean,
        'correct_entropy_se': tot_ent_in_corr_stderr,
        'incorrect_entropy_mean': tot_ent_in_inc_mean,
        'incorrect_entropy_se': tot_ent_in_inc_stderr,
        'ood_close_mean': tot_ent_od_close_mean,
        'ood_close_se': tot_ent_od_close_stderr,
        'ood_corrupt_mean': tot_ent_od_corrupt_mean,
        'ood_corrupt_se': tot_ent_od_corrupt_stderr,
        'ood_far_mean': tot_ent_od_far_mean,
        'ood_far_se': tot_ent_od_far_stderr,
    }
    savemat('BayesianNN_CIFAR_MAP_metrics.mat', metrics)
    print("\nAll computed metrics (and raw replicate results) have been saved to 'BayesianNN_CIFAR_MAP_metrics.mat'.")
