{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cN47nGeIa5od"
      },
      "source": [
        "# Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The code below was run on a `T4 GPU` using a free Google Colab account (see: https://colab.research.google.com)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QsdqK9OBa5of"
      },
      "source": [
        "# Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "DwfpMMTHa5og"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "from jax.example_libraries import optimizers\n",
        "from functools import partial\n",
        "from keras.datasets import mnist\n",
        "from sklearn.model_selection import train_test_split\n",
        "import matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4gI0CZEqa5oh"
      },
      "source": [
        "# Utility Functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hOyA6kzvQW1b"
      },
      "outputs": [],
      "source": [
        "### Data Extraction & Pre-Processing ###\n",
        "\n",
        "def load_and_preprocess_mnist(\n",
        "    digits_to_include: list = [4, 9],\n",
        "    train_subsample_size: int = 1000,\n",
        "    random_state_subsample: int = 12345,\n",
        "):\n",
        "    \"\"\"\n",
        "    Load MNIST, keep the requested digits, subsample and standardise the images.\n",
        "\n",
        "    Only the first split returned by `mnist.load_data()` is used. The kept\n",
        "    digits are relabelled 0, 1, 2, ... in ascending digit order, a stratified\n",
        "    subsample is drawn with `train_test_split`, and every pixel is standardised\n",
        "    with the per-pixel mean and standard deviation of that subsample.\n",
        "\n",
        "    Args:\n",
        "        digits_to_include: Digits to keep. Duplicates are ignored. The default\n",
        "            list is only read, never mutated.\n",
        "        train_subsample_size: Forwarded to `train_test_split` as `train_size`.\n",
        "        random_state_subsample: Seed forwarded to `train_test_split`.\n",
        "\n",
        "    Returns:\n",
        "        `(images_normalized, labels_subsampled, num_classes)`: the standardised\n",
        "        float32 images, their remapped labels, and the number of distinct digits.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If `digits_to_include` is empty.\n",
        "    \"\"\"\n",
        "    print(\"Loading and Preprocessing Data!\")\n",
        "    (images, labels), _ = mnist.load_data()\n",
        "    images = np.array(images).astype(np.float32)\n",
        "    labels = np.array(labels).astype(np.int32)\n",
        "\n",
        "    # De-duplicate first so the class count always agrees with the label map below.\n",
        "    unique_digits = sorted(set(digits_to_include))\n",
        "    if not unique_digits:\n",
        "        raise ValueError(\"digits_to_include must contain at least one digit.\")\n",
        "\n",
        "    print(f\"Filtering for digits: {digits_to_include}\")\n",
        "    keep_mask = np.isin(labels, unique_digits)\n",
        "    labels = labels[keep_mask]\n",
        "    images = images[keep_mask, :, :]\n",
        "\n",
        "    num_classes = len(unique_digits)\n",
        "\n",
        "    # Create a mapping from original digit to new class index (0, 1, 2, ...)\n",
        "    label_map = {digit: i for i, digit in enumerate(unique_digits)}\n",
        "    new_labels = np.array([label_map[l] for l in labels])\n",
        "    print(f\"Remapped labels. Example: {labels[:5]} -> {new_labels[:5]}\")\n",
        "\n",
        "    images_subsampled, _, labels_subsampled, _ = train_test_split(\n",
        "        images,\n",
        "        new_labels,\n",
        "        train_size=train_subsample_size,\n",
        "        random_state=random_state_subsample,\n",
        "        stratify=new_labels,\n",
        "    )\n",
        "    print(f\"Subsampled {images_subsampled.shape[0]} images.\")\n",
        "\n",
        "    pixel_mean = images_subsampled.mean(axis=0)\n",
        "    pixel_std = images_subsampled.std(axis=0)\n",
        "\n",
        "    # Constant pixels have zero spread; divide by 1 there so they map to 0 instead of NaN.\n",
        "    pixel_std[pixel_std == 0] = 1.0\n",
        "\n",
        "    images_normalized = (images_subsampled - pixel_mean) / pixel_std\n",
        "    print(f\"Normalisation Done! Image shape: {images_normalized.shape}, Num Classes: {num_classes}\")\n",
        "\n",
        "    return images_normalized, labels_subsampled, num_classes\n",
        "\n",
        "### Utility Functions (in JAX) ###\n",
        "\n",
        "@jax.jit\n",
        "def _log_nn_eval(w_particle, v_particle, image_flat_vector):\n",
        "    \"\"\"\n",
        "    TODO\n",
        "    \"\"\"\n",
        "    arg = jnp.dot(v_particle, jnp.tanh(jnp.dot(w_particle, image_flat_vector)))\n",
        "\n",
        "    return jax.nn.log_softmax(arg)\n",
        "\n",
        "_log_nn_eval_vec_images = jax.vmap(_log_nn_eval, in_axes=(None, None, 0), out_axes=0)\n",
        "\n",
        "@jax.jit\n",
        "def _log_prior_eval(particle_weights, log_sigma_prior):\n",
        "    \"\"\"\n",
        "    TODO\n",
        "    \"\"\"\n",
        "    flat_weights = particle_weights.reshape(-1)\n",
        "    variance = jnp.exp(2 * log_sigma_prior)\n",
        "    num_dims = flat_weights.size\n",
        "    log_prob = -0.5 * jnp.sum(jnp.square(flat_weights)) / variance \\\n",
        "               - 0.5 * num_dims * (jnp.log(2 * jnp.pi) + 2 * log_sigma_prior)\n",
        "\n",
        "    return log_prob\n",
        "\n",
        "@jax.jit\n",
        "def _log_likelihood_eval_single_particle(w_particle, v_particle, images_flat_batch, labels_batch):\n",
        "    \"\"\"\n",
        "    Total log-likelihood of a labelled batch under one particle's network.\n",
        "\n",
        "    Args:\n",
        "        w_particle: First-layer weights of the particle.\n",
        "        v_particle: Output-layer weights of the particle.\n",
        "        images_flat_batch: Flattened images, one per row.\n",
        "        labels_batch: Integer class index of every image.\n",
        "\n",
        "    Returns:\n",
        "        Sum over the batch of the log-probability assigned to the true class.\n",
        "    \"\"\"\n",
        "    log_probs = _log_nn_eval_vec_images(w_particle, v_particle, images_flat_batch)\n",
        "    row_index = jnp.arange(labels_batch.size)\n",
        "    # Pick, for each image, the log-probability of its own label.\n",
        "    true_class_log_probs = log_probs[row_index, labels_batch]\n",
        "\n",
        "    return jnp.sum(true_class_log_probs)\n",
        "\n",
        "@jax.jit\n",
        "def _log_density_eval_single_particle(w_particle, v_particle, alpha_hyper, beta_hyper, images_flat_batch, labels_batch):\n",
        "    \"\"\"\n",
        "    Joint log-density of one particle: both Gaussian log-priors plus the log-likelihood.\n",
        "\n",
        "    Args:\n",
        "        w_particle: First-layer weights of the particle.\n",
        "        v_particle: Output-layer weights of the particle.\n",
        "        alpha_hyper: Log standard deviation of the prior on `w_particle`.\n",
        "        beta_hyper: Log standard deviation of the prior on `v_particle`.\n",
        "        images_flat_batch: Flattened images, one per row.\n",
        "        labels_batch: Integer class index of every image.\n",
        "\n",
        "    Returns:\n",
        "        Scalar `log p(w) + log p(v) + log p(labels | images, w, v)`.\n",
        "    \"\"\"\n",
        "    log_prior_total = _log_prior_eval(w_particle, alpha_hyper) + _log_prior_eval(v_particle, beta_hyper)\n",
        "    data_term = _log_likelihood_eval_single_particle(w_particle, v_particle, images_flat_batch, labels_batch)\n",
        "\n",
        "    return log_prior_total + data_term\n",
        "\n",
        "@jax.jit\n",
        "def _grad_param_eval(particle_weights, log_sigma_hyper):\n",
        "    \"\"\"\n",
        "    TODO\n",
        "    \"\"\"\n",
        "    flat_weights = particle_weights.reshape(-1)\n",
        "    num_dims = flat_weights.size\n",
        "    variance = jnp.exp(2 * log_sigma_hyper)\n",
        "\n",
        "    return jnp.sum(jnp.square(flat_weights)) / variance - num_dims\n",
        "\n",
        "@jax.jit\n",
        "def ave_grad_param_eval(weights_cloud, log_sigma_hyper):\n",
        "    \"\"\"\n",
        "    Particle average of the log-prior derivative with respect to the log-scale.\n",
        "\n",
        "    Args:\n",
        "        weights_cloud: Weights of all particles, with particles along axis 2.\n",
        "        log_sigma_hyper: Log of the prior standard deviation (shared by all particles).\n",
        "\n",
        "    Returns:\n",
        "        Scalar mean of `_grad_param_eval` over the particles.\n",
        "    \"\"\"\n",
        "    per_particle_grad = jax.vmap(_grad_param_eval, in_axes=(2, None), out_axes=0)\n",
        "    return jnp.mean(per_particle_grad(weights_cloud, log_sigma_hyper))\n",
        "\n",
        "# Gradients of the single-particle joint log-density with respect to the\n",
        "# first-layer weights (argument 0) and the output-layer weights (argument 1).\n",
        "_grad_w_log_density_for_single = jax.grad(_log_density_eval_single_particle, argnums=0)\n",
        "_grad_v_log_density_for_single = jax.grad(_log_density_eval_single_particle, argnums=1)\n",
        "\n",
        "@jax.jit\n",
        "def wgrad_eval(w_cloud, v_cloud, alpha, beta, images_flat_batch, labels_batch):\n",
        "    \"\"\"\n",
        "    Gradient of the joint log-density with respect to `w`, for every particle.\n",
        "\n",
        "    Args:\n",
        "        w_cloud: First-layer weights of all particles, particles along axis 2.\n",
        "        v_cloud: Output-layer weights of all particles, particles along axis 2.\n",
        "        alpha: Log standard deviation of the prior on `w`.\n",
        "        beta: Log standard deviation of the prior on `v`.\n",
        "        images_flat_batch: Flattened images, one per row.\n",
        "        labels_batch: Integer class index of every image.\n",
        "\n",
        "    Returns:\n",
        "        Array with the layout of `w_cloud` (particles along axis 2).\n",
        "    \"\"\"\n",
        "    grad_over_particles = jax.vmap(\n",
        "        _grad_w_log_density_for_single,\n",
        "        in_axes=(2, 2, None, None, None, None),\n",
        "        out_axes=2,\n",
        "    )\n",
        "    return grad_over_particles(w_cloud, v_cloud, alpha, beta, images_flat_batch, labels_batch)\n",
        "\n",
        "@jax.jit\n",
        "def vgrad_eval(w_cloud, v_cloud, alpha, beta, images_flat_batch, labels_batch):\n",
        "    \"\"\"\n",
        "    Gradient of the joint log-density with respect to `v`, for every particle.\n",
        "\n",
        "    Args:\n",
        "        w_cloud: First-layer weights of all particles, particles along axis 2.\n",
        "        v_cloud: Output-layer weights of all particles, particles along axis 2.\n",
        "        alpha: Log standard deviation of the prior on `w`.\n",
        "        beta: Log standard deviation of the prior on `v`.\n",
        "        images_flat_batch: Flattened images, one per row.\n",
        "        labels_batch: Integer class index of every image.\n",
        "\n",
        "    Returns:\n",
        "        Array with the layout of `v_cloud` (particles along axis 2).\n",
        "    \"\"\"\n",
        "    grad_over_particles = jax.vmap(\n",
        "        _grad_v_log_density_for_single,\n",
        "        in_axes=(2, 2, None, None, None, None),\n",
        "        out_axes=2,\n",
        "    )\n",
        "    return grad_over_particles(w_cloud, v_cloud, alpha, beta, images_flat_batch, labels_batch)\n",
        "\n",
        "\n",
        "@jax.jit\n",
        "def _nn_output_metric_single_particle(w_particle, v_particle, image_flat_vector):\n",
        "    \"\"\"\n",
        "    TODO\n",
        "    \"\"\"\n",
        "    hidden_activations = jnp.tanh(jnp.dot(w_particle, image_flat_vector))\n",
        "    output_logits = jnp.dot(v_particle, hidden_activations)\n",
        "\n",
        "    return jax.nn.softmax(output_logits)\n",
        "\n",
        "_nn_output_metric_vec_images = jax.vmap(_nn_output_metric_single_particle, in_axes=(None, None, 0), out_axes=0)\n",
        "_nn_output_metric_vec_particles_images = jax.vmap(_nn_output_metric_vec_images, in_axes=(2, 2, None), out_axes=2)\n",
        "\n",
        "@jax.jit\n",
        "def log_pointwise_pred_density_metric(w_cloud, v_cloud, images_flat_batch, labels_batch):\n",
        "    \"\"\"\n",
        "    Mean log predictive probability of the true labels under the particle ensemble.\n",
        "\n",
        "    Class probabilities are averaged over particles, the probability of each\n",
        "    image's true label is extracted, floored at 1e-30, and the mean log is returned.\n",
        "\n",
        "    Args:\n",
        "        w_cloud: First-layer weights of all particles, particles along axis 2.\n",
        "        v_cloud: Output-layer weights of all particles, particles along axis 2.\n",
        "        images_flat_batch: Flattened images, one per row.\n",
        "        labels_batch: Integer class index of every image.\n",
        "\n",
        "    Returns:\n",
        "        Scalar mean of the log predictive probabilities.\n",
        "    \"\"\"\n",
        "    mean_preds = jnp.mean(_nn_output_metric_vec_particles_images(w_cloud, v_cloud, images_flat_batch), axis=2)\n",
        "    true_label_probs = mean_preds[jnp.arange(labels_batch.size), labels_batch]\n",
        "    # Floor the probabilities so the log never sees an exact zero. `jnp.maximum` is used\n",
        "    # rather than `jnp.clip(..., a_min=...)` because the `a_min` keyword is deprecated in\n",
        "    # recent JAX releases; the lower-bound-only result is the same.\n",
        "    probs_true = jnp.maximum(true_label_probs, 1e-30)\n",
        "\n",
        "    return jnp.mean(jnp.log(probs_true))\n",
        "\n",
        "@jax.jit\n",
        "def test_error_metric(w_cloud, v_cloud, images_flat_batch, labels_batch):\n",
        "    \"\"\"\n",
        "    Misclassification rate of the particle-averaged classifier.\n",
        "\n",
        "    Args:\n",
        "        w_cloud: First-layer weights of all particles, particles along axis 2.\n",
        "        v_cloud: Output-layer weights of all particles, particles along axis 2.\n",
        "        images_flat_batch: Flattened images, one per row.\n",
        "        labels_batch: Integer class index of every image.\n",
        "\n",
        "    Returns:\n",
        "        Fraction of images whose arg-max averaged prediction differs from the label.\n",
        "    \"\"\"\n",
        "    ensemble_probs = _nn_output_metric_vec_particles_images(w_cloud, v_cloud, images_flat_batch)\n",
        "    predicted = jnp.argmax(jnp.mean(ensemble_probs, axis=2), axis=1)\n",
        "    # An image counts as an error whenever the predicted and true labels differ.\n",
        "    return jnp.mean(labels_batch != predicted)\n",
        "\n",
        "@jax.jit\n",
        "def get_predictions_and_labels(w_cloud, v_cloud, images_flat_batch, labels_batch):\n",
        "    \"\"\"\n",
        "    Predicted labels of the particle-averaged classifier, paired with the true labels.\n",
        "\n",
        "    Args:\n",
        "        w_cloud: First-layer weights of all particles, particles along axis 2.\n",
        "        v_cloud: Output-layer weights of all particles, particles along axis 2.\n",
        "        images_flat_batch: Flattened images, one per row.\n",
        "        labels_batch: Integer class index of every image (returned unchanged).\n",
        "\n",
        "    Returns:\n",
        "        `(predicted_labels, labels_batch)`.\n",
        "    \"\"\"\n",
        "    ensemble_probs = _nn_output_metric_vec_particles_images(w_cloud, v_cloud, images_flat_batch)\n",
        "    averaged_probs = jnp.mean(ensemble_probs, axis=2)\n",
        "    return jnp.argmax(averaged_probs, axis=1), labels_batch\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SCwTeVA6a5oi"
      },
      "source": [
        "# JALA-EM\n",
        "\n",
        "Implementation of the JALA-EM algorithm, as described in \"Learning Latent Variable Models via Jarzynski-adjusted Langevin Algorithm\" (see: https://arxiv.org/pdf/2505.18427). Specifically, see Appendix C.3 of that paper."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9PBHesOIa5oj"
      },
      "outputs": [],
      "source": [
        "@jax.jit\n",
        "def get_grads_U_particle_jala(w_particle, v_particle, alpha_param, beta_param, images_flat_batch, labels_batch):\n",
        "    \"\"\"\n",
        "    Gradients of the potential `U = -log density` for one particle.\n",
        "\n",
        "    Args:\n",
        "        w_particle: First-layer weights of the particle.\n",
        "        v_particle: Output-layer weights of the particle.\n",
        "        alpha_param: Log standard deviation of the prior on `w`.\n",
        "        beta_param: Log standard deviation of the prior on `v`.\n",
        "        images_flat_batch: Flattened images, one per row.\n",
        "        labels_batch: Integer class index of every image.\n",
        "\n",
        "    Returns:\n",
        "        `(grad_w_U, grad_v_U)`: the negated gradients of the joint log-density.\n",
        "    \"\"\"\n",
        "    density_args = (w_particle, v_particle, alpha_param, beta_param, images_flat_batch, labels_batch)\n",
        "    grad_w_U = -_grad_w_log_density_for_single(*density_args)\n",
        "    grad_v_U = -_grad_v_log_density_for_single(*density_args)\n",
        "\n",
        "    return grad_w_U, grad_v_U\n",
        "\n",
        "@jax.jit\n",
        "def compute_alpha_term_particle_jala(\n",
        "    alpha_param,\n",
        "    beta_param,\n",
        "    wl,\n",
        "    vl,\n",
        "    wr,\n",
        "    vr,\n",
        "    grad_wl_U,\n",
        "    grad_vl_U,\n",
        "    h_step,\n",
        "    images_flat_batch,\n",
        "    labels_batch,\n",
        "):\n",
        "    \"\"\"\n",
        "    One-sided term used in the JALA log-weight increment.\n",
        "\n",
        "    With `x = (wl, vl)`, `y = (wr, vr)` and `g` the supplied gradient at `x`,\n",
        "    returns `U(x) + 0.5 * <y - x, g> + (h_step / 4) * ||g||^2`, where\n",
        "    `U` is the negative joint log-density under `(alpha_param, beta_param)`.\n",
        "\n",
        "    Args:\n",
        "        alpha_param: Log standard deviation of the prior on `w`.\n",
        "        beta_param: Log standard deviation of the prior on `v`.\n",
        "        wl, vl: Weights at which the potential is evaluated.\n",
        "        wr, vr: The other end point of the move.\n",
        "        grad_wl_U, grad_vl_U: Gradient of the potential supplied by the caller.\n",
        "        h_step: Langevin step size.\n",
        "        images_flat_batch: Flattened images, one per row.\n",
        "        labels_batch: Integer class index of every image.\n",
        "\n",
        "    Returns:\n",
        "        Scalar value of the term.\n",
        "    \"\"\"\n",
        "    potential_at_left = -_log_density_eval_single_particle(wl, vl, alpha_param, beta_param, images_flat_batch, labels_batch)\n",
        "\n",
        "    displacement_dot_grad = jnp.sum((wr - wl) * grad_wl_U) + jnp.sum((vr - vl) * grad_vl_U)\n",
        "    grad_sq_norm = jnp.sum(jnp.square(grad_wl_U)) + jnp.sum(jnp.square(grad_vl_U))\n",
        "\n",
        "    return potential_at_left + 0.5 * displacement_dot_grad + (h_step / 4.0) * grad_sq_norm\n",
        "\n",
        "@jax.jit\n",
        "def compute_ess_jala(log_A_values):\n",
        "    \"\"\"\n",
        "    TODO\n",
        "    \"\"\"\n",
        "    log_A_size_float = jnp.array(log_A_values.size, dtype=jnp.float32)\n",
        "    log_weights_unnorm = jnp.nan_to_num(log_A_values, nan=-jnp.inf, posinf=jnp.log(jnp.finfo(jnp.float32).max)-10, neginf=-jnp.inf)\n",
        "    log_sum_weights = jax.scipy.special.logsumexp(log_weights_unnorm)\n",
        "    log_weights_norm = jnp.where(\n",
        "        jnp.isneginf(log_sum_weights),\n",
        "        -jnp.log(jnp.maximum(1.0, log_A_size_float)),\n",
        "        log_weights_unnorm - log_sum_weights\n",
        "    )\n",
        "\n",
        "    weights_norm = jnp.exp(log_weights_norm)\n",
        "    ess = 1.0 / (jnp.sum(jnp.square(weights_norm)) + 1e-30)\n",
        "\n",
        "    return ess\n",
        "\n",
        "@partial(jax.jit, static_argnums=(2,))\n",
        "def systematic_resample_jala(weights_norm, key, N_particles):\n",
        "    \"\"\"\n",
        "    TODO\n",
        "    \"\"\"\n",
        "    weights_norm_safe = jnp.nan_to_num(weights_norm, nan=0.0)\n",
        "    sum_weights = jnp.sum(weights_norm_safe)\n",
        "    N_particles_float = jnp.array(N_particles, dtype=jnp.float32)\n",
        "    weights_norm_final = jnp.where(\n",
        "        sum_weights > 1e-12,\n",
        "        weights_norm_safe / sum_weights,\n",
        "        jnp.ones(N_particles, dtype=jnp.float32) / N_particles_float\n",
        "    )\n",
        "\n",
        "    u_offset = jax.random.uniform(key, (), dtype=jnp.float32)\n",
        "    positions = (jnp.arange(N_particles, dtype=jnp.float32) + u_offset) / N_particles_float\n",
        "    cumulative_sum = jnp.cumsum(weights_norm_final)\n",
        "    indexes = jnp.searchsorted(cumulative_sum, positions)\n",
        "\n",
        "    return jnp.clip(indexes, 0, N_particles - 1)\n",
        "\n",
        "@jax.jit\n",
        "def jala_em_particle_step(\n",
        "    particle_pack,\n",
        "    alpha_k_param,\n",
        "    beta_k_param,\n",
        "    alpha_kp1_param,\n",
        "    beta_kp1_param,\n",
        "    h_particle_step,\n",
        "    images_flat_batch,\n",
        "    labels_batch,\n",
        "    key_particle_noise,\n",
        "    max_grad_norm_particle_val: float,\n",
        "):\n",
        "    \"\"\"\n",
        "    Move one particle with a Langevin step and update its log-weight.\n",
        "\n",
        "    The particle is moved using the (norm-clipped) gradient of the potential under\n",
        "    the current hyper-parameters. The log-weight is then incremented by the\n",
        "    forward term (old point, old hyper-parameters) minus the backward term\n",
        "    (new point, new hyper-parameters), both from `compute_alpha_term_particle_jala`.\n",
        "\n",
        "    Args:\n",
        "        particle_pack: Tuple `(w, v, log_A)` for a single particle.\n",
        "        alpha_k_param, beta_k_param: Current prior log-scales.\n",
        "        alpha_kp1_param, beta_kp1_param: Updated prior log-scales.\n",
        "        h_particle_step: Langevin step size.\n",
        "        images_flat_batch: Flattened images, one per row.\n",
        "        labels_batch: Integer class index of every image.\n",
        "        key_particle_noise: PRNG key for this particle's noise.\n",
        "        max_grad_norm_particle_val: Maximum joint norm of the `(w, v)` gradient.\n",
        "\n",
        "    Returns:\n",
        "        Tuple `(w_next, v_next, log_A_next)`.\n",
        "    \"\"\"\n",
        "    def _clip_joint_norm(grad_w, grad_v):\n",
        "        # Rescale both gradients together so their joint L2 norm is at most the cap.\n",
        "        joint = jnp.concatenate([grad_w.reshape(-1), grad_v.reshape(-1)])\n",
        "        joint_norm = jnp.linalg.norm(joint) + 1e-9\n",
        "        exceeds_cap = joint_norm > max_grad_norm_particle_val\n",
        "        scale = jnp.where(exceeds_cap, max_grad_norm_particle_val / joint_norm, 1.0)\n",
        "        return grad_w * scale, grad_v * scale\n",
        "\n",
        "    w_old, v_old, log_A_old = particle_pack\n",
        "    key_noise_w, key_noise_v = jax.random.split(key_particle_noise)\n",
        "\n",
        "    # Gradient at the current point under the current hyper-parameters.\n",
        "    raw_grad_w_old, raw_grad_v_old = get_grads_U_particle_jala(\n",
        "        w_old, v_old, alpha_k_param, beta_k_param, images_flat_batch, labels_batch\n",
        "    )\n",
        "    grad_w_old, grad_v_old = _clip_joint_norm(raw_grad_w_old, raw_grad_v_old)\n",
        "\n",
        "    # Unadjusted Langevin proposal.\n",
        "    xi_w = jax.random.normal(key_noise_w, w_old.shape, dtype=w_old.dtype)\n",
        "    xi_v = jax.random.normal(key_noise_v, v_old.shape, dtype=v_old.dtype)\n",
        "    w_new = w_old - h_particle_step * grad_w_old + jnp.sqrt(2 * h_particle_step) * xi_w\n",
        "    v_new = v_old - h_particle_step * grad_v_old + jnp.sqrt(2 * h_particle_step) * xi_v\n",
        "\n",
        "    forward_term = compute_alpha_term_particle_jala(\n",
        "        alpha_k_param, beta_k_param,\n",
        "        w_old, v_old,\n",
        "        w_new, v_new,\n",
        "        grad_w_old, grad_v_old,\n",
        "        h_particle_step, images_flat_batch, labels_batch,\n",
        "    )\n",
        "\n",
        "    # Gradient at the new point under the updated hyper-parameters.\n",
        "    raw_grad_w_new, raw_grad_v_new = get_grads_U_particle_jala(\n",
        "        w_new, v_new, alpha_kp1_param, beta_kp1_param, images_flat_batch, labels_batch\n",
        "    )\n",
        "    grad_w_new, grad_v_new = _clip_joint_norm(raw_grad_w_new, raw_grad_v_new)\n",
        "\n",
        "    backward_term = compute_alpha_term_particle_jala(\n",
        "        alpha_kp1_param, beta_kp1_param,\n",
        "        w_new, v_new,\n",
        "        w_old, v_old,\n",
        "        grad_w_new, grad_v_new,\n",
        "        h_particle_step, images_flat_batch, labels_batch,\n",
        "    )\n",
        "\n",
        "    log_A_new = log_A_old - backward_term + forward_term\n",
        "\n",
        "    return (w_new, v_new, log_A_new)\n",
        "\n",
        "def jala_em(\n",
        "        ltrain_data,\n",
        "        itrain_data_28x28,\n",
        "        ltest_data,\n",
        "        itest_data_28x28,\n",
        "        h_particle_val,\n",
        "        K_iters_val,\n",
        "        N_particles_val,\n",
        "        a_init_scalar_val,\n",
        "        b_init_scalar_val,\n",
        "        w_init_cloud_val,\n",
        "        v_init_cloud_val,\n",
        "        use_adam_optimizer: bool = True,\n",
        "        h_theta_sgd_val: float = 1e-7,\n",
        "        h_theta_adam_val: float = 1e-5,\n",
        "        ess_threshold_C_fraction: float = 0.5,\n",
        "        master_key_val: jax.random.PRNGKey = None,\n",
        "        alpha_clip_min: float = -7.0,\n",
        "        alpha_clip_max: float = 7.0,\n",
        "        beta_clip_min: float = -7.0,\n",
        "        beta_clip_max: float = 7.0,\n",
        "        max_grad_norm_particle: float = 100.0,\n",
        "        print_every: int = 10\n",
        "    ):\n",
        "    \"\"\"\n",
        "    Run JALA-EM on the Bayesian neural network model.\n",
        "\n",
        "    Every iteration (1) records the test error, LPPD and ESS of the current cloud,\n",
        "    (2) updates the prior log-scales `alpha`/`beta` with one Adam or plain gradient\n",
        "    step on the weight-averaged log-prior derivative and clips them, (3) moves all\n",
        "    particles with `jala_em_particle_step` and updates their log-weights, and\n",
        "    (4) resamples systematically and resets the log-weights when the recorded ESS\n",
        "    is below `ess_threshold_C_fraction * N_particles_val`.\n",
        "\n",
        "    Args:\n",
        "        ltrain_data, itrain_data_28x28: Training labels and images.\n",
        "        ltest_data, itest_data_28x28: Test labels and images.\n",
        "        h_particle_val: Langevin step size for the particles.\n",
        "        K_iters_val: Number of iterations.\n",
        "        N_particles_val: Number of particles.\n",
        "        a_init_scalar_val, b_init_scalar_val: Initial `alpha` and `beta`.\n",
        "        w_init_cloud_val, v_init_cloud_val: Initial weight clouds, particles along axis 2.\n",
        "        use_adam_optimizer: Use Adam for `alpha`/`beta` if True, plain gradient ascent otherwise.\n",
        "        h_theta_sgd_val: Step size of the plain gradient update.\n",
        "        h_theta_adam_val: Step size of the Adam update.\n",
        "        ess_threshold_C_fraction: Resampling threshold as a fraction of `N_particles_val`.\n",
        "        master_key_val: JAX PRNG key. If None, `jax.random.PRNGKey(0)` is used.\n",
        "        alpha_clip_min, alpha_clip_max: Clipping range for `alpha`.\n",
        "        beta_clip_min, beta_clip_max: Clipping range for `beta`.\n",
        "        max_grad_norm_particle: Cap on the joint gradient norm in the particle step.\n",
        "        print_every: Progress is printed every this many iterations and at the last one.\n",
        "\n",
        "    Returns:\n",
        "        Tuple `(a_history, b_history, w_particles, v_particles, lppd_history,\n",
        "        error_history, ess_history, log_A)` of NumPy arrays. The two parameter\n",
        "        histories have `K_iters_val + 1` entries (initial value included).\n",
        "    \"\"\"\n",
        "    print(\"Running JAX-compatible JALA-EM...\")\n",
        "\n",
        "    # The declared default key is None, which cannot be split; fall back to a fixed seed.\n",
        "    if master_key_val is None:\n",
        "        master_key_val = jax.random.PRNGKey(0)\n",
        "\n",
        "    def _normalized_weights(log_weights):\n",
        "        # Softmax of the log-weights; uniform weights if every log-weight is -inf.\n",
        "        log_total = jax.scipy.special.logsumexp(log_weights)\n",
        "        uniform_log_weight = -jnp.log(jnp.maximum(1.0, jnp.array(log_weights.size, dtype=jnp.float32)))\n",
        "        return jnp.exp(jnp.where(jnp.isneginf(log_total), uniform_log_weight, log_weights - log_total))\n",
        "\n",
        "    itrain_flat = itrain_data_28x28.reshape(itrain_data_28x28.shape[0], -1)\n",
        "    itest_flat = itest_data_28x28.reshape(itest_data_28x28.shape[0], -1)\n",
        "\n",
        "    static_itrain_jax = jnp.array(itrain_flat, dtype=jnp.float32)\n",
        "    static_ltrain_jax = jnp.array(ltrain_data, dtype=jnp.int32)\n",
        "    static_itest_jax = jnp.array(itest_flat, dtype=jnp.float32)\n",
        "    static_ltest_jax = jnp.array(ltest_data, dtype=jnp.int32)\n",
        "\n",
        "    w_particles = jnp.array(w_init_cloud_val, dtype=jnp.float32)\n",
        "    v_particles = jnp.array(v_init_cloud_val, dtype=jnp.float32)\n",
        "    log_A = jnp.zeros(N_particles_val, dtype=jnp.float32)\n",
        "\n",
        "    lppd_history = np.zeros(K_iters_val, dtype=np.float32)\n",
        "    error_history = np.zeros(K_iters_val, dtype=np.float32)\n",
        "    ess_history = np.zeros(K_iters_val, dtype=np.float32)\n",
        "\n",
        "    # Number of weights per particle; used to scale the hyper-parameter gradients.\n",
        "    Dw = jnp.maximum(1.0, jnp.array(w_init_cloud_val.shape[0] * w_init_cloud_val.shape[1], dtype=jnp.float32))\n",
        "    Dv = jnp.maximum(1.0, jnp.array(v_init_cloud_val.shape[0] * v_init_cloud_val.shape[1], dtype=jnp.float32))\n",
        "\n",
        "    # The particle step is mapped over a leading particle axis (clouds are transposed below).\n",
        "    vmapped_jala_particle_step_fn = jax.vmap(\n",
        "        jala_em_particle_step,\n",
        "        in_axes=( (0, 0, 0), None, None, None, None, None, None, None, 0, None),\n",
        "        out_axes=(0, 0, 0)\n",
        "    )\n",
        "\n",
        "    # Finite replacement for +inf log-weights.\n",
        "    log_cap_val = jnp.log(jnp.finfo(jnp.float32).max / jnp.maximum(1.0, N_particles_val) + 1e-9) - 1.0\n",
        "\n",
        "    opt_state_alpha, opt_state_beta = None, None\n",
        "    if use_adam_optimizer:\n",
        "        opt_init_alpha, opt_update_alpha, get_params_alpha = optimizers.adam(step_size=h_theta_adam_val)\n",
        "        opt_init_beta, opt_update_beta, get_params_beta = optimizers.adam(step_size=h_theta_adam_val)\n",
        "        opt_state_alpha = opt_init_alpha(jnp.array(a_init_scalar_val, dtype=jnp.float32))\n",
        "        opt_state_beta = opt_init_beta(jnp.array(b_init_scalar_val, dtype=jnp.float32))\n",
        "\n",
        "    actual_a_history = [float(a_init_scalar_val)]\n",
        "    actual_b_history = [float(b_init_scalar_val)]\n",
        "\n",
        "    for k_iter in range(K_iters_val):\n",
        "        loop_key, master_key_val = jax.random.split(master_key_val)\n",
        "        key_k_iter, key_resample_iter = jax.random.split(loop_key)\n",
        "\n",
        "        # NOTE(review): with Adam the current value is read from the optimizer state, which is\n",
        "        # never clipped, whereas the recorded history holds the clipped value. Confirm intent.\n",
        "        if use_adam_optimizer:\n",
        "            current_alpha = get_params_alpha(opt_state_alpha)\n",
        "            current_beta = get_params_beta(opt_state_beta)\n",
        "        else:\n",
        "            current_alpha = jnp.array(actual_a_history[-1], dtype=jnp.float32)\n",
        "            current_beta = jnp.array(actual_b_history[-1], dtype=jnp.float32)\n",
        "\n",
        "        error_history[k_iter] = float(test_error_metric(w_particles, v_particles, static_itest_jax, static_ltest_jax))\n",
        "        lppd_history[k_iter] = float(log_pointwise_pred_density_metric(w_particles, v_particles, static_itest_jax, static_ltest_jax))\n",
        "\n",
        "        ess_val = float(compute_ess_jala(log_A))\n",
        "        ess_history[k_iter] = ess_val\n",
        "\n",
        "        if k_iter % print_every == 0 or k_iter == K_iters_val - 1:\n",
        "            print(f\"JALA-EM Iter {k_iter+1}/{K_iters_val}, ESS: {ess_val:.2f}, Test Error: {error_history[k_iter]:.4f}, LPPD: {lppd_history[k_iter]:.4f}, Alpha: {current_alpha.item():.4f}, Beta: {current_beta.item():.4f}\")\n",
        "\n",
        "        # Weighted average of the log-prior derivatives with respect to alpha and beta.\n",
        "        normalized_weights_k = _normalized_weights(log_A)\n",
        "\n",
        "        all_alpha_grads_logprior = jax.vmap(_grad_param_eval, in_axes=(2, None), out_axes=0)(w_particles, current_alpha)\n",
        "        sum_weighted_nabla_logP_alpha = jnp.sum(normalized_weights_k * all_alpha_grads_logprior)\n",
        "\n",
        "        all_beta_grads_logprior = jax.vmap(_grad_param_eval, in_axes=(2, None), out_axes=0)(v_particles, current_beta)\n",
        "        sum_weighted_nabla_logP_beta = jnp.sum(normalized_weights_k * all_beta_grads_logprior)\n",
        "\n",
        "        if use_adam_optimizer:\n",
        "            # Adam minimises, so the ascent direction is passed with its sign flipped.\n",
        "            grad_for_adam_alpha = -(sum_weighted_nabla_logP_alpha / Dw)\n",
        "            opt_state_alpha = opt_update_alpha(k_iter, grad_for_adam_alpha, opt_state_alpha)\n",
        "            next_alpha_unclipped = get_params_alpha(opt_state_alpha)\n",
        "\n",
        "            grad_for_adam_beta = -(sum_weighted_nabla_logP_beta / Dv)\n",
        "            opt_state_beta = opt_update_beta(k_iter, grad_for_adam_beta, opt_state_beta)\n",
        "            next_beta_unclipped = get_params_beta(opt_state_beta)\n",
        "        else:\n",
        "            next_alpha_unclipped = current_alpha + h_theta_sgd_val * (sum_weighted_nabla_logP_alpha / Dw)\n",
        "            next_beta_unclipped = current_beta + h_theta_sgd_val * (sum_weighted_nabla_logP_beta / Dv)\n",
        "\n",
        "        next_alpha_clipped = jnp.clip(next_alpha_unclipped, alpha_clip_min, alpha_clip_max)\n",
        "        next_beta_clipped = jnp.clip(next_beta_unclipped, beta_clip_min, beta_clip_max)\n",
        "\n",
        "        actual_a_history.append(float(next_alpha_clipped.item()))\n",
        "        actual_b_history.append(float(next_beta_clipped.item()))\n",
        "\n",
        "        # Move the particles: bring the particle axis to the front for the vmapped step.\n",
        "        keys_for_particles = jax.random.split(key_k_iter, N_particles_val)\n",
        "        w_particles_transposed = jnp.transpose(w_particles, (2, 0, 1))\n",
        "        v_particles_transposed = jnp.transpose(v_particles, (2, 0, 1))\n",
        "\n",
        "        updated_w_transposed, updated_v_transposed, next_log_A_raw = vmapped_jala_particle_step_fn(\n",
        "            (w_particles_transposed, v_particles_transposed, log_A),\n",
        "            current_alpha,\n",
        "            current_beta,\n",
        "            next_alpha_clipped,\n",
        "            next_beta_clipped,\n",
        "            h_particle_val,\n",
        "            static_itrain_jax,\n",
        "            static_ltrain_jax,\n",
        "            keys_for_particles,\n",
        "            max_grad_norm_particle\n",
        "        )\n",
        "\n",
        "        w_particles = jnp.transpose(updated_w_transposed, (1, 2, 0))\n",
        "        v_particles = jnp.transpose(updated_v_transposed, (1, 2, 0))\n",
        "        log_A = jnp.nan_to_num(next_log_A_raw, nan=-jnp.inf, posinf=log_cap_val, neginf=-jnp.inf)\n",
        "\n",
        "        # NOTE(review): `ess_val` was computed from the log-weights before this iteration's\n",
        "        # move, while the resampling below uses the updated log-weights. Confirm the lag is intended.\n",
        "        if ess_val < ess_threshold_C_fraction * N_particles_val:\n",
        "            if k_iter % print_every == 0 or k_iter == K_iters_val - 1:\n",
        "                print(f\"Resampling triggered for Iter {k_iter+1} (ESS={ess_val:.2f})\")\n",
        "\n",
        "            normalized_weights_resample = _normalized_weights(log_A)\n",
        "            indices = systematic_resample_jala(normalized_weights_resample, key_resample_iter, N_particles_val)\n",
        "\n",
        "            w_particles = w_particles[:, :, indices]\n",
        "            v_particles = v_particles[:, :, indices]\n",
        "            # After resampling all particles carry equal weight again.\n",
        "            log_A = jnp.zeros(N_particles_val, dtype=jnp.float32)\n",
        "\n",
        "    a_history_np = np.array(actual_a_history, dtype=np.float32)\n",
        "    b_history_np = np.array(actual_b_history, dtype=np.float32)\n",
        "\n",
        "    return a_history_np, b_history_np, np.array(w_particles), np.array(v_particles), lppd_history, error_history, ess_history, np.array(log_A)\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "raFsInAya5ok"
      },
      "source": [
        "# PGD\n",
        "\n",
        "Implementation of the PGD algorithm as described in \"Particle algorithms for maximum likelihood training of latent variable models\" (see: https://arxiv.org/pdf/2204.12965). Also see: https://colab.research.google.com/github/juankuntz/ParEM/blob/main/jax/bayesian_neural_network.ipynb#scrollTo=LtqjvuN98ZmU for inspiration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HldFuORda5ok"
      },
      "outputs": [],
      "source": [
        "def pgd_jax(\n",
        "        ltrain_data,\n",
        "        itrain_data_28x28,\n",
        "        ltest_data,\n",
        "        itest_data_28x28,\n",
        "        h_step,\n",
        "        K_iters,\n",
        "        N_particles,\n",
        "        a_init_scalar,\n",
        "        b_init_scalar,\n",
        "        w_init_cloud,\n",
        "        v_init_cloud,\n",
        "        key_pgd\n",
        "    ):\n",
        "    \"\"\"\n",
        "    Run particle gradient descent (PGD) on the Bayesian neural network model.\n",
        "\n",
        "    Every iteration records the test error and LPPD of the current cloud, takes a\n",
        "    gradient step on the prior log-scales using the particle-averaged log-prior\n",
        "    derivative, and moves all particles with one unadjusted Langevin step.\n",
        "\n",
        "    Args:\n",
        "        ltrain_data, itrain_data_28x28: Training labels and images.\n",
        "        ltest_data, itest_data_28x28: Test labels and images.\n",
        "        h_step: Step size shared by the parameter and particle updates.\n",
        "        K_iters: Number of iterations.\n",
        "        N_particles: Not used inside this function.\n",
        "        a_init_scalar, b_init_scalar: Initial prior log-scales for `w` and `v`.\n",
        "        w_init_cloud, v_init_cloud: Initial weight clouds, particles along axis 2.\n",
        "        key_pgd: JAX PRNG key.\n",
        "\n",
        "    Returns:\n",
        "        Tuple `(a_history, b_history, w_particles, v_particles, lppd_history,\n",
        "        error_history)` of NumPy arrays.\n",
        "    \"\"\"\n",
        "    train_images_flat = itrain_data_28x28.reshape(itrain_data_28x28.shape[0], -1)\n",
        "    test_images_flat = itest_data_28x28.reshape(itest_data_28x28.shape[0], -1)\n",
        "\n",
        "    # Number of weights per particle; used to scale the hyper-parameter gradients.\n",
        "    Dw = np.maximum(1.0, w_init_cloud.shape[0] * w_init_cloud.shape[1])\n",
        "    Dv = np.maximum(1.0, v_init_cloud.shape[0] * v_init_cloud.shape[1])\n",
        "\n",
        "    a_hist_pgd = [float(a_init_scalar)]\n",
        "    b_hist_pgd = [float(b_init_scalar)]\n",
        "    w_pgd = jnp.array(w_init_cloud, dtype=jnp.float32)\n",
        "    v_pgd = jnp.array(v_init_cloud, dtype=jnp.float32)\n",
        "\n",
        "    lppd_pgd_hist = np.zeros(K_iters)\n",
        "    error_pgd_hist = np.zeros(K_iters)\n",
        "\n",
        "    # Report roughly ten times over the run, plus the final iteration.\n",
        "    report_every = max(1, K_iters // 10)\n",
        "    last_iter = K_iters - 1\n",
        "\n",
        "    for k in range(K_iters):\n",
        "        key_pgd, key_w_noise, key_v_noise = jax.random.split(key_pgd, 3)\n",
        "        current_a_val = jnp.array(a_hist_pgd[-1], dtype=jnp.float32)\n",
        "        current_b_val = jnp.array(b_hist_pgd[-1], dtype=jnp.float32)\n",
        "\n",
        "        error_pgd_hist[k] = float(test_error_metric(w_pgd, v_pgd, test_images_flat, ltest_data))\n",
        "        lppd_pgd_hist[k] = float(log_pointwise_pred_density_metric(w_pgd, v_pgd, test_images_flat, ltest_data))\n",
        "\n",
        "        if k % report_every == 0 or k == last_iter:\n",
        "            print(f\"PGD Iter {k+1}/{K_iters}, Test Error: {error_pgd_hist[k]:.4f}, LPPD: {lppd_pgd_hist[k]:.4f}, Alpha: {current_a_val.item():.4f}, Beta: {current_b_val.item():.4f}\")\n",
        "\n",
        "        # Parameter update from the particle-averaged log-prior derivatives.\n",
        "        grad_a = ave_grad_param_eval(w_pgd, current_a_val)\n",
        "        grad_b = ave_grad_param_eval(v_pgd, current_b_val)\n",
        "        new_a = current_a_val + h_step * grad_a / Dw\n",
        "        new_b = current_b_val + h_step * grad_b / Dv\n",
        "        a_hist_pgd.append(float(new_a.item()))\n",
        "        b_hist_pgd.append(float(new_b.item()))\n",
        "\n",
        "        # Particle update: both gradients are evaluated at the old cloud and old parameters.\n",
        "        w_grad_val = wgrad_eval(w_pgd, v_pgd, current_a_val, current_b_val, train_images_flat, ltrain_data)\n",
        "        v_grad_val = vgrad_eval(w_pgd, v_pgd, current_a_val, current_b_val, train_images_flat, ltrain_data)\n",
        "        w_noise = jax.random.normal(key_w_noise, w_pgd.shape, dtype=w_pgd.dtype)\n",
        "        v_noise = jax.random.normal(key_v_noise, v_pgd.shape, dtype=v_pgd.dtype)\n",
        "\n",
        "        w_pgd = w_pgd + h_step * w_grad_val + jnp.sqrt(2 * h_step) * w_noise\n",
        "        v_pgd = v_pgd + h_step * v_grad_val + jnp.sqrt(2 * h_step) * v_noise\n",
        "\n",
        "    return np.array(a_hist_pgd), np.array(b_hist_pgd), np.array(w_pgd), np.array(v_pgd), lppd_pgd_hist, error_pgd_hist"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iAFBCOtZa5ol"
      },
      "source": [
        "# SOUL\n",
        "\n",
        "Implementation of the SOUL algorithm as described in \"Particle algorithms for maximum likelihood training of latent variable models\" (see: https://arxiv.org/pdf/2204.12965). Also see: https://colab.research.google.com/github/juankuntz/ParEM/blob/main/jax/bayesian_neural_network.ipynb#scrollTo=LtqjvuN98ZmU for inspiration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3QPkEcwZa5ol"
      },
      "outputs": [],
      "source": [
        "def soul_jax(\n",
        "        ltrain_data,\n",
        "        itrain_data_28x28,\n",
        "        ltest_data,\n",
        "        itest_data_28x28,\n",
        "        h_step,\n",
        "        K_iters,\n",
        "        N_particles,\n",
        "        a_init_scalar,\n",
        "        b_init_scalar,\n",
        "        w_init_cloud,\n",
        "        v_init_cloud,\n",
        "        key_soul\n",
        "    ):\n",
        "    \"\"\"\n",
        "    Run SOUL: one ULA chain per outer iteration, then a gradient step on (a, b).\n",
        "\n",
        "    Each of the `K_iters` outer iterations:\n",
        "      1. records the test error and LPPD of the current particle cloud,\n",
        "      2. runs `N_particles` ULA steps starting from the last particle of the\n",
        "         current cloud and keeps every iterate as the new cloud,\n",
        "      3. moves the scalars `a` and `b` (logged as Alpha and Beta) one step of\n",
        "         size `h_step` along the particle-averaged gradient, scaled by\n",
        "         1 / Dw and 1 / Dv respectively.\n",
        "\n",
        "    Args:\n",
        "        ltrain_data: Training labels, forwarded to `wgrad_eval` / `vgrad_eval`.\n",
        "        itrain_data_28x28: Training images; flattened to (n_train, -1).\n",
        "        ltest_data: Test labels, forwarded to the metric functions.\n",
        "        itest_data_28x28: Test images; flattened to (n_test, -1).\n",
        "        h_step: Step size shared by the ULA updates and the (a, b) updates.\n",
        "        K_iters: Number of outer iterations.\n",
        "        N_particles: Number of ULA steps per outer iteration, and therefore\n",
        "            the size of the cloud after the first iteration.\n",
        "        a_init_scalar: Initial value of `a`.\n",
        "        b_init_scalar: Initial value of `b`.\n",
        "        w_init_cloud: Initial W particles, indexed (row, col, particle).\n",
        "        v_init_cloud: Initial V particles, indexed (row, col, particle).\n",
        "        key_soul: JAX PRNG key driving all of the Langevin noise.\n",
        "\n",
        "    Returns:\n",
        "        Tuple `(a_hist, b_hist, w_final, v_final, lppd_hist, error_hist)`.\n",
        "        `a_hist` / `b_hist` are NumPy arrays of length `K_iters + 1` (initial\n",
        "        value included), `w_final` / `v_final` are the final clouds as NumPy\n",
        "        arrays, and `lppd_hist` / `error_hist` hold the metrics recorded at\n",
        "        the start of each outer iteration.\n",
        "    \"\"\"\n",
        "    itrain_flat = itrain_data_28x28.reshape(itrain_data_28x28.shape[0], -1)\n",
        "    itest_flat = itest_data_28x28.reshape(itest_data_28x28.shape[0], -1)\n",
        "\n",
        "    # Entries per particle; clamped to 1 so the divisions below are always defined.\n",
        "    Dw = np.maximum(1.0, w_init_cloud.shape[0] * w_init_cloud.shape[1])\n",
        "    Dv = np.maximum(1.0, v_init_cloud.shape[0] * v_init_cloud.shape[1])\n",
        "\n",
        "    a_hist_soul = [float(a_init_scalar)]\n",
        "    b_hist_soul = [float(b_init_scalar)]\n",
        "    w_soul = jnp.array(w_init_cloud, dtype=jnp.float32)\n",
        "    v_soul = jnp.array(v_init_cloud, dtype=jnp.float32)\n",
        "\n",
        "    num_inner_steps = N_particles\n",
        "    noise_scale = jnp.sqrt(2 * h_step)  # loop-invariant Langevin noise scale\n",
        "\n",
        "    lppd_soul_hist = np.zeros(K_iters)\n",
        "    error_soul_hist = np.zeros(K_iters)\n",
        "\n",
        "    for k in range(K_iters):\n",
        "        # A PRNG key must be consumed only once, so the W and V noise streams\n",
        "        # each get their own subkey instead of both being derived from one key.\n",
        "        key_soul, key_w_noise, key_v_noise = jax.random.split(key_soul, 3)\n",
        "        current_a_val = jnp.array(a_hist_soul[-1], dtype=jnp.float32)\n",
        "        current_b_val = jnp.array(b_hist_soul[-1], dtype=jnp.float32)\n",
        "\n",
        "        # Metrics describe the cloud *before* this iteration's update.\n",
        "        lppd_soul_hist[k] = float(log_pointwise_pred_density_metric(w_soul, v_soul, itest_flat, ltest_data))\n",
        "        error_soul_hist[k] = float(test_error_metric(w_soul, v_soul, itest_flat, ltest_data))\n",
        "\n",
        "        if k % max(1, K_iters // 10) == 0 or k == K_iters - 1:\n",
        "            print(f\"SOUL Iter {k+1}/{K_iters}, Test Error: {error_soul_hist[k]:.4f}, LPPD: {lppd_soul_hist[k]:.4f}, Alpha: {current_a_val.item():.4f}, Beta: {current_b_val.item():.4f}\")\n",
        "\n",
        "        # The chain restarts from the last particle; the trailing axis of size 1\n",
        "        # is kept so the gradient helpers still receive a 3-D cloud.\n",
        "        wkn = w_soul[:, :, -1:]\n",
        "        vkn = v_soul[:, :, -1:]\n",
        "\n",
        "        w_soul_next_iter_list = []\n",
        "        v_soul_next_iter_list = []\n",
        "        keys_ula_inner_w = jax.random.split(key_w_noise, num_inner_steps)\n",
        "        keys_ula_inner_v = jax.random.split(key_v_noise, num_inner_steps)\n",
        "\n",
        "        for n_ula in range(num_inner_steps):\n",
        "            # Both gradients are evaluated at the current (wkn, vkn) before either is updated.\n",
        "            w_grad_ula = wgrad_eval(wkn, vkn, current_a_val, current_b_val, itrain_flat, ltrain_data)\n",
        "            v_grad_ula = vgrad_eval(wkn, vkn, current_a_val, current_b_val, itrain_flat, ltrain_data)\n",
        "            w_noise = jax.random.normal(keys_ula_inner_w[n_ula], wkn.shape, dtype=wkn.dtype)\n",
        "            v_noise = jax.random.normal(keys_ula_inner_v[n_ula], vkn.shape, dtype=vkn.dtype)\n",
        "            wkn = wkn + h_step * w_grad_ula + noise_scale * w_noise\n",
        "            vkn = vkn + h_step * v_grad_ula + noise_scale * v_noise\n",
        "            w_soul_next_iter_list.append(wkn.squeeze(axis=2))\n",
        "            v_soul_next_iter_list.append(vkn.squeeze(axis=2))\n",
        "\n",
        "        # With zero inner steps there is nothing to stack, so the cloud is left unchanged.\n",
        "        if num_inner_steps > 0:\n",
        "            # (step, row, col) -> (row, col, step): the chain iterates become the particle axis.\n",
        "            w_soul = jnp.transpose(jnp.stack(w_soul_next_iter_list), (1, 2, 0))\n",
        "            v_soul = jnp.transpose(jnp.stack(v_soul_next_iter_list), (1, 2, 0))\n",
        "\n",
        "        grad_a = ave_grad_param_eval(w_soul, current_a_val)\n",
        "        grad_b = ave_grad_param_eval(v_soul, current_b_val)\n",
        "        new_a = current_a_val + h_step * grad_a / Dw\n",
        "        new_b = current_b_val + h_step * grad_b / Dv\n",
        "        a_hist_soul.append(float(new_a.item()))\n",
        "        b_hist_soul.append(float(new_b.item()))\n",
        "\n",
        "    return np.array(a_hist_soul), np.array(b_hist_soul), np.array(w_soul), np.array(v_soul), lppd_soul_hist, error_soul_hist"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5cxxIjsGa5oo"
      },
      "source": [
        "# Experiment Pipeline (Multiple Runs)\n",
        "\n",
        "Logic to run the core experimental trial, for all algorithms, over a number of repeats.\n",
        "\n",
        "*(TODO: Abstract function specific logic to experiment config)*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BVuhlPXwaCpu"
      },
      "outputs": [],
      "source": [
        "def run_experiment(\n",
        "    hidden_neurons: int = 40,\n",
        "    n_runs: int = 10,\n",
        "    digits_to_use=(2, 4, 7, 9),\n",
        "    n_particles: int = 50,\n",
        "    k_iters: int = 500,\n",
        "    step_size: float = 7.5e-2,\n",
        "    base_seed: int = 12345,\n",
        "):\n",
        "    \"\"\"\n",
        "    Compare PGD, JALA-EM and SOUL over repeated trials and print test-error statistics.\n",
        "\n",
        "    The dataset subsample and the train/test split both use fixed seeds, so\n",
        "    they are built once and shared by every run. Only the PRNG key\n",
        "    (`base_seed + run index`), and with it the initial particle cloud and the\n",
        "    sampling noise, changes between runs. Within a run all three algorithms\n",
        "    start from the same initial cloud.\n",
        "\n",
        "    Args:\n",
        "        hidden_neurons: Hidden-layer width; W particles are\n",
        "            (hidden_neurons, 784) and V particles (num_classes, hidden_neurons).\n",
        "        n_runs: Number of independent trials; must be at least 1.\n",
        "        digits_to_use: MNIST digits kept in the dataset.\n",
        "        n_particles: Number of particles handed to each algorithm.\n",
        "        k_iters: Iterations per algorithm; must be at least 1.\n",
        "        step_size: Step size used for PGD, for SOUL, and for both the particle\n",
        "            and the parameter updates of JALA-EM.\n",
        "        base_seed: Seed of the first run; run `i` (0-based) uses `base_seed + i`.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If `n_runs` or `k_iters` is smaller than 1.\n",
        "\n",
        "    Returns:\n",
        "        None. Progress and the summary (mean and `np.std` of each algorithm's\n",
        "        last recorded test error) are printed.\n",
        "    \"\"\"\n",
        "    # Without at least one run / one iteration there is no final error to summarise.\n",
        "    if n_runs < 1:\n",
        "        raise ValueError(\"n_runs must be at least 1\")\n",
        "    if k_iters < 1:\n",
        "        raise ValueError(\"k_iters must be at least 1\")\n",
        "\n",
        "    a0_init_scalar = 0.0\n",
        "    b0_init_scalar = 0.0\n",
        "\n",
        "    # Fixed seeds make the data identical for every run, so load and split once.\n",
        "    print(\"Loading Dataset D!\")\n",
        "    all_images_processed_28x28, all_labels_processed, num_classes = load_and_preprocess_mnist(\n",
        "        digits_to_include=list(digits_to_use),\n",
        "        train_subsample_size=2500,\n",
        "        random_state_subsample=12345,\n",
        "    )\n",
        "\n",
        "    print(\"\\nSplitting D into D_train and D_test!\")\n",
        "    itrain_common_28x28, itest_common_28x28, ltrain_common, ltest_common = train_test_split(\n",
        "        all_images_processed_28x28,\n",
        "        all_labels_processed,\n",
        "        test_size=0.2,\n",
        "        random_state=0,\n",
        "        stratify=all_labels_processed,\n",
        "    )\n",
        "    print(f\"Final training images shape (28x28): {itrain_common_28x28.shape}\")\n",
        "    print(f\"Final testing images shape (28x28): {itest_common_28x28.shape}\\n\")\n",
        "\n",
        "    D_w_rows, D_w_cols = hidden_neurons, 28 ** 2\n",
        "    D_v_rows, D_v_cols = num_classes, hidden_neurons\n",
        "    n_params_w = D_w_rows * D_w_cols\n",
        "    n_params_v = D_v_rows * D_v_cols\n",
        "    total_params = n_params_w + n_params_v\n",
        "\n",
        "    # Final test error of every run, keyed by the label used in the summary.\n",
        "    final_errors = {\"PGD\": [], \"JALA-EM\": [], \"SOUL\": []}\n",
        "\n",
        "    # Loop for each experimental trial\n",
        "    for i in range(n_runs):\n",
        "        print(f\"\\n{'='*25} RUN {i+1}/{n_runs} {'='*25}\")\n",
        "\n",
        "        master_key = jax.random.PRNGKey(base_seed + i)\n",
        "        print(f\"Master Key: {master_key}\")\n",
        "\n",
        "        key_pgd_run, key_soul_run, key_jala_run, key_init_particles = jax.random.split(master_key, 4)\n",
        "        print(key_pgd_run, key_soul_run, key_jala_run, key_init_particles)\n",
        "\n",
        "        print(\"\\nBNN Architecture:\")\n",
        "        print(f\"Hidden Neurons: {hidden_neurons}\")\n",
        "        print(f\"Total Parameters: {total_params} (W: {n_params_w}, V: {n_params_v})\")\n",
        "\n",
        "        # Fresh initial cloud for each run, shared by all three algorithms.\n",
        "        key_w_init, key_v_init = jax.random.split(key_init_particles)\n",
        "        w0_common_cloud = jax.random.normal(key_w_init, (D_w_rows, D_w_cols, n_particles), dtype=jnp.float32) * jnp.exp(jnp.float32(a0_init_scalar))\n",
        "        v0_common_cloud = jax.random.normal(key_v_init, (D_v_rows, D_v_cols, n_particles), dtype=jnp.float32) * jnp.exp(jnp.float32(b0_init_scalar))\n",
        "\n",
        "        # Run PGD\n",
        "        print(f\"\\nRunning PGD (iters={k_iters}, h={step_size})...\")\n",
        "        _, _, _, _, _, error_pgd = pgd_jax(\n",
        "            ltrain_common,\n",
        "            itrain_common_28x28,\n",
        "            ltest_common,\n",
        "            itest_common_28x28,\n",
        "            step_size,\n",
        "            k_iters,\n",
        "            n_particles,\n",
        "            a0_init_scalar,\n",
        "            b0_init_scalar,\n",
        "            w0_common_cloud,\n",
        "            v0_common_cloud,\n",
        "            key_pgd_run,\n",
        "        )\n",
        "        final_errors[\"PGD\"].append(error_pgd[-1])\n",
        "        print(f\"PGD run {i+1} finished with Test Error: {error_pgd[-1]:.4f}\")\n",
        "\n",
        "        # Run JALA-EM\n",
        "        print(f\"\\nRunning JALA-EM (iters={k_iters})...\")\n",
        "        _, _, _, _, _, error_jala, _, _ = jala_em(\n",
        "            ltrain_common,\n",
        "            itrain_common_28x28,\n",
        "            ltest_common,\n",
        "            itest_common_28x28,\n",
        "            h_particle_val=step_size,\n",
        "            K_iters_val=k_iters,\n",
        "            N_particles_val=n_particles,\n",
        "            a_init_scalar_val=a0_init_scalar,\n",
        "            b_init_scalar_val=b0_init_scalar,\n",
        "            w_init_cloud_val=w0_common_cloud,\n",
        "            v_init_cloud_val=v0_common_cloud,\n",
        "            use_adam_optimizer=False,\n",
        "            h_theta_sgd_val=step_size,\n",
        "            master_key_val=key_jala_run,\n",
        "            max_grad_norm_particle=200000.0,\n",
        "            ess_threshold_C_fraction=0.95,\n",
        "            print_every=max(1, k_iters // 10),\n",
        "        )\n",
        "        final_errors[\"JALA-EM\"].append(error_jala[-1])\n",
        "        print(f\"JALA-EM run {i+1} finished with Test Error: {error_jala[-1]:.4f}\")\n",
        "\n",
        "        # Run SOUL\n",
        "        print(f\"\\nRunning SOUL (iters={k_iters}, h={step_size})...\")\n",
        "        _, _, _, _, _, error_soul = soul_jax(\n",
        "            ltrain_common,\n",
        "            itrain_common_28x28,\n",
        "            ltest_common,\n",
        "            itest_common_28x28,\n",
        "            step_size,\n",
        "            k_iters,\n",
        "            n_particles,\n",
        "            a0_init_scalar,\n",
        "            b0_init_scalar,\n",
        "            w0_common_cloud,\n",
        "            v0_common_cloud,\n",
        "            key_soul_run,\n",
        "        )\n",
        "        final_errors[\"SOUL\"].append(error_soul[-1])\n",
        "        print(f\"SOUL run {i+1} finished with Test Error: {error_soul[-1]:.4f}\")\n",
        "\n",
        "    print(f\"\\n\\n{'='*20} FINAL SUMMARY (OVER {n_runs} RUNS) {'='*20}\")\n",
        "\n",
        "    # Mean and standard deviation of the final test error, per algorithm.\n",
        "    print(\"\\nFinal Test Error Statistics:\")\n",
        "    for algo_name, algo_errors in final_errors.items():\n",
        "        print(f\"{algo_name:<7} -> Mean: {np.mean(algo_errors):.4f}, Std: {np.std(algo_errors):.4f}\")\n",
        "    print(f\"{'='*60}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Experiment: $D_{h} = 40$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "XJtVl3npa5op",
        "outputId": "d3c2a924-c612-4739-a127-9e12ff277b57"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "========================= RUN 1/10 =========================\n",
            "Master Key: [    0 12345]\n",
            "[1214163296  439912094] [ 867802714 3762255628] [ 795667951 2300365598] [ 709272294 1661227809]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
            "\u001b[1m11490434/11490434\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 0us/step\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 40\n",
            "Total Parameters: 31520 (W: 31360, V: 160)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.6880, LPPD: -1.3734, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0600, LPPD: -0.2486, Alpha: 2.5391, Beta: 8.2260\n",
            "PGD Iter 101/500, Test Error: 0.0620, LPPD: -0.2660, Alpha: 2.6019, Beta: 4.5055\n",
            "PGD Iter 151/500, Test Error: 0.0580, LPPD: -0.2368, Alpha: 2.6299, Beta: 2.9344\n",
            "PGD Iter 201/500, Test Error: 0.0620, LPPD: -0.2966, Alpha: 2.6465, Beta: 2.8078\n",
            "PGD Iter 251/500, Test Error: 0.0560, LPPD: -0.2938, Alpha: 2.6589, Beta: 2.7134\n",
            "PGD Iter 301/500, Test Error: 0.0500, LPPD: -0.2496, Alpha: 2.6684, Beta: 2.6584\n",
            "PGD Iter 351/500, Test Error: 0.0480, LPPD: -0.2262, Alpha: 2.6762, Beta: 2.6233\n",
            "PGD Iter 401/500, Test Error: 0.0520, LPPD: -0.1978, Alpha: 2.6826, Beta: 2.5940\n",
            "PGD Iter 451/500, Test Error: 0.0460, LPPD: -0.1969, Alpha: 2.6885, Beta: 2.5725\n",
            "PGD Iter 500/500, Test Error: 0.0440, LPPD: -0.1948, Alpha: 2.6928, Beta: 2.5500\n",
            "PGD run 1 finished with Test Error: 0.0440\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.6880, LPPD: -1.3734, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.1460, LPPD: -7.9699, Alpha: 2.0417, Beta: 3.5725\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.1700, LPPD: -8.6715, Alpha: 2.0876, Beta: 2.7786\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.1320, LPPD: -6.3250, Alpha: 2.1033, Beta: 2.5359\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.1500, LPPD: -6.5448, Alpha: 2.1176, Beta: 2.5353\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.1100, LPPD: -5.6916, Alpha: 2.1285, Beta: 2.4976\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.1100, LPPD: -4.1401, Alpha: 2.1384, Beta: 2.5006\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0980, LPPD: -3.7999, Alpha: 2.1418, Beta: 2.4161\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.0840, LPPD: -3.2410, Alpha: 2.1488, Beta: 2.4632\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.0900, LPPD: -3.5344, Alpha: 2.1514, Beta: 2.3745\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.0980, LPPD: -2.2332, Alpha: 2.1504, Beta: 2.2856\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 1 finished with Test Error: 0.0980\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.6880, LPPD: -1.3734, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0840, LPPD: -0.9336, Alpha: 2.2154, Beta: 20.4738\n",
            "SOUL Iter 101/500, Test Error: 0.0860, LPPD: -2.0115, Alpha: 2.4153, Beta: 16.7237\n",
            "SOUL Iter 151/500, Test Error: 0.0780, LPPD: -1.5915, Alpha: 2.5662, Beta: 12.9737\n",
            "SOUL Iter 201/500, Test Error: 0.1020, LPPD: -1.8537, Alpha: 2.6822, Beta: 9.2237\n",
            "SOUL Iter 251/500, Test Error: 0.0980, LPPD: -2.1003, Alpha: 2.7686, Beta: 5.4775\n",
            "SOUL Iter 301/500, Test Error: 0.0920, LPPD: -1.9883, Alpha: 2.8234, Beta: 3.2140\n",
            "SOUL Iter 351/500, Test Error: 0.0900, LPPD: -1.7011, Alpha: 2.8737, Beta: 2.9798\n",
            "SOUL Iter 401/500, Test Error: 0.1040, LPPD: -1.9863, Alpha: 2.9106, Beta: 2.5798\n",
            "SOUL Iter 451/500, Test Error: 0.0840, LPPD: -1.1849, Alpha: 2.9413, Beta: 2.1897\n",
            "SOUL Iter 500/500, Test Error: 0.0840, LPPD: -1.5211, Alpha: 2.9649, Beta: 2.1554\n",
            "SOUL run 1 finished with Test Error: 0.0840\n",
            "\n",
            "========================= RUN 2/10 =========================\n",
            "Master Key: [    0 12346]\n",
            "[ 632167804 2604013364] [1621771532 2839293200] [3728677534 1583167789] [2821711515 2397397985]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 40\n",
            "Total Parameters: 31520 (W: 31360, V: 160)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7820, LPPD: -1.4270, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0660, LPPD: -0.3912, Alpha: 2.5540, Beta: 8.9322\n",
            "PGD Iter 101/500, Test Error: 0.0540, LPPD: -0.2421, Alpha: 2.6122, Beta: 5.1901\n",
            "PGD Iter 151/500, Test Error: 0.0540, LPPD: -0.1927, Alpha: 2.6419, Beta: 2.9880\n",
            "PGD Iter 201/500, Test Error: 0.0460, LPPD: -0.2812, Alpha: 2.6593, Beta: 2.8414\n",
            "PGD Iter 251/500, Test Error: 0.0520, LPPD: -0.2266, Alpha: 2.6712, Beta: 2.7631\n",
            "PGD Iter 301/500, Test Error: 0.0520, LPPD: -0.3100, Alpha: 2.6802, Beta: 2.7111\n",
            "PGD Iter 351/500, Test Error: 0.0480, LPPD: -0.1707, Alpha: 2.6875, Beta: 2.6711\n",
            "PGD Iter 401/500, Test Error: 0.0540, LPPD: -0.3524, Alpha: 2.6934, Beta: 2.6336\n",
            "PGD Iter 451/500, Test Error: 0.0520, LPPD: -0.2185, Alpha: 2.6988, Beta: 2.6015\n",
            "PGD Iter 500/500, Test Error: 0.0540, LPPD: -0.2704, Alpha: 2.7029, Beta: 2.5822\n",
            "PGD run 2 finished with Test Error: 0.0540\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7820, LPPD: -1.4270, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.1620, LPPD: -7.9503, Alpha: 1.9695, Beta: 3.0980\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.1300, LPPD: -5.1879, Alpha: 2.0127, Beta: 2.7287\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.1440, LPPD: -6.8944, Alpha: 2.0424, Beta: 2.6918\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.1040, LPPD: -3.8111, Alpha: 2.0585, Beta: 2.5998\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.0840, LPPD: -3.4739, Alpha: 2.0744, Beta: 2.6411\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.0820, LPPD: -3.3969, Alpha: 2.0854, Beta: 2.5947\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0900, LPPD: -3.5840, Alpha: 2.0913, Beta: 2.4453\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.1060, LPPD: -3.7694, Alpha: 2.0953, Beta: 2.4109\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.1080, LPPD: -3.6571, Alpha: 2.0994, Beta: 2.2557\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.0820, LPPD: -2.5660, Alpha: 2.1006, Beta: 2.2448\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 2 finished with Test Error: 0.0820\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7820, LPPD: -1.4270, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0920, LPPD: -1.2771, Alpha: 2.4319, Beta: 33.0631\n",
            "SOUL Iter 101/500, Test Error: 0.0700, LPPD: -1.2994, Alpha: 2.5540, Beta: 29.3130\n",
            "SOUL Iter 151/500, Test Error: 0.0880, LPPD: -1.2316, Alpha: 2.6746, Beta: 25.5630\n",
            "SOUL Iter 201/500, Test Error: 0.0940, LPPD: -1.5991, Alpha: 2.7445, Beta: 21.8130\n",
            "SOUL Iter 251/500, Test Error: 0.0820, LPPD: -1.7916, Alpha: 2.8074, Beta: 18.0629\n",
            "SOUL Iter 301/500, Test Error: 0.0900, LPPD: -1.6554, Alpha: 2.8587, Beta: 14.3129\n",
            "SOUL Iter 351/500, Test Error: 0.0720, LPPD: -1.3028, Alpha: 2.9112, Beta: 10.5629\n",
            "SOUL Iter 401/500, Test Error: 0.0740, LPPD: -1.4670, Alpha: 2.9541, Beta: 6.8134\n",
            "SOUL Iter 451/500, Test Error: 0.1020, LPPD: -2.4056, Alpha: 2.9968, Beta: 3.6518\n",
            "SOUL Iter 500/500, Test Error: 0.0820, LPPD: -1.8593, Alpha: 3.0303, Beta: 3.2699\n",
            "SOUL run 2 finished with Test Error: 0.0820\n",
            "\n",
            "========================= RUN 3/10 =========================\n",
            "Master Key: [    0 12347]\n",
            "[2760833302 2604130258] [2353787085 2855123098] [1645460419  607253567] [1693954434 2313133562]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 40\n",
            "Total Parameters: 31520 (W: 31360, V: 160)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7640, LPPD: -1.4282, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0660, LPPD: -0.5892, Alpha: 2.5915, Beta: 8.4523\n",
            "PGD Iter 101/500, Test Error: 0.0600, LPPD: -0.3329, Alpha: 2.6527, Beta: 4.7244\n",
            "PGD Iter 151/500, Test Error: 0.0440, LPPD: -0.2999, Alpha: 2.6801, Beta: 3.0148\n",
            "PGD Iter 201/500, Test Error: 0.0540, LPPD: -0.2315, Alpha: 2.6980, Beta: 2.8788\n",
            "PGD Iter 251/500, Test Error: 0.0500, LPPD: -0.1745, Alpha: 2.7104, Beta: 2.7974\n",
            "PGD Iter 301/500, Test Error: 0.0480, LPPD: -0.2485, Alpha: 2.7195, Beta: 2.7289\n",
            "PGD Iter 351/500, Test Error: 0.0540, LPPD: -0.2858, Alpha: 2.7264, Beta: 2.6894\n",
            "PGD Iter 401/500, Test Error: 0.0440, LPPD: -0.1932, Alpha: 2.7327, Beta: 2.6533\n",
            "PGD Iter 451/500, Test Error: 0.0400, LPPD: -0.2957, Alpha: 2.7377, Beta: 2.6160\n",
            "PGD Iter 500/500, Test Error: 0.0440, LPPD: -0.1650, Alpha: 2.7425, Beta: 2.5983\n",
            "PGD run 3 finished with Test Error: 0.0440\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7640, LPPD: -1.4282, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.1000, LPPD: -5.4308, Alpha: 2.1073, Beta: 3.6918\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.0900, LPPD: -4.5294, Alpha: 2.1598, Beta: 3.0866\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.0880, LPPD: -4.3773, Alpha: 2.1850, Beta: 2.9714\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.0760, LPPD: -3.2729, Alpha: 2.1948, Beta: 2.7785\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.0820, LPPD: -4.1342, Alpha: 2.1990, Beta: 2.7228\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.0660, LPPD: -3.2113, Alpha: 2.2059, Beta: 2.6749\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0820, LPPD: -3.8132, Alpha: 2.2105, Beta: 2.5744\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.0820, LPPD: -3.7152, Alpha: 2.2145, Beta: 2.5149\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.0920, LPPD: -3.3556, Alpha: 2.2162, Beta: 2.5667\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.0740, LPPD: -2.1133, Alpha: 2.2165, Beta: 2.4821\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 3 finished with Test Error: 0.0740\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7640, LPPD: -1.4282, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0800, LPPD: -0.9451, Alpha: 2.2362, Beta: 20.2888\n",
            "SOUL Iter 101/500, Test Error: 0.0460, LPPD: -1.1798, Alpha: 2.4477, Beta: 16.5387\n",
            "SOUL Iter 151/500, Test Error: 0.0740, LPPD: -1.4126, Alpha: 2.6049, Beta: 12.7887\n",
            "SOUL Iter 201/500, Test Error: 0.0960, LPPD: -1.2491, Alpha: 2.6865, Beta: 9.0387\n",
            "SOUL Iter 251/500, Test Error: 0.0800, LPPD: -1.7198, Alpha: 2.7592, Beta: 5.2956\n",
            "SOUL Iter 301/500, Test Error: 0.0960, LPPD: -2.1624, Alpha: 2.8282, Beta: 3.2077\n",
            "SOUL Iter 351/500, Test Error: 0.1140, LPPD: -2.0879, Alpha: 2.8874, Beta: 2.9408\n",
            "SOUL Iter 401/500, Test Error: 0.0960, LPPD: -1.6644, Alpha: 2.9190, Beta: 2.6362\n",
            "SOUL Iter 451/500, Test Error: 0.1080, LPPD: -1.1343, Alpha: 2.9548, Beta: 2.1568\n",
            "SOUL Iter 500/500, Test Error: 0.0940, LPPD: -2.1299, Alpha: 2.9829, Beta: 2.1145\n",
            "SOUL run 3 finished with Test Error: 0.0940\n",
            "\n",
            "========================= RUN 4/10 =========================\n",
            "Master Key: [    0 12348]\n",
            "[204896617 362581595] [ 556737389 2100154806] [3692117573 3309360152] [1340738131 1883904564]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 40\n",
            "Total Parameters: 31520 (W: 31360, V: 160)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7520, LPPD: -1.3960, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0640, LPPD: -0.3647, Alpha: 2.5607, Beta: 8.1908\n",
            "PGD Iter 101/500, Test Error: 0.0580, LPPD: -0.3375, Alpha: 2.6194, Beta: 4.4732\n",
            "PGD Iter 151/500, Test Error: 0.0580, LPPD: -0.4006, Alpha: 2.6472, Beta: 2.9391\n",
            "PGD Iter 201/500, Test Error: 0.0500, LPPD: -0.2168, Alpha: 2.6649, Beta: 2.8053\n",
            "PGD Iter 251/500, Test Error: 0.0540, LPPD: -0.2867, Alpha: 2.6769, Beta: 2.7168\n",
            "PGD Iter 301/500, Test Error: 0.0460, LPPD: -0.2046, Alpha: 2.6863, Beta: 2.6589\n",
            "PGD Iter 351/500, Test Error: 0.0420, LPPD: -0.3051, Alpha: 2.6928, Beta: 2.6073\n",
            "PGD Iter 401/500, Test Error: 0.0460, LPPD: -0.3148, Alpha: 2.6985, Beta: 2.5934\n",
            "PGD Iter 451/500, Test Error: 0.0480, LPPD: -0.2849, Alpha: 2.7035, Beta: 2.5724\n",
            "PGD Iter 500/500, Test Error: 0.0440, LPPD: -0.2498, Alpha: 2.7076, Beta: 2.5419\n",
            "PGD run 4 finished with Test Error: 0.0440\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7520, LPPD: -1.3960, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.1740, LPPD: -10.0012, Alpha: 2.1226, Beta: 3.2047\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.1280, LPPD: -5.8814, Alpha: 2.1664, Beta: 2.9750\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.1060, LPPD: -4.9258, Alpha: 2.1810, Beta: 2.8505\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.1100, LPPD: -4.8933, Alpha: 2.1943, Beta: 2.6911\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.0920, LPPD: -4.0325, Alpha: 2.2027, Beta: 2.7032\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.0960, LPPD: -3.9745, Alpha: 2.2101, Beta: 2.5587\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.1220, LPPD: -5.3101, Alpha: 2.2154, Beta: 2.4186\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.1020, LPPD: -3.8737, Alpha: 2.2171, Beta: 2.4603\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.1000, LPPD: -4.1074, Alpha: 2.2192, Beta: 2.4522\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.0900, LPPD: -2.6721, Alpha: 2.2260, Beta: 2.4720\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 4 finished with Test Error: 0.0900\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7520, LPPD: -1.3960, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.1040, LPPD: -1.1803, Alpha: 2.2741, Beta: 15.9986\n",
            "SOUL Iter 101/500, Test Error: 0.0700, LPPD: -1.4075, Alpha: 2.4690, Beta: 12.2486\n",
            "SOUL Iter 151/500, Test Error: 0.0860, LPPD: -1.4840, Alpha: 2.5991, Beta: 8.4986\n",
            "SOUL Iter 201/500, Test Error: 0.0900, LPPD: -1.7500, Alpha: 2.6910, Beta: 4.7650\n",
            "SOUL Iter 251/500, Test Error: 0.0960, LPPD: -2.0863, Alpha: 2.7627, Beta: 3.0737\n",
            "SOUL Iter 301/500, Test Error: 0.0960, LPPD: -2.4570, Alpha: 2.8262, Beta: 2.8068\n",
            "SOUL Iter 351/500, Test Error: 0.1020, LPPD: -1.6471, Alpha: 2.8704, Beta: 2.3671\n",
            "SOUL Iter 401/500, Test Error: 0.1060, LPPD: -2.0682, Alpha: 2.9048, Beta: 2.3362\n",
            "SOUL Iter 451/500, Test Error: 0.1080, LPPD: -1.6970, Alpha: 2.9457, Beta: 2.2688\n",
            "SOUL Iter 500/500, Test Error: 0.0960, LPPD: -2.2132, Alpha: 2.9706, Beta: 2.2533\n",
            "SOUL run 4 finished with Test Error: 0.0960\n",
            "\n",
            "========================= RUN 5/10 =========================\n",
            "Master Key: [    0 12349]\n",
            "[1516789638  985863638] [2166140375  885614934] [826974616 878131554] [ 902379975 3020925066]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 40\n",
            "Total Parameters: 31520 (W: 31360, V: 160)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7860, LPPD: -1.4469, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0660, LPPD: -0.4183, Alpha: 2.5505, Beta: 8.8304\n",
            "PGD Iter 101/500, Test Error: 0.0540, LPPD: -0.2972, Alpha: 2.6111, Beta: 5.0897\n",
            "PGD Iter 151/500, Test Error: 0.0600, LPPD: -0.3419, Alpha: 2.6399, Beta: 2.9464\n",
            "PGD Iter 201/500, Test Error: 0.0560, LPPD: -0.2649, Alpha: 2.6582, Beta: 2.8155\n",
            "PGD Iter 251/500, Test Error: 0.0520, LPPD: -0.3823, Alpha: 2.6706, Beta: 2.7220\n",
            "PGD Iter 301/500, Test Error: 0.0480, LPPD: -0.3004, Alpha: 2.6795, Beta: 2.6674\n",
            "PGD Iter 351/500, Test Error: 0.0460, LPPD: -0.2662, Alpha: 2.6868, Beta: 2.6401\n",
            "PGD Iter 401/500, Test Error: 0.0460, LPPD: -0.2883, Alpha: 2.6926, Beta: 2.6105\n",
            "PGD Iter 451/500, Test Error: 0.0480, LPPD: -0.2714, Alpha: 2.6976, Beta: 2.5902\n",
            "PGD Iter 500/500, Test Error: 0.0500, LPPD: -0.2258, Alpha: 2.7020, Beta: 2.5634\n",
            "PGD run 5 finished with Test Error: 0.0500\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7860, LPPD: -1.4469, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.1720, LPPD: -9.1691, Alpha: 2.0691, Beta: 3.2592\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.1160, LPPD: -5.0180, Alpha: 2.1057, Beta: 2.9887\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.1360, LPPD: -6.5012, Alpha: 2.1274, Beta: 2.8518\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.0860, LPPD: -4.2010, Alpha: 2.1417, Beta: 2.8167\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.0880, LPPD: -4.0914, Alpha: 2.1484, Beta: 2.7484\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.0880, LPPD: -3.7459, Alpha: 2.1571, Beta: 2.7342\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0980, LPPD: -3.8245, Alpha: 2.1650, Beta: 2.7205\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.0780, LPPD: -2.5711, Alpha: 2.1699, Beta: 2.6272\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.0960, LPPD: -4.1803, Alpha: 2.1725, Beta: 2.6020\n",
            "JALA-EM Iter 500/500, ESS: 1.03, Test Error: 0.0840, LPPD: -2.9965, Alpha: 2.1776, Beta: 2.6704\n",
            "Resampling triggered for Iter 500 (ESS=1.03)\n",
            "JALA-EM run 5 finished with Test Error: 0.0840\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7860, LPPD: -1.4469, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0740, LPPD: -0.7702, Alpha: 2.4106, Beta: 26.3622\n",
            "SOUL Iter 101/500, Test Error: 0.1040, LPPD: -1.3880, Alpha: 2.5771, Beta: 22.6122\n",
            "SOUL Iter 151/500, Test Error: 0.0680, LPPD: -1.2977, Alpha: 2.6651, Beta: 18.8621\n",
            "SOUL Iter 201/500, Test Error: 0.0780, LPPD: -1.4022, Alpha: 2.7419, Beta: 15.1121\n",
            "SOUL Iter 251/500, Test Error: 0.0760, LPPD: -1.2539, Alpha: 2.8155, Beta: 11.3621\n",
            "SOUL Iter 301/500, Test Error: 0.0940, LPPD: -1.7175, Alpha: 2.8548, Beta: 7.6122\n",
            "SOUL Iter 351/500, Test Error: 0.0960, LPPD: -2.2505, Alpha: 2.9163, Beta: 4.0064\n",
            "SOUL Iter 401/500, Test Error: 0.0840, LPPD: -1.5110, Alpha: 2.9599, Beta: 3.2807\n",
            "SOUL Iter 451/500, Test Error: 0.0880, LPPD: -2.2240, Alpha: 2.9932, Beta: 3.1524\n",
            "SOUL Iter 500/500, Test Error: 0.0920, LPPD: -2.0442, Alpha: 3.0199, Beta: 3.0253\n",
            "SOUL run 5 finished with Test Error: 0.0920\n",
            "\n",
            "========================= RUN 6/10 =========================\n",
            "Master Key: [    0 12350]\n",
            "[ 663308273 3542856509] [ 255441095 3931730473] [2352717910 3214489336] [1824469452 2789628418]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 40\n",
            "Total Parameters: 31520 (W: 31360, V: 160)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7320, LPPD: -1.4167, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0700, LPPD: -0.5581, Alpha: 2.5643, Beta: 8.5543\n",
            "PGD Iter 101/500, Test Error: 0.0640, LPPD: -0.2016, Alpha: 2.6249, Beta: 4.8198\n",
            "PGD Iter 151/500, Test Error: 0.0520, LPPD: -0.1926, Alpha: 2.6506, Beta: 2.9212\n",
            "PGD Iter 201/500, Test Error: 0.0520, LPPD: -0.3369, Alpha: 2.6671, Beta: 2.7968\n",
            "PGD Iter 251/500, Test Error: 0.0480, LPPD: -0.1968, Alpha: 2.6785, Beta: 2.7060\n",
            "PGD Iter 301/500, Test Error: 0.0540, LPPD: -0.1803, Alpha: 2.6865, Beta: 2.6463\n",
            "PGD Iter 351/500, Test Error: 0.0460, LPPD: -0.1926, Alpha: 2.6936, Beta: 2.6115\n",
            "PGD Iter 401/500, Test Error: 0.0520, LPPD: -0.1879, Alpha: 2.6992, Beta: 2.5937\n",
            "PGD Iter 451/500, Test Error: 0.0520, LPPD: -0.1733, Alpha: 2.7042, Beta: 2.5519\n",
            "PGD Iter 500/500, Test Error: 0.0480, LPPD: -0.1670, Alpha: 2.7085, Beta: 2.5422\n",
            "PGD run 6 finished with Test Error: 0.0480\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7320, LPPD: -1.4167, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.1120, LPPD: -6.0921, Alpha: 2.0833, Beta: 3.2839\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.1220, LPPD: -5.6853, Alpha: 2.1161, Beta: 2.7752\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.1400, LPPD: -6.0568, Alpha: 2.1339, Beta: 2.6216\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.1800, LPPD: -8.6532, Alpha: 2.1410, Beta: 2.5650\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.1060, LPPD: -4.3156, Alpha: 2.1487, Beta: 2.5692\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.0960, LPPD: -3.8286, Alpha: 2.1537, Beta: 2.5546\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.1220, LPPD: -5.5423, Alpha: 2.1578, Beta: 2.4812\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.0960, LPPD: -4.0018, Alpha: 2.1623, Beta: 2.4830\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.1140, LPPD: -4.5477, Alpha: 2.1681, Beta: 2.4090\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.1000, LPPD: -3.3172, Alpha: 2.1738, Beta: 2.5121\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 6 finished with Test Error: 0.1000\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7320, LPPD: -1.4167, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0800, LPPD: -0.9225, Alpha: 2.2775, Beta: 18.9233\n",
            "SOUL Iter 101/500, Test Error: 0.0940, LPPD: -1.0546, Alpha: 2.4507, Beta: 15.1733\n",
            "SOUL Iter 151/500, Test Error: 0.0760, LPPD: -1.4269, Alpha: 2.5763, Beta: 11.4233\n",
            "SOUL Iter 201/500, Test Error: 0.0940, LPPD: -1.8462, Alpha: 2.6698, Beta: 7.6734\n",
            "SOUL Iter 251/500, Test Error: 0.0960, LPPD: -1.6669, Alpha: 2.7483, Beta: 4.0029\n",
            "SOUL Iter 301/500, Test Error: 0.1000, LPPD: -1.8406, Alpha: 2.8043, Beta: 2.8997\n",
            "SOUL Iter 351/500, Test Error: 0.1100, LPPD: -1.3758, Alpha: 2.8538, Beta: 2.5753\n",
            "SOUL Iter 401/500, Test Error: 0.1120, LPPD: -2.0241, Alpha: 2.9017, Beta: 2.2262\n",
            "SOUL Iter 451/500, Test Error: 0.0860, LPPD: -1.7676, Alpha: 2.9236, Beta: 2.0543\n",
            "SOUL Iter 500/500, Test Error: 0.0900, LPPD: -1.8935, Alpha: 2.9500, Beta: 2.1894\n",
            "SOUL run 6 finished with Test Error: 0.0900\n",
            "\n",
            "========================= RUN 7/10 =========================\n",
            "Master Key: [    0 12351]\n",
            "[4291887911 2647564198] [ 205525861 3451136917] [   9081508 2997207858] [3725356808 2941390772]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 40\n",
            "Total Parameters: 31520 (W: 31360, V: 160)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7640, LPPD: -1.4337, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0620, LPPD: -0.2186, Alpha: 2.5607, Beta: 9.0345\n",
            "PGD Iter 101/500, Test Error: 0.0560, LPPD: -0.2545, Alpha: 2.6180, Beta: 5.2909\n",
            "PGD Iter 151/500, Test Error: 0.0580, LPPD: -0.3299, Alpha: 2.6459, Beta: 2.9961\n",
            "PGD Iter 201/500, Test Error: 0.0560, LPPD: -0.3567, Alpha: 2.6640, Beta: 2.8438\n",
            "PGD Iter 251/500, Test Error: 0.0540, LPPD: -0.2908, Alpha: 2.6770, Beta: 2.7719\n",
            "PGD Iter 301/500, Test Error: 0.0460, LPPD: -0.2583, Alpha: 2.6865, Beta: 2.7131\n",
            "PGD Iter 351/500, Test Error: 0.0480, LPPD: -0.2559, Alpha: 2.6942, Beta: 2.6717\n",
            "PGD Iter 401/500, Test Error: 0.0460, LPPD: -0.2210, Alpha: 2.7007, Beta: 2.6391\n",
            "PGD Iter 451/500, Test Error: 0.0500, LPPD: -0.2159, Alpha: 2.7057, Beta: 2.6205\n",
            "PGD Iter 500/500, Test Error: 0.0500, LPPD: -0.1724, Alpha: 2.7105, Beta: 2.5972\n",
            "PGD run 7 finished with Test Error: 0.0500\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7640, LPPD: -1.4337, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.1080, LPPD: -5.7384, Alpha: 2.1270, Beta: 3.6003\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.0780, LPPD: -3.9856, Alpha: 2.1654, Beta: 3.0719\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.0980, LPPD: -4.6334, Alpha: 2.1859, Beta: 2.9450\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.1020, LPPD: -4.8263, Alpha: 2.1984, Beta: 2.8261\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.0940, LPPD: -4.9417, Alpha: 2.2071, Beta: 2.8030\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.1180, LPPD: -5.1420, Alpha: 2.2151, Beta: 2.7566\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.1120, LPPD: -4.9455, Alpha: 2.2197, Beta: 2.7668\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.1020, LPPD: -4.8444, Alpha: 2.2225, Beta: 2.8178\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.1060, LPPD: -4.5971, Alpha: 2.2252, Beta: 2.7904\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.0980, LPPD: -3.8030, Alpha: 2.2248, Beta: 2.7383\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 7 finished with Test Error: 0.0980\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7640, LPPD: -1.4337, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0800, LPPD: -1.1183, Alpha: 2.2601, Beta: 17.9814\n",
            "SOUL Iter 101/500, Test Error: 0.0920, LPPD: -1.2088, Alpha: 2.4430, Beta: 14.2313\n",
            "SOUL Iter 151/500, Test Error: 0.0780, LPPD: -1.1872, Alpha: 2.5702, Beta: 10.4814\n",
            "SOUL Iter 201/500, Test Error: 0.0800, LPPD: -1.4290, Alpha: 2.6745, Beta: 6.7316\n",
            "SOUL Iter 251/500, Test Error: 0.0860, LPPD: -1.8620, Alpha: 2.7660, Beta: 3.3856\n",
            "SOUL Iter 301/500, Test Error: 0.0920, LPPD: -1.5630, Alpha: 2.8227, Beta: 2.8056\n",
            "SOUL Iter 351/500, Test Error: 0.1000, LPPD: -2.4627, Alpha: 2.8738, Beta: 2.4930\n",
            "SOUL Iter 401/500, Test Error: 0.0960, LPPD: -1.6687, Alpha: 2.9076, Beta: 2.2789\n",
            "SOUL Iter 451/500, Test Error: 0.0840, LPPD: -1.4583, Alpha: 2.9419, Beta: 2.2530\n",
            "SOUL Iter 500/500, Test Error: 0.0900, LPPD: -1.2978, Alpha: 2.9663, Beta: 2.0556\n",
            "SOUL run 7 finished with Test Error: 0.0900\n",
            "\n",
            "========================= RUN 8/10 =========================\n",
            "Master Key: [    0 12352]\n",
            "[4145063284  827122148] [3377278252 3595848238] [1835196978 3325478056] [ 194721067 3532734789]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 40\n",
            "Total Parameters: 31520 (W: 31360, V: 160)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7920, LPPD: -1.4382, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0740, LPPD: -0.3508, Alpha: 2.5542, Beta: 8.5808\n",
            "PGD Iter 101/500, Test Error: 0.0580, LPPD: -0.3369, Alpha: 2.6130, Beta: 4.8458\n",
            "PGD Iter 151/500, Test Error: 0.0580, LPPD: -0.2487, Alpha: 2.6420, Beta: 2.9406\n",
            "PGD Iter 201/500, Test Error: 0.0520, LPPD: -0.1886, Alpha: 2.6599, Beta: 2.8083\n",
            "PGD Iter 251/500, Test Error: 0.0540, LPPD: -0.1951, Alpha: 2.6722, Beta: 2.7288\n",
            "PGD Iter 301/500, Test Error: 0.0480, LPPD: -0.1733, Alpha: 2.6809, Beta: 2.6608\n",
            "PGD Iter 351/500, Test Error: 0.0440, LPPD: -0.3112, Alpha: 2.6875, Beta: 2.6022\n",
            "PGD Iter 401/500, Test Error: 0.0480, LPPD: -0.2386, Alpha: 2.6936, Beta: 2.5870\n",
            "PGD Iter 451/500, Test Error: 0.0500, LPPD: -0.2147, Alpha: 2.6985, Beta: 2.5688\n",
            "PGD Iter 500/500, Test Error: 0.0460, LPPD: -0.2937, Alpha: 2.7024, Beta: 2.5525\n",
            "PGD run 8 finished with Test Error: 0.0460\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7920, LPPD: -1.4382, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.1340, LPPD: -7.3457, Alpha: 2.1328, Beta: 3.6515\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.1280, LPPD: -6.2839, Alpha: 2.1605, Beta: 2.9720\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.1420, LPPD: -7.6361, Alpha: 2.1770, Beta: 2.8063\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.0840, LPPD: -3.7523, Alpha: 2.1910, Beta: 2.7027\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.1340, LPPD: -5.5788, Alpha: 2.1987, Beta: 2.6672\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.1180, LPPD: -5.3331, Alpha: 2.2015, Beta: 2.6604\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0900, LPPD: -4.5786, Alpha: 2.2049, Beta: 2.6357\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.0800, LPPD: -4.0803, Alpha: 2.2086, Beta: 2.6017\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.1100, LPPD: -4.4747, Alpha: 2.2129, Beta: 2.6140\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.1020, LPPD: -2.7147, Alpha: 2.2159, Beta: 2.5287\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 8 finished with Test Error: 0.1020\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7920, LPPD: -1.4382, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0920, LPPD: -0.9546, Alpha: 2.2245, Beta: 16.8149\n",
            "SOUL Iter 101/500, Test Error: 0.0760, LPPD: -0.9350, Alpha: 2.4322, Beta: 13.0649\n",
            "SOUL Iter 151/500, Test Error: 0.0780, LPPD: -1.6990, Alpha: 2.5775, Beta: 9.3149\n",
            "SOUL Iter 201/500, Test Error: 0.0780, LPPD: -1.5772, Alpha: 2.6673, Beta: 5.5679\n",
            "SOUL Iter 251/500, Test Error: 0.1000, LPPD: -1.6693, Alpha: 2.7389, Beta: 3.0265\n",
            "SOUL Iter 301/500, Test Error: 0.0700, LPPD: -1.2309, Alpha: 2.7971, Beta: 2.6294\n",
            "SOUL Iter 351/500, Test Error: 0.1040, LPPD: -1.8831, Alpha: 2.8362, Beta: 2.4022\n",
            "SOUL Iter 401/500, Test Error: 0.1040, LPPD: -1.0071, Alpha: 2.8681, Beta: 2.2313\n",
            "SOUL Iter 451/500, Test Error: 0.0920, LPPD: -1.6508, Alpha: 2.9016, Beta: 2.1060\n",
            "SOUL Iter 500/500, Test Error: 0.1020, LPPD: -1.4370, Alpha: 2.9358, Beta: 2.0233\n",
            "SOUL run 8 finished with Test Error: 0.1020\n",
            "\n",
            "========================= RUN 9/10 =========================\n",
            "Master Key: [    0 12353]\n",
            "[2991864547 2732004738] [1738224256 3454507916] [193717663  78461347] [3179925608 4234927111]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 40\n",
            "Total Parameters: 31520 (W: 31360, V: 160)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7620, LPPD: -1.4218, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0640, LPPD: -0.4450, Alpha: 2.5592, Beta: 8.8799\n",
            "PGD Iter 101/500, Test Error: 0.0600, LPPD: -0.2736, Alpha: 2.6165, Beta: 5.1380\n",
            "PGD Iter 151/500, Test Error: 0.0500, LPPD: -0.3054, Alpha: 2.6443, Beta: 2.9379\n",
            "PGD Iter 201/500, Test Error: 0.0540, LPPD: -0.2376, Alpha: 2.6604, Beta: 2.7861\n",
            "PGD Iter 251/500, Test Error: 0.0580, LPPD: -0.3473, Alpha: 2.6727, Beta: 2.6980\n",
            "PGD Iter 301/500, Test Error: 0.0540, LPPD: -0.3230, Alpha: 2.6815, Beta: 2.6471\n",
            "PGD Iter 351/500, Test Error: 0.0460, LPPD: -0.1736, Alpha: 2.6896, Beta: 2.6076\n",
            "PGD Iter 401/500, Test Error: 0.0500, LPPD: -0.2361, Alpha: 2.6958, Beta: 2.5767\n",
            "PGD Iter 451/500, Test Error: 0.0500, LPPD: -0.2643, Alpha: 2.7017, Beta: 2.5661\n",
            "PGD Iter 500/500, Test Error: 0.0500, LPPD: -0.3561, Alpha: 2.7055, Beta: 2.5362\n",
            "PGD run 9 finished with Test Error: 0.0500\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7620, LPPD: -1.4218, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.1100, LPPD: -6.1515, Alpha: 2.1391, Beta: 3.6477\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.1300, LPPD: -6.6899, Alpha: 2.1720, Beta: 2.8338\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.1120, LPPD: -5.1482, Alpha: 2.1866, Beta: 2.7146\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.1180, LPPD: -4.9901, Alpha: 2.1968, Beta: 2.6501\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.1100, LPPD: -5.1272, Alpha: 2.2040, Beta: 2.5798\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.1020, LPPD: -4.3037, Alpha: 2.2067, Beta: 2.5389\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0840, LPPD: -3.5143, Alpha: 2.2103, Beta: 2.4694\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.0880, LPPD: -2.6725, Alpha: 2.2150, Beta: 2.4046\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.0780, LPPD: -3.3714, Alpha: 2.2197, Beta: 2.3865\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.0920, LPPD: -3.0765, Alpha: 2.2239, Beta: 2.4237\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 9 finished with Test Error: 0.0920\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7620, LPPD: -1.4218, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0980, LPPD: -1.1506, Alpha: 2.2746, Beta: 23.7078\n",
            "SOUL Iter 101/500, Test Error: 0.0840, LPPD: -1.2294, Alpha: 2.4804, Beta: 19.9578\n",
            "SOUL Iter 151/500, Test Error: 0.0920, LPPD: -1.2345, Alpha: 2.6171, Beta: 16.2077\n",
            "SOUL Iter 201/500, Test Error: 0.1060, LPPD: -2.0964, Alpha: 2.7182, Beta: 12.4578\n",
            "SOUL Iter 251/500, Test Error: 0.0800, LPPD: -1.0553, Alpha: 2.7765, Beta: 8.7078\n",
            "SOUL Iter 301/500, Test Error: 0.0880, LPPD: -1.6510, Alpha: 2.8356, Beta: 4.9712\n",
            "SOUL Iter 351/500, Test Error: 0.0840, LPPD: -1.5292, Alpha: 2.8794, Beta: 3.1639\n",
            "SOUL Iter 401/500, Test Error: 0.0900, LPPD: -1.3736, Alpha: 2.9330, Beta: 2.9521\n",
            "SOUL Iter 451/500, Test Error: 0.0920, LPPD: -1.9531, Alpha: 2.9694, Beta: 2.6709\n",
            "SOUL Iter 500/500, Test Error: 0.1100, LPPD: -2.0149, Alpha: 3.0017, Beta: 2.3671\n",
            "SOUL run 9 finished with Test Error: 0.1100\n",
            "\n",
            "========================= RUN 10/10 =========================\n",
            "Master Key: [    0 12354]\n",
            "[4231861230 3988647259] [ 130090031 2951097947] [1227610480 1699833013] [2036372542 1321690354]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 40\n",
            "Total Parameters: 31520 (W: 31360, V: 160)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7400, LPPD: -1.4063, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0700, LPPD: -0.4269, Alpha: 2.5234, Beta: 8.5548\n",
            "PGD Iter 101/500, Test Error: 0.0580, LPPD: -0.3334, Alpha: 2.5824, Beta: 4.8197\n",
            "PGD Iter 151/500, Test Error: 0.0560, LPPD: -0.2984, Alpha: 2.6113, Beta: 2.9180\n",
            "PGD Iter 201/500, Test Error: 0.0480, LPPD: -0.3159, Alpha: 2.6291, Beta: 2.7929\n",
            "PGD Iter 251/500, Test Error: 0.0500, LPPD: -0.3788, Alpha: 2.6428, Beta: 2.7132\n",
            "PGD Iter 301/500, Test Error: 0.0500, LPPD: -0.3925, Alpha: 2.6524, Beta: 2.6561\n",
            "PGD Iter 351/500, Test Error: 0.0480, LPPD: -0.2056, Alpha: 2.6595, Beta: 2.6085\n",
            "PGD Iter 401/500, Test Error: 0.0540, LPPD: -0.1604, Alpha: 2.6661, Beta: 2.5873\n",
            "PGD Iter 451/500, Test Error: 0.0420, LPPD: -0.2686, Alpha: 2.6712, Beta: 2.5644\n",
            "PGD Iter 500/500, Test Error: 0.0440, LPPD: -0.2498, Alpha: 2.6757, Beta: 2.5378\n",
            "PGD run 10 finished with Test Error: 0.0440\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7400, LPPD: -1.4063, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.2400, LPPD: -13.6995, Alpha: 2.1836, Beta: 3.5715\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.1100, LPPD: -6.0578, Alpha: 2.2200, Beta: 3.0210\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.1600, LPPD: -8.7024, Alpha: 2.2362, Beta: 2.8452\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.1000, LPPD: -4.5152, Alpha: 2.2482, Beta: 2.7240\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.1080, LPPD: -4.9902, Alpha: 2.2534, Beta: 2.7170\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.0920, LPPD: -4.3437, Alpha: 2.2618, Beta: 2.7046\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0960, LPPD: -4.2800, Alpha: 2.2674, Beta: 2.6291\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.1040, LPPD: -4.6041, Alpha: 2.2733, Beta: 2.6151\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.0980, LPPD: -4.7388, Alpha: 2.2769, Beta: 2.6156\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.1040, LPPD: -3.2935, Alpha: 2.2802, Beta: 2.5271\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 10 finished with Test Error: 0.1040\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7400, LPPD: -1.4063, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0700, LPPD: -0.9009, Alpha: 2.1986, Beta: 24.9609\n",
            "SOUL Iter 101/500, Test Error: 0.0800, LPPD: -1.1877, Alpha: 2.4340, Beta: 21.2109\n",
            "SOUL Iter 151/500, Test Error: 0.0740, LPPD: -1.2237, Alpha: 2.5775, Beta: 17.4609\n",
            "SOUL Iter 201/500, Test Error: 0.0840, LPPD: -1.4973, Alpha: 2.6634, Beta: 13.7108\n",
            "SOUL Iter 251/500, Test Error: 0.0800, LPPD: -1.6892, Alpha: 2.7370, Beta: 9.9609\n",
            "SOUL Iter 301/500, Test Error: 0.0820, LPPD: -1.3494, Alpha: 2.7979, Beta: 6.2119\n",
            "SOUL Iter 351/500, Test Error: 0.1280, LPPD: -2.4090, Alpha: 2.8514, Beta: 3.2939\n",
            "SOUL Iter 401/500, Test Error: 0.1000, LPPD: -2.3624, Alpha: 2.8979, Beta: 2.9760\n",
            "SOUL Iter 451/500, Test Error: 0.0880, LPPD: -2.2918, Alpha: 2.9268, Beta: 2.7249\n",
            "SOUL Iter 500/500, Test Error: 0.0860, LPPD: -1.4091, Alpha: 2.9550, Beta: 2.4245\n",
            "SOUL run 10 finished with Test Error: 0.0860\n",
            "\n",
            "\n",
            "==================== FINAL SUMMARY (OVER 10 RUNS) ====================\n",
            "\n",
            "Final Test Error Statistics:\n",
            "PGD     -> Mean: 0.0474, Std: 0.0034\n",
            "JALA-EM -> Mean: 0.0924, Std: 0.0093\n",
            "SOUL    -> Mean: 0.0926, Std: 0.0081\n",
            "============================================================\n"
          ]
        }
      ],
      "source": [
        "run_experiment(hidden_neurons=40)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Experiment: $D_{h} = 512$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uK9O4R4Ta5op",
        "outputId": "cbcb8c06-cc4e-42c0-f44c-ebabfa236ceb"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "========================= RUN 1/10 =========================\n",
            "Master Key: [    0 12345]\n",
            "[1214163296  439912094] [ 867802714 3762255628] [ 795667951 2300365598] [ 709272294 1661227809]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
            "\u001b[1m11490434/11490434\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 0us/step\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 512\n",
            "Total Parameters: 403456 (W: 401408, V: 2048)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7800, LPPD: -1.4224, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0580, LPPD: -0.5626, Alpha: 2.3580, Beta: 6.1150\n",
            "PGD Iter 101/500, Test Error: 0.0540, LPPD: -0.5468, Alpha: 2.4480, Beta: 3.5865\n",
            "PGD Iter 151/500, Test Error: 0.0580, LPPD: -0.6621, Alpha: 2.4921, Beta: 3.5648\n",
            "PGD Iter 201/500, Test Error: 0.0560, LPPD: -0.4047, Alpha: 2.5198, Beta: 3.5757\n",
            "PGD Iter 251/500, Test Error: 0.0560, LPPD: -0.5309, Alpha: 2.5382, Beta: 3.5801\n",
            "PGD Iter 301/500, Test Error: 0.0480, LPPD: -0.6269, Alpha: 2.5527, Beta: 3.5829\n",
            "PGD Iter 351/500, Test Error: 0.0500, LPPD: -0.7582, Alpha: 2.5639, Beta: 3.5843\n",
            "PGD Iter 401/500, Test Error: 0.0460, LPPD: -0.5282, Alpha: 2.5761, Beta: 3.5865\n",
            "PGD Iter 451/500, Test Error: 0.0460, LPPD: -0.5263, Alpha: 2.5891, Beta: 3.5899\n",
            "PGD Iter 500/500, Test Error: 0.0440, LPPD: -0.5178, Alpha: 2.6012, Beta: 3.5944\n",
            "PGD run 1 finished with Test Error: 0.0440\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7800, LPPD: -1.4224, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.0740, LPPD: -5.0596, Alpha: 2.0339, Beta: 3.6574\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.0640, LPPD: -4.2850, Alpha: 2.1192, Beta: 3.2920\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.0780, LPPD: -5.2343, Alpha: 2.1313, Beta: 3.2878\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.0740, LPPD: -4.5994, Alpha: 2.1332, Beta: 3.2891\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.0780, LPPD: -5.0801, Alpha: 2.1364, Beta: 3.2890\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.0860, LPPD: -5.2511, Alpha: 2.1382, Beta: 3.2898\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0700, LPPD: -4.1719, Alpha: 2.1411, Beta: 3.2935\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.0660, LPPD: -4.3735, Alpha: 2.1447, Beta: 3.2959\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.0760, LPPD: -4.9265, Alpha: 2.1481, Beta: 3.2963\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.0660, LPPD: -3.1789, Alpha: 2.1509, Beta: 3.2928\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 1 finished with Test Error: 0.0660\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7800, LPPD: -1.4224, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0720, LPPD: -2.6601, Alpha: 2.6467, Beta: 13.9901\n",
            "SOUL Iter 101/500, Test Error: 0.0740, LPPD: -2.9781, Alpha: 2.9144, Beta: 10.2401\n",
            "SOUL Iter 151/500, Test Error: 0.0780, LPPD: -3.5579, Alpha: 3.0600, Beta: 6.4924\n",
            "SOUL Iter 201/500, Test Error: 0.0860, LPPD: -3.3813, Alpha: 3.1742, Beta: 3.9383\n",
            "SOUL Iter 251/500, Test Error: 0.0680, LPPD: -2.6149, Alpha: 3.2491, Beta: 3.8981\n",
            "SOUL Iter 301/500, Test Error: 0.0840, LPPD: -3.4372, Alpha: 3.3041, Beta: 3.8980\n",
            "SOUL Iter 351/500, Test Error: 0.0860, LPPD: -4.0110, Alpha: 3.3454, Beta: 3.8952\n",
            "SOUL Iter 401/500, Test Error: 0.0820, LPPD: -3.4537, Alpha: 3.3833, Beta: 3.8707\n",
            "SOUL Iter 451/500, Test Error: 0.0740, LPPD: -3.2595, Alpha: 3.4150, Beta: 3.8607\n",
            "SOUL Iter 500/500, Test Error: 0.0880, LPPD: -3.6683, Alpha: 3.4419, Beta: 3.8625\n",
            "SOUL run 1 finished with Test Error: 0.0880\n",
            "\n",
            "========================= RUN 2/10 =========================\n",
            "Master Key: [    0 12346]\n",
            "[ 632167804 2604013364] [1621771532 2839293200] [3728677534 1583167789] [2821711515 2397397985]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 512\n",
            "Total Parameters: 403456 (W: 401408, V: 2048)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7520, LPPD: -1.3996, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0540, LPPD: -0.8114, Alpha: 2.3474, Beta: 5.9721\n",
            "PGD Iter 101/500, Test Error: 0.0560, LPPD: -0.6758, Alpha: 2.4412, Beta: 3.5797\n",
            "PGD Iter 151/500, Test Error: 0.0500, LPPD: -0.5375, Alpha: 2.4847, Beta: 3.5651\n",
            "PGD Iter 201/500, Test Error: 0.0480, LPPD: -0.6674, Alpha: 2.5105, Beta: 3.5734\n",
            "PGD Iter 251/500, Test Error: 0.0500, LPPD: -0.6587, Alpha: 2.5276, Beta: 3.5780\n",
            "PGD Iter 301/500, Test Error: 0.0520, LPPD: -0.5298, Alpha: 2.5453, Beta: 3.5842\n",
            "PGD Iter 351/500, Test Error: 0.0460, LPPD: -0.2647, Alpha: 2.5602, Beta: 3.5891\n",
            "PGD Iter 401/500, Test Error: 0.0460, LPPD: -0.5225, Alpha: 2.5756, Beta: 3.5948\n",
            "PGD Iter 451/500, Test Error: 0.0520, LPPD: -0.5255, Alpha: 2.5870, Beta: 3.5964\n",
            "PGD Iter 500/500, Test Error: 0.0460, LPPD: -0.5240, Alpha: 2.5972, Beta: 3.5980\n",
            "PGD run 2 finished with Test Error: 0.0460\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7520, LPPD: -1.3996, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.0760, LPPD: -4.9108, Alpha: 1.9961, Beta: 3.4214\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.0620, LPPD: -4.1185, Alpha: 2.0866, Beta: 3.3271\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.0800, LPPD: -5.2771, Alpha: 2.0960, Beta: 3.3228\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.0640, LPPD: -3.9300, Alpha: 2.1425, Beta: 3.3656\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.0680, LPPD: -4.3990, Alpha: 2.1584, Beta: 3.3696\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.0620, LPPD: -4.2829, Alpha: 2.1626, Beta: 3.3672\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0640, LPPD: -4.2377, Alpha: 2.1648, Beta: 3.3634\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.0680, LPPD: -4.4608, Alpha: 2.1692, Beta: 3.3625\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.0680, LPPD: -4.2779, Alpha: 2.1715, Beta: 3.3619\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.0800, LPPD: -3.4363, Alpha: 2.1753, Beta: 3.3614\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 2 finished with Test Error: 0.0800\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7520, LPPD: -1.3996, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0840, LPPD: -3.6108, Alpha: 2.7202, Beta: 15.4407\n",
            "SOUL Iter 101/500, Test Error: 0.0720, LPPD: -3.1366, Alpha: 2.9713, Beta: 11.6907\n",
            "SOUL Iter 151/500, Test Error: 0.0700, LPPD: -2.7931, Alpha: 3.1161, Beta: 7.9409\n",
            "SOUL Iter 201/500, Test Error: 0.0920, LPPD: -3.5058, Alpha: 3.1995, Beta: 4.4228\n",
            "SOUL Iter 251/500, Test Error: 0.0820, LPPD: -3.7740, Alpha: 3.2810, Beta: 3.9638\n",
            "SOUL Iter 301/500, Test Error: 0.0820, LPPD: -2.9737, Alpha: 3.3376, Beta: 3.9595\n",
            "SOUL Iter 351/500, Test Error: 0.0620, LPPD: -2.7078, Alpha: 3.3799, Beta: 3.9526\n",
            "SOUL Iter 401/500, Test Error: 0.0860, LPPD: -4.0403, Alpha: 3.4118, Beta: 3.9476\n",
            "SOUL Iter 451/500, Test Error: 0.0860, LPPD: -3.1311, Alpha: 3.4489, Beta: 3.9413\n",
            "SOUL Iter 500/500, Test Error: 0.0680, LPPD: -2.9997, Alpha: 3.4804, Beta: 3.9445\n",
            "SOUL run 2 finished with Test Error: 0.0680\n",
            "\n",
            "========================= RUN 3/10 =========================\n",
            "Master Key: [    0 12347]\n",
            "[2760833302 2604130258] [2353787085 2855123098] [1645460419  607253567] [1693954434 2313133562]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 512\n",
            "Total Parameters: 403456 (W: 401408, V: 2048)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7420, LPPD: -1.3963, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0540, LPPD: -0.5580, Alpha: 2.3234, Beta: 5.7552\n",
            "PGD Iter 101/500, Test Error: 0.0540, LPPD: -0.5880, Alpha: 2.4180, Beta: 3.5363\n",
            "PGD Iter 151/500, Test Error: 0.0540, LPPD: -0.4046, Alpha: 2.4671, Beta: 3.5386\n",
            "PGD Iter 201/500, Test Error: 0.0500, LPPD: -0.7942, Alpha: 2.4930, Beta: 3.5460\n",
            "PGD Iter 251/500, Test Error: 0.0480, LPPD: -0.9115, Alpha: 2.5131, Beta: 3.5539\n",
            "PGD Iter 301/500, Test Error: 0.0540, LPPD: -0.6640, Alpha: 2.5298, Beta: 3.5602\n",
            "PGD Iter 351/500, Test Error: 0.0500, LPPD: -0.6618, Alpha: 2.5480, Beta: 3.5689\n",
            "PGD Iter 401/500, Test Error: 0.0520, LPPD: -0.7652, Alpha: 2.5619, Beta: 3.5748\n",
            "PGD Iter 451/500, Test Error: 0.0540, LPPD: -0.5296, Alpha: 2.5743, Beta: 3.5768\n",
            "PGD Iter 500/500, Test Error: 0.0480, LPPD: -0.6123, Alpha: 2.5863, Beta: 3.5797\n",
            "PGD run 3 finished with Test Error: 0.0480\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7420, LPPD: -1.3963, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.1460, LPPD: -9.5458, Alpha: 2.0043, Beta: 3.3939\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.0560, LPPD: -3.7912, Alpha: 2.1068, Beta: 3.2732\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.0660, LPPD: -4.3116, Alpha: 2.1170, Beta: 3.2714\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.0720, LPPD: -4.6375, Alpha: 2.1186, Beta: 3.2707\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.0780, LPPD: -5.1360, Alpha: 2.1200, Beta: 3.2669\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.0660, LPPD: -4.2692, Alpha: 2.1215, Beta: 3.2637\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0620, LPPD: -4.1401, Alpha: 2.1248, Beta: 3.2641\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.0700, LPPD: -4.3611, Alpha: 2.1266, Beta: 3.2628\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.0780, LPPD: -4.9671, Alpha: 2.1293, Beta: 3.2644\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.0760, LPPD: -3.2181, Alpha: 2.1360, Beta: 3.2650\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 3 finished with Test Error: 0.0760\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7420, LPPD: -1.3963, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0700, LPPD: -2.3680, Alpha: 2.6759, Beta: 18.3246\n",
            "SOUL Iter 101/500, Test Error: 0.0880, LPPD: -3.2928, Alpha: 2.9429, Beta: 14.5745\n",
            "SOUL Iter 151/500, Test Error: 0.0920, LPPD: -4.2836, Alpha: 3.1186, Beta: 10.8245\n",
            "SOUL Iter 201/500, Test Error: 0.0980, LPPD: -4.1779, Alpha: 3.2299, Beta: 7.0755\n",
            "SOUL Iter 251/500, Test Error: 0.0960, LPPD: -3.9985, Alpha: 3.3161, Beta: 4.1235\n",
            "SOUL Iter 301/500, Test Error: 0.0880, LPPD: -3.6110, Alpha: 3.3810, Beta: 4.0295\n",
            "SOUL Iter 351/500, Test Error: 0.0840, LPPD: -4.5449, Alpha: 3.4285, Beta: 4.0289\n",
            "SOUL Iter 401/500, Test Error: 0.0800, LPPD: -3.8507, Alpha: 3.4608, Beta: 4.0201\n",
            "SOUL Iter 451/500, Test Error: 0.0920, LPPD: -3.9972, Alpha: 3.4894, Beta: 4.0083\n",
            "SOUL Iter 500/500, Test Error: 0.1000, LPPD: -3.4274, Alpha: 3.5125, Beta: 3.9959\n",
            "SOUL run 3 finished with Test Error: 0.1000\n",
            "\n",
            "========================= RUN 4/10 =========================\n",
            "Master Key: [    0 12348]\n",
            "[204896617 362581595] [ 556737389 2100154806] [3692117573 3309360152] [1340738131 1883904564]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 512\n",
            "Total Parameters: 403456 (W: 401408, V: 2048)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7140, LPPD: -1.3821, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0600, LPPD: -0.6847, Alpha: 2.3351, Beta: 5.7295\n",
            "PGD Iter 101/500, Test Error: 0.0480, LPPD: -0.5408, Alpha: 2.4242, Beta: 3.5438\n",
            "PGD Iter 151/500, Test Error: 0.0500, LPPD: -0.6793, Alpha: 2.4721, Beta: 3.5444\n",
            "PGD Iter 201/500, Test Error: 0.0540, LPPD: -0.5292, Alpha: 2.4997, Beta: 3.5541\n",
            "PGD Iter 251/500, Test Error: 0.0480, LPPD: -0.5271, Alpha: 2.5181, Beta: 3.5586\n",
            "PGD Iter 301/500, Test Error: 0.0520, LPPD: -0.4876, Alpha: 2.5313, Beta: 3.5600\n",
            "PGD Iter 351/500, Test Error: 0.0520, LPPD: -0.5277, Alpha: 2.5465, Beta: 3.5669\n",
            "PGD Iter 401/500, Test Error: 0.0480, LPPD: -0.6564, Alpha: 2.5595, Beta: 3.5703\n",
            "PGD Iter 451/500, Test Error: 0.0500, LPPD: -0.5265, Alpha: 2.5721, Beta: 3.5744\n",
            "PGD Iter 500/500, Test Error: 0.0500, LPPD: -0.5264, Alpha: 2.5830, Beta: 3.5767\n",
            "PGD run 4 finished with Test Error: 0.0500\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7140, LPPD: -1.3821, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.0660, LPPD: -4.2552, Alpha: 2.0850, Beta: 3.5819\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.0860, LPPD: -5.5598, Alpha: 2.1516, Beta: 3.3437\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.0720, LPPD: -4.6503, Alpha: 2.1582, Beta: 3.3421\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.0700, LPPD: -4.3084, Alpha: 2.1605, Beta: 3.3415\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.0860, LPPD: -5.5526, Alpha: 2.1620, Beta: 3.3393\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.0800, LPPD: -5.4222, Alpha: 2.1635, Beta: 3.3396\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0780, LPPD: -5.0720, Alpha: 2.1675, Beta: 3.3382\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.0720, LPPD: -4.8196, Alpha: 2.1703, Beta: 3.3375\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.0720, LPPD: -4.8638, Alpha: 2.2298, Beta: 3.3789\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.0700, LPPD: -3.1695, Alpha: 2.2863, Beta: 3.4210\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 4 finished with Test Error: 0.0700\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7140, LPPD: -1.3821, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0680, LPPD: -2.5548, Alpha: 2.6842, Beta: 12.5514\n",
            "SOUL Iter 101/500, Test Error: 0.0860, LPPD: -4.0428, Alpha: 2.9549, Beta: 8.8014\n",
            "SOUL Iter 151/500, Test Error: 0.0680, LPPD: -2.9514, Alpha: 3.0837, Beta: 5.0940\n",
            "SOUL Iter 201/500, Test Error: 0.0680, LPPD: -1.9213, Alpha: 3.1867, Beta: 3.8920\n",
            "SOUL Iter 251/500, Test Error: 0.0840, LPPD: -3.3402, Alpha: 3.2592, Beta: 3.8857\n",
            "SOUL Iter 301/500, Test Error: 0.0680, LPPD: -2.5099, Alpha: 3.3166, Beta: 3.8864\n",
            "SOUL Iter 351/500, Test Error: 0.0760, LPPD: -3.4503, Alpha: 3.3600, Beta: 3.8656\n",
            "SOUL Iter 401/500, Test Error: 0.0760, LPPD: -3.1073, Alpha: 3.3902, Beta: 3.8608\n",
            "SOUL Iter 451/500, Test Error: 0.0640, LPPD: -2.7094, Alpha: 3.4209, Beta: 3.8593\n",
            "SOUL Iter 500/500, Test Error: 0.0740, LPPD: -3.3733, Alpha: 3.4465, Beta: 3.8528\n",
            "SOUL run 4 finished with Test Error: 0.0740\n",
            "\n",
            "========================= RUN 5/10 =========================\n",
            "Master Key: [    0 12349]\n",
            "[1516789638  985863638] [2166140375  885614934] [826974616 878131554] [ 902379975 3020925066]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 512\n",
            "Total Parameters: 403456 (W: 401408, V: 2048)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7480, LPPD: -1.4145, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0560, LPPD: -0.6792, Alpha: 2.3844, Beta: 6.1431\n",
            "PGD Iter 101/500, Test Error: 0.0500, LPPD: -0.6630, Alpha: 2.4778, Beta: 3.6362\n",
            "PGD Iter 151/500, Test Error: 0.0520, LPPD: -0.6662, Alpha: 2.5251, Beta: 3.6153\n",
            "PGD Iter 201/500, Test Error: 0.0480, LPPD: -0.7865, Alpha: 2.5525, Beta: 3.6221\n",
            "PGD Iter 251/500, Test Error: 0.0520, LPPD: -0.6561, Alpha: 2.5698, Beta: 3.6239\n",
            "PGD Iter 301/500, Test Error: 0.0460, LPPD: -0.7560, Alpha: 2.5896, Beta: 3.6348\n",
            "PGD Iter 351/500, Test Error: 0.0460, LPPD: -0.5365, Alpha: 2.6039, Beta: 3.6370\n",
            "PGD Iter 401/500, Test Error: 0.0500, LPPD: -0.5328, Alpha: 2.6154, Beta: 3.6386\n",
            "PGD Iter 451/500, Test Error: 0.0520, LPPD: -0.5259, Alpha: 2.6276, Beta: 3.6423\n",
            "PGD Iter 500/500, Test Error: 0.0580, LPPD: -0.6590, Alpha: 2.6367, Beta: 3.6430\n",
            "PGD run 5 finished with Test Error: 0.0580\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7480, LPPD: -1.4145, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.0660, LPPD: -4.3216, Alpha: 1.8672, Beta: 2.9814\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.0720, LPPD: -4.6909, Alpha: 1.9607, Beta: 3.0502\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.0700, LPPD: -4.4790, Alpha: 1.9644, Beta: 3.0482\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.0800, LPPD: -5.0978, Alpha: 1.9662, Beta: 3.0491\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.0800, LPPD: -5.2371, Alpha: 1.9703, Beta: 3.0435\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.0800, LPPD: -5.2132, Alpha: 1.9744, Beta: 3.0409\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0800, LPPD: -5.4090, Alpha: 1.9804, Beta: 3.0350\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.0820, LPPD: -5.4787, Alpha: 1.9903, Beta: 3.0346\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.0820, LPPD: -5.4165, Alpha: 2.0946, Beta: 3.1186\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.0820, LPPD: -3.5833, Alpha: 2.1077, Beta: 3.1195\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 5 finished with Test Error: 0.0820\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7480, LPPD: -1.4145, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0600, LPPD: -2.5442, Alpha: 2.6991, Beta: 29.7405\n",
            "SOUL Iter 101/500, Test Error: 0.0640, LPPD: -2.7517, Alpha: 2.9573, Beta: 25.9905\n",
            "SOUL Iter 151/500, Test Error: 0.0820, LPPD: -3.9110, Alpha: 3.1048, Beta: 22.2405\n",
            "SOUL Iter 201/500, Test Error: 0.0760, LPPD: -3.2243, Alpha: 3.1929, Beta: 18.4904\n",
            "SOUL Iter 251/500, Test Error: 0.0880, LPPD: -3.5702, Alpha: 3.2797, Beta: 14.7404\n",
            "SOUL Iter 301/500, Test Error: 0.0860, LPPD: -4.1294, Alpha: 3.3336, Beta: 10.9904\n",
            "SOUL Iter 351/500, Test Error: 0.0940, LPPD: -5.0255, Alpha: 3.3964, Beta: 7.2413\n",
            "SOUL Iter 401/500, Test Error: 0.0980, LPPD: -4.3429, Alpha: 3.4455, Beta: 4.2338\n",
            "SOUL Iter 451/500, Test Error: 0.0860, LPPD: -3.2435, Alpha: 3.4893, Beta: 4.1203\n",
            "SOUL Iter 500/500, Test Error: 0.0880, LPPD: -3.5841, Alpha: 3.5204, Beta: 4.1180\n",
            "SOUL run 5 finished with Test Error: 0.0880\n",
            "\n",
            "========================= RUN 6/10 =========================\n",
            "Master Key: [    0 12350]\n",
            "[ 663308273 3542856509] [ 255441095 3931730473] [2352717910 3214489336] [1824469452 2789628418]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 512\n",
            "Total Parameters: 403456 (W: 401408, V: 2048)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7500, LPPD: -1.4218, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0580, LPPD: -0.5543, Alpha: 2.3177, Beta: 5.5862\n",
            "PGD Iter 101/500, Test Error: 0.0560, LPPD: -0.6667, Alpha: 2.4113, Beta: 3.5396\n",
            "PGD Iter 151/500, Test Error: 0.0560, LPPD: -0.5255, Alpha: 2.4600, Beta: 3.5459\n",
            "PGD Iter 201/500, Test Error: 0.0500, LPPD: -0.5252, Alpha: 2.4843, Beta: 3.5489\n",
            "PGD Iter 251/500, Test Error: 0.0480, LPPD: -0.4842, Alpha: 2.5044, Beta: 3.5546\n",
            "PGD Iter 301/500, Test Error: 0.0480, LPPD: -0.5306, Alpha: 2.5216, Beta: 3.5602\n",
            "PGD Iter 351/500, Test Error: 0.0460, LPPD: -0.3969, Alpha: 2.5375, Beta: 3.5651\n",
            "PGD Iter 401/500, Test Error: 0.0460, LPPD: -0.6560, Alpha: 2.5518, Beta: 3.5689\n",
            "PGD Iter 451/500, Test Error: 0.0520, LPPD: -0.6648, Alpha: 2.5655, Beta: 3.5739\n",
            "PGD Iter 500/500, Test Error: 0.0420, LPPD: -0.5247, Alpha: 2.5798, Beta: 3.5796\n",
            "PGD run 6 finished with Test Error: 0.0420\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7500, LPPD: -1.4218, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.0780, LPPD: -5.0082, Alpha: 2.0700, Beta: 3.3255\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.0600, LPPD: -3.7004, Alpha: 2.1350, Beta: 3.3379\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.0660, LPPD: -4.3023, Alpha: 2.1410, Beta: 3.3386\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.0720, LPPD: -4.6206, Alpha: 2.1441, Beta: 3.3384\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.0720, LPPD: -4.5961, Alpha: 2.1448, Beta: 3.3418\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.0640, LPPD: -4.3007, Alpha: 2.1462, Beta: 3.3389\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0700, LPPD: -4.3726, Alpha: 2.1487, Beta: 3.3417\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.0700, LPPD: -4.6597, Alpha: 2.1530, Beta: 3.3406\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.0800, LPPD: -5.3200, Alpha: 2.1568, Beta: 3.3394\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.0660, LPPD: -3.3859, Alpha: 2.1617, Beta: 3.3393\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 6 finished with Test Error: 0.0660\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7500, LPPD: -1.4218, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0720, LPPD: -2.7342, Alpha: 2.7271, Beta: 9.1186\n",
            "SOUL Iter 101/500, Test Error: 0.0760, LPPD: -2.6529, Alpha: 2.9451, Beta: 5.3883\n",
            "SOUL Iter 151/500, Test Error: 0.0560, LPPD: -2.2828, Alpha: 3.0948, Beta: 3.8634\n",
            "SOUL Iter 201/500, Test Error: 0.0940, LPPD: -4.1504, Alpha: 3.1791, Beta: 3.8447\n",
            "SOUL Iter 251/500, Test Error: 0.0640, LPPD: -3.4361, Alpha: 3.2453, Beta: 3.8445\n",
            "SOUL Iter 301/500, Test Error: 0.0860, LPPD: -3.4751, Alpha: 3.2922, Beta: 3.8506\n",
            "SOUL Iter 351/500, Test Error: 0.0840, LPPD: -4.0110, Alpha: 3.3311, Beta: 3.8365\n",
            "SOUL Iter 401/500, Test Error: 0.0800, LPPD: -3.5966, Alpha: 3.3663, Beta: 3.8246\n",
            "SOUL Iter 451/500, Test Error: 0.0900, LPPD: -4.0288, Alpha: 3.3967, Beta: 3.8233\n",
            "SOUL Iter 500/500, Test Error: 0.0900, LPPD: -4.0071, Alpha: 3.4258, Beta: 3.8034\n",
            "SOUL run 6 finished with Test Error: 0.0900\n",
            "\n",
            "========================= RUN 7/10 =========================\n",
            "Master Key: [    0 12351]\n",
            "[4291887911 2647564198] [ 205525861 3451136917] [   9081508 2997207858] [3725356808 2941390772]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 512\n",
            "Total Parameters: 403456 (W: 401408, V: 2048)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7420, LPPD: -1.4255, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0620, LPPD: -0.8230, Alpha: 2.2786, Beta: 5.8329\n",
            "PGD Iter 101/500, Test Error: 0.0560, LPPD: -0.5358, Alpha: 2.3750, Beta: 3.5013\n",
            "PGD Iter 151/500, Test Error: 0.0540, LPPD: -0.4015, Alpha: 2.4189, Beta: 3.4930\n",
            "PGD Iter 201/500, Test Error: 0.0540, LPPD: -0.6596, Alpha: 2.4453, Beta: 3.5019\n",
            "PGD Iter 251/500, Test Error: 0.0580, LPPD: -0.3999, Alpha: 2.4640, Beta: 3.5077\n",
            "PGD Iter 301/500, Test Error: 0.0500, LPPD: -0.5252, Alpha: 2.4806, Beta: 3.5132\n",
            "PGD Iter 351/500, Test Error: 0.0520, LPPD: -0.6566, Alpha: 2.5008, Beta: 3.5228\n",
            "PGD Iter 401/500, Test Error: 0.0500, LPPD: -0.6554, Alpha: 2.5139, Beta: 3.5236\n",
            "PGD Iter 451/500, Test Error: 0.0500, LPPD: -0.4875, Alpha: 2.5264, Beta: 3.5273\n",
            "PGD Iter 500/500, Test Error: 0.0580, LPPD: -0.4006, Alpha: 2.5382, Beta: 3.5300\n",
            "PGD run 7 finished with Test Error: 0.0580\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7420, LPPD: -1.4255, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.0960, LPPD: -6.3482, Alpha: 1.8890, Beta: 3.5988\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.0700, LPPD: -4.5931, Alpha: 1.9446, Beta: 3.0990\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.0680, LPPD: -4.2867, Alpha: 1.9478, Beta: 3.1001\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.0640, LPPD: -4.1076, Alpha: 1.9528, Beta: 3.1007\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.0620, LPPD: -3.8677, Alpha: 1.9599, Beta: 3.1013\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.0760, LPPD: -4.6098, Alpha: 1.9688, Beta: 3.1020\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0740, LPPD: -4.6176, Alpha: 2.0366, Beta: 3.1533\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.0700, LPPD: -4.7306, Alpha: 2.1198, Beta: 3.1969\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.0760, LPPD: -4.9517, Alpha: 2.1243, Beta: 3.1942\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.0740, LPPD: -3.4004, Alpha: 2.1284, Beta: 3.1920\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 7 finished with Test Error: 0.0740\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7420, LPPD: -1.4255, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0820, LPPD: -2.6544, Alpha: 2.7054, Beta: 16.3503\n",
            "SOUL Iter 101/500, Test Error: 0.0820, LPPD: -3.2050, Alpha: 2.9659, Beta: 12.6004\n",
            "SOUL Iter 151/500, Test Error: 0.0920, LPPD: -3.6227, Alpha: 3.1048, Beta: 8.8504\n",
            "SOUL Iter 201/500, Test Error: 0.0780, LPPD: -3.3469, Alpha: 3.2020, Beta: 5.1438\n",
            "SOUL Iter 251/500, Test Error: 0.0980, LPPD: -3.7313, Alpha: 3.2762, Beta: 3.9690\n",
            "SOUL Iter 301/500, Test Error: 0.0780, LPPD: -2.9841, Alpha: 3.3355, Beta: 3.9603\n",
            "SOUL Iter 351/500, Test Error: 0.0760, LPPD: -3.5042, Alpha: 3.3801, Beta: 3.9405\n",
            "SOUL Iter 401/500, Test Error: 0.0840, LPPD: -3.8026, Alpha: 3.4131, Beta: 3.9458\n",
            "SOUL Iter 451/500, Test Error: 0.0780, LPPD: -3.6764, Alpha: 3.4450, Beta: 3.9415\n",
            "SOUL Iter 500/500, Test Error: 0.0920, LPPD: -4.1075, Alpha: 3.4733, Beta: 3.9439\n",
            "SOUL run 7 finished with Test Error: 0.0920\n",
            "\n",
            "========================= RUN 8/10 =========================\n",
            "Master Key: [    0 12352]\n",
            "[4145063284  827122148] [3377278252 3595848238] [1835196978 3325478056] [ 194721067 3532734789]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 512\n",
            "Total Parameters: 403456 (W: 401408, V: 2048)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.6820, LPPD: -1.3659, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0540, LPPD: -0.6869, Alpha: 2.3338, Beta: 5.6908\n",
            "PGD Iter 101/500, Test Error: 0.0520, LPPD: -0.6645, Alpha: 2.4244, Beta: 3.5394\n",
            "PGD Iter 151/500, Test Error: 0.0460, LPPD: -0.3956, Alpha: 2.4687, Beta: 3.5361\n",
            "PGD Iter 201/500, Test Error: 0.0460, LPPD: -0.5241, Alpha: 2.4962, Beta: 3.5452\n",
            "PGD Iter 251/500, Test Error: 0.0520, LPPD: -0.6586, Alpha: 2.5162, Beta: 3.5506\n",
            "PGD Iter 301/500, Test Error: 0.0480, LPPD: -0.5204, Alpha: 2.5333, Beta: 3.5566\n",
            "PGD Iter 351/500, Test Error: 0.0480, LPPD: -0.3956, Alpha: 2.5490, Beta: 3.5627\n",
            "PGD Iter 401/500, Test Error: 0.0540, LPPD: -0.5251, Alpha: 2.5638, Beta: 3.5673\n",
            "PGD Iter 451/500, Test Error: 0.0520, LPPD: -0.6567, Alpha: 2.5748, Beta: 3.5692\n",
            "PGD Iter 500/500, Test Error: 0.0480, LPPD: -0.6533, Alpha: 2.5844, Beta: 3.5705\n",
            "PGD run 8 finished with Test Error: 0.0480\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.6820, LPPD: -1.3659, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.1180, LPPD: -7.7604, Alpha: 1.8974, Beta: 3.1645\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.0840, LPPD: -5.7085, Alpha: 1.9625, Beta: 3.1265\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.0720, LPPD: -4.9309, Alpha: 1.9673, Beta: 3.1248\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.0700, LPPD: -4.4977, Alpha: 1.9707, Beta: 3.1221\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.0900, LPPD: -5.8433, Alpha: 1.9746, Beta: 3.1220\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.0900, LPPD: -5.5142, Alpha: 1.9796, Beta: 3.1207\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0920, LPPD: -6.2583, Alpha: 1.9868, Beta: 3.1191\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.0720, LPPD: -4.4442, Alpha: 2.0709, Beta: 3.1736\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.0740, LPPD: -5.0403, Alpha: 2.1443, Beta: 3.2315\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.0780, LPPD: -4.1845, Alpha: 2.1483, Beta: 3.2315\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 8 finished with Test Error: 0.0780\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.6820, LPPD: -1.3659, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0680, LPPD: -2.8596, Alpha: 2.6651, Beta: 15.5048\n",
            "SOUL Iter 101/500, Test Error: 0.0900, LPPD: -3.2386, Alpha: 2.9231, Beta: 11.7548\n",
            "SOUL Iter 151/500, Test Error: 0.0760, LPPD: -3.5756, Alpha: 3.0893, Beta: 8.0049\n",
            "SOUL Iter 201/500, Test Error: 0.0940, LPPD: -3.5342, Alpha: 3.1887, Beta: 4.4449\n",
            "SOUL Iter 251/500, Test Error: 0.0880, LPPD: -4.1124, Alpha: 3.2713, Beta: 3.9211\n",
            "SOUL Iter 301/500, Test Error: 0.0860, LPPD: -3.8941, Alpha: 3.3206, Beta: 3.9144\n",
            "SOUL Iter 351/500, Test Error: 0.0660, LPPD: -3.5273, Alpha: 3.3619, Beta: 3.9167\n",
            "SOUL Iter 401/500, Test Error: 0.0920, LPPD: -3.4758, Alpha: 3.4016, Beta: 3.9153\n",
            "SOUL Iter 451/500, Test Error: 0.0980, LPPD: -4.8538, Alpha: 3.4361, Beta: 3.9217\n",
            "SOUL Iter 500/500, Test Error: 0.0660, LPPD: -2.8023, Alpha: 3.4677, Beta: 3.9189\n",
            "SOUL run 8 finished with Test Error: 0.0660\n",
            "\n",
            "========================= RUN 9/10 =========================\n",
            "Master Key: [    0 12353]\n",
            "[2991864547 2732004738] [1738224256 3454507916] [193717663  78461347] [3179925608 4234927111]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 512\n",
            "Total Parameters: 403456 (W: 401408, V: 2048)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7520, LPPD: -1.4165, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0540, LPPD: -0.8112, Alpha: 2.3555, Beta: 5.9911\n",
            "PGD Iter 101/500, Test Error: 0.0500, LPPD: -0.4056, Alpha: 2.4442, Beta: 3.5723\n",
            "PGD Iter 151/500, Test Error: 0.0580, LPPD: -0.6590, Alpha: 2.4898, Beta: 3.5646\n",
            "PGD Iter 201/500, Test Error: 0.0500, LPPD: -0.6122, Alpha: 2.5145, Beta: 3.5687\n",
            "PGD Iter 251/500, Test Error: 0.0500, LPPD: -0.6354, Alpha: 2.5334, Beta: 3.5747\n",
            "PGD Iter 301/500, Test Error: 0.0500, LPPD: -0.6539, Alpha: 2.5498, Beta: 3.5791\n",
            "PGD Iter 351/500, Test Error: 0.0520, LPPD: -0.5226, Alpha: 2.5625, Beta: 3.5820\n",
            "PGD Iter 401/500, Test Error: 0.0460, LPPD: -0.6505, Alpha: 2.5756, Beta: 3.5861\n",
            "PGD Iter 451/500, Test Error: 0.0560, LPPD: -0.5465, Alpha: 2.5888, Beta: 3.5909\n",
            "PGD Iter 500/500, Test Error: 0.0540, LPPD: -0.5280, Alpha: 2.5998, Beta: 3.5925\n",
            "PGD run 9 finished with Test Error: 0.0540\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7520, LPPD: -1.4165, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.1600, LPPD: -10.7099, Alpha: 1.9901, Beta: 3.6561\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.0680, LPPD: -4.3076, Alpha: 2.0634, Beta: 3.2565\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.0640, LPPD: -4.1329, Alpha: 2.0684, Beta: 3.2576\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.0560, LPPD: -3.6872, Alpha: 2.1290, Beta: 3.3147\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.0580, LPPD: -3.5587, Alpha: 2.1359, Beta: 3.3147\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.0540, LPPD: -3.5564, Alpha: 2.1370, Beta: 3.3151\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0560, LPPD: -3.7554, Alpha: 2.1392, Beta: 3.3158\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.0620, LPPD: -4.2538, Alpha: 2.1411, Beta: 3.3151\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.0680, LPPD: -4.3633, Alpha: 2.1447, Beta: 3.3136\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.0680, LPPD: -3.5084, Alpha: 2.1475, Beta: 3.3135\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 9 finished with Test Error: 0.0680\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7520, LPPD: -1.4165, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0760, LPPD: -2.7479, Alpha: 2.6702, Beta: 28.3118\n",
            "SOUL Iter 101/500, Test Error: 0.0740, LPPD: -3.3777, Alpha: 2.9225, Beta: 24.5618\n",
            "SOUL Iter 151/500, Test Error: 0.0800, LPPD: -3.3299, Alpha: 3.0714, Beta: 20.8117\n",
            "SOUL Iter 201/500, Test Error: 0.0680, LPPD: -3.3033, Alpha: 3.1891, Beta: 17.0617\n",
            "SOUL Iter 251/500, Test Error: 0.0820, LPPD: -3.1040, Alpha: 3.2708, Beta: 13.3117\n",
            "SOUL Iter 301/500, Test Error: 0.0860, LPPD: -3.8561, Alpha: 3.3414, Beta: 9.5617\n",
            "SOUL Iter 351/500, Test Error: 0.0760, LPPD: -3.0855, Alpha: 3.3950, Beta: 5.8265\n",
            "SOUL Iter 401/500, Test Error: 0.0720, LPPD: -3.2297, Alpha: 3.4443, Beta: 4.1324\n",
            "SOUL Iter 451/500, Test Error: 0.0800, LPPD: -3.1138, Alpha: 3.4915, Beta: 4.1286\n",
            "SOUL Iter 500/500, Test Error: 0.0740, LPPD: -3.8716, Alpha: 3.5308, Beta: 4.1272\n",
            "SOUL run 9 finished with Test Error: 0.0740\n",
            "\n",
            "========================= RUN 10/10 =========================\n",
            "Master Key: [    0 12354]\n",
            "[4231861230 3988647259] [ 130090031 2951097947] [1227610480 1699833013] [2036372542 1321690354]\n",
            "Loading Dataset D!\n",
            "Loading and Preprocessing Data!\n",
            "Filtering for digits: [2, 4, 7, 9]\n",
            "Remapped labels. Example: [4 9 2 4 7] -> [1 3 0 1 2]\n",
            "Subsampled 2500 images.\n",
            "Normalisation Done! Image shape: (2500, 28, 28), Num Classes: 4\n",
            "\n",
            "Splitting D into D_train and D_test!\n",
            "Final training images shape (28x28): (2000, 28, 28)\n",
            "Final testing images shape (28x28): (500, 28, 28)\n",
            "\n",
            "\n",
            "BNN Architecture:\n",
            "Hidden Neurons: 512\n",
            "Total Parameters: 403456 (W: 401408, V: 2048)\n",
            "\n",
            "Running PGD (iters=500, h=0.075)...\n",
            "PGD Iter 1/500, Test Error: 0.7320, LPPD: -1.3970, Alpha: 0.0000, Beta: 0.0000\n",
            "PGD Iter 51/500, Test Error: 0.0560, LPPD: -0.6928, Alpha: 2.3117, Beta: 5.6339\n",
            "PGD Iter 101/500, Test Error: 0.0560, LPPD: -0.6702, Alpha: 2.4057, Beta: 3.5195\n",
            "PGD Iter 151/500, Test Error: 0.0540, LPPD: -0.6708, Alpha: 2.4514, Beta: 3.5209\n",
            "PGD Iter 201/500, Test Error: 0.0560, LPPD: -0.6676, Alpha: 2.4788, Beta: 3.5307\n",
            "PGD Iter 251/500, Test Error: 0.0520, LPPD: -0.5308, Alpha: 2.4939, Beta: 3.5329\n",
            "PGD Iter 301/500, Test Error: 0.0520, LPPD: -0.5348, Alpha: 2.5108, Beta: 3.5390\n",
            "PGD Iter 351/500, Test Error: 0.0500, LPPD: -0.6386, Alpha: 2.5269, Beta: 3.5447\n",
            "PGD Iter 401/500, Test Error: 0.0580, LPPD: -0.6693, Alpha: 2.5413, Beta: 3.5493\n",
            "PGD Iter 451/500, Test Error: 0.0540, LPPD: -0.6641, Alpha: 2.5549, Beta: 3.5527\n",
            "PGD Iter 500/500, Test Error: 0.0540, LPPD: -0.6594, Alpha: 2.5677, Beta: 3.5568\n",
            "PGD run 10 finished with Test Error: 0.0540\n",
            "\n",
            "Running JALA-EM (iters=500)...\n",
            "Running JAX-compatible JALA-EM...\n",
            "JALA-EM Iter 1/500, ESS: 50.00, Test Error: 0.7320, LPPD: -1.3970, Alpha: 0.0000, Beta: 0.0000\n",
            "JALA-EM Iter 51/500, ESS: 50.00, Test Error: 0.0700, LPPD: -4.1343, Alpha: 1.9061, Beta: 3.2120\n",
            "JALA-EM Iter 101/500, ESS: 50.00, Test Error: 0.0660, LPPD: -4.0285, Alpha: 1.9620, Beta: 3.1892\n",
            "JALA-EM Iter 151/500, ESS: 50.00, Test Error: 0.0620, LPPD: -4.1261, Alpha: 1.9705, Beta: 3.1887\n",
            "JALA-EM Iter 201/500, ESS: 50.00, Test Error: 0.0640, LPPD: -4.1646, Alpha: 1.9776, Beta: 3.1928\n",
            "JALA-EM Iter 251/500, ESS: 50.00, Test Error: 0.0740, LPPD: -4.6778, Alpha: 1.9853, Beta: 3.1920\n",
            "JALA-EM Iter 301/500, ESS: 50.00, Test Error: 0.1380, LPPD: -9.3005, Alpha: 2.0803, Beta: 3.2487\n",
            "JALA-EM Iter 351/500, ESS: 50.00, Test Error: 0.0600, LPPD: -3.8387, Alpha: 2.1302, Beta: 3.2875\n",
            "JALA-EM Iter 401/500, ESS: 50.00, Test Error: 0.0640, LPPD: -4.0017, Alpha: 2.1344, Beta: 3.2872\n",
            "JALA-EM Iter 451/500, ESS: 50.00, Test Error: 0.0620, LPPD: -4.1816, Alpha: 2.1373, Beta: 3.2887\n",
            "JALA-EM Iter 500/500, ESS: 1.00, Test Error: 0.0560, LPPD: -2.7005, Alpha: 2.1429, Beta: 3.2886\n",
            "Resampling triggered for Iter 500 (ESS=1.00)\n",
            "JALA-EM run 10 finished with Test Error: 0.0560\n",
            "\n",
            "Running SOUL (iters=500, h=0.075)...\n",
            "SOUL Iter 1/500, Test Error: 0.7320, LPPD: -1.3970, Alpha: 0.0000, Beta: 0.0000\n",
            "SOUL Iter 51/500, Test Error: 0.0940, LPPD: -2.6426, Alpha: 2.6290, Beta: 16.0409\n",
            "SOUL Iter 101/500, Test Error: 0.0740, LPPD: -3.1257, Alpha: 2.9119, Beta: 12.2909\n",
            "SOUL Iter 151/500, Test Error: 0.0940, LPPD: -4.1365, Alpha: 3.0614, Beta: 8.5409\n",
            "SOUL Iter 201/500, Test Error: 0.0800, LPPD: -3.6476, Alpha: 3.1626, Beta: 4.8638\n",
            "SOUL Iter 251/500, Test Error: 0.0820, LPPD: -2.7941, Alpha: 3.2425, Beta: 3.9307\n",
            "SOUL Iter 301/500, Test Error: 0.0740, LPPD: -2.8788, Alpha: 3.3013, Beta: 3.9309\n",
            "SOUL Iter 351/500, Test Error: 0.0800, LPPD: -3.5313, Alpha: 3.3446, Beta: 3.9138\n",
            "SOUL Iter 401/500, Test Error: 0.0660, LPPD: -2.8665, Alpha: 3.3852, Beta: 3.9026\n",
            "SOUL Iter 451/500, Test Error: 0.0900, LPPD: -3.9425, Alpha: 3.4182, Beta: 3.8995\n",
            "SOUL Iter 500/500, Test Error: 0.0860, LPPD: -3.5220, Alpha: 3.4455, Beta: 3.8757\n",
            "SOUL run 10 finished with Test Error: 0.0860\n",
            "\n",
            "\n",
            "==================== FINAL SUMMARY (OVER 10 RUNS) ====================\n",
            "\n",
            "Final Test Error Statistics:\n",
            "PGD     -> Mean: 0.0502, Std: 0.0053\n",
            "JALA-EM -> Mean: 0.0716, Std: 0.0075\n",
            "SOUL    -> Mean: 0.0826, Std: 0.0107\n",
            "============================================================\n"
          ]
        }
      ],
      "source": [
        "run_experiment(hidden_neurons=512)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
