{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "24a55a23",
   "metadata": {},
   "source": [
    "# Setup\n",
    "\n",
    "Please ensure all required packages are installed before running the cells below. This can be done by running the following command in your virtual environment:\n",
    "```\n",
    "pip install -r requirements.txt\n",
    "```\n",
    "This can be run locally on a personal laptop."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "706d9fbe",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fbc843e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy.stats\n",
    "import scipy.optimize\n",
    "import scipy.special\n",
    "\n",
    "import pickle\n",
    "\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from tueplots import bundles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7889eedc",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.seterr(divide='ignore', over='ignore', under='ignore', invalid='ignore') "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "259a2657",
   "metadata": {},
   "source": [
    "# Utility Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8616abd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Data Generation ###\n",
    "\n",
    "def generate_synthetic_data(\n",
    "        M, \n",
    "        D, \n",
    "        true_sigma_scale, \n",
    "        true_w_params, \n",
    "        likelihood_type='gaussian',\n",
    "        true_nu=None,\n",
    "    ):\n",
    "    \"\"\"\n",
    "    TODO\n",
    "    \"\"\"\n",
    "    true_w = np.random.multivariate_normal(true_w_params['mean'], true_w_params['cov'], size=1).T\n",
    "    \n",
    "    X = np.random.normal(0, 1, size=(M,D))\n",
    "    y_mean = X @ true_w\n",
    "    \n",
    "    if likelihood_type == 'student_t': # Student-t model\n",
    "        assert true_nu is not None, \"true_nu must be provided for Student-t data\"\n",
    "        errors = scipy.stats.t.rvs(df=true_nu, loc=0, scale=true_sigma_scale, size=(M,1))\n",
    "\n",
    "    else: # Gaussian model\n",
    "        errors = np.random.normal(0, true_sigma_scale, size=(M,1))\n",
    "        \n",
    "    y = y_mean + errors\n",
    "\n",
    "    return X, y, true_w\n",
    "\n",
    "\n",
    "### Bayesian Linear Regression Model Utility Functions (Gaussian Error) ###\n",
    "\n",
    "def potential_U_gaussian(theta_phi, w, X, y):\n",
    "    \"\"\"\n",
    "    TODO\n",
    "    \"\"\"\n",
    "    phi1, phi2 = theta_phi\n",
    "    sigma_sq, alpha = np.exp(phi1), np.exp(phi2)\n",
    "    M, D = X.shape\n",
    "    w = w.reshape(D, 1)\n",
    "\n",
    "    if sigma_sq < 1e-12:\n",
    "        return 1e20 \n",
    "\n",
    "    log_L_term = (0.5 / sigma_sq) * np.sum((y - X @ w)**2)\n",
    "    log_P_w_term = (0.5 * alpha) * np.sum(w**2)\n",
    "    const_L = (M / 2) * (np.log(2 * np.pi) + phi1)\n",
    "    const_P_w = (D / 2) * (np.log(2 * np.pi) - phi2)\n",
    "    val = log_L_term + log_P_w_term + const_L + const_P_w\n",
    "\n",
    "    return val if np.isfinite(val) else 1e20\n",
    "\n",
    "def grad_U_theta_phi_gaussian(theta_phi, w, X, y):\n",
    "    \"\"\"\n",
    "    TODO\n",
    "    \"\"\"\n",
    "    phi1, phi2 = theta_phi\n",
    "    sigma_sq, alpha = np.exp(phi1), np.exp(phi2)\n",
    "    M, D = X.shape\n",
    "    w = w.reshape(D, 1)\n",
    "\n",
    "    grad_phi1 = - (0.5 / sigma_sq) * np.sum((y - X @ w) ** 2) + M / 2 if sigma_sq > 1e-12 else M/2\n",
    "    grad_phi2 = (0.5 * alpha) * np.sum(w ** 2) - D / 2\n",
    "    grads = np.array([grad_phi1, grad_phi2])\n",
    "\n",
    "    return np.nan_to_num(grads, nan=0.0, posinf=1e10, neginf=-1e10)\n",
    "\n",
    "def grad_U_w_gaussian(theta_phi, w, X, y):\n",
    "    \"\"\"\n",
    "    TODO\n",
    "    \"\"\"\n",
    "    phi1, phi2 = theta_phi\n",
    "    sigma_sq, alpha = np.exp(phi1), np.exp(phi2)\n",
    "    D_feat = X.shape[1]\n",
    "    w = w.reshape(D_feat, 1)\n",
    "\n",
    "    if sigma_sq < 1e-12:\n",
    "        sigma_sq = 1e-12\n",
    "\n",
    "    grad = (1 / sigma_sq) * X.T @ (X @ w - y) + alpha * w\n",
    "\n",
    "    return np.nan_to_num(grad, nan=0.0, posinf=1e10, neginf=-1e10)\n",
    "\n",
    "def initialize_particles_gaussian(theta_phi, X, y, N_particles, D_features):\n",
    "    \"\"\"\n",
    "    TODO\n",
    "    \"\"\"\n",
    "    sigma_sq, alpha = np.exp(theta_phi[0]), np.exp(theta_phi[1])\n",
    "\n",
    "    if sigma_sq < 1e-12:\n",
    "        sigma_sq = 1e-12\n",
    "    if alpha < 1e-12:\n",
    "        alpha = 1e-12\n",
    "\n",
    "    Sigma_post_inv_0 = (1 / sigma_sq) * X.T @ X + alpha * np.eye(D_features) + 1e-9 * np.eye(D_features)\n",
    "\n",
    "    try:\n",
    "        Sigma_post_0 = np.linalg.inv(Sigma_post_inv_0)\n",
    "\n",
    "    except np.linalg.LinAlgError: \n",
    "        Sigma_post_0 = np.linalg.pinv(Sigma_post_inv_0)\n",
    "\n",
    "    mu_post_0 = Sigma_post_0 @ ((1 / sigma_sq) * X.T @ y)\n",
    "\n",
    "    return np.random.multivariate_normal(mu_post_0.flatten(), Sigma_post_0, size=N_particles)\n",
    "\n",
    "def analytical_marginal_likelihood_gaussian(theta_phi, X, y, D_features_not_used=None, return_log=True):\n",
    "    \"\"\"\n",
    "    TODO\n",
    "    \"\"\"\n",
    "    phi1, phi2 = theta_phi\n",
    "    sigma_sq, alpha = np.exp(phi1), np.exp(phi2)\n",
    "    M, D = X.shape\n",
    "\n",
    "    if not (1e-12 < sigma_sq < 1e12 and 1e-12 < alpha < 1e12):\n",
    "        return -np.inf if return_log else 0.0\n",
    "    \n",
    "    Sigma_y = sigma_sq * np.eye(M) + (1/alpha) * (X @ X.T) + 1e-9 * np.eye(M)\n",
    "\n",
    "    try:\n",
    "        L = np.linalg.cholesky(Sigma_y)\n",
    "        logdet_Sigma_y = 2 * np.sum(np.log(np.diag(L)))\n",
    "        y_flat = y.flatten()\n",
    "        alpha_vec = np.linalg.solve(L,y_flat)\n",
    "        y_Sigma_inv_y = np.dot(alpha_vec, alpha_vec)\n",
    "        log_ml = -0.5 * (M * np.log(2 * np.pi) + logdet_Sigma_y + y_Sigma_inv_y)\n",
    "\n",
    "    except np.linalg.LinAlgError:\n",
    "        return -np.inf if return_log else 0.0\n",
    "        \n",
    "    return log_ml if return_log and np.isfinite(log_ml) else (-np.inf if return_log else 0.0)\n",
    "\n",
    "def get_theta_ML_gaussian(X, y, initial_theta_phi_guess):\n",
    "    \"\"\"\n",
    "    Type-II maximum-likelihood hyperparameters for the Gaussian model.\n",
    "\n",
    "    Minimises the negative analytical log marginal likelihood over\n",
    "    theta_phi = (log sigma^2, log alpha). It first tries L-BFGS-B with each\n",
    "    coordinate bounded to [-15, 15]. If that does not report success, it runs\n",
    "    Nelder-Mead (no bounds) from the same initial guess.\n",
    "\n",
    "    Returns the successful optimiser's solution, or initial_theta_phi_guess\n",
    "    unchanged if neither optimiser reports success.\n",
    "    \"\"\"\n",
    "    def neg_log_evidence(theta_phi_gauss, X_data, y_data):\n",
    "        return -analytical_marginal_likelihood_gaussian(\n",
    "            theta_phi_gauss, X_data, y_data,\n",
    "            D_features_not_used=X_data.shape[1], return_log=True,\n",
    "        )\n",
    "\n",
    "    lbfgs_result = scipy.optimize.minimize(\n",
    "        neg_log_evidence,\n",
    "        initial_theta_phi_guess,\n",
    "        args=(X, y),\n",
    "        method='L-BFGS-B',\n",
    "        bounds=[(-15, 15), (-15, 15)],\n",
    "        options={'ftol': 1e-8, 'gtol': 1e-6, 'maxiter': 200},\n",
    "    )\n",
    "    if lbfgs_result.success:\n",
    "        return lbfgs_result.x\n",
    "\n",
    "    # Derivative-free fallback; note that it does not enforce the bounds above.\n",
    "    nm_result = scipy.optimize.minimize(\n",
    "        neg_log_evidence,\n",
    "        initial_theta_phi_guess,\n",
    "        args=(X, y),\n",
    "        method='Nelder-Mead',\n",
    "        options={'xatol': 1e-5, 'fatol': 1e-5, 'maxiter': 400},\n",
    "    )\n",
    "    return nm_result.x if nm_result.success else initial_theta_phi_guess\n",
    "\n",
    "def log_gaussian_pdf(x, mu, Sigma, Sigma_inv=None, logdet_Sigma=None):\n",
    "    \"\"\"\n",
    "    TODO\n",
    "    \"\"\"\n",
    "    D = len(mu)\n",
    "\n",
    "    if Sigma_inv is None or logdet_Sigma is None:\n",
    "        try:\n",
    "            Sigma_stable = Sigma + 1e-9 * np.eye(D)\n",
    "            L_Sigma = np.linalg.cholesky(Sigma_stable)\n",
    "            logdet_Sigma = 2 * np.sum(np.log(np.diag(L_Sigma)))\n",
    "            Sigma_inv = np.linalg.inv(Sigma_stable)\n",
    "\n",
    "        except np.linalg.LinAlgError:\n",
    "            try:\n",
    "                eigvals = np.linalg.eigvalsh(Sigma)\n",
    "                logdet_Sigma = np.sum(np.log(np.maximum(eigvals, 1e-50)))\n",
    "                Sigma_inv = np.linalg.pinv(Sigma)\n",
    "            except np.linalg.LinAlgError:\n",
    "                return -np.inf\n",
    "            \n",
    "    x_minus_mu = x.reshape(D,1) - mu.reshape(D,1)\n",
    "    term = -0.5 * x_minus_mu.T @ Sigma_inv @ x_minus_mu\n",
    "    norm_const = -0.5 * D * np.log(2 * np.pi) - 0.5 * logdet_Sigma\n",
    "\n",
    "    return (norm_const + term).item()\n",
    "\n",
    "### Bayesian Linear Regression Model Utility Functions (Student-t Error) ###\n",
    "\n",
    "def student_t_log_pdf_scalar(y_scalar, loc, scale_sq, nu):\n",
    "    \"\"\"\n",
    "    TODO\n",
    "    \"\"\"\n",
    "    if nu <= 1e-1 or scale_sq <= 1e-12:\n",
    "        return -np.inf\n",
    "    \n",
    "    log_const = scipy.special.gammaln((nu + 1) / 2) - scipy.special.gammaln(nu / 2) - 0.5 * np.log(np.pi * nu * scale_sq)\n",
    "\n",
    "    delta_sq_norm = (y_scalar - loc)**2 / (nu * scale_sq)\n",
    "\n",
    "    log_main = -((nu + 1) / 2) * np.log(1 + delta_sq_norm)\n",
    "\n",
    "    val = log_const + log_main\n",
    "\n",
    "    return val if np.isfinite(val) else -1e20\n",
    "\n",
    "def potential_U_studentT(theta_phi, w, X, y):\n",
    "    \"\"\"\n",
    "    Potential energy U(theta, w) of the Student-t regression model.\n",
    "\n",
    "    theta_phi = (phi1, phi2, phi3) gives scale^2 = exp(phi1), prior precision\n",
    "    alpha = exp(phi2) and degrees of freedom nu = exp(phi3). U is the sum of\n",
    "      * the negative Student-t log-likelihood of y given X w, with nu floored at\n",
    "        0.1 and scale^2 at 1e-12;\n",
    "      * the negative log N(0, alpha^{-1} I) prior on w, with alpha floored at\n",
    "        1e-12 in the quadratic term;\n",
    "      * the penalty 0.1 * nu - log(0.1), i.e. the negative log Exponential(rate=0.1)\n",
    "        density of nu (inf if nu is not positive).\n",
    "\n",
    "    Returns 1e20 if the total is not finite.\n",
    "    \"\"\"\n",
    "    log_scale_sq, log_prior_prec, log_dof = theta_phi\n",
    "    scale_sq = np.exp(log_scale_sq)\n",
    "    prior_prec = np.exp(log_prior_prec)\n",
    "    dof = np.exp(log_dof)\n",
    "    n_obs, n_feat = X.shape\n",
    "    weights = w.reshape(n_feat, 1)\n",
    "\n",
    "    dof_floored = max(dof, 0.1)\n",
    "    scale_sq_floored = max(scale_sq, 1e-12)\n",
    "    prior_prec_floored = max(prior_prec, 1e-12)\n",
    "\n",
    "    predictions = X @ weights\n",
    "    total_log_lik = sum(\n",
    "        student_t_log_pdf_scalar(y[row, 0], predictions[row, 0], scale_sq_floored, dof_floored)\n",
    "        for row in range(n_obs)\n",
    "    )\n",
    "\n",
    "    rate_nu = 0.1\n",
    "    penalty_nu = rate_nu * dof - np.log(rate_nu) if dof > 0 else np.inf\n",
    "\n",
    "    total = (-total_log_lik\n",
    "             + (0.5 * prior_prec_floored) * np.sum(weights ** 2)\n",
    "             + (n_feat / 2) * (np.log(2 * np.pi) - log_prior_prec)\n",
    "             + penalty_nu)\n",
    "\n",
    "    return total if np.isfinite(total) else 1e20\n",
    "\n",
    "def grad_U_w_studentT(theta_phi, w, X, y):\n",
    "    \"\"\"\n",
    "    TODO\n",
    "    \"\"\"\n",
    "    phi1, phi2, phi3 = theta_phi\n",
    "    sigma_sq, alpha, nu = np.exp(phi1), np.exp(phi2), np.exp(phi3)\n",
    "    M, D_feat = X.shape\n",
    "    w = w.reshape(D_feat, 1)\n",
    "\n",
    "    nu_calc = max(nu, 0.1)\n",
    "    sigma_sq_calc = max(sigma_sq, 1e-12)\n",
    "    alpha_calc = max(alpha, 1e-12)\n",
    "\n",
    "    y_pred = X @ w\n",
    "    delta = y - y_pred\n",
    "    \n",
    "    numerator = (nu_calc + 1) * (-delta)\n",
    "    denominator = (nu_calc * sigma_sq_calc + delta**2)\n",
    "    denominator[denominator < 1e-12] = 1e-12\n",
    "    \n",
    "    term_in_sum = numerator / denominator\n",
    "    grad_w_likelihood = X.T @ term_in_sum\n",
    "\n",
    "    grad_w_prior = alpha_calc * w\n",
    "    grad = grad_w_likelihood + grad_w_prior\n",
    "\n",
    "    return np.nan_to_num(grad, nan=0.0, posinf=1e10, neginf=-1e10)\n",
    "\n",
    "def grad_U_theta_phi_studentT(theta_phi, w, X, y):\n",
    "    \"\"\"\n",
    "    TODO\n",
    "    \"\"\"\n",
    "    phi1, phi2, phi3 = theta_phi\n",
    "    sigma_sq, alpha, nu = np.exp(phi1), np.exp(phi2), np.exp(phi3)\n",
    "    M, D = X.shape\n",
    "    w = w.reshape(D,1)\n",
    "\n",
    "    nu_calc = max(nu, 0.1)\n",
    "    sigma_sq_calc = max(sigma_sq, 1e-12) \n",
    "    alpha_calc = max(alpha, 1e-12)\n",
    "\n",
    "    y_pred = X @ w\n",
    "    delta_sq = (y - y_pred) ** 2\n",
    "\n",
    "    A_i_for_phi1 = delta_sq / (nu_calc * sigma_sq_calc + 1e-100)\n",
    "    term_for_phi1 = (nu_calc + 1) * A_i_for_phi1 / (1 + A_i_for_phi1 + 1e-100)\n",
    "    grad_phi1 = np.sum(0.5 * (1 - term_for_phi1))\n",
    "    \n",
    "    grad_phi2 = (0.5 * alpha_calc) * np.sum(w**2) - D / 2\n",
    "\n",
    "    digamma_term_sum = M * 0.5 * (scipy.special.digamma((nu_calc + 1) / 2) - scipy.special.digamma(nu_calc / 2))\n",
    "    inv_nu_term_sum = - M / (2 * nu_calc)\n",
    "    \n",
    "    A_i_for_phi3 = delta_sq / (nu_calc * sigma_sq_calc + 1e-100)\n",
    "    log_1_plus_A_sum = 0.5 * np.sum(np.log(1 + A_i_for_phi3 + 1e-100))\n",
    "    \n",
    "    frac_term_A_phi3 = A_i_for_phi3 / (1 + A_i_for_phi3 + 1e-100)\n",
    "    last_term_sum = ((nu_calc + 1)/(2 *nu_calc)) * np.sum(frac_term_A_phi3)\n",
    "    \n",
    "    sum_dLidnu = digamma_term_sum + inv_nu_term_sum - log_1_plus_A_sum + last_term_sum\n",
    "    \n",
    "    lambda_nu_prior = 0.1\n",
    "    grad_prior_nu_term_d_nu = lambda_nu_prior\n",
    "    \n",
    "    grad_nu = -sum_dLidnu + grad_prior_nu_term_d_nu\n",
    "    grad_phi3 = grad_nu * nu\n",
    "\n",
    "    grads = np.array([grad_phi1, grad_phi2, grad_phi3])\n",
    "\n",
    "    return np.nan_to_num(grads, nan=0.0, posinf=1e10, neginf=-1e10)\n",
    "\n",
    "def initialize_particles_studentT(\n",
    "        theta_phi,\n",
    "        X,\n",
    "        y,\n",
    "        N_particles,\n",
    "        D_features,\n",
    "        mcmc_steps=200,\n",
    "        mcmc_step_size_init=0.001\n",
    "    ):\n",
    "    \"\"\"\n",
    "    Initialise Student-t particles with short unadjusted Langevin runs.\n",
    "\n",
    "    Each particle starts from N(0, alpha^{-1} I), where alpha = exp(theta_phi[1])\n",
    "    (1.0 if theta_phi has a single entry) is floored at 1e-9. It then takes\n",
    "    mcmc_steps moves\n",
    "        w <- w - (h/2) grad_w U(theta, w) + sqrt(h) * N(0, I)\n",
    "    with every move accepted. After each move the step size h is multiplied by 0.9\n",
    "    when ||grad|| > 1000 * D_features and h > 1e-6, or by 1.05 when\n",
    "    ||grad|| < 10 * D_features and h < 0.1. The gradient used for this check is\n",
    "    the one at the pre-move point.\n",
    "\n",
    "    Returns an array of shape (N_particles, D_features).\n",
    "    \"\"\"\n",
    "    print(f\"Initializing {N_particles} Student-T particles using {mcmc_steps} MCMC steps each...\")\n",
    "\n",
    "    prior_prec = np.exp(theta_phi[1]) if len(theta_phi) > 1 else 1.0\n",
    "    init_std = 1.0 / np.sqrt(max(prior_prec, 1e-9))\n",
    "    particles = np.zeros((N_particles, D_features))\n",
    "\n",
    "    for idx in range(N_particles):\n",
    "        w_cur = np.random.normal(loc=0, scale=init_std, size=(D_features,1))\n",
    "        step = mcmc_step_size_init\n",
    "\n",
    "        for _ in range(mcmc_steps):\n",
    "            drift = -grad_U_w_studentT(theta_phi, w_cur, X, y)\n",
    "            kick = np.random.normal(size=(D_features,1))\n",
    "            w_next = w_cur + 0.5 * step * drift + np.sqrt(step) * kick\n",
    "\n",
    "            drift_norm = np.linalg.norm(drift)\n",
    "            if drift_norm > D_features * 1000 and step > 1e-6:\n",
    "                step *= 0.9\n",
    "            elif drift_norm < D_features * 10 and step < 0.1:\n",
    "                step *= 1.05\n",
    "\n",
    "            w_cur = w_next\n",
    "\n",
    "        particles[idx, :] = w_cur.flatten()\n",
    "\n",
    "    return particles\n",
    "\n",
    "\n",
    "def estimate_Z0_studentT_MC_IS(theta_phi, X, y, D_features, S_samples=5000):\n",
    "    \"\"\"\n",
    "    Importance-sampling estimate of log Z0 = log p(y | theta_phi) for the Student-t model.\n",
    "\n",
    "    The proposal q(w) is the Gaussian-likelihood posterior N(mu_q, Sigma_q) with\n",
    "        Sigma_q^{-1} = alpha I + X^T X / sigma^2  (+ 1e-9 I jitter),\n",
    "        mu_q = Sigma_q X^T y / sigma^2,\n",
    "    where sigma^2 = exp(phi1) and alpha = exp(phi2) are floored at 1e-12. With\n",
    "    S_samples draws w_s ~ q,\n",
    "        log Z0 ~= logsumexp_s[log L(w_s) + log p(w_s) - log q(w_s)] - log S_samples,\n",
    "    where L is the Student-t likelihood (scale^2 = exp(phi1), nu = exp(phi3), both\n",
    "    unfloored) and p(w) = N(0, alpha^{-1} I). If sampling from q raises, samples are\n",
    "    drawn from N(0, alpha^{-1} I) instead, and that density becomes the proposal.\n",
    "\n",
    "    Args:\n",
    "        theta_phi: (phi1, phi2, phi3) = (log sigma^2, log alpha, log nu).\n",
    "        X: Design matrix of shape (M, D_features).\n",
    "        y: Targets of shape (M, 1).\n",
    "        D_features: Number of features.\n",
    "        S_samples: Number of importance samples.\n",
    "\n",
    "    Returns:\n",
    "        The log Z0 estimate, or -1e20 if it is not finite.\n",
    "    \"\"\"\n",
    "    print(f\"Estimating Z0 for Student-T model using {S_samples} Importance Samples...\")\n",
    "\n",
    "    phi1, phi2, phi3 = theta_phi\n",
    "    sigma_sq_0, alpha_0, nu_0 = np.exp(phi1), np.exp(phi2), np.exp(phi3)\n",
    "\n",
    "    sigma_sq_q = max(sigma_sq_0, 1e-12)\n",
    "    alpha_q = max(alpha_0, 1e-12)\n",
    "\n",
    "    Sigma_q_inv = alpha_q * np.eye(D_features) + (1 / sigma_sq_q) * (X.T @ X)\n",
    "    Sigma_q_inv += 1e-9 * np.eye(D_features)\n",
    "\n",
    "    try:\n",
    "        Sigma_q = np.linalg.inv(Sigma_q_inv)\n",
    "    except np.linalg.LinAlgError:\n",
    "        Sigma_q = np.linalg.pinv(Sigma_q_inv)\n",
    "\n",
    "    mu_q = (Sigma_q @ ((1 / sigma_sq_q) * X.T @ y)).flatten()\n",
    "\n",
    "    # Sample w_s from q(w)\n",
    "    try:\n",
    "        w_samples = np.random.multivariate_normal(mu_q, Sigma_q, size=S_samples)\n",
    "    except np.linalg.LinAlgError:\n",
    "        print(\"Warning: Sampling from q(w) failed in Student-T Z0 IS. Using approximate samples.\")\n",
    "        # Fall back to prior-like sampling. The proposal parameters must be switched\n",
    "        # too, otherwise log q(w_s) below is evaluated under the wrong density and\n",
    "        # the importance weights are biased.\n",
    "        mu_q = np.zeros(D_features)\n",
    "        Sigma_q = np.eye(D_features) / alpha_q\n",
    "        w_samples = np.random.multivariate_normal(mu_q, Sigma_q, size=S_samples)\n",
    "\n",
    "    # Log-prior N(0, alpha^{-1} I) and log-proposal, vectorised over samples.\n",
    "    log_prior_w = (-0.5 * D_features * np.log(2 * np.pi)\n",
    "                   + 0.5 * D_features * np.log(alpha_q)\n",
    "                   - 0.5 * alpha_q * np.sum(w_samples ** 2, axis=1))\n",
    "    # atleast_1d: logpdf returns a scalar when S_samples == 1.\n",
    "    log_q_w = np.atleast_1d(scipy.stats.multivariate_normal.logpdf(\n",
    "        w_samples, mean=mu_q, cov=Sigma_q, allow_singular=True))\n",
    "\n",
    "    log_importance_weights = np.full(S_samples, -np.inf)\n",
    "\n",
    "    for s in range(S_samples):\n",
    "        y_pred_s = X @ w_samples[s, :]\n",
    "        log_L_s = sum(\n",
    "            student_t_log_pdf_scalar(y[i_obs, 0], y_pred_s[i_obs], sigma_sq_0, nu_0)\n",
    "            for i_obs in range(X.shape[0])\n",
    "        )\n",
    "        if np.isinf(log_L_s) or np.isinf(log_prior_w[s]) or np.isinf(log_q_w[s]):\n",
    "            continue  # zero importance weight (entry stays -inf)\n",
    "        log_importance_weights[s] = log_L_s + log_prior_w[s] - log_q_w[s]\n",
    "\n",
    "    log_Z0_est = scipy.special.logsumexp(log_importance_weights) - np.log(S_samples)\n",
    "\n",
    "    return log_Z0_est if np.isfinite(log_Z0_est) else -1e20"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2434e4f8",
   "metadata": {},
   "source": [
    "# JALA-EM\n",
    "\n",
    "Implementation of the JALA-EM algorithm described in \"Learning Latent Variable Models via Jarzynski-adjusted Langevin Algorithm\" (https://arxiv.org/pdf/2505.18427); see in particular Appendix C.4 of that paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "043125c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_jala_em(\n",
    "        X_data, \n",
    "        y_data,\n",
    "        model_fns,\n",
    "        K_iters, \n",
    "        N_particles,\n",
    "        h_langevin,\n",
    "        opt_learning_rate=0.01, \n",
    "        ess_threshold_frac=0.5,\n",
    "        use_resampling=True, \n",
    "        D_features=None,\n",
    "    ):\n",
    "    \"\"\"\n",
    "    TODO\n",
    "    \"\"\"\n",
    "    M, D = X_data.shape\n",
    "\n",
    "    if D_features is None:\n",
    "        D_features = D\n",
    "\n",
    "    elif D_features != D:\n",
    "        raise ValueError(f\"D_features ({D_features}) in model_fns does not match X_data.shape[1] ({D})!!!!\")\n",
    "\n",
    "    theta_phi_k = np.array(model_fns['theta_phi_initial'], dtype=float)\n",
    "\n",
    "    theta_phi_history = [np.copy(theta_phi_k)]\n",
    "    log_Z_k_est_history = []\n",
    "\n",
    "    print(f\"\\nInitialising for model: {model_fns['model_name']}\")\n",
    "    print(f\"Initial theta_phi ({model_fns['model_name']}): {theta_phi_k}\")\n",
    "\n",
    "    start_init_particles_time = time.time()\n",
    "    particles_X_k = model_fns['initialize_particles_fn'](theta_phi_k, X_data, y_data, N_particles, D_features)\n",
    "    print(f\"Particle initialisation took {time.time() - start_init_particles_time:.2f}s\")\n",
    "\n",
    "    A_k = np.zeros(N_particles)\n",
    "\n",
    "    m_opt = np.zeros(model_fns['D_theta'])\n",
    "    beta_momentum = 0.9  # Use \\beta_{1} = 0.9 for Adam\n",
    "\n",
    "    start_Z0_time = time.time()\n",
    "    \n",
    "    log_Z_0_val = model_fns['estimate_Z0_fn'](theta_phi_k, X_data, y_data, D_features)\n",
    "\n",
    "    print(f\"Z0 estimation took {time.time() - start_Z0_time:.2f}s\")\n",
    "\n",
    "    if np.isnan(log_Z_0_val) or np.isinf(log_Z_0_val):\n",
    "        print(f\"Warning ({model_fns['model_name']}): Initial log_Z_0 is {log_Z_0_val}. Using 0.0 as fallback.\")\n",
    "        log_Z_0_val = 0.0\n",
    "\n",
    "    print(f\"Estimated log_Z_0 ({model_fns['model_name']}): {log_Z_0_val:.4f}\")\n",
    "\n",
    "    current_log_Z_k_est = log_Z_0_val + (scipy.special.logsumexp(A_k) - np.log(N_particles if N_particles > 0 else 1))\n",
    "    log_Z_k_est_history.append(current_log_Z_k_est)\n",
    "\n",
    "    potential_U_fn = model_fns['potential_U_fn']\n",
    "    grad_U_theta_phi_fn = model_fns['grad_U_theta_phi_fn']\n",
    "    grad_U_w_fn = model_fns['grad_U_w_fn']\n",
    "    phi_clips = model_fns.get('phi_clips', [(-20, 20)] * model_fns['D_theta'])\n",
    "\n",
    "    for k in range(K_iters):\n",
    "        start_time_iter = time.time()\n",
    "\n",
    "        max_A_k = np.max(A_k) if len(A_k) > 0 else 0\n",
    "        shifted_A_k = A_k - max_A_k\n",
    "        exp_A_k_shifted = np.exp(shifted_A_k)\n",
    "        sum_exp_A_k_shifted = np.sum(exp_A_k_shifted)\n",
    "\n",
    "        if sum_exp_A_k_shifted < 1e-100 or np.isnan(sum_exp_A_k_shifted) or N_particles == 0:\n",
    "            W_k = np.ones(N_particles) / N_particles if N_particles > 0 else np.array([])\n",
    "        else:\n",
    "            W_k = exp_A_k_shifted / sum_exp_A_k_shifted\n",
    "\n",
    "        g_k_accum = np.zeros(model_fns['D_theta'])\n",
    "\n",
    "        if N_particles > 0:\n",
    "            for i in range(N_particles):\n",
    "                current_particle_w = particles_X_k[i,:].reshape(D_features,1)\n",
    "                g_k_accum += W_k[i] * grad_U_theta_phi_fn(theta_phi_k, current_particle_w, X_data, y_data)\n",
    "\n",
    "        if np.any(np.isnan(g_k_accum)) or np.any(np.isinf(g_k_accum)):\n",
    "            g_k_accum = np.zeros_like(g_k_accum)\n",
    "\n",
    "        m_opt = beta_momentum * m_opt + (1 - beta_momentum) * g_k_accum\n",
    "        \n",
    "        theta_phi_k_plus_1_from_grad = theta_phi_k - opt_learning_rate * m_opt\n",
    "        theta_phi_k_plus_1 = np.copy(theta_phi_k_plus_1_from_grad)\n",
    "\n",
    "        log_clipping_this_iter = False\n",
    "        clipping_details_str = []\n",
    "\n",
    "        for dim_idx in range(model_fns['D_theta']):\n",
    "            val_before_clip = theta_phi_k_plus_1_from_grad[dim_idx]\n",
    "            clip_min, clip_max = phi_clips[dim_idx]\n",
    "            val_after_clip = np.clip(val_before_clip, clip_min, clip_max)\n",
    "            theta_phi_k_plus_1[dim_idx] = val_after_clip\n",
    "\n",
    "            if not np.isclose(val_before_clip, val_after_clip):\n",
    "                log_clipping_this_iter = True\n",
    "                clip_type = \"low\" if np.isclose(val_after_clip, clip_min) else \"high\"\n",
    "                clipping_details_str.append(\n",
    "                    f\"phi[{dim_idx}]: {val_before_clip:.4f} -> {val_after_clip:.4f} (clipped {clip_type})\"\n",
    "                )\n",
    "        \n",
    "        # Debugging stuff...\n",
    "        if k < 15 :\n",
    "            if log_clipping_this_iter:\n",
    "                print(f\"DEBUG Model: {model_fns['model_name']}, Iter {k+1}, CLIPPING OCCURRED:\")\n",
    "                for detail in clipping_details_str:\n",
    "                    print(f\"  {detail}\")\n",
    "                print(f\"  g_k_accum: {[f'{x:.3e}' for x in g_k_accum]}\")\n",
    "                print(f\"  m_opt: {[f'{x:.3e}' for x in m_opt]}\")\n",
    "                print(f\"  theta_phi_k (before this M-step): {[f'{x:.3f}' for x in theta_phi_k]}\")\n",
    "                print(f\"  theta_phi_k_plus_1_from_grad: {[f'{x:.3f}' for x in theta_phi_k_plus_1_from_grad]}\")\n",
    "                print(f\"  theta_phi_k_plus_1 (after clip): {[f'{x:.3f}' for x in theta_phi_k_plus_1]}\")\n",
    "            else:\n",
    "                 print(f\"DEBUG Model: {model_fns['model_name']}, Iter {k+1}, No clipping. theta_phi_k: {[f'{x:.3f}' for x in theta_phi_k]}, g_k_accum: {[f'{x:.3e}' for x in g_k_accum]}, m_opt: {[f'{x:.3e}' for x in m_opt]}, theta_phi_k+1: {[f'{x:.3f}' for x in theta_phi_k_plus_1]}\")\n",
    "        elif log_clipping_this_iter and (k + 1) % (K_iters // 20 or 1) == 0 :\n",
    "             print(f\"DEBUG Model: {model_fns['model_name']}, Iter {k+1}, CLIPPING OCCURRED (on reporting interval):\")\n",
    "             for detail in clipping_details_str:\n",
    "                print(f\"  {detail}\")\n",
    "\n",
    "        # End of debugging stuff...\n",
    "\n",
    "        particles_X_k_plus_1 = np.zeros_like(particles_X_k)\n",
    "        A_k_plus_1 = np.zeros_like(A_k)\n",
    "        u_next_vals_k0 = []\n",
    "        a_k_plus_1_vals_k0 = []\n",
    "\n",
    "        if N_particles > 0:\n",
    "            for i in range(N_particles):\n",
    "                x_curr = particles_X_k[i,:].reshape(D_features,1)\n",
    "                U_curr_val = potential_U_fn(theta_phi_k, x_curr, X_data, y_data)\n",
    "                grad_w_U_curr = grad_U_w_fn(theta_phi_k, x_curr, X_data, y_data)\n",
    "                xi_noise = np.random.normal(size=(D_features,1))\n",
    "                x_next_proposal = x_curr - h_langevin * grad_w_U_curr + np.sqrt(2 * h_langevin) * xi_noise\n",
    "\n",
    "                if np.any(np.isnan(x_next_proposal)) or np.any(np.isinf(x_next_proposal)):\n",
    "                     x_next = np.copy(x_curr)\n",
    "                     current_A_update_val = A_k[i] - 1000\n",
    "                     A_k_plus_1[i] = current_A_update_val\n",
    "                else:\n",
    "                    x_next = x_next_proposal\n",
    "                    particles_X_k_plus_1[i,:] = x_next.flatten()\n",
    "                    U_next_val = potential_U_fn(theta_phi_k_plus_1, x_next, X_data, y_data)\n",
    "                    grad_w_U_next = grad_U_w_fn(theta_phi_k_plus_1, x_next, X_data, y_data)\n",
    "                    \n",
    "                    if U_curr_val >= 1e19 or U_next_val >= 1e19 :\n",
    "                         current_A_update_val = A_k[i] - (U_next_val - U_curr_val)\n",
    "                         A_k_plus_1[i] = current_A_update_val\n",
    "                    else:\n",
    "                        term1_dot = np.dot((x_curr - x_next).flatten(), grad_w_U_next.flatten())\n",
    "                        term2_dot = np.dot((x_next - x_curr).flatten(), grad_w_U_curr.flatten())\n",
    "                        alpha_term_k_plus_1 = U_next_val + 0.5 * term1_dot + (h_langevin / 4) * np.sum(grad_w_U_next ** 2)\n",
    "                        alpha_term_k = U_curr_val + 0.5 * term2_dot + (h_langevin / 4) * np.sum(grad_w_U_curr ** 2)\n",
    "                        \n",
    "                        if (np.isinf(alpha_term_k_plus_1) and np.isinf(alpha_term_k) and\n",
    "                            np.sign(alpha_term_k_plus_1) != np.sign(alpha_term_k)) or \\\n",
    "                           np.isnan(alpha_term_k_plus_1) or np.isnan(alpha_term_k):\n",
    "                            current_A_update_val = A_k[i] - 1000\n",
    "                            A_k_plus_1[i] = current_A_update_val\n",
    "                        else:\n",
    "                            current_A_update_val = A_k[i] - alpha_term_k_plus_1 + alpha_term_k\n",
    "                            A_k_plus_1[i] = current_A_update_val\n",
    "                \n",
    "                if k == 0 and i < 5 :\n",
    "                    if 'U_next_val' in locals(): # Ensure U_next_val was defined\n",
    "                        u_next_vals_k0.append(U_next_val)\n",
    "                    else: # U_next_val was not computed\n",
    "                        u_next_vals_k0.append(np.nan)\n",
    "                        \n",
    "                    a_k_plus_1_vals_k0.append(current_A_update_val)\n",
    "\n",
    "                    if i == 4:\n",
    "                        print(f\"DEBUG Model: {model_fns['model_name']}, Iter {k+1}, Particle E-step details (first 5 particles):\")\n",
    "                        phi1_1_eff = theta_phi_k_plus_1[0]\n",
    "                        print(f\"  theta_phi_k_plus_1 (used for U_next): {[f'{x:.3f}' for x in theta_phi_k_plus_1]}, effective sigma_sq_1 for U_next: {np.exp(phi1_1_eff):.2e}\")\n",
    "                        print(f\"  U_next_vals for first 5 particles: {[f'{v:.2f}' for v in u_next_vals_k0]}\")\n",
    "                        print(f\"  A_k_plus_1 vals for first 5 particles: {[f'{v:.2f}' for v in a_k_plus_1_vals_k0]}\")\n",
    "            if k == 0 and N_particles > 0 and N_particles <=5 and len(a_k_plus_1_vals_k0) < 5 and len(a_k_plus_1_vals_k0) > 0 : # In case `N_particles` is less than 5 but > 0\n",
    "                 print(f\"DEBUG Model: {model_fns['model_name']}, Iter {k+1}, Particle E-step details (all {N_particles} particles):\")\n",
    "                 phi1_1_eff = theta_phi_k_plus_1[0]\n",
    "                 print(f\"  theta_phi_k_plus_1 (used for U_next): {[f'{x:.3f}' for x in theta_phi_k_plus_1]}, effective sigma_sq_1 for U_next: {np.exp(phi1_1_eff):.2e}\")\n",
    "                 print(f\"  U_next_vals for {N_particles} particles: {[f'{v:.2f}' for v in u_next_vals_k0]}\")\n",
    "                 print(f\"  A_k_plus_1 vals for {N_particles} particles: {[f'{v:.2f}' for v in a_k_plus_1_vals_k0]}\")\n",
    "\n",
    "\n",
    "        theta_phi_k = np.copy(theta_phi_k_plus_1)\n",
    "        A_k = np.copy(A_k_plus_1)\n",
    "        particles_X_k = np.copy(particles_X_k_plus_1)\n",
    "\n",
    "        if use_resampling and N_particles > 0:\n",
    "            max_A_k_resample = np.max(A_k) if len(A_k) > 0 else 0 # ensure A_k not empty\n",
    "            exp_A_k_resample_shifted = np.exp(A_k - max_A_k_resample)\n",
    "            sum_exp_A_k_resample = np.sum(exp_A_k_resample_shifted)\n",
    "\n",
    "            if sum_exp_A_k_resample < 1e-100 or np.isnan(sum_exp_A_k_resample) or N_particles == 0:\n",
    "                W_k_resample = np.ones(N_particles) / N_particles if N_particles > 0 else np.array([])\n",
    "            else:\n",
    "                W_k_resample = exp_A_k_resample_shifted / sum_exp_A_k_resample\n",
    "            \n",
    "            W_k_resample = np.maximum(W_k_resample, 0)\n",
    "            W_k_resample_sum = np.sum(W_k_resample)\n",
    "\n",
    "            if W_k_resample_sum < 1e-9:\n",
    "                 W_k_resample = np.ones(N_particles) / N_particles if N_particles > 0 else np.array([])\n",
    "            else:\n",
    "                 W_k_resample = W_k_resample / W_k_resample_sum\n",
    "\n",
    "            if np.any(np.isnan(W_k_resample)):\n",
    "                 W_k_resample = np.ones(N_particles) / N_particles if N_particles > 0 else np.array([])\n",
    "            \n",
    "            ess = 1.0 / np.sum(W_k_resample**2) if np.sum(W_k_resample ** 2) > 0 else (N_particles if N_particles > 0 else 1.0)\n",
    "\n",
    "            if ess < ess_threshold_frac * N_particles and N_particles > 0:\n",
    "                try:\n",
    "                    current_sum_W_k_resample = np.sum(W_k_resample)\n",
    "                    if not np.isclose(current_sum_W_k_resample, 1.0) or current_sum_W_k_resample <= 1e-9:\n",
    "                        if current_sum_W_k_resample <= 1e-9 or np.isnan(current_sum_W_k_resample):\n",
    "                            W_k_resample_norm = np.ones(N_particles) / N_particles\n",
    "                        else: # Sum positive but not 1\n",
    "                            W_k_resample_norm = W_k_resample / current_sum_W_k_resample\n",
    "                    else: # Sum close to 1\n",
    "                        W_k_resample_norm = W_k_resample\n",
    "                    \n",
    "                    if np.any(np.isnan(W_k_resample_norm)) or not np.all(np.isfinite(W_k_resample_norm)) or not np.all(W_k_resample_norm >=0): # Check valid probabilities\n",
    "                        W_k_resample_norm = np.ones(N_particles) / N_particles\n",
    "                    \n",
    "                    # Ensure sum is 1 after all normalisations\n",
    "                    if not np.isclose(np.sum(W_k_resample_norm), 1.0):\n",
    "                        W_k_resample_norm = np.ones(N_particles) / N_particles\n",
    "                    \n",
    "                    indices = np.random.choice(N_particles, size=N_particles, p=W_k_resample_norm, replace=True)\n",
    "                    particles_X_k = particles_X_k[indices, :]\n",
    "                    A_k = A_k[indices]\n",
    "\n",
    "                except ValueError as e:\n",
    "                    pass\n",
    "\n",
    "\n",
    "        theta_phi_history.append(np.copy(theta_phi_k))\n",
    "        \n",
    "        current_log_Z_k_est = log_Z_0_val + (scipy.special.logsumexp(A_k) - np.log(N_particles if N_particles > 0 else 1))\n",
    "        log_Z_k_est_history.append(current_log_Z_k_est)\n",
    "        \n",
    "        if (k + 1) % (K_iters // 10 or 1) == 0:\n",
    "            iter_time = time.time() - start_time_iter\n",
    "            phi_str = \", \".join([f\"{x:.3f}\" for x in theta_phi_k])\n",
    "            print(f\"Model: {model_fns['model_name']}, Iter {k+1}/{K_iters}, theta_phi = [{phi_str}], log_Z_est = {current_log_Z_k_est:.3f}, iter_time={iter_time:.2f}s\")\n",
    "\n",
    "    return np.array(theta_phi_history), np.array(log_Z_k_est_history)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f179253e",
   "metadata": {},
   "source": [
    "# Plotting Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e3f9084",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _set_logZ_ylimits(ax, log_Z_hist_list):\n",
    "    \"\"\"\n",
    "    TODO\n",
    "    \"\"\"\n",
    "    valid_Z_logs_list_processed = []\n",
    "\n",
    "    for hist in log_Z_hist_list:\n",
    "        if hist is not None:\n",
    "            arr_hist = np.array(hist).flatten()\n",
    "            valid_Z_logs_list_processed.append(arr_hist[~np.isnan(arr_hist) & np.isfinite(arr_hist)])\n",
    "\n",
    "    if valid_Z_logs_list_processed:\n",
    "        valid_arrays = [arr for arr in valid_Z_logs_list_processed if len(arr) > 0]\n",
    "\n",
    "        if not valid_arrays:\n",
    "            return\n",
    "        \n",
    "        valid_Z_logs = np.concatenate(valid_arrays)\n",
    "\n",
    "        if len(valid_Z_logs) == 0:\n",
    "            return\n",
    "        \n",
    "        if len(valid_Z_logs) > 2:\n",
    "            min_lim_plot, max_lim_plot = np.percentile(valid_Z_logs, [1, 99])\n",
    "            padding = abs(max_lim_plot - min_lim_plot) * 0.1 + 1.0\n",
    "\n",
    "            if not (np.isnan(min_lim_plot) or np.isnan(max_lim_plot) or min_lim_plot >= max_lim_plot):\n",
    "                 ax.set_ylim([min_lim_plot - padding, max_lim_plot + padding])\n",
    "\n",
    "            elif not (np.isnan(min_lim_plot) or np.isnan(max_lim_plot)):\n",
    "                 ax.set_ylim([min_lim_plot -1 , max_lim_plot + 1])\n",
    "\n",
    "        elif len(valid_Z_logs) > 0:\n",
    "            min_val = np.min(valid_Z_logs)\n",
    "            max_val = np.max(valid_Z_logs)\n",
    "            padding = abs(max_val - min_val) * 0.1 + 1.0 if max_val != min_val else 1.0\n",
    "            ax.set_ylim([min_val - padding, max_val + padding])\n",
    "\n",
    "\n",
    "def populate_comparison_row(\n",
    "    axes_row,\n",
    "    iterations_axis_gauss, \n",
    "    theta_phi_gauss_hist, \n",
    "    log_Z_k_gauss_est_hist,\n",
    "    iterations_axis_student_t, \n",
    "    theta_phi_student_t_hist, \n",
    "    log_Z_k_student_t_hist,\n",
    "    true_params,\n",
    "    theta_phi_ML_gaussian,\n",
    "    log_Z_k_gauss_analytical_hist,\n",
    "    log_Z_true_gen_analytical_gaussian_data,\n",
    "    is_gaussian_data_case,\n",
    "    colors,\n",
    "    linestyles,\n",
    "    line_widths,\n",
    "    legend_fontsize,\n",
    "    subplot_titles,\n",
    "):\n",
    "    \"\"\"\n",
    "    TODO\n",
    "    \"\"\"\n",
    "    phi_param_names_display = ['$\\\\phi_1$', '$\\\\phi_2$'] # Y-axis labels\n",
    "    true_param_keys = ['log_sigma_sq_gen', 'log_alpha_gen']\n",
    "    \n",
    "    local_legend_handles = [[], [], []]\n",
    "\n",
    "    # PLot 1\n",
    "    ax1 = axes_row[0]\n",
    "    ax1.plot(\n",
    "        iterations_axis_gauss,\n",
    "        theta_phi_gauss_hist[:, 0],\n",
    "        color=colors['gauss_jala'], \n",
    "        linestyle=linestyles['jala'], \n",
    "        linewidth=line_widths['main_trace'], \n",
    "        alpha=0.9,\n",
    "    )\n",
    "    ax1.plot(\n",
    "        iterations_axis_student_t,\n",
    "        theta_phi_student_t_hist[:, 0],\n",
    "        color=colors['st_jala'],\n",
    "        linestyle=linestyles['st_jala'],\n",
    "        linewidth=line_widths['main_trace'],\n",
    "        alpha=0.9,\n",
    "    )\n",
    "\n",
    "    if true_params and true_param_keys[0] in true_params:\n",
    "        true_val = true_params[true_param_keys[0]]\n",
    "        h, = ax1.plot(\n",
    "            [], [],\n",
    "            color=colors['true_gen'],\n",
    "            linestyle=linestyles['true_gen'],\n",
    "            label=f'True {phi_param_names_display[0]}',\n",
    "            linewidth=line_widths['ax_line'],\n",
    "        )\n",
    "        ax1.axhline(\n",
    "            true_val, color=colors['true_gen'], linestyle=linestyles['true_gen'], linewidth=line_widths['ax_line']\n",
    "        )\n",
    "        local_legend_handles[0].append(h)\n",
    "\n",
    "    if is_gaussian_data_case and theta_phi_ML_gaussian is not None and 0 < len(theta_phi_ML_gaussian):\n",
    "        h, = ax1.plot(\n",
    "            [], [], \n",
    "            color=colors['gauss_ml'],\n",
    "            linestyle=linestyles['gauss_ml'],\n",
    "            label='$\\\\mathcal{M}_{G}$ ML-II',\n",
    "            linewidth=line_widths['ax_line'],\n",
    "        )\n",
    "        ax1.axhline(theta_phi_ML_gaussian[0], color=colors['gauss_ml'], linestyle=linestyles['gauss_ml'], linewidth=line_widths['ax_line'])\n",
    "        local_legend_handles[0].append(h)\n",
    "        \n",
    "    ax1.set_ylabel(phi_param_names_display[0])\n",
    "    if subplot_titles and subplot_titles[0]: # If passed...\n",
    "        ax1.set_title(subplot_titles[0])\n",
    "\n",
    "    if local_legend_handles[0]:\n",
    "        ax1.legend(handles=local_legend_handles[0], loc='best', fontsize=legend_fontsize, frameon=False)\n",
    "\n",
    "    elif ax1.get_legend() is not None:\n",
    "        ax1.get_legend().set_visible(False)\n",
    "\n",
    "    ax1.grid(True, linestyle=':', alpha=0.7)\n",
    "\n",
    "\n",
    "    # Plot 2\n",
    "    ax2 = axes_row[1]\n",
    "    ax2.plot(\n",
    "        iterations_axis_gauss,\n",
    "        theta_phi_gauss_hist[:, 1],\n",
    "        color=colors['gauss_jala'], \n",
    "        linestyle=linestyles['jala'],\n",
    "        linewidth=line_widths['main_trace'],\n",
    "        alpha=0.9,\n",
    "    )\n",
    "    ax2.plot(\n",
    "        iterations_axis_student_t,\n",
    "        theta_phi_student_t_hist[:, 1],\n",
    "        color=colors['st_jala'],\n",
    "        linestyle=linestyles['st_jala'],\n",
    "        linewidth=line_widths['main_trace'],\n",
    "        alpha=0.9,\n",
    "    )\n",
    "\n",
    "    if true_params and true_param_keys[1] in true_params:\n",
    "        true_val = true_params[true_param_keys[1]]\n",
    "        h, = ax2.plot(\n",
    "            [], [], \n",
    "            color=colors['true_gen'], \n",
    "            linestyle=linestyles['true_gen'],\n",
    "            label=f'True {phi_param_names_display[1]}',\n",
    "            linewidth=line_widths['ax_line'],\n",
    "        )\n",
    "        ax2.axhline(\n",
    "            true_val, color=colors['true_gen'], linestyle=linestyles['true_gen'], linewidth=line_widths['ax_line']\n",
    "        )\n",
    "        local_legend_handles[1].append(h)\n",
    "\n",
    "    if is_gaussian_data_case and theta_phi_ML_gaussian is not None and 1 < len(theta_phi_ML_gaussian):\n",
    "        h, = ax2.plot(\n",
    "            [], [],\n",
    "            color=colors['gauss_ml'],\n",
    "            linestyle=linestyles['gauss_ml'],\n",
    "            label='$\\\\mathcal{M}_{G}$ ML-II',\n",
    "            linewidth=line_widths['ax_line'],\n",
    "        )\n",
    "        ax2.axhline(theta_phi_ML_gaussian[1], color=colors['gauss_ml'], linestyle=linestyles['gauss_ml'], linewidth=line_widths['ax_line'])\n",
    "        local_legend_handles[1].append(h)\n",
    "\n",
    "    ax2.set_ylabel(phi_param_names_display[1])\n",
    "\n",
    "    if subplot_titles and subplot_titles[1]:\n",
    "        ax2.set_title(subplot_titles[1])\n",
    "\n",
    "    if local_legend_handles[1]:\n",
    "        ax2.legend(handles=local_legend_handles[1], loc='best', fontsize=legend_fontsize, frameon=False)\n",
    "\n",
    "    elif ax2.get_legend() is not None:\n",
    "        ax2.get_legend().set_visible(False)\n",
    "\n",
    "    ax2.grid(True, linestyle=':', alpha=0.7)\n",
    "\n",
    "\n",
    "    # Plot 3\n",
    "    ax3 = axes_row[2]\n",
    "\n",
    "    ax3.plot(\n",
    "        iterations_axis_gauss,\n",
    "        log_Z_k_gauss_est_hist,\n",
    "        color=colors['gauss_jala'],\n",
    "        linestyle=linestyles['jala'],\n",
    "        linewidth=line_widths['main_trace'],\n",
    "        alpha=0.9,\n",
    "    )\n",
    "\n",
    "    ax3.plot(\n",
    "        iterations_axis_student_t,\n",
    "        log_Z_k_student_t_hist,\n",
    "        color=colors['st_jala'],\n",
    "        linestyle=linestyles['st_jala'],\n",
    "        linewidth=line_widths['main_trace'],\n",
    "        alpha=0.9,\n",
    "    )\n",
    "\n",
    "    if log_Z_k_gauss_analytical_hist is not None:\n",
    "        h, = ax3.plot(\n",
    "            iterations_axis_gauss,\n",
    "            log_Z_k_gauss_analytical_hist,\n",
    "            label='Analytic $\\\\log Z_{\\mathcal{M}_{G}, \\\\theta_k}$',\n",
    "            color=colors['gauss_analytical_est'],\n",
    "            linestyle=linestyles['analytical_est'],\n",
    "            linewidth=line_widths['main_trace'],\n",
    "            alpha=0.8,\n",
    "        )\n",
    "        local_legend_handles[2].append(h)\n",
    "\n",
    "    if is_gaussian_data_case and log_Z_true_gen_analytical_gaussian_data is not None and np.isfinite(log_Z_true_gen_analytical_gaussian_data):\n",
    "        h, = ax3.plot(\n",
    "            [], [],\n",
    "            color=colors['gauss_analytical_true_gen'],\n",
    "            linestyle=linestyles['analytical_true_gen'],\n",
    "            label='Analytic $\\\\log Z_{{\\mathcal{M}_{G}, \\\\theta_\\\\star}}$',\n",
    "            linewidth=line_widths['ax_line']\n",
    "        )\n",
    "        ax3.axhline(\n",
    "            log_Z_true_gen_analytical_gaussian_data, \n",
    "            color=colors['gauss_analytical_true_gen'],\n",
    "            linestyle=linestyles['analytical_true_gen'],\n",
    "            linewidth=line_widths['ax_line'],\n",
    "            )\n",
    "        local_legend_handles[2].append(h)\n",
    "\n",
    "    ax3.set_ylabel('$\\\\log Z_k$')\n",
    "\n",
    "    if subplot_titles and subplot_titles[2]:\n",
    "        ax3.set_title(subplot_titles[2])\n",
    "\n",
    "    if local_legend_handles[2]:\n",
    "        ax3.legend(handles=local_legend_handles[2], loc='best', fontsize=legend_fontsize, frameon=False)\n",
    "\n",
    "    elif ax3.get_legend() is not None:\n",
    "        ax3.get_legend().set_visible(False)\n",
    "\n",
    "    ax3.grid(True, linestyle=':', alpha=0.7)\n",
    "\n",
    "    logZ_hist_list_for_ylim = [log_Z_k_gauss_est_hist, log_Z_k_student_t_hist, log_Z_k_gauss_analytical_hist]\n",
    "\n",
    "    if is_gaussian_data_case and log_Z_true_gen_analytical_gaussian_data is not None:\n",
    "        logZ_hist_list_for_ylim.append(log_Z_true_gen_analytical_gaussian_data)\n",
    "    \n",
    "    _set_logZ_ylimits(ax3, logZ_hist_list_for_ylim)\n",
    "\n",
    "    for ax_idx, ax in enumerate(axes_row):\n",
    "        ax.set_xlabel('Iteration, $k$')\n",
    "        if not local_legend_handles[ax_idx] and ax.get_legend() is not None:\n",
    "             ax.get_legend().set_visible(False)\n",
    "\n",
    "\n",
    "def plot_stacked(\n",
    "    results_g_data,\n",
    "    results_st_data,\n",
    "    output_filename=\"stacked_comparison_neurips.pdf\",\n",
    "    custom_figsize=None,\n",
    "    custom_font_options=None,\n",
    "    custom_line_options=None,\n",
    "    use_tex=True,\n",
    "):\n",
    "    \"\"\"\n",
    "    TODO\n",
    "    \"\"\"\n",
    "    if custom_font_options is None:\n",
    "        custom_font_options = {\n",
    "            \"default\": 10,\n",
    "            \"axes_label\": 10,\n",
    "            \"axes_title\": 11,\n",
    "            \"xtick_label\": 9,\n",
    "            \"ytick_label\": 9,\n",
    "            \"legend\": 9,\n",
    "            \"shared_legend\": 10\n",
    "        }\n",
    "    if custom_line_options is None:\n",
    "        custom_line_options = {\"main_trace\": 1.5, \"grid_line\": 0.5, \"ax_line\": 1.5}\n",
    "\n",
    "    current_rc_params = plt.rcParams.copy()\n",
    "\n",
    "    try:\n",
    "        base_font_size = custom_font_options.get(\"default\", 10)\n",
    "        neurips_rc = bundles.neurips2023(nrows=2, ncols=3, family=\"Times New Roman\", usetex=use_tex, rel_width=1.0)\n",
    "        neurips_rc['font.size'] = custom_font_options.get(\"default\", base_font_size)\n",
    "        neurips_rc['axes.labelsize'] = custom_font_options.get(\"axes_label\", base_font_size)\n",
    "        neurips_rc['xtick.labelsize'] = custom_font_options.get(\"xtick_label\", base_font_size * 0.9)\n",
    "        neurips_rc['ytick.labelsize'] = custom_font_options.get(\"ytick_label\", base_font_size * 0.9)\n",
    "        local_legend_fontsize = custom_font_options.get(\"legend\", base_font_size * 0.9)\n",
    "        shared_legend_fontsize = custom_font_options.get(\"shared_legend\", base_font_size)\n",
    "        row_title_fontsize = custom_font_options.get(\"axes_title\", base_font_size * 1.1)\n",
    "\n",
    "        neurips_rc['lines.linewidth'] = custom_line_options.get(\"main_trace\", 1.5)\n",
    "        neurips_rc['grid.linewidth'] = custom_line_options.get(\"grid_line\", 0.5)\n",
    "        \n",
    "        plt.rcParams.update(neurips_rc)\n",
    "        if use_tex:\n",
    "            plt.rcParams['font.family'] = 'serif'\n",
    "            plt.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif']\n",
    "\n",
    "        colors = {\n",
    "            'gauss_jala': 'tab:blue', \n",
    "            'st_jala': 'tab:orange',\n",
    "            'gauss_analytical_est': 'tab:green', \n",
    "            'gauss_analytical_true_gen': 'tab:red',\n",
    "            'true_gen': 'dimgray', \n",
    "            'gauss_ml': 'darkviolet',\n",
    "        }\n",
    "        linestyles = {\n",
    "            'jala': '-', \n",
    "            'st_jala': '--', \n",
    "            'analytical_est': ':',\n",
    "            'analytical_true_gen': '-', \n",
    "            'true_gen': ':', \n",
    "            'gauss_ml': '-.',\n",
    "        }\n",
    "        line_widths_dict = {\n",
    "            'main_trace': custom_line_options.get(\"main_trace\", 1.5),\n",
    "            'ax_line': custom_line_options.get(\"ax_line\", 1.5)\n",
    "        }\n",
    "\n",
    "        figsize_to_use = custom_figsize if custom_figsize else plt.rcParams.get('figure.figsize')\n",
    "        fig, axes = plt.subplots(2, 3, figsize=figsize_to_use, sharey=False)\n",
    "\n",
    "        no_titles = [\"\", \"\", \"\"] # No individual titles!\n",
    "        \n",
    "        # Row 1 -> Data generated from Gaussian\n",
    "        populate_comparison_row(\n",
    "            axes_row=axes[0,:],\n",
    "            iterations_axis_gauss=results_g_data['iter_axis_g'],\n",
    "            theta_phi_gauss_hist=results_g_data['th_phi_g_on_g'],\n",
    "            log_Z_k_gauss_est_hist=results_g_data['logZ_g_on_g'],\n",
    "            iterations_axis_student_t=results_g_data['iter_axis_st'],\n",
    "            theta_phi_student_t_hist=results_g_data['th_phi_st_on_g'],\n",
    "            log_Z_k_student_t_hist=results_g_data['logZ_st_on_g'],\n",
    "            true_params=results_g_data['true_params'],\n",
    "            theta_phi_ML_gaussian=results_g_data['ml_g_on_g'],\n",
    "            log_Z_k_gauss_analytical_hist=results_g_data['logZ_an_g_on_g'],\n",
    "            log_Z_true_gen_analytical_gaussian_data=results_g_data.get('true_log_Z_analytical_gen_g_data'),\n",
    "            is_gaussian_data_case=True,\n",
    "            colors=colors, linestyles=linestyles, line_widths=line_widths_dict, legend_fontsize=local_legend_fontsize,\n",
    "            subplot_titles=no_titles,\n",
    "        )\n",
    "        axes[0,1].set_title(\"$\\\\mathcal{G} = \\\\mathcal{M}_{G}$\", fontsize=row_title_fontsize, y=1.03)\n",
    "\n",
    "        # Row 2 -> Data generated from Student-t\n",
    "        populate_comparison_row(\n",
    "            axes_row=axes[1,:],\n",
    "            iterations_axis_gauss=results_st_data['iter_axis_g'],\n",
    "            theta_phi_gauss_hist=results_st_data['th_phi_g_on_st'],\n",
    "            log_Z_k_gauss_est_hist=results_st_data['logZ_g_on_st'],\n",
    "            iterations_axis_student_t=results_st_data['iter_axis_st'],\n",
    "            theta_phi_student_t_hist=results_st_data['th_phi_st_on_st'],\n",
    "            log_Z_k_student_t_hist=results_st_data['logZ_st_on_st'],\n",
    "            true_params=results_st_data['true_params'],\n",
    "            theta_phi_ML_gaussian=results_st_data['ml_g_on_st'],\n",
    "            log_Z_k_gauss_analytical_hist=results_st_data['logZ_an_g_on_st'],\n",
    "            log_Z_true_gen_analytical_gaussian_data=None,\n",
    "            is_gaussian_data_case=False,\n",
    "            colors=colors, linestyles=linestyles, line_widths=line_widths_dict, legend_fontsize=local_legend_fontsize,\n",
    "            subplot_titles=no_titles,\n",
    "        )\n",
    "        axes[1,1].set_title(\"$\\\\mathcal{G} = \\\\mathcal{M}_{T}$\", fontsize=row_title_fontsize, y=1.03)\n",
    "\n",
    "        for j in range(3):\n",
    "            axes[0,j].set_xlabel('')\n",
    "\n",
    "        # Shared legennd!\n",
    "        handles_for_shared_legend = []\n",
    "\n",
    "        if axes[0,0].lines:\n",
    "            handles_for_shared_legend.append(axes[0,0].lines[0])\n",
    "\n",
    "        if len(axes[0,0].lines) > 1:\n",
    "            handles_for_shared_legend.append(axes[0,0].lines[1])\n",
    "        \n",
    "        shared_labels = [\"JALA-EM ($\\\\mathcal{M}_{G}$)\", \"JALA-EM ($\\\\mathcal{M}_{T}$)\"]\n",
    "\n",
    "        if len(handles_for_shared_legend) != len(shared_labels) and handles_for_shared_legend :\n",
    "             handles_for_shared_legend = handles_for_shared_legend[:len(shared_labels)]\n",
    "\n",
    "        fig.subplots_adjust(left=0.08, right=0.98, top=0.92, bottom=0.15, hspace=0.35, wspace=0.25)\n",
    "\n",
    "        if handles_for_shared_legend:\n",
    "            fig.legend(handles_for_shared_legend, shared_labels[:len(handles_for_shared_legend)],\n",
    "                       loc='lower center',\n",
    "                       bbox_to_anchor=(0.5, -0.1),\n",
    "                       ncol=2, \n",
    "                       frameon=False,\n",
    "                       fontsize=shared_legend_fontsize)\n",
    "        else:\n",
    "            print(\"Warning: Issue with shared legend!\")\n",
    "\n",
    "        if output_filename:\n",
    "            plt.savefig(output_filename, bbox_inches='tight', dpi=300)\n",
    "        plt.show()\n",
    "\n",
    "    finally:\n",
    "        plt.rcParams.update(current_rc_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b509eaae",
   "metadata": {},
   "source": [
    "### Running a single experimental trial..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07da9a00",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Data Generation Config ###\n",
    "M_data = 1500\n",
    "D_features = 8\n",
    "true_w_mean_gen = np.zeros(D_features)\n",
    "true_w_var_gen = 1\n",
    "true_w_cov_gen = np.eye(D_features) * true_w_var_gen\n",
    "true_alpha_gen_val = 1.0/true_w_var_gen\n",
    "\n",
    "true_sigma_scale_gen_default = np.sqrt(1.0) # For Gaussian data, and default for Student-t sigma\n",
    "true_nu_gen_val_for_st_data = 4.0 # Specifically for when DATA_GENERATION_TYPE = 'student_t'\n",
    "\n",
    "### Algorithm Config ###\n",
    "K_iterations = 250\n",
    "N_particles_jala = 50\n",
    "h_langevin_step = 0.00005\n",
    "opt_lr = 0.005\n",
    "ess_thresh_frac_jala_script = 0.0\n",
    "\n",
    "### Numerical Stability Config ###\n",
    "gauss_phi_clips = [(-15, 10), (-10, 15)]\n",
    "studentT_phi_clips = [(-15, 10), (-10, 15), (np.log(0.2), np.log(5.0))] # Wider nu range\n",
    "\n",
    "results_for_g_data = {}\n",
    "results_for_st_data = {}\n",
    "\n",
    "# For reproducibility!\n",
    "np.random.seed(420) \n",
    "\n",
    "### Experiment 1: Data Generated from Gaussian Distribution ###\n",
    "print(\"\\nEXPERIMENT 1: DATA GENERATED FROM GAUSSIAN...\")\n",
    "DATA_GEN_TYPE_EXP1 = 'gaussian'\n",
    "true_sigma_sq_gen_exp1 = true_sigma_scale_gen_default ** 2\n",
    "\n",
    "X_g, y_g, _ = generate_synthetic_data(\n",
    "    M_data, \n",
    "    D_features, \n",
    "    true_sigma_scale_gen_default,\n",
    "    {'mean': true_w_mean_gen, 'cov': true_w_cov_gen},\n",
    "    likelihood_type=DATA_GEN_TYPE_EXP1,\n",
    "    true_nu=true_nu_gen_val_for_st_data,\n",
    ")\n",
    "results_for_g_data['true_params'] = {\n",
    "    'log_sigma_sq_gen': np.log(true_sigma_sq_gen_exp1),\n",
    "    'log_alpha_gen': np.log(true_alpha_gen_val),\n",
    "}\n",
    "print(f\"Generated {DATA_GEN_TYPE_EXP1} data: X({X_g.shape}), y({y_g.shape})\")\n",
    "\n",
    "# Calculate true analytical Log Z for the true Gaussian parameters\n",
    "true_theta_phi_for_gen_g_data = np.array([\n",
    "    results_for_g_data['true_params']['log_sigma_sq_gen'],\n",
    "    results_for_g_data['true_params']['log_alpha_gen']\n",
    "])\n",
    "results_for_g_data['true_log_Z_analytical_gen_g_data'] = analytical_marginal_likelihood_gaussian(\n",
    "    true_theta_phi_for_gen_g_data, X_g, y_g, D_features\n",
    ")\n",
    "print(f\"True Analytical Log Z (for Gaussian Data generation): {results_for_g_data['true_log_Z_analytical_gen_g_data']:.3f}\")\n",
    "\n",
    "# Initialise params peterbed from true values\n",
    "initial_phi1_g_exp1 = np.log(true_sigma_sq_gen_exp1) + 1\n",
    "initial_phi2_g_exp1 = np.log(true_alpha_gen_val) + 1\n",
    "initial_phi3_st_exp1 = np.log(5.0) # Default nu when true data is not Student-t\n",
    "\n",
    "# Run Gaussian model on Gaussian data...\n",
    "print(\"\\nRunning Gaussian Model on Gaussian Data...\")\n",
    "model_fns_g_on_g = {\n",
    "    'potential_U_fn': potential_U_gaussian, \n",
    "    'grad_U_theta_phi_fn': grad_U_theta_phi_gaussian,\n",
    "    'grad_U_w_fn': grad_U_w_gaussian, \n",
    "    'initialize_particles_fn': initialize_particles_gaussian,\n",
    "    'estimate_Z0_fn': analytical_marginal_likelihood_gaussian,\n",
    "    'theta_phi_initial': np.array([initial_phi1_g_exp1, initial_phi2_g_exp1]),\n",
    "    'D_theta': 2, \n",
    "    'model_name': 'G_on_G_Data', \n",
    "    'phi_clips': gauss_phi_clips,\n",
    "}\n",
    "\n",
    "th_g_g, lz_g_g = run_jala_em(\n",
    "    X_g,\n",
    "    y_g,\n",
    "    model_fns_g_on_g,\n",
    "    K_iterations,\n",
    "    N_particles_jala,\n",
    "    h_langevin_step,\n",
    "    opt_lr,\n",
    "    ess_thresh_frac_jala_script,\n",
    "    D_features,\n",
    ")\n",
    "results_for_g_data['th_phi_g_on_g'] = th_g_g\n",
    "results_for_g_data['logZ_g_on_g'] = lz_g_g\n",
    "results_for_g_data['iter_axis_g'] = np.arange(th_g_g.shape[0])\n",
    "results_for_g_data['ml_g_on_g'] = get_theta_ML_gaussian(X_g, y_g, model_fns_g_on_g['theta_phi_initial'])\n",
    "results_for_g_data['logZ_an_g_on_g'] = [analytical_marginal_likelihood_gaussian(tp, X_g, y_g, D_features) for tp in th_g_g]\n",
    "\n",
    "# Run Student-t model on Gaussian data...\n",
    "print(\"\\nRunning Student-t Model on Gaussian Data...\")\n",
    "model_fns_st_on_g = {\n",
    "    'potential_U_fn': potential_U_studentT, \n",
    "    'grad_U_theta_phi_fn': grad_U_theta_phi_studentT,\n",
    "    'grad_U_w_fn': grad_U_w_studentT,\n",
    "    'initialize_particles_fn': initialize_particles_studentT,\n",
    "    'estimate_Z0_fn': estimate_Z0_studentT_MC_IS,\n",
    "    'theta_phi_initial': np.array([initial_phi1_g_exp1, initial_phi2_g_exp1, initial_phi3_st_exp1]),\n",
    "    'D_theta': 3, \n",
    "    'model_name': 'ST_on_G_Data', \n",
    "    'phi_clips': studentT_phi_clips,\n",
    "}\n",
    "\n",
    "th_st_g, lz_st_g = run_jala_em(\n",
    "    X_g,\n",
    "    y_g,\n",
    "    model_fns_st_on_g,\n",
    "    K_iterations,\n",
    "    N_particles_jala,\n",
    "    h_langevin_step,\n",
    "    opt_lr,\n",
    "    ess_thresh_frac_jala_script,\n",
    "    D_features,\n",
    ")\n",
    "results_for_g_data['th_phi_st_on_g'] = th_st_g\n",
    "results_for_g_data['logZ_st_on_g'] = lz_st_g\n",
    "results_for_g_data['iter_axis_st'] = np.arange(th_st_g.shape[0])\n",
    "\n",
    "\n",
    "### Experiment 2: Data Generated from Student-t Distribution ###\n",
    "print(\"\\nEXPERIMENT 2: DATA GENERATED FROM STUDENT-T...\")\n",
    "DATA_GEN_TYPE_EXP2 = 'student_t'\n",
    "true_sigma_sq_gen_exp2 = true_sigma_scale_gen_default ** 2 # Sigma for Student-t\n",
    "true_nu_gen_exp2 = true_nu_gen_val_for_st_data\n",
    "\n",
    "X_st, y_st, _ = generate_synthetic_data(\n",
    "    M_data, \n",
    "    D_features, \n",
    "    true_sigma_scale_gen_default,\n",
    "    {'mean': true_w_mean_gen, 'cov': true_w_cov_gen},\n",
    "    likelihood_type=DATA_GEN_TYPE_EXP2,\n",
    "    true_nu=true_nu_gen_exp2,\n",
    ")\n",
    "results_for_st_data['true_params'] = {\n",
    "    'log_sigma_sq_gen': np.log(true_sigma_sq_gen_exp2),\n",
    "    'log_alpha_gen': np.log(true_alpha_gen_val),\n",
    "    'log_nu_gen': np.log(true_nu_gen_exp2) # \\nu is relevant now\n",
    "}\n",
    "print(f\"Generated {DATA_GEN_TYPE_EXP2} data: X({X_st.shape}), y({y_st.shape})\")\n",
    "\n",
    "# Initialise params peterbed from true values\n",
    "initial_phi1_st_exp2 = np.log(true_sigma_sq_gen_exp2) + 1\n",
    "initial_phi2_st_exp2 = np.log(true_alpha_gen_val) + 1\n",
    "initial_phi3_st_exp2_true_nu = np.log(true_nu_gen_exp2) + 1\n",
    "\n",
    "# Run Gaussian model on Student-t data\n",
    "print(\"\\nRunning Gaussian Model on Student-t Data...\")\n",
    "model_fns_g_on_st = {\n",
    "    'potential_U_fn': potential_U_gaussian, \n",
    "    'grad_U_theta_phi_fn': grad_U_theta_phi_gaussian,\n",
    "    'grad_U_w_fn': grad_U_w_gaussian, \n",
    "    'initialize_particles_fn': initialize_particles_gaussian,\n",
    "    'estimate_Z0_fn': analytical_marginal_likelihood_gaussian,\n",
    "    'theta_phi_initial': np.array([initial_phi1_st_exp2, initial_phi2_st_exp2]),\n",
    "    'D_theta': 2, \n",
    "    'model_name': 'G_on_ST_Data', \n",
    "    'phi_clips': gauss_phi_clips,\n",
    "}\n",
    "\n",
    "th_g_st, lz_g_st = run_jala_em(\n",
    "    X_st,\n",
    "    y_st,\n",
    "    model_fns_g_on_st,\n",
    "    K_iterations,\n",
    "    N_particles_jala,\n",
    "    h_langevin_step,\n",
    "    opt_lr,\n",
    "    ess_thresh_frac_jala_script,\n",
    "    D_features,\n",
    ")\n",
    "results_for_st_data['th_phi_g_on_st'] = th_g_st\n",
    "results_for_st_data['logZ_g_on_st'] = lz_g_st\n",
    "results_for_st_data['iter_axis_g'] = np.arange(th_g_st.shape[0])\n",
    "results_for_st_data['ml_g_on_st'] = get_theta_ML_gaussian(X_st, y_st, model_fns_g_on_st['theta_phi_initial'])\n",
    "results_for_st_data['logZ_an_g_on_st'] = [analytical_marginal_likelihood_gaussian(tp, X_st, y_st, D_features) for tp in th_g_st]\n",
    "\n",
    "# Run Student-t model on Student-t data...\n",
    "print(\"\\nRunning Student-t Model on Student-t Data...\") \n",
    "model_fns_st_on_st = {\n",
    "    'potential_U_fn': potential_U_studentT, \n",
    "    'grad_U_theta_phi_fn': grad_U_theta_phi_studentT,\n",
    "    'grad_U_w_fn': grad_U_w_studentT,\n",
    "    'initialize_particles_fn': initialize_particles_studentT,\n",
    "    'estimate_Z0_fn': estimate_Z0_studentT_MC_IS,\n",
    "    'theta_phi_initial': np.array([initial_phi1_st_exp2, initial_phi2_st_exp2, initial_phi3_st_exp2_true_nu]),\n",
    "    'D_theta': 3, \n",
    "    'model_name': 'ST_on_ST_Data',\n",
    "    'phi_clips': studentT_phi_clips,\n",
    "}\n",
    "\n",
    "th_st_st, lz_st_st = run_jala_em(\n",
    "    X_st,\n",
    "    y_st,\n",
    "    model_fns_st_on_st,\n",
    "    K_iterations,\n",
    "    N_particles_jala,\n",
    "    h_langevin_step,\n",
    "    opt_lr,\n",
    "    ess_thresh_frac_jala_script,\n",
    "    D_features,\n",
    ")\n",
    "results_for_st_data['th_phi_st_on_st'] = th_st_st\n",
    "results_for_st_data['logZ_st_on_st'] = lz_st_st\n",
    "results_for_st_data['iter_axis_st'] = np.arange(th_st_st.shape[0])\n",
    "\n",
    "\n",
    "print(\"\\nGenerating Plot...\")\n",
    "\n",
    "stacked_figsize = (12, 7)\n",
    "\n",
    "font_opts = {\n",
    "    \"default\": 16, \n",
    "    \"axes_label\": 16, \n",
    "    \"axes_title\": 16,\n",
    "    \"row_title\": 16, \n",
    "    \"xtick_label\": 14, \n",
    "    \"ytick_label\": 14,\n",
    "    \"legend\": 16, \n",
    "    \"shared_legend\": 16\n",
    "}\n",
    "line_opts = {\n",
    "    \"main_trace\": 2.5, \n",
    "    \"ax_line\": 2.0, \n",
    "    \"grid_line\": 1.5,\n",
    "}\n",
    "stacked_fig_size = (16, 12.0) \n",
    "\n",
    "plot_stacked(\n",
    "    results_g_data=results_for_g_data,\n",
    "    results_st_data=results_for_st_data,\n",
    "    output_filename=f\"stacked_M{M_data}_D{D_features}_for_report.pdf\",\n",
    "    custom_figsize=stacked_fig_size,\n",
    "    custom_font_options=font_opts,\n",
    "    custom_line_options=line_opts,\n",
    "    use_tex=True\n",
    ")\n",
    "\n",
    "print(\"\\nSingle Experimental Trial End!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd2cf756",
   "metadata": {},
   "source": [
    "## Some More Plotting Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7167c3ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _set_logZ_ylimits(ax, log_Z_hist_list):\n",
    "    \"\"\"\n",
    "    TODO\n",
    "    \"\"\"\n",
    "    valid_Z_logs_list_processed = []\n",
    "    for hist in log_Z_hist_list:\n",
    "        if hist is not None:\n",
    "            arr_hist = np.array(hist).flatten()\n",
    "            valid_Z_logs_list_processed.append(arr_hist[~np.isnan(arr_hist) & np.isfinite(arr_hist)])\n",
    "\n",
    "    if valid_Z_logs_list_processed:\n",
    "        valid_Z_logs = np.concatenate([arr for arr in valid_Z_logs_list_processed if len(arr) > 0])\n",
    "\n",
    "        if len(valid_Z_logs) > 2:\n",
    "            min_lim_plot, max_lim_plot = np.percentile(valid_Z_logs, [1, 99])\n",
    "            padding = abs(max_lim_plot - min_lim_plot) * 0.1 + 1.0\n",
    "\n",
    "            if not (np.isnan(min_lim_plot) or np.isnan(max_lim_plot) or min_lim_plot >= max_lim_plot):\n",
    "                 ax.set_ylim([min_lim_plot - padding, max_lim_plot + padding])\n",
    "\n",
    "            elif not (np.isnan(min_lim_plot) or np.isnan(max_lim_plot)):\n",
    "                 ax.set_ylim([min_lim_plot -1 , max_lim_plot + 1])\n",
    "\n",
    "        elif len(valid_Z_logs) > 0:\n",
    "            min_val = np.min(valid_Z_logs)\n",
    "            max_val = np.max(valid_Z_logs)\n",
    "            padding = abs(max_val - min_val) * 0.1 + 1.0 if max_val != min_val else 1.0\n",
    "            ax.set_ylim([min_val - padding, max_val + padding])\n",
    "\n",
    "\n",
    "def plot_logZ_comparison_side_by_side(\n",
    "    results_g_data,\n",
    "    results_st_data,\n",
    "    output_filename=\"logZ_comparison_final.pdf\",\n",
    "    custom_figsize=None,\n",
    "    custom_font_options=None,\n",
    "    custom_line_options=None\n",
    "):\n",
    "    \"\"\"\n",
    "    TODO\n",
    "    \"\"\"\n",
    "    # Some default options, if not provided!\n",
    "    if custom_font_options is None:\n",
    "        custom_font_options = {\n",
    "            \"default\": 10,\n",
    "            \"axes_label\": 10, \n",
    "            \"axes_title\": 11,\n",
    "            \"xtick_label\": 9,\n",
    "            \"ytick_label\": 9,\n",
    "            \"legend\": 9,\n",
    "        }\n",
    "        \n",
    "    if custom_line_options is None:\n",
    "        custom_line_options = {\n",
    "            \"main_trace\": 1.5,\n",
    "            \"grid_line\": 0.5,\n",
    "            \"ax_line\": 1.5,\n",
    "        }\n",
    "\n",
    "    current_rc_params = plt.rcParams.copy()\n",
    "\n",
    "    try:\n",
    "        base_font_size = custom_font_options.get(\"default\", 10)\n",
    "\n",
    "        neurips_rc = bundles.neurips2023(ncols=2, nrows=1, family=\"serif\", usetex=True)\n",
    "        \n",
    "        neurips_rc['font.size'] = custom_font_options.get(\"default\", base_font_size)\n",
    "        neurips_rc['axes.labelsize'] = custom_font_options.get(\"axes_label\", base_font_size)\n",
    "        neurips_rc['axes.titlesize'] = custom_font_options.get(\"axes_title\", base_font_size * 1.2)\n",
    "        neurips_rc['xtick.labelsize'] = custom_font_options.get(\"xtick_label\", base_font_size * 0.9)\n",
    "        neurips_rc['ytick.labelsize'] = custom_font_options.get(\"ytick_label\", base_font_size * 0.9)\n",
    "        neurips_rc['legend.fontsize'] = custom_font_options.get(\"legend\", base_font_size * 0.9)\n",
    "        \n",
    "        neurips_rc['lines.linewidth'] = custom_line_options.get(\"main_trace\", 1.5)\n",
    "        neurips_rc['grid.linewidth'] = custom_line_options.get(\"grid_line\", 0.5)\n",
    "        ax_line_width = custom_line_options.get(\"ax_line\", 1.5)\n",
    "        main_trace_lw = custom_line_options.get(\"main_trace\", 1.5)\n",
    "\n",
    "        plt.rcParams.update(neurips_rc)\n",
    "\n",
    "        figsize_to_use = custom_figsize if custom_figsize else plt.rcParams[\"figure.figsize\"]\n",
    "\n",
    "        fig, axes = plt.subplots(1, 2, figsize=figsize_to_use, sharey=False)\n",
    "\n",
    "        colors = {\n",
    "            'gauss_jala': 'tab:blue', \n",
    "            'st_jala': 'tab:orange',\n",
    "            'gauss_analytical_est': 'tab:green', \n",
    "            'gauss_analytical_true': 'tab:red',\n",
    "        }\n",
    "        linestyles = {\n",
    "            'jala': '-', \n",
    "            'st_jala': '--',\n",
    "            'analytical_est': ':', \n",
    "            'analytical_true': '-',\n",
    "        }\n",
    "\n",
    "        handles_for_shared_legend = []\n",
    "\n",
    "        # Plot 1 -> Models on Gaussian Data\n",
    "        ax1 = axes[0]\n",
    "        iter_axis_g_g = results_g_data['iter_axis_g']\n",
    "        logZ_g_on_g = results_g_data['logZ_g_on_g']\n",
    "        iter_axis_st_g = results_g_data['iter_axis_st']\n",
    "        logZ_st_on_g = results_g_data['logZ_st_on_g']\n",
    "        logZ_an_g_on_g = results_g_data['logZ_an_g_on_g']\n",
    "        true_logZ_gen_g_data = results_g_data.get('true_log_Z_analytical_gen_g_data')\n",
    "\n",
    "        line_g_est1, = ax1.plot(\n",
    "            iter_axis_g_g, \n",
    "            logZ_g_on_g, \n",
    "            label=\"JALA-EM ($\\mathcal{M}_{G}$)\",\n",
    "            color=colors['gauss_jala'],\n",
    "            linestyle=linestyles['jala'],\n",
    "            linewidth=main_trace_lw,\n",
    "        )\n",
    "        handles_for_shared_legend.append(line_g_est1)\n",
    "\n",
    "        line_st_est1, = ax1.plot(\n",
    "            iter_axis_st_g,\n",
    "            logZ_st_on_g,\n",
    "            label=\"JALA-EM ($\\mathcal{M}_{T}$)\",\n",
    "            color=colors['st_jala'],\n",
    "            linestyle=linestyles['st_jala'],\n",
    "            linewidth=main_trace_lw,\n",
    "        )\n",
    "        handles_for_shared_legend.append(line_st_est1)\n",
    "\n",
    "        line_g_an_est1, = ax1.plot(\n",
    "            iter_axis_g_g,\n",
    "            logZ_an_g_on_g,\n",
    "            label=\"Analytic $\\\\log Z_{\\mathcal{M}_{G}, \\\\theta_{k}}$\",\n",
    "            color=colors['gauss_analytical_est'], \n",
    "            linestyle=linestyles['analytical_est'],\n",
    "            linewidth=main_trace_lw,\n",
    "        )\n",
    "        handles_for_shared_legend.append(line_g_an_est1)\n",
    "\n",
    "        h_true_gen_proxy_ax1 = None\n",
    "        if true_logZ_gen_g_data is not None and np.isfinite(true_logZ_gen_g_data):\n",
    "            ax1.axhline(\n",
    "                true_logZ_gen_g_data,\n",
    "                color=colors['gauss_analytical_true'],\n",
    "                linestyle=linestyles['analytical_true'],\n",
    "                linewidth=ax_line_width,\n",
    "            )\n",
    "\n",
    "            h_true_gen_proxy_ax1, = ax1.plot(\n",
    "                [], [],\n",
    "                color=colors['gauss_analytical_true'],\n",
    "                linestyle=linestyles['analytical_true'],\n",
    "                linewidth=ax_line_width,\n",
    "                label=\"Analytic $\\\\log Z_{\\mathcal{M}_{G}, \\\\theta_{\\star}}$\"\n",
    "            )\n",
    "\n",
    "        ax1.set_title(\"$\\\\mathcal{G} = \\\\mathcal{M}_{G}$\", fontsize=base_font_size)\n",
    "        ax1.set_xlabel('Iteration, $k$')\n",
    "        ax1.set_ylabel('$\\\\log Z_k$')\n",
    "        ax1.grid(True, linestyle=':', alpha=0.7)\n",
    "\n",
    "        if h_true_gen_proxy_ax1:\n",
    "            ax1.legend(handles=[h_true_gen_proxy_ax1], loc='best')\n",
    "        \n",
    "        _set_logZ_ylimits(ax1, [logZ_g_on_g, logZ_st_on_g, logZ_an_g_on_g, true_logZ_gen_g_data])\n",
    "\n",
    "        # Plot 2 -> Models on Student-t Data\n",
    "        ax2 = axes[1]\n",
    "        iter_axis_g_st = results_st_data['iter_axis_g']\n",
    "        logZ_g_on_st = results_st_data['logZ_g_on_st']\n",
    "        iter_axis_st_st = results_st_data['iter_axis_st']\n",
    "        logZ_st_on_st = results_st_data['logZ_st_on_st']\n",
    "        logZ_an_g_on_st = results_st_data['logZ_an_g_on_st']\n",
    "\n",
    "        ax2.plot(\n",
    "            iter_axis_g_st,\n",
    "            logZ_g_on_st,\n",
    "            color=colors['gauss_jala'],\n",
    "            linestyle=linestyles['jala'],\n",
    "            linewidth=main_trace_lw,\n",
    "        )\n",
    "        ax2.plot(\n",
    "            iter_axis_st_st,\n",
    "            logZ_st_on_st,\n",
    "            color=colors['st_jala'],\n",
    "            linestyle=linestyles['st_jala'],\n",
    "            linewidth=main_trace_lw,\n",
    "        )\n",
    "        ax2.plot(\n",
    "            iter_axis_g_st,\n",
    "            logZ_an_g_on_st,\n",
    "            color=colors['gauss_analytical_est'],\n",
    "            linestyle=linestyles['analytical_est'],\n",
    "            linewidth=main_trace_lw,\n",
    "        )\n",
    "\n",
    "        ax2.set_title(\"$\\\\mathcal{G} = \\\\mathcal{M}_{T}$\", fontsize=base_font_size)\n",
    "        ax2.set_xlabel('Iteration, $k$')\n",
    "        ax2.grid(True, linestyle=':', alpha=0.7)\n",
    "\n",
    "        _set_logZ_ylimits(ax2, [logZ_g_on_st, logZ_st_on_st, logZ_an_g_on_st])\n",
    "\n",
    "        # Share Legend! (TODO: below can be cleaned up...(is quite messy))\n",
    "        final_shared_labels = [h.get_label() for h in handles_for_shared_legend]\n",
    "\n",
    "        bottom_margin = 0.30\n",
    "        fig.subplots_adjust(bottom=bottom_margin)\n",
    "\n",
    "        legend_y_position = -0.10\n",
    "\n",
    "        if not handles_for_shared_legend:\n",
    "             print(\"No handles found for legend!\")\n",
    "        else:\n",
    "            fig.legend(\n",
    "                handles_for_shared_legend,\n",
    "                final_shared_labels,\n",
    "                loc='lower center',\n",
    "                bbox_to_anchor=(0.5, legend_y_position),\n",
    "                ncol=3,\n",
    "                frameon=False,\n",
    "            )\n",
    "\n",
    "        if output_filename:\n",
    "            plt.savefig(output_filename, bbox_inches='tight', dpi=300)\n",
    "            print(f\"Final LogZ comparison plot saved as {output_filename}\")\n",
    "        plt.show()\n",
    "\n",
    "    finally:\n",
    "        plt.rcParams.update(current_rc_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "357571d5",
   "metadata": {},
   "source": [
    "### Plotting the single experimental trial..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d95d56b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"\\nGenerating Log Z Side-by-Side Comparison Plot\")\n",
    "\n",
    "custom_font_options = {\n",
    "    \"default\": 20,\n",
    "    \"axes_label\": 20,\n",
    "    \"axes_title\": 20,\n",
    "    \"xtick_label\": 18,\n",
    "    \"ytick_label\": 18,\n",
    "    \"legend\": 20,\n",
    "}\n",
    "\n",
    "custom_line_options = {\n",
    "    \"main_trace\": 3.0,\n",
    "    \"ax_line\": 2.0,\n",
    "    \"grid_line\": 2.0,\n",
    "}\n",
    "\n",
    "logZ_figsize = (16, 6)\n",
    "\n",
    "plot_logZ_comparison_side_by_side(\n",
    "    results_g_data=results_for_g_data,\n",
    "    results_st_data=results_for_st_data,\n",
    "    output_filename=\"logZ_side_by_side_latex_1500_datapoints.pdf\",\n",
    "    custom_figsize=logZ_figsize,\n",
    "    custom_font_options=custom_font_options,\n",
    "    custom_line_options=custom_line_options\n",
    ")\n",
    "\n",
    "print(\"\\nPlot Generated!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "083db7fd",
   "metadata": {},
   "source": [
    "# Experiment Pipeline (Multiple Runs)\n",
    "\n",
    "Logic to run the core experimental trial over a number of repeats."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3e861e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_TOTAL_EXPERIMENTS = 100  # Number of different experimental trials we run!\n",
    "\n",
    "### Data Generation Conifg ###\n",
    "M_data_setting = 500\n",
    "D_features = 8\n",
    "true_w_mean_gen = np.zeros(D_features)\n",
    "true_w_var_gen = 1.0\n",
    "true_w_cov_gen = np.eye(D_features) * true_w_var_gen\n",
    "true_alpha_gen_val = 1.0 / true_w_var_gen\n",
    "\n",
    "true_sigma_scale_gen_default = np.sqrt(1.0)\n",
    "true_nu_gen_val_for_st_data = 4.0\n",
    "\n",
    "### Model Config ###\n",
    "K_iterations = 250\n",
    "N_particles_jala = 50\n",
    "h_langevin_step = 0.00005\n",
    "opt_lr = 0.005\n",
    "ess_thresh_frac_jala_script = 0.0\n",
    "\n",
    "### Numerical Stability Config ###\n",
    "gauss_phi_clips = [(-15, 10), (-10, 15)]\n",
    "studentT_phi_clips = [(-15, 10), (-10, 15), (np.log(0.2), np.log(5.0))]\n",
    "\n",
    "all_runs_results_list = []\n",
    "\n",
    "print(f\"Starting {NUM_TOTAL_EXPERIMENTS} experiments with M_data = {M_data_setting}\\n\")\n",
    "\n",
    "for i_exp_run in range(NUM_TOTAL_EXPERIMENTS):\n",
    "\n",
    "    # For reproducibility!\n",
    "    current_seed = i_exp_run\n",
    "    np.random.seed(current_seed)\n",
    "\n",
    "    print(f\"\\nRUNNING EXPERIMENT {i_exp_run + 1}/{NUM_TOTAL_EXPERIMENTS} (Seed: {current_seed})...\")\n",
    "\n",
    "    current_run_results_dict = {'seed': current_seed}\n",
    "    results_for_g_data_this_run = {}\n",
    "    results_for_st_data_this_run = {}\n",
    "\n",
    "    print(f\"\\nSub-Experiment 1 (Seed {current_seed}): Data Generated from GAUSSIAN!\")\n",
    "\n",
    "    DATA_GEN_TYPE_EXP1 = 'gaussian'\n",
    "    true_sigma_sq_gen_exp1 = true_sigma_scale_gen_default ** 2\n",
    "\n",
    "    X_g, y_g, true_w_g_gen = generate_synthetic_data(\n",
    "        M_data_setting, \n",
    "        D_features, \n",
    "        true_sigma_scale_gen_default,\n",
    "        {'mean': true_w_mean_gen, \n",
    "         'cov': true_w_cov_gen},\n",
    "        likelihood_type=DATA_GEN_TYPE_EXP1, \n",
    "        true_nu=true_nu_gen_val_for_st_data,\n",
    "    )\n",
    "\n",
    "    results_for_g_data_this_run['true_params'] = {\n",
    "        'log_sigma_sq_gen': np.log(true_sigma_sq_gen_exp1),\n",
    "        'log_alpha_gen': np.log(true_alpha_gen_val), \n",
    "        'true_w_gen': true_w_g_gen,\n",
    "    }\n",
    "\n",
    "    true_theta_phi_for_gen_g_data = np.array([\n",
    "        results_for_g_data_this_run['true_params']['log_sigma_sq_gen'],\n",
    "        results_for_g_data_this_run['true_params']['log_alpha_gen']\n",
    "    ])\n",
    "\n",
    "    results_for_g_data_this_run['true_log_Z_analytical_gen_g_data'] = analytical_marginal_likelihood_gaussian(\n",
    "        true_theta_phi_for_gen_g_data, X_g, y_g, D_features\n",
    "    )\n",
    "\n",
    "    initial_phi1_common = 1.0\n",
    "    initial_phi2_common = 1.0\n",
    "    initial_phi3_st_on_g_case = np.log(5.0)\n",
    "\n",
    "    model_fns_g_on_g = {\n",
    "        'potential_U_fn': potential_U_gaussian,\n",
    "        'grad_U_theta_phi_fn': grad_U_theta_phi_gaussian,\n",
    "        'grad_U_w_fn': grad_U_w_gaussian,\n",
    "        'initialize_particles_fn': initialize_particles_gaussian,\n",
    "        'estimate_Z0_fn': analytical_marginal_likelihood_gaussian,\n",
    "        'theta_phi_initial': np.array([initial_phi1_common, initial_phi2_common]),\n",
    "        'D_theta': 2, \n",
    "        'model_name': f'G_on_G_S{current_seed}', \n",
    "        'phi_clips': gauss_phi_clips,\n",
    "    }\n",
    "\n",
    "    th_g_g, lz_g_g = run_jala_em(\n",
    "        X_g,\n",
    "        y_g,\n",
    "        model_fns_g_on_g,\n",
    "        K_iterations,\n",
    "        N_particles_jala,\n",
    "        h_langevin_step,\n",
    "        opt_lr,\n",
    "        ess_thresh_frac_jala_script,\n",
    "        D_features=D_features,\n",
    "    )\n",
    "    results_for_g_data_this_run['th_phi_g_on_g'] = th_g_g\n",
    "    results_for_g_data_this_run['logZ_g_on_g'] = lz_g_g\n",
    "    results_for_g_data_this_run['iter_axis_g'] = np.arange(th_g_g.shape[0])\n",
    "    results_for_g_data_this_run['ml_g_on_g'] = get_theta_ML_gaussian(X_g, y_g, model_fns_g_on_g['theta_phi_initial'])\n",
    "    results_for_g_data_this_run['logZ_an_g_on_g'] = [analytical_marginal_likelihood_gaussian(tp, X_g, y_g, D_features) for tp in th_g_g]\n",
    "\n",
    "    model_fns_st_on_g = {\n",
    "        'potential_U_fn': potential_U_studentT, \n",
    "        'grad_U_theta_phi_fn': grad_U_theta_phi_studentT,\n",
    "        'grad_U_w_fn': grad_U_w_studentT, \n",
    "        'initialize_particles_fn': initialize_particles_studentT,\n",
    "        'estimate_Z0_fn': estimate_Z0_studentT_MC_IS,\n",
    "        'theta_phi_initial': np.array([initial_phi1_common, initial_phi2_common, initial_phi3_st_on_g_case]),\n",
    "        'D_theta': 3, \n",
    "        'model_name': f'ST_on_G_S{current_seed}', \n",
    "        'phi_clips': studentT_phi_clips,\n",
    "    }\n",
    "\n",
    "    th_st_g, lz_st_g = run_jala_em(\n",
    "        X_g,\n",
    "        y_g,\n",
    "        model_fns_st_on_g,\n",
    "        K_iterations,\n",
    "        N_particles_jala,\n",
    "        h_langevin_step,\n",
    "        opt_lr,\n",
    "        ess_thresh_frac_jala_script,\n",
    "        D_features=D_features,\n",
    "    )\n",
    "    results_for_g_data_this_run['th_phi_st_on_g'] = th_st_g\n",
    "    results_for_g_data_this_run['logZ_st_on_g'] = lz_st_g\n",
    "    results_for_g_data_this_run['iter_axis_st'] = np.arange(th_st_g.shape[0])\n",
    "    current_run_results_dict['gaussian_data_scenario'] = results_for_g_data_this_run\n",
    "\n",
    "    print(f\"\\nSub-Experiment 2 (Seed {current_seed}): Data Generated from STUDENT-T!\")\n",
    "\n",
    "    DATA_GEN_TYPE_EXP2 = 'student_t'\n",
    "    true_sigma_sq_gen_exp2 = true_sigma_scale_gen_default ** 2\n",
    "    true_nu_gen_exp2 = true_nu_gen_val_for_st_data\n",
    "\n",
    "    X_st, y_st, true_w_st_gen = generate_synthetic_data(\n",
    "        M_data_setting, \n",
    "        D_features, \n",
    "        true_sigma_scale_gen_default,\n",
    "        {'mean': true_w_mean_gen, 'cov': true_w_cov_gen},\n",
    "        likelihood_type=DATA_GEN_TYPE_EXP2,\n",
    "        true_nu=true_nu_gen_exp2\n",
    "    )\n",
    "\n",
    "    results_for_st_data_this_run['true_params'] = {\n",
    "        'log_sigma_sq_gen': np.log(true_sigma_sq_gen_exp2),\n",
    "        'log_alpha_gen': np.log(true_alpha_gen_val),\n",
    "        'log_nu_gen': np.log(true_nu_gen_exp2), \n",
    "        'true_w_gen': true_w_st_gen\n",
    "    }\n",
    "\n",
    "    initial_phi3_st_on_st_case = np.log(true_nu_gen_exp2) + 1.0\n",
    "\n",
    "    model_fns_g_on_st = {\n",
    "        'potential_U_fn': potential_U_gaussian, \n",
    "        'grad_U_theta_phi_fn': grad_U_theta_phi_gaussian,\n",
    "        'grad_U_w_fn': grad_U_w_gaussian, \n",
    "        'initialize_particles_fn': initialize_particles_gaussian,\n",
    "        'estimate_Z0_fn': analytical_marginal_likelihood_gaussian,\n",
    "        'theta_phi_initial': np.array([initial_phi1_common, initial_phi2_common]),\n",
    "        'D_theta': 2, \n",
    "        'model_name': f'G_on_ST_S{current_seed}',\n",
    "        'phi_clips': gauss_phi_clips\n",
    "    }\n",
    "\n",
    "    th_g_st, lz_g_st = run_jala_em(\n",
    "        X_st,\n",
    "        y_st,\n",
    "        model_fns_g_on_st,\n",
    "        K_iterations,\n",
    "        N_particles_jala,\n",
    "        h_langevin_step,\n",
    "        opt_lr,\n",
    "        ess_thresh_frac_jala_script,\n",
    "        D_features=D_features,\n",
    "    )\n",
    "    results_for_st_data_this_run['th_phi_g_on_st'] = th_g_st\n",
    "    results_for_st_data_this_run['logZ_g_on_st'] = lz_g_st\n",
    "    results_for_st_data_this_run['iter_axis_g'] = np.arange(th_g_st.shape[0])\n",
    "    results_for_st_data_this_run['ml_g_on_st'] = get_theta_ML_gaussian(X_st, y_st, model_fns_g_on_st['theta_phi_initial'])\n",
    "    results_for_st_data_this_run['logZ_an_g_on_st'] = [analytical_marginal_likelihood_gaussian(tp, X_st, y_st, D_features) for tp in th_g_st]\n",
    "\n",
    "    model_fns_st_on_st = {\n",
    "        'potential_U_fn': potential_U_studentT,\n",
    "        'grad_U_theta_phi_fn': grad_U_theta_phi_studentT,\n",
    "        'grad_U_w_fn': grad_U_w_studentT,\n",
    "        'initialize_particles_fn': initialize_particles_studentT,\n",
    "        'estimate_Z0_fn': estimate_Z0_studentT_MC_IS,\n",
    "        'theta_phi_initial': np.array([initial_phi1_common, initial_phi2_common, initial_phi3_st_on_st_case]),\n",
    "        'D_theta': 3,\n",
    "        'model_name': f'ST_on_ST_S{current_seed}',\n",
    "        'phi_clips': studentT_phi_clips,\n",
    "    }\n",
    "\n",
    "    th_st_st, lz_st_st = run_jala_em(\n",
    "        X_st,\n",
    "        y_st,\n",
    "        model_fns_st_on_st,\n",
    "        K_iterations,\n",
    "        N_particles_jala,\n",
    "        h_langevin_step,\n",
    "        opt_lr,\n",
    "        ess_thresh_frac_jala_script,\n",
    "        D_features=D_features,\n",
    "    )\n",
    "    results_for_st_data_this_run['th_phi_st_on_st'] = th_st_st\n",
    "    results_for_st_data_this_run['logZ_st_on_st'] = lz_st_st\n",
    "    results_for_st_data_this_run['iter_axis_st'] = np.arange(th_st_st.shape[0])\n",
    "    current_run_results_dict['student_t_data_scenario'] = results_for_st_data_this_run\n",
    "\n",
    "    all_runs_results_list.append(current_run_results_dict)\n",
    "    \n",
    "    # Saving results...\n",
    "    if (i_exp_run + 1) % 10 == 0:\n",
    "            temp_filename = f\"jala_em_M{M_data_setting}_D{D_features}_runs_temp_upto_{i_exp_run+1}.pkl\"\n",
    "            with open(temp_filename, 'wb') as f_temp:\n",
    "                pickle.dump(all_runs_results_list, f_temp)\n",
    "            print(f\"Temporarily saved results up to run {i_exp_run+1} to {temp_filename}\")\n",
    "\n",
    "print(f\"\\n\\nALL {NUM_TOTAL_EXPERIMENTS} EXPERIMENTS COMPLETED!\")\n",
    "\n",
    "final_results_filename = f\"jala_em_M{M_data_setting}_D{D_features}_total{NUM_TOTAL_EXPERIMENTS}_runs_final_text_params.pkl\"\n",
    "\n",
    "with open(final_results_filename, 'wb') as f_final:\n",
    "    pickle.dump(all_runs_results_list, f_final)\n",
    "\n",
    "print(f\"All experiment results saved to {final_results_filename}\")\n",
    "\n",
    "# Evaluation of model recovery!\n",
    "if NUM_TOTAL_EXPERIMENTS > 0 and len(all_runs_results_list) == NUM_TOTAL_EXPERIMENTS:\n",
    "    correctly_recovered_when_g_is_true = 0\n",
    "    correctly_recovered_when_st_is_true = 0\n",
    "    total_valid_g_scenarios = 0\n",
    "    total_valid_st_scenarios = 0\n",
    "\n",
    "    for i_run, run_results in enumerate(all_runs_results_list):\n",
    "        seed = run_results.get('seed', 'Unknown')\n",
    "        \n",
    "        # Scenario 1 (i.e. $\\mathcal{G} = \\mathcal{M}_G$)\n",
    "        g_data_results = run_results.get('gaussian_data_scenario')\n",
    "        if g_data_results and \\\n",
    "            g_data_results.get('logZ_g_on_g') is not None and len(g_data_results['logZ_g_on_g']) > 0 and \\\n",
    "            g_data_results.get('logZ_st_on_g') is not None and len(g_data_results['logZ_st_on_g']) > 0:\n",
    "            \n",
    "            logZ_g_final_for_g_data = g_data_results['logZ_g_on_g'][-1]\n",
    "            logZ_st_final_for_g_data = g_data_results['logZ_st_on_g'][-1]\n",
    "\n",
    "            if np.isfinite(logZ_g_final_for_g_data) and np.isfinite(logZ_st_final_for_g_data):\n",
    "                total_valid_g_scenarios += 1\n",
    "                if logZ_g_final_for_g_data > logZ_st_final_for_g_data:\n",
    "                    correctly_recovered_when_g_is_true += 1\n",
    "            else:\n",
    "                print(f\"Warning (Run {i_run + 1}, Seed {seed}, G-data): Non-finite logZ values found. logZ_G={logZ_g_final_for_g_data}, logZ_ST={logZ_st_final_for_g_data}. Skipping this run for G-scenario recovery count.\")\n",
    "        else:\n",
    "            print(f\"Warning (Run {i_run +1}, Seed {seed}, G-data): Missing or incomplete logZ data. Skipping this run for Gaussian-scenario recovery count.\")\n",
    "\n",
    "        # Scenario 2 (i.e. $\\mathcal{G} = \\mathcal{M}_T$)\n",
    "        st_data_results = run_results.get('student_t_data_scenario')\n",
    "        if st_data_results and \\\n",
    "            st_data_results.get('logZ_st_on_st') is not None and len(st_data_results['logZ_st_on_st']) > 0 and \\\n",
    "            st_data_results.get('logZ_g_on_st') is not None and len(st_data_results['logZ_g_on_st']) > 0:\n",
    "\n",
    "            logZ_st_final_for_st_data = st_data_results['logZ_st_on_st'][-1]\n",
    "            logZ_g_final_for_st_data = st_data_results['logZ_g_on_st'][-1]\n",
    "\n",
    "            if np.isfinite(logZ_st_final_for_st_data) and np.isfinite(logZ_g_final_for_st_data):\n",
    "                total_valid_st_scenarios += 1\n",
    "                if logZ_st_final_for_st_data > logZ_g_final_for_st_data:\n",
    "                    correctly_recovered_when_st_is_true += 1\n",
    "            else:\n",
    "                print(f\"Warning (Run {i_run + 1}, Seed {seed}, ST-data): Non-finite logZ values found. logZ_ST={logZ_st_final_for_st_data}, logZ_G={logZ_g_final_for_st_data}. Skipping this run for ST-scenario recovery count.\")\n",
    "        else:\n",
    "            print(f\"Warning (Run {i_run + 1}, Seed {seed}, ST-data): Missing or incomplete logZ data. Skipping this run for Student-T-scenario recovery count.\")\n",
    "    \n",
    "    print(\"\\nModel Recovery Proportions:\")\n",
    "    if total_valid_g_scenarios > 0:\n",
    "        percentage_g_data_recovered = (correctly_recovered_when_g_is_true / total_valid_g_scenarios) * 100\n",
    "        print(f\"Data Generation: Gaussian ($\\mathcal{{G}} = \\mathcal{{M}}_G$)\")\n",
    "        print(f\"Proportion of trials where $\\mathcal{{M}}_G$ was correctly selected: {correctly_recovered_when_g_is_true}/{total_valid_g_scenarios} ({percentage_g_data_recovered:.2f}%)\")\n",
    "    else:\n",
    "        print(\"Data Generation: Gaussian ($\\mathcal{{G}} = \\mathcal{{M}}_G$) -> No valid scenarios found to calculate recovery rate!\")\n",
    "\n",
    "    if total_valid_st_scenarios > 0:\n",
    "        percentage_st_data_recovered = (correctly_recovered_when_st_is_true / total_valid_st_scenarios) * 100\n",
    "        print(f\"\\nData Generation: Student-t ($\\mathcal{{G}} = \\mathcal{{M}}_T$)\")\n",
    "        print(f\"Proportion of trials where $\\mathcal{{M}}_T$ was correctly selected: {correctly_recovered_when_st_is_true}/{total_valid_st_scenarios} ({percentage_st_data_recovered:.2f}%)\")\n",
    "    else:\n",
    "        print(\"Data Generation: Student-t ($\\mathcal{{G}} = \\mathcal{{M}}_T$) -> No valid scenarios found to calculate recovery rate!\")\n",
    "        \n",
    "else:\n",
    "    print(\"Could not calculate model recovery percentages!\")\n",
    "\n",
    "\n",
    "if NUM_TOTAL_EXPERIMENTS > 0 and len(all_runs_results_list) > 0:\n",
    "    print(\"\\nGenerating Stacked Comparison Plot for the First Experiment Run (Seed 0)...\")\n",
    "    first_run_data = all_runs_results_list[0]\n",
    "    results_g_data_first_run = first_run_data.get('gaussian_data_scenario', {})\n",
    "    results_st_data_first_run = first_run_data.get('student_t_data_scenario', {})\n",
    "\n",
    "    if results_g_data_first_run and results_st_data_first_run:\n",
    "        stacked_figsize = (12, 7)\n",
    "        plot_stacked(\n",
    "            results_g_data=results_g_data_first_run,\n",
    "            results_st_data=results_st_data_first_run,\n",
    "            output_filename=f\"stacked_M{M_data_setting}_D{D_features}_seed0_text_params_example.pdf\",\n",
    "            custom_figsize=stacked_figsize\n",
    "        )\n",
    "    else:\n",
    "        print(\"Could not generate plot for the Seed 0 run!\")\n",
    "else:\n",
    "    print(\"Errpr!\")\n",
    "\n",
    "print(\"\\nFinished!\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
