#!/usr/bin/env python3
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset
import torch.optim as optim
import sys
from scipy.io import savemat, loadmat

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import HMC, MCMC

# Use torch.multiprocessing instead of standard multiprocessing
import torch.multiprocessing as mp

from torch.func import functional_call

# For F1 and AUC-PR metrics.
from sklearn.metrics import f1_score, average_precision_score
from sklearn.preprocessing import label_binarize

# Select the compute device: CUDA when available, otherwise CPU.
# (No seeds are set here; seeding happens in the main script and inside the SMC routines.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

############################################
# Global Prior Parameters for initialization (N(0, v)).
# Each value is a standard deviation, i.e. sqrt(v) with variance v = 0.1.
############################################
sigma_b    = np.sqrt(0.1)   # standard deviation for biases
sigma_conv = np.sqrt(0.1)    # standard deviation for conv weights
sigma_fc   = np.sqrt(0.1)    # standard deviation for fc weights

############################################
# Tuning scales for the N(MAP, s) prior used in SMC.
# Despite the "variance" wording, each value is a standard deviation:
# sqrt(s) with tuning variance s = 0.01.
############################################
sigma_b_s    = np.sqrt(0.01)   # tuning standard deviation for biases
sigma_conv_s = np.sqrt(0.01)   # tuning standard deviation for conv weights
sigma_fc_s   = np.sqrt(0.01)   # tuning standard deviation for fc weights

# Dictionary for ease of use in the SMC functions.
# NOTE: the name says "variances" but the stored values are standard deviations;
# consumers multiply noise by them and divide squared deviations by 2 * sigma**2.
s_variances = {
    'sigma_b': sigma_b_s,
    'sigma_conv': sigma_conv_s,
    'sigma_fc': sigma_fc_s
}

############################################
# Helper Functions for SMC & HMC
############################################
def unflatten_params(flat, net):
    """Split a flat parameter vector into a ``{name: tensor}`` dictionary.

    Slices of ``flat`` are taken in the order of ``net.named_parameters()``
    and each slice is viewed with the shape of the matching parameter.
    """
    rebuilt = {}
    start = 0
    for key, reference in net.named_parameters():
        stop = start + reference.numel()
        rebuilt[key] = flat[start:stop].view(reference.shape)
        start = stop
    return rebuilt

def init_particle(net, prior_params, s_variances):
    """
    Draw one flattened particle around the prior centre.

    For every parameter of `net` (in `named_parameters()` order) the value
    `prior_params[name]` is perturbed by Gaussian noise whose standard
    deviation is taken from `s_variances`: 'sigma_conv' for "conv.weight",
    'sigma_fc' for "fc.weight" and 'sigma_b' for everything else (the biases).
    The perturbed tensors are concatenated into a single 1-D tensor.
    """
    scale_key_by_name = {"conv.weight": 'sigma_conv', "fc.weight": 'sigma_fc'}
    pieces = []
    for name, param in net.named_parameters():
        # any name not listed above (conv.bias / fc.bias) uses the bias scale
        sigma = s_variances[scale_key_by_name.get(name, 'sigma_b')]
        centre = prior_params[name].to(device)
        perturbed = centre + torch.randn(param.shape, device=device) * sigma
        pieces.append(perturbed.view(-1))
    return torch.cat(pieces)

def model_loss_func_ll(output, y, temp):
    """Return the cross-entropy of `output` vs. labels `y`, summed over samples and scaled by `temp`."""
    targets = y.long().view(-1)  # labels as a flat int64 vector
    summed_ce = F.cross_entropy(output, targets, reduction='sum')
    return summed_ce * temp

def model_loss_func(params, x, y, temp, net, prior_params, s_variances):
    """
    Potential energy used by HMC: Gaussian prior penalty plus tempered data loss.

    The prior term is sum((param - prior)**2) / (2 * sigma**2) over the four
    tensors "conv.weight", "conv.bias", "fc.weight" and "fc.bias", where sigma
    comes from `s_variances` ('sigma_conv', 'sigma_fc', or 'sigma_b' for both
    biases). The data term is the summed cross-entropy of `net` evaluated
    with `params` on (x, y), multiplied by `temp`.
    """
    scale_name_for = {
        "conv.weight": 'sigma_conv',
        "conv.bias": 'sigma_b',
        "fc.weight": 'sigma_fc',
        "fc.bias": 'sigma_b',
    }
    prior_penalty = 0.0
    for key, scale_name in scale_name_for.items():
        deviation = params[key] - prior_params[key]
        sigma = s_variances[scale_name]
        prior_penalty += (deviation**2).sum() / (2 * sigma**2)

    data_term = model_loss_func_ll(functional_call(net, params, x), y, temp)
    return prior_penalty + data_term

def compute_accuracy(net, particles, x_test, y_test):
    """
    Accuracy of the particle ensemble on (x_test, y_test).

    Each flat particle is evaluated through `net`, the softmax probabilities
    are averaged over particles, and the arg-max class is compared with the
    labels. Returns the fraction of correct predictions.
    """
    with torch.no_grad():
        per_particle_probs = [
            F.softmax(functional_call(net, unflatten_params(flat, net), x_test), dim=1)
            for flat in particles
        ]
        mean_probs = torch.stack(per_particle_probs).mean(dim=0)
        predicted = mean_probs.argmax(dim=1)
        hits = (predicted == y_test.view(-1)).sum().item()
        fraction_correct = hits / y_test.size(0)
    return fraction_correct

############################################
# New Functions Using torch.multiprocessing
############################################
def hmc_update_particle(net, x_train, y_train, params_init, num_samples, warmup_steps, step_size, num_steps, temp, prior_params, s_variances):
    """
    Run one HMC chain started at `params_init` with `model_loss_func` as potential.

    Step-size and mass-matrix adaptation are both switched off. Returns a
    pair: the last sample of the chain (flat parameter vector) and, as a
    float, the last "acc. prob" value reported by the kernel's `logging()`
    during the run.
    """
    logged_acc = []

    def capture_diagnostics(kernel, params, stage, i):
        # Hook passed to MCMC; keeps whatever the kernel logs as "acc. prob".
        info = kernel.logging()
        if "acc. prob" in info:
            logged_acc.append(info["acc. prob"])

    def potential_fn(state):
        # The chain state is a dict holding the flat vector under "params".
        named = unflatten_params(state["params"], net)
        return model_loss_func(named, x_train, y_train, temp, net, prior_params=prior_params, s_variances=s_variances)

    kernel = HMC(
        potential_fn=potential_fn,
        step_size=step_size,
        num_steps=num_steps,
        adapt_step_size=False,
        adapt_mass_matrix=False,
        target_accept_prob=0.65,
    )
    sampler = MCMC(
        kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        initial_params={"params": params_init},
        disable_progbar=True,
        hook_fn=capture_diagnostics,
    )
    sampler.run()

    chain = sampler.get_samples()["params"]
    last_acc = logged_acc[-1]
    return chain[-1], float(last_acc)

def update_particle(params, net, x_train, y_train, step_size, L, temp, M, prior_params, s_variances, seed):
    """
    Mutate a single particle with HMC; intended as the worker function of a pool.

    All random generators are re-seeded with `seed`, then an HMC chain of `M`
    samples (no warm-up, `L` leapfrog steps of size `step_size`) is run at
    temperature `temp`. Returns the updated flat parameters, the untempered
    (temp=1) summed cross-entropy on the training data, and the acceptance
    value reported by `hmc_update_particle`.
    """
    # Same seeding order as elsewhere in the file: torch, numpy, random, pyro.
    for reseed in (torch.manual_seed, np.random.seed, random.seed, pyro.set_rng_seed):
        reseed(seed)

    moved, acc_rate = hmc_update_particle(
        net, x_train, y_train, params,
        num_samples=M, warmup_steps=0, step_size=step_size,
        num_steps=L, temp=temp, prior_params=prior_params, s_variances=s_variances
    )

    logits = functional_call(net, unflatten_params(moved, net), x_train)
    nll = model_loss_func_ll(logits, y_train, temp=1)
    nll_value = nll.item() if torch.is_tensor(nll) else nll
    return moved, nll_value, acc_rate

def _tempering_weights(llike, temp_increment):
    """Return (lwhat, lmax, w, ess) for a tempering step of size `temp_increment`.

    `lwhat` are the unnormalised log incremental weights, `lmax` their maximum,
    `w` the normalised weights and `ess` the effective sample size 1/sum(w**2).
    """
    lwhat = -temp_increment * llike
    lmax = np.max(lwhat)
    w = np.exp(lwhat - lmax)
    w /= np.sum(w)
    ess = 1.0 / np.sum(w**2)
    return lwhat, lmax, w, ess


def _systematic_resample_indices(w):
    """Return ancestor indices drawn by systematic resampling from normalised weights `w`."""
    n = len(w)
    cumulative_w = np.cumsum(w)
    # Round-off can leave the final cumulative weight slightly below 1, in
    # which case the last position would match no particle; pin it to 1.
    cumulative_w[-1] = 1.0
    positions = (np.arange(n) + np.random.uniform(0, 1)) / n
    # First index j with positions[i] < cumulative_w[j]; clipped as a last
    # guard in case a position rounds up to exactly 1.0.
    ancestors = np.searchsorted(cumulative_w, positions, side='right')
    return np.minimum(ancestors, n - 1)


def psmc_single(seed, x_train, y_train, x_test, y_test, 
                NetClass, N_particles, step_size, L, M, trajectory, prior_params, numwork, s_variances):
    """
    Runs an SMC sampler (single replicate) with N_particles using tempering and HMC mutations.

    Particles are initialised around `prior_params`, then moved from
    temperature 0 to 1. Each step picks the largest increment keeping the
    effective sample size at or above N_particles / 2 (by repeated halving),
    resamples systematically and mutates every particle with HMC in a pool of
    `numwork` worker processes. A mutation phase is repeated with a smaller or
    larger step size while the mean acceptance lies outside [0.4, 0.95].

    Returns a tuple
        (t_elapsed, preds_np, particles_flat, ZZ, KK, count, accuracy, Ls)
    where `preds_np` holds the logits of every particle on `x_test` with shape
    (N_particles, x_test.size(0), 8), `particles_flat` has shape
    (N_particles, d), `ZZ` accumulates log(mean(exp(lwhat - lmax))) and `KK`
    accumulates `lmax` over the tempering steps, `count` is the number of
    tempering steps, `accuracy` the ensemble accuracy on the test data and
    `Ls` the accumulated L * M per tempering step.
    """
    # Seed overall run for reproducibility.
    torch.manual_seed(seed)
    np.random.seed(seed)
    pyro.set_rng_seed(seed)
    random.seed(seed)
    
    net = NetClass().to(device)
    net.eval()
    net.share_memory()  # Allow sharing between processes
    
    d = sum(p.numel() for p in net.parameters())
    num_classes = 8  # Filtered MNIST (digits 0-7)
    
    t_start = time.time()
    particles = []
    llike = []
    preds = []
    
    # Particle initialization using the MAP prior and tuning scales.
    for _ in range(N_particles):
        params_init = init_particle(net, prior_params, s_variances)
        particles.append(params_init)
        param_dict = unflatten_params(params_init, net)
        output = functional_call(net, param_dict, x_train)
        loss_val = model_loss_func_ll(output, y_train, temp=1)
        llike.append(loss_val.item() if torch.is_tensor(loss_val) else loss_val)
    llike = np.array(llike)
    
    tempcurr = 0.0
    ZZ = 0.0
    KK = 0.0
    count = 0
    Ls = 0
    
    while tempcurr < 1.0:
        # Try to jump straight to temperature 1; halve the increment until
        # the effective sample size is acceptable.
        temp_increment = 1.0 - tempcurr
        halved = False
        lwhat, lmax, w, ess = _tempering_weights(llike, temp_increment)
        while ess < N_particles / 2:
            temp_increment /= 2.0
            halved = True
            lwhat, lmax, w, ess = _tempering_weights(llike, temp_increment)
        
        # Without halving the target is exactly 1.0; assigning it directly
        # avoids a spurious extra step when tempcurr + (1.0 - tempcurr)
        # rounds just below 1.
        proposed_temp = tempcurr + temp_increment if halved else 1.0
        if proposed_temp >= 1.0 and count == 0:
            # Never finish in a single step: stop at 0.5 first. The weights
            # must then be recomputed for the increment actually taken,
            # otherwise resampling and ZZ/KK would correspond to an increment
            # of 1.0 while the temperature only advances by 0.5.
            proposed_temp = 0.5
            temp_increment = proposed_temp - tempcurr
            lwhat, lmax, w, ess = _tempering_weights(llike, temp_increment)
        print("Current temperature:", tempcurr, ", dT=", temp_increment, "Step=", count)
        ZZ += np.log(np.mean(np.exp(lwhat - lmax)))
        KK += lmax
        Ls += L * M
        
        # Systematic resampling.
        ancestors = _systematic_resample_indices(w)
        particles = [particles[j] for j in ancestors]
        llike = llike[ancestors]
        
        mutation_success = False
        while not mutation_success:
            # --- Parallel Mutation Phase using torch.multiprocessing ---
            # A rejected attempt leaves `particles` / `llike` untouched, so a
            # retry starts again from the resampled population.
            mutation_seeds = [random.randint(0, 2**32 - 1) for _ in range(len(particles))]
            with mp.Pool(processes=numwork) as pool:
                results = pool.starmap(
                    update_particle,
                    [
                        (params, net, x_train, y_train, step_size, L, proposed_temp, M, prior_params, s_variances, seed)
                        for params, seed in zip(particles, mutation_seeds)
                    ]
                )
            updated_particles, updated_llike, acc_rates = zip(*results)
            overall_acc = np.mean(acc_rates)
            print("Overall acceptance for mutation phase:", overall_acc, "Stepsize=", step_size, "L=", L)
            if overall_acc < 0.4 or overall_acc > 0.95:
                if overall_acc < 0.4:
                    step_size *= 0.7
                else:
                    step_size *= 1.1
                print("Unsatisfactory acceptance. Restarting mutation phase with step size", step_size)
            else:
                mutation_success = True
                # Accepted, but still nudge the step size towards the
                # 0.6-0.8 acceptance band for the next tempering step.
                if overall_acc < 0.6:
                    step_size *= 0.7
                elif overall_acc > 0.8:
                    step_size *= 1.1
                particles = updated_particles
                llike = np.array(updated_llike)
            # Keep the trajectory length roughly constant, with 1 <= L <= 100.
            L = min(max(1, int(trajectory/step_size)), 100)
        tempcurr = proposed_temp
        count += 1
    
    # Logits of every final particle on the test inputs.
    for params in particles:
        param_dict = unflatten_params(params, net)
        output = functional_call(net, param_dict, x_test)
        preds.append(output)
    
    accuracy = compute_accuracy(net, particles, x_test, y_test)
    print("Validation accuracy:", accuracy)
    
    particles_tensor = torch.stack(particles).detach().cpu()
    particles_flat = particles_tensor.numpy().reshape(N_particles, d)
    preds_tensor = torch.stack(preds).detach().cpu()
    preds_np = preds_tensor.numpy().reshape(N_particles, x_test.size(0), num_classes)
    t_elapsed = time.time() - t_start
    return t_elapsed, preds_np, particles_flat, ZZ, KK, count, accuracy, Ls

def smc_predict(image, particles_flat, NetClass):
    """
    Ensemble prediction for a single image.

    A fresh `NetClass` instance (in eval mode) is evaluated once per flattened
    particle on `image` with a leading batch axis added; the softmax outputs
    are averaged over particles and the batch axis is removed again.
    """
    net = NetClass().to(device)
    net.eval()

    def predict_one(flat):
        named = unflatten_params(torch.tensor(flat, device=device), net)
        return F.softmax(functional_call(net, named, image.unsqueeze(0)), dim=1)

    stacked = torch.stack([predict_one(flat) for flat in particles_flat])
    return stacked.mean(dim=0).squeeze(0)

############################################
# Simple CNN Architecture (Filtered MNIST: 8 classes)
############################################
class SimpleCNN(nn.Module):
    """One 3x3 conv layer (1 -> 4 channels), 2x2 max-pool and a linear layer with 8 outputs.

    Weights and biases are re-drawn from zero-mean normals whose standard
    deviations are the module-level `sigma_conv`, `sigma_fc` and `sigma_b`.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(4 * 14 * 14, 8)

        # Draw order matters for reproducibility:
        # conv weight, conv bias, fc weight, fc bias.
        for layer, weight_std in ((self.conv, sigma_conv), (self.fc, sigma_fc)):
            nn.init.normal_(layer.weight, mean=0, std=weight_std)
            if layer.bias is not None:
                nn.init.normal_(layer.bias, mean=0, std=sigma_b)

    def forward(self, x):
        features = self.pool(F.relu(self.conv(x)))
        flattened = features.view(features.size(0), -1)
        return self.fc(flattened)

class FilteredDataset(Dataset):
    """In-memory dataset holding only the (img, label) pairs whose label is in `allowed_labels`."""

    def __init__(self, dataset, allowed_labels):
        # The source dataset is iterated once and the kept pairs are stored eagerly.
        kept = []
        for img, label in dataset:
            if label in allowed_labels:
                kept.append((img, label))
        self.data = kept

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
############################################
# Helper function for analysis: softmax
############################################
def softmax_np(x):
    """Softmax over the last axis; the per-row maximum is subtracted before exponentiating."""
    row_max = np.max(x, axis=-1, keepdims=True)
    exps = np.exp(x - row_max)
    normaliser = np.sum(exps, axis=-1, keepdims=True)
    return exps / normaliser

############################################
# Main Script
############################################
# Script entry point. Usage: <script> SEED NUMWORK
#   SEED    - integer seed for every RNG; also used in the output file names
#   NUMWORK - number of worker processes for the parallel HMC mutations
# Pipeline: load filtered MNIST -> train a MAP CNN -> run SMC centred on the
# MAP parameters -> in-domain and out-of-domain metrics -> save .mat files.
if __name__ == '__main__':

    # Set the multiprocessing start method to spawn.
    mp.set_start_method('spawn', force=True)
    
    # Both command-line arguments are required; a missing one raises IndexError.
    seed = int(sys.argv[1])
    numwork = int(sys.argv[2])

    ############################################
    # Data Loading and Filtering (Filtered MNIST: digits 0-7)
    ############################################
    N_total_train = 1200
    N_tr          = 1000
    N_val         = 200

    transform = transforms.Compose([transforms.ToTensor()])
    full_train_dataset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform)
    full_test_dataset  = torchvision.datasets.MNIST(
        root='./data', train=False, download=True, transform=transform)

    allowed_labels      = list(range(8))
    filtered_train_pool = FilteredDataset(full_train_dataset, allowed_labels)
    filtered_test_pool  = FilteredDataset(full_test_dataset,  allowed_labels)

    # MAP splits on the first N_total_train (1200) samples of the training pool:
    # the first N_tr (1000) for training, the next N_val (200) for validation.
    filtered_train_total = Subset(filtered_train_pool, list(range(N_total_train)))
    map_train_dataset    = Subset(filtered_train_total, list(range(N_tr)))
    map_val_dataset      = Subset(
                              filtered_train_total,
                              list(range(N_tr, N_tr + N_val))
                          )

    train_loader = DataLoader(map_train_dataset, batch_size=64, shuffle=True)
    val_loader   = DataLoader(map_val_dataset,   batch_size=64, shuffle=False)

    # Single-batch loader over all N_total_train samples (train + val) for HMC.
    train_loader_full = DataLoader(
        filtered_train_total,
        batch_size=len(filtered_train_total),
        shuffle=False
    )
    x_train, y_train = next(iter(train_loader_full))
    x_train, y_train = x_train.to(device), y_train.to(device)

    ############################################
    # Prepare ID-Test Set (first 1000 of filtered test pool)
    ############################################
    test_id_dataset  = Subset(filtered_test_pool, list(range(1000)))
    test_loader_full = DataLoader(
        test_id_dataset,
        batch_size=len(test_id_dataset),
        shuffle=False
    )
    x_test, y_test = next(iter(test_loader_full))
    x_test, y_test = x_test.to(device), y_test.to(device)

    # Seed every RNG before the model is built, so initialisation and
    # mini-batch shuffling are reproducible for a given SEED.
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    pyro.set_rng_seed(seed)

    ############################################
    # MAP (Deterministic) CNN Training on Filtered MNIST
    ############################################
    model_cnn = SimpleCNN().to(device)
    optimizer_cnn = optim.Adam(model_cnn.parameters(), lr=0.001)
    criterion_cnn = nn.CrossEntropyLoss()

    # `net` is only an alias used to count parameters; the training loop
    # below switches model_cnn between train() and eval() itself.
    net = model_cnn
    net.eval()
    d = sum(p.numel() for p in net.parameters())
    print(f"The dimension of the parameters in SimpleCNN is {d}")

    train_losses = []
    val_losses = []
    # Early stopping: stop once the moving average of the validation loss
    # (window of 10 epochs) has failed to improve for `patience` checks in a row.
    moving_avg_window = 10
    best_moving_avg = float('inf')
    patience = 5
    no_improve_count = 0

    start_time = time.time()

    for epoch in range(1000):
        model_cnn.train()
        running_loss = 0.0
        for images, labels in train_loader:
            optimizer_cnn.zero_grad()
            outputs = model_cnn(images.to(device))
            ce_loss = criterion_cnn(outputs, labels.to(device))
            # Gaussian-prior (L2) penalty with the N(0, v) prior scales.
            reg_conv = (torch.sum(model_cnn.conv.weight**2) / (2 * sigma_conv**2) +
                        torch.sum(model_cnn.conv.bias**2) / (2 * sigma_b**2))
            reg_fc = (torch.sum(model_cnn.fc.weight**2) / (2 * sigma_fc**2) +
                      torch.sum(model_cnn.fc.bias**2) / (2 * sigma_b**2))
            reg_loss = reg_conv + reg_fc
            # ce_loss is a per-sample mean, so the penalty is divided by the
            # training-set size to put both terms on the same scale.
            loss = ce_loss + reg_loss/len(map_train_dataset)
            loss.backward()
            optimizer_cnn.step()
            running_loss += loss.item()
        train_loss = running_loss / len(train_loader)
        train_losses.append(train_loss)
        
        model_cnn.eval()
        val_running_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                outputs = model_cnn(images.to(device))
                ce_loss = criterion_cnn(outputs, labels.to(device))
                reg_conv = (torch.sum(model_cnn.conv.weight**2) / (2 * sigma_conv**2) +
                            torch.sum(model_cnn.conv.bias**2) / (2 * sigma_b**2))
                reg_fc = (torch.sum(model_cnn.fc.weight**2) / (2 * sigma_fc**2) +
                          torch.sum(model_cnn.fc.bias**2) / (2 * sigma_b**2))
                reg_loss = reg_conv + reg_fc
                # Same objective as in training, including the division by
                # the training-set size.
                loss = ce_loss + reg_loss/len(map_train_dataset)
                val_running_loss += loss.item()
        val_loss = val_running_loss / len(val_loader)
        val_losses.append(val_loss)
        
        print(f"MAP Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
        
        # The moving average is only defined once a full window is available.
        if epoch >= moving_avg_window - 1:
            moving_avg = sum(val_losses[-moving_avg_window:]) / moving_avg_window
            if moving_avg < best_moving_avg:
                best_moving_avg = moving_avg
                no_improve_count = 0
            else:
                no_improve_count += 1
            if no_improve_count >= patience:
                print("MAP: Early stopping due to no improvement.")
                break

    total_map_time = time.time() - start_time
    print(f"MAP Total execution time: {total_map_time:.2f} seconds")

    ############################################
    # Define Prior for SMC using the trained MAP model's state dict
    ############################################
    # The weights of the last epoch are used; no best-epoch checkpoint is kept.
    prior_params = model_cnn.state_dict()
    print("Prior parameters extracted from trained MAP model.")

    ############################################
    # Run SMC Inference
    ############################################
    # SMC Sampler Parameters
    N = 10    # Number of particles
    # With trajectory = 0.0 the formula below always gives L = 1.
    trajectory = 0.0
    step_size = 0.035
    L = min(max(1, int(trajectory / step_size)), 100)        # Number of leapfrog steps per HMC step
    M = 10        # Number of HMC steps per mutation

    start_time = time.time()
    t_elapsed, preds_np, particles_np, Z, K, count, smc_accuracy, Ls = psmc_single(
        seed=seed,
        x_train=x_train,
        y_train=y_train,
        x_test=x_test,
        y_test=y_test,
        NetClass=SimpleCNN,
        N_particles=N,
        step_size=step_size,
        L=L,
        M=M,
        trajectory=trajectory,
        prior_params=prior_params,
        numwork=numwork,
        s_variances=s_variances
    )
    total_smc_time = time.time() - start_time
    print(f"SMC Total execution time: {total_smc_time:.2f} seconds")
    print("SMC Validation Accuracy:", smc_accuracy)
    print("SMC Lsum:", Ls)

    psmc_single_time = np.array(t_elapsed)
    psmc_single_pred = preds_np  # logits on x_test, shape: (N_particles, n_test, 8)
    psmc_single_x = np.array(particles_np)

    # --- Save the SMC results ---
    # NOTE: the file name carries N_tr / N_val, although the SMC run used all
    # N_total_train samples and the stored predictions are for the test set.
    savemat(f'BayesianNN_MNIST_psmc_SimpleNN_results_d{d}_train{N_tr}_val{N_val}_N{N}_M{M}_node{seed}.mat', {
        'psmc_single_time': psmc_single_time,
        'psmc_single_pred': psmc_single_pred,
        'psmc_single_x': psmc_single_x,
        'Z': Z,
        'K': K,
        'count': count,
        'Lsum': Ls
    })

    ############################################
    # SMC Prediction on a Single Sample
    ############################################
    sample_image = x_test[0]
    pred = smc_predict(sample_image, particles_np, SimpleCNN)
    print("SMC Prediction for first sample:", pred)

    ############################################
    # In-Domain Analysis (ID test data: Digits 0-7)
    ############################################
    # The "val" names below refer to the ID test set (x_test / y_test),
    # not to the MAP validation split.
    # Compute softmax probabilities for each particle
    particle_probs = softmax_np(psmc_single_pred)  # shape: (N_particles, n_test, 8)
    # Ensemble probability (average over particles)
    ensemble_probs = np.mean(particle_probs, axis=0)  # shape: (n_test, 8)
    
    # Convert the test labels to a flat numpy array
    y_val_np = y_test.cpu().numpy().flatten()  # shape: (n_test,)
    
    # Compute NLL per particle for each sample, then average over particles
    nlls_per_particle = -np.log(particle_probs[:, np.arange(len(y_val_np)), y_val_np] + 1e-12)
    # NOTE(review): this is the SUM over samples of the particle-averaged NLL,
    # although the name and the printed label call it an average - confirm intent.
    avg_nll = np.sum(np.mean(nlls_per_particle, axis=0))
    
    # Compute Total Entropy on ensemble probability for each sample
    total_entropy = -np.sum(ensemble_probs * np.log(ensemble_probs + 1e-12), axis=1)
    avg_total_entropy = np.mean(total_entropy)
    
    # Compute per-particle entropy and then epistemic uncertainty (mutual information)
    particle_entropy = -np.sum(particle_probs * np.log(particle_probs + 1e-12), axis=2)  # shape: (N_particles, n_test)
    avg_particle_entropy = np.mean(particle_entropy, axis=0)  # shape: (n_test,)
    epistemic_uncertainty = total_entropy - avg_particle_entropy
    avg_epistemic = np.mean(epistemic_uncertainty)
    
    # Compute F1 score (macro) and AUC-PR for in-domain predictions.
    ens_preds = np.argmax(ensemble_probs, axis=1)
    f1_in = f1_score(y_val_np, ens_preds, average='macro', zero_division=0)
    # Binarize true labels for AUC-PR.
    y_val_bin = label_binarize(y_val_np, classes=np.arange(8))
    aucpr_in = average_precision_score(y_val_bin, ensemble_probs, average='macro')
    
    print("\nIn-Domain Analysis (Digits 0-7):")
    print("Average NLL over validation set: {:.4f}".format(avg_nll))
    print("Average Total Entropy: {:.4f}".format(avg_total_entropy))
    print("Average Epistemic Uncertainty (Mutual Information): {:.4f}".format(avg_epistemic))
    print("F1 Score (macro): {:.4f}".format(f1_in))
    print("AUC-PR (macro): {:.4f}".format(aucpr_in))
    
    ############################################
    # Out-Of-Domain Analysis (Digits 8 & 9)
    ############################################
    # Select the first 100 samples of digit 8 and the first 100 samples of
    # digit 9 from full_test_dataset.

    # 1) Build a small OOD-only pool from the full test set:
    od_pool = FilteredDataset(full_test_dataset, [8, 9])

    # 2) Subset the pool to at most 100 samples per label.
    #    (Subset is already imported at the top of the file; this re-import is redundant.)
    from torch.utils.data import Subset

    # Collect up to 100 indices of each label, in pool order
    od_indices = []
    count8, count9 = 0, 0
    for idx, (img, label) in enumerate(od_pool):
        if label == 8 and count8 < 100:
            od_indices.append(idx)
            count8 += 1
        elif label == 9 and count9 < 100:
            od_indices.append(idx)
            count9 += 1
        if count8 == 100 and count9 == 100:
            break

    od_dataset = Subset(od_pool, od_indices)

    # 3) Load the whole OOD subset as a single batch (200 samples when the
    #    pool holds at least 100 of each digit):
    od_loader = DataLoader(od_dataset, batch_size=len(od_dataset), shuffle=False)

    od_images, od_labels = next(iter(od_loader))
    od_images = od_images.to(device)
    # Note: od_labels are not used for NLL since these are OOD samples.
    
    # Compute SMC predictions for OOD data.
    particle_logits_od = []
    net = SimpleCNN().to(device)
    net.eval()
    for i in range(particles_np.shape[0]):
        flat = particles_np[i]
        params = unflatten_params(torch.tensor(flat, device=device), net)
        logits = functional_call(net, params, od_images)
        particle_logits_od.append(logits.detach().cpu().numpy())
    particle_logits_od = np.stack(particle_logits_od, axis=0)  # shape: (N_particles, N_od, 8)
    
    # Convert logits to softmax probabilities.
    particle_probs_od = softmax_np(particle_logits_od)
    ensemble_probs_od = np.mean(particle_probs_od, axis=0)
    
    # Compute entropy for OOD data.
    total_entropy_od = -np.sum(ensemble_probs_od * np.log(ensemble_probs_od + 1e-12), axis=1)
    avg_total_entropy_od = np.mean(total_entropy_od)
    
    particle_entropy_od = -np.sum(particle_probs_od * np.log(particle_probs_od + 1e-12), axis=2)
    avg_particle_entropy_od = np.mean(particle_entropy_od, axis=0)
    epistemic_uncertainty_od = total_entropy_od - avg_particle_entropy_od
    avg_epistemic_od = np.mean(epistemic_uncertainty_od)
    
    # For OOD predictions, compute F1 score.
    # NOTE(review): the true labels are 8/9 while the model can only predict
    # 0-7, so no prediction can match a label - confirm this metric is wanted.
    ens_preds_od = np.argmax(ensemble_probs_od, axis=1)
    f1_od = f1_score(od_labels, ens_preds_od, average='macro', zero_division=0)
    # Since the model is trained only on digits 0-7, the AUC-PR is not meaningful.
    aucpr_od = 0.0
    
    print("\nOut-Of-Domain Analysis (Digits 8 & 9):")
    print("Average Total Entropy: {:.4f}".format(avg_total_entropy_od))
    print("Average Epistemic Uncertainty (Mutual Information): {:.4f}".format(avg_epistemic_od))
    print("F1 Score (macro): {:.4f}".format(f1_od))
    print("AUC-PR (macro): {:.4f}".format(aucpr_od))
    
    ############################################
    # Save all computed metrics using savemat
    ############################################
    metrics = {
        'avg_nll_in': avg_nll,
        'avg_total_entropy_in': avg_total_entropy,
        'avg_epistemic_in': avg_epistemic,
        'f1_in': f1_in,
        'aucpr_in': aucpr_in,
        'avg_total_entropy_od': avg_total_entropy_od,
        'avg_epistemic_od': avg_epistemic_od,
        'f1_od': f1_od,
        'aucpr_od': aucpr_od,
        'SMC_Validation_Accuracy': smc_accuracy,
        'psmc_single_time': psmc_single_time,
        'Z': Z,
        'K': K,
        'count': count,
        'Lsum': Ls,
        'total_smc_time': total_smc_time
    }
    
    savemat(f'BayesianNN_MNIST_psmc_SimpleNN_metrics_d{d}_train{N_tr}_val{N_val}_N{N}_M{M}_node{seed}.mat', metrics)
