{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30f9be6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams[\"font.size\"] = 14"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "381ea642-a9e9-416f-a397-ec8e045bf2d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "\n",
    "fig, axes = plt.subplots(2, 2, figsize=(11, 4.5), dpi=300)\n",
    "fig.subplots_adjust(wspace=0.1, hspace=0.1) \n",
    "\n",
    "md_cuu = np.genfromtxt(f\"../water_qtip4pf/qtip4pf-100/cft.dat\")\n",
    "arrays_cuu = {}\n",
    "for n in [1, 2, 4, 8, 16, 32]:\n",
    "    try:\n",
    "        arrays_cuu[n] = np.genfromtxt(f\"../water_qtip4pf/{n}-100/cft.dat\")\n",
    "    except Exception as e:\n",
    "        print(e)\n",
    "        pass\n",
    "\n",
    "# ======== TIME CORRELATION ========\n",
    "l = len(md_cuu)\n",
    "s = len(md_cuu) // 100\n",
    "\n",
    "axes[0,0].plot(md_cuu[:l:s, 0] / 4, md_cuu[:l:s, 3], label=\"MD\", color=\"black\", linewidth=2, zorder=100)\n",
    "axes[0,0].fill_between(\n",
    "    md_cuu[:l:s, 0] / 4,\n",
    "    md_cuu[:l:s, 3] - md_cuu[:l:s, 4],\n",
    "    md_cuu[:l:s, 3] + md_cuu[:l:s, 4],\n",
    "    color=\"black\",\n",
    "    alpha=0.3,\n",
    "    linewidth=0,\n",
    "    zorder=100,\n",
    ")\n",
    "\n",
    "\n",
    "for key, array in arrays_cuu.items():\n",
    "    l = len(array)\n",
    "    s = len(array) // 100\n",
    "    print(s)\n",
    "    axes[0,0].plot(array[:l:s, 0] * key, array[:l:s, 3], label=f\"{str(key)} fs\")\n",
    "    axes[0,0].fill_between(\n",
    "        array[:l:s, 0] * key,\n",
    "        array[:l:s, 3] - array[:l:s, 4],\n",
    "        array[:l:s, 3] + array[:l:s, 4],\n",
    "        alpha=0.3,\n",
    "        linewidth=0,\n",
    "    )\n",
    "\n",
    "# ======== MSD ========\n",
    "\n",
    "for key, array in arrays_cuu.items():\n",
    "    l = len(array)\n",
    "    s = len(array) // 10\n",
    "    axes[1,0].plot(array[:l, 0] * key, array[:l, 1], label=f\"{str(key)} fs\")\n",
    "    axes[1,0].fill_between(\n",
    "        array[:l, 0] * key,\n",
    "        array[:l, 1] - array[:l, 2],\n",
    "        array[:l, 1] + array[:l, 2],\n",
    "        alpha=0.3\n",
    "    )\n",
    "l = len(md_cuu)\n",
    "s = len(md_cuu) // 10\n",
    "axes[1,0].plot(md_cuu[:l, 0] / 4, md_cuu[:l, 1], label=\"MD\", color=\"black\", linewidth=2)\n",
    "axes[1,0].fill_between(\n",
    "    md_cuu[:l, 0] / 4,\n",
    "    md_cuu[:l, 1] - md_cuu[:l, 2],\n",
    "    md_cuu[:l, 1] + md_cuu[:l, 2],\n",
    "    color=\"black\",\n",
    "    alpha=0.3\n",
    ")\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# ========================== 1000 FS DATA ===========================\n",
    "\n",
    "\n",
    "md_cuu = np.genfromtxt(f\"../water_qtip4pf/qtip4pf-1000/cft.dat\")\n",
    "arrays_cuu = {}\n",
    "for n in [1, 2, 4, 8, 16, 32]:\n",
    "    try:\n",
    "        arrays_cuu[n] = np.genfromtxt(f\"../water_qtip4pf/{n}-1000/cft.dat\")\n",
    "    except Exception as e:\n",
    "        print(e)\n",
    "        pass\n",
    "\n",
    "\n",
    "# ======== TIME CORRELATION ========\n",
    "l = len(md_cuu)\n",
    "s = len(md_cuu) // 100\n",
    "\n",
    "axes[0,1].plot(md_cuu[:l:s, 0] / 4, md_cuu[:l:s, 3], label=\"MD\", color=\"black\", linewidth=2, zorder=100)\n",
    "axes[0,1].fill_between(\n",
    "    md_cuu[:l:s, 0] / 4,\n",
    "    md_cuu[:l:s, 3] - md_cuu[:l:s, 4],\n",
    "    md_cuu[:l:s, 3] + md_cuu[:l:s, 4],\n",
    "    color=\"black\",\n",
    "    alpha=0.3,\n",
    "    linewidth=0,\n",
    "    zorder=100,\n",
    ")\n",
    "\n",
    "\n",
    "for key, array in arrays_cuu.items():\n",
    "    l = len(array)\n",
    "    s = len(array) // 100\n",
    "    print(s)\n",
    "    axes[0,1].plot(array[:l:s, 0] * key, array[:l:s, 3], label=f\"{str(key)} fs\")\n",
    "    axes[0,1].fill_between(\n",
    "        array[:l:s, 0] * key,\n",
    "        array[:l:s, 3] - array[:l:s, 4],\n",
    "        array[:l:s, 3] + array[:l:s, 4],\n",
    "        alpha=0.3,\n",
    "        linewidth=0,\n",
    "    )\n",
    "\n",
    "\n",
    "# ======== MSD ========\n",
    "\n",
    "for key, array in arrays_cuu.items():\n",
    "    l = len(array)\n",
    "    s = len(array) // 10\n",
    "    axes[1,1].plot(array[:l, 0] * key, array[:l, 1], label=f\"{str(key)} fs\")\n",
    "    axes[1,1].fill_between(\n",
    "        array[:l, 0] * key,\n",
    "        array[:l, 1] - array[:l, 2],\n",
    "        array[:l, 1] + array[:l, 2],\n",
    "        alpha=0.3\n",
    "    )\n",
    "l = len(md_cuu)\n",
    "s = len(md_cuu) // 10\n",
    "axes[1,1].plot(md_cuu[:l, 0] / 4, md_cuu[:l, 1], label=\"MD\", color=\"black\", linewidth=2)\n",
    "axes[1,1].fill_between(\n",
    "    md_cuu[:l, 0] / 4,\n",
    "    md_cuu[:l, 1] - md_cuu[:l, 2],\n",
    "    md_cuu[:l, 1] + md_cuu[:l, 2],\n",
    "    color=\"black\",\n",
    "    alpha=0.3\n",
    ")\n",
    "\n",
    "\n",
    "for ax in axes.flatten():\n",
    "    ax.set_xlim(-0.1*1000, 10.1*1000)\n",
    "    ax.set_xticks(np.linspace(0, 10000, 5))\n",
    "    ax.set_xticklabels(np.linspace(0, 10, 5))\n",
    "    ax.tick_params(axis='both', labelsize=10) \n",
    "    \n",
    "    \n",
    "\n",
    "axes[0,0].set_ylim(-0.05, 1.05)\n",
    "axes[0,1].set_ylim(-0.05, 1.05)\n",
    "axes[0,0].set_yticks(np.linspace(0, 1, 6))\n",
    "axes[0,1].set_yticks(np.linspace(0, 1, 6))\n",
    "\n",
    "axes[1,0].set_ylim(-0.5, 20.5)\n",
    "axes[1,1].set_ylim(-0.5, 20.5)\n",
    "axes[1,0].set_yticks(np.linspace(0, 20, 5))\n",
    "axes[1,1].set_yticks(np.linspace(0, 20, 5))\n",
    "\n",
    "axes[0,0].set_xticklabels([])\n",
    "axes[0,1].set_xticklabels([])\n",
    "axes[0,1].set_yticklabels([])\n",
    "axes[1,1].set_yticklabels([])\n",
    "\n",
    "\n",
    "axes[1,0].legend(fontsize=8, ncols=2)\n",
    "\n",
    "axes[0,0].set_title(r\"$\\tau_L = 100$ fs\", fontsize=12) \n",
    "axes[0,1].set_title(r\"$\\tau_L = 1000$ fs\", fontsize=12)\n",
    "\n",
    "axes[0,0].set_ylabel(r\"$c_{uu}(t)$\", fontsize=12) \n",
    "axes[1,0].set_ylabel(\"MSD [Å$^2$]\", fontsize=12) \n",
    "\n",
    "left = axes[1, 0].get_position()\n",
    "right = axes[1, 1].get_position()\n",
    "\n",
    "# Compute the horizontal center of the plotting region\n",
    "xcenter = (left.x0 + right.x1) / 2\n",
    "\n",
    "# Add the common x-label\n",
    "fig.text(xcenter, 0, 'Time [ps]', ha='center', va='center', fontsize=14)\n",
    "\n",
    "\n",
    "plt.savefig(\"figure_qtip4pf.pdf\", dpi=300, bbox_inches=\"tight\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c45fcf9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "p_uncertainties_epistemic_16 = np.load(\"../uq/p_uncertainties_epistemic_16.npy\")\n",
    "p_uncertainties_epistemic_funny_16 = np.load(\"../uq/p_uncertainties_epistemic_funny_16.npy\")\n",
    "p_uncertainties_epistemic_64 = np.load(\"../uq/p_uncertainties_epistemic_64.npy\")\n",
    "q_uncertainties_epistemic_16 = np.load(\"../uq/q_uncertainties_epistemic_16.npy\")\n",
    "q_uncertainties_epistemic_funny_16 = np.load(\"../uq/q_uncertainties_epistemic_funny_16.npy\")\n",
    "q_uncertainties_epistemic_64 = np.load(\"../uq/q_uncertainties_epistemic_64.npy\")\n",
    "\n",
    "p_uncertainties_aleatoric_16 = np.load(\"../uq/p_uncertainties_aleatoric_16.npy\")\n",
    "p_uncertainties_aleatoric_funny_16 = np.load(\"../uq/p_uncertainties_aleatoric_funny_16.npy\")\n",
    "p_uncertainties_aleatoric_64 = np.load(\"../uq/p_uncertainties_aleatoric_64.npy\")\n",
    "q_uncertainties_aleatoric_16 = np.load(\"../uq/q_uncertainties_aleatoric_16.npy\")\n",
    "q_uncertainties_aleatoric_funny_16 = np.load(\"../uq/q_uncertainties_aleatoric_funny_16.npy\")\n",
    "q_uncertainties_aleatoric_64 = np.load(\"../uq/q_uncertainties_aleatoric_64.npy\")\n",
    "\n",
    "q_squared_residuals_16 = np.load(\"../uq/q_squared_residuals_16.npy\")\n",
    "q_squared_residuals_64 = np.load(\"../uq/q_squared_residuals_64.npy\")\n",
    "p_squared_residuals_16 = np.load(\"../uq/p_squared_residuals_16.npy\")\n",
    "p_squared_residuals_64 = np.load(\"../uq/p_squared_residuals_64.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c789d5df",
   "metadata": {},
   "outputs": [],
   "source": [
    "sigma = 1.0\n",
    "func = lambda x: x * np.exp(-x**2/(2*sigma**2)) * 1.0/(sigma*np.sqrt(2*np.pi))\n",
    "x = np.geomspace(0.01, 10, 10000)\n",
    "\n",
    "from scipy.optimize import root_scalar\n",
    "\n",
    "def pdf(x, sigma):\n",
    "    return x * np.exp(-x**2/(2*sigma**2)) * 1.0/(sigma*np.sqrt(2*np.pi))\n",
    "\n",
    "def find_where_pdf_is_c(c, sigma):\n",
    "    # Finds the two values of x where the pdf is equal to c\n",
    "    mode_value = pdf(sigma, sigma)\n",
    "    if c > mode_value:\n",
    "        raise ValueError(\"c must be less than mode_value\")\n",
    "    where_below_mode = root_scalar(lambda x: pdf(x, sigma) - c, bracket=[0, sigma]).root\n",
    "    where_above_mode = root_scalar(lambda x: pdf(x, sigma) - c, bracket=[sigma, 100]).root\n",
    "    return where_below_mode, where_above_mode\n",
    "\n",
    "def pdf_integral(sigma, c):\n",
    "    # Calculates the integral (analytical) of the pdf from x1 to x2,\n",
    "    # where x1 and x2 are the two values of x where the pdf is equal to c\n",
    "    x1, x2 = find_where_pdf_is_c(c, sigma)\n",
    "    return np.exp(-x1**2/(2*sigma**2)) - np.exp(-x2**2/(2*sigma**2))\n",
    "\n",
    "def find_fraction(sigma, fraction):\n",
    "    # Finds the value of c where the integral of the pdf from x1 to x2 is equal to fraction,\n",
    "    # where x1 and x2 are the two values of x where the pdf is equal to c\n",
    "    mode_value = pdf(sigma, sigma)\n",
    "    return root_scalar(lambda x: pdf_integral(sigma, x) - fraction, x0=mode_value-0.01, x1=mode_value-0.02).root\n",
    "\n",
    "from scipy.stats import norm\n",
    "\n",
    "desired_fractions = [\n",
    "    norm.cdf(1, 0.0, 1.0) - norm.cdf(-1, 0.0, 1.0),  # 1 sigma\n",
    "    norm.cdf(2, 0.0, 1.0) - norm.cdf(-2, 0.0, 1.0),  # 2 sigma\n",
    "    norm.cdf(3, 0.0, 1.0) - norm.cdf(-3, 0.0, 1.0),  # 3 sigma\n",
    "]\n",
    "# print(desired_fractions)\n",
    "\n",
    "sigmas = np.linspace(2e-5, 5e0, 5)\n",
    "\n",
    "lower_bounds = []\n",
    "upper_bounds = []\n",
    "for desired_fraction in desired_fractions:\n",
    "    lower_bounds.append([])\n",
    "    upper_bounds.append([])\n",
    "    for sigma in sigmas:\n",
    "        isoline_value = find_fraction(sigma, desired_fraction)\n",
    "        x1, x2 = find_where_pdf_is_c(isoline_value, sigma)\n",
    "        lower_bounds[-1].append(x1)\n",
    "        upper_bounds[-1].append(x2)\n",
    "\n",
    "    additional_sigma = 100.0\n",
    "    lower_bounds[-1].append(\n",
    "        lower_bounds[-1][-1] + (lower_bounds[-1][-1] - lower_bounds[-1][-2])/(sigmas[-1] - sigmas[-2]) * additional_sigma\n",
    "    )\n",
    "    upper_bounds[-1].append(\n",
    "        upper_bounds[-1][-1] + (upper_bounds[-1][-1] - upper_bounds[-1][-2])/(sigmas[-1] - sigmas[-2]) * additional_sigma\n",
    "    )\n",
    "\n",
    "    lower_bounds[-1] = np.array(lower_bounds[-1])\n",
    "    upper_bounds[-1] = np.array(upper_bounds[-1])\n",
    "\n",
    "sigmas = np.concatenate([sigmas, np.array([100.0])])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2934305c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    " \n",
    "gs = gridspec.GridSpec(nrows=5, ncols=7, width_ratios=[1, 0.1, 1, 0.6, 1, 0.1, 1], height_ratios=[1, 0.1, 1, 0.55, 1])\n",
    "fig = plt.figure(figsize=(11, 8), dpi=300)\n",
    "fig.subplots_adjust(wspace=0, hspace=0) \n",
    "\n",
    "all_axes = []\n",
    "\n",
    "# ===== 16 fs =====\n",
    "\n",
    "ax16_q_llpr = fig.add_subplot(gs[0, 0])\n",
    "all_axes.append(ax16_q_llpr)\n",
    "ax16_q_llpr.text(0.975, 0.025, r\"$\\sigma_{\\text{LLPR}} (\\tilde{\\boldsymbol{q}}_i)$\", transform=ax16_q_llpr.transAxes,\n",
    "            fontsize=15, color=\"k\", va='bottom', ha='right')\n",
    "ax16_q_llpr.plot(q_uncertainties_epistemic_16, q_squared_residuals_16, \".\", markersize=0.5, rasterized=True, alpha=0.2, color=\"#228C22\")\n",
    "\n",
    "ax16_p_llpr = fig.add_subplot(gs[0, 2])\n",
    "all_axes.append(ax16_p_llpr)\n",
    "ax16_p_llpr.text(0.975, 0.025, r\"$\\sigma_{\\text{LLPR}} (\\tilde{\\boldsymbol{p}}_i)$\", transform=ax16_p_llpr.transAxes,\n",
    "            fontsize=15, color=\"k\", va='bottom', ha='right')\n",
    "ax16_p_llpr.plot(p_uncertainties_epistemic_16, p_squared_residuals_16, \".\", markersize=0.5, rasterized=True, alpha=0.2, color=\"#228C22\")\n",
    "\n",
    "ax16_q_mve =  fig.add_subplot(gs[2, 0])\n",
    "all_axes.append(ax16_q_mve)\n",
    "ax16_q_mve.text(0.975, 0.025, r\"$\\sigma_{\\text{MVE}} (\\tilde{\\boldsymbol{q}}_i)$\", transform=ax16_q_mve.transAxes,\n",
    "            fontsize=15, color=\"k\", va='bottom', ha='right')\n",
    "ax16_q_mve.plot(q_uncertainties_aleatoric_16, q_squared_residuals_16, \".\", markersize=0.5, rasterized=True, alpha=0.2, color=\"orange\")\n",
    "\n",
    "ax16_p_mve =  fig.add_subplot(gs[2, 2])\n",
    "all_axes.append(ax16_p_mve)\n",
    "ax16_p_mve.text(0.975, 0.025, r\"$\\sigma_{\\text{MVE}} (\\tilde{\\boldsymbol{p}}_i)$\", transform=ax16_p_mve.transAxes,\n",
    "            fontsize=15, color=\"k\", va='bottom', ha='right')\n",
    "ax16_p_mve.plot(p_uncertainties_aleatoric_16, p_squared_residuals_16, \".\", markersize=0.5, rasterized=True, alpha=0.2, color=\"orange\")\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# ===== 64 fs =====\n",
    "\n",
    "ax64_q_llpr = fig.add_subplot(gs[0, 4])\n",
    "all_axes.append(ax64_q_llpr)\n",
    "ax64_q_llpr.text(0.975, 0.025, r\"$\\sigma_{\\text{LLPR}} (\\tilde{\\boldsymbol{q}}_i)$\", transform=ax64_q_llpr.transAxes,\n",
    "            fontsize=15, color=\"k\", va='bottom', ha='right')\n",
    "ax64_q_llpr.plot(q_uncertainties_epistemic_64, q_squared_residuals_64, \".\", markersize=0.5, rasterized=True, alpha=0.2, color=\"#228C22\")\n",
    "\n",
    "ax64_p_llpr = fig.add_subplot(gs[0, 6])\n",
    "all_axes.append(ax64_p_llpr)\n",
    "ax64_p_llpr.text(0.975, 0.025, r\"$\\sigma_{\\text{LLPR}} (\\tilde{\\boldsymbol{p}}_i)$\", transform=ax64_p_llpr.transAxes,\n",
    "            fontsize=15, color=\"k\", va='bottom', ha='right')\n",
    "ax64_p_llpr.plot(p_uncertainties_epistemic_64, p_squared_residuals_64, \".\", markersize=0.5, rasterized=True, alpha=0.2, color=\"#228C22\")\n",
    "\n",
    "ax64_q_mve = fig.add_subplot(gs[2, 4])\n",
    "all_axes.append(ax64_q_mve)\n",
    "ax64_q_mve.text(0.975, 0.025, r\"$\\sigma_{\\text{MVE}} (\\tilde{\\boldsymbol{q}}_i)$\", transform=ax64_q_mve.transAxes,\n",
    "            fontsize=15, color=\"k\", va='bottom', ha='right')\n",
    "ax64_q_mve.plot(q_uncertainties_aleatoric_64, q_squared_residuals_64, \".\", markersize=0.5, rasterized=True, alpha=0.2, color=\"orange\")\n",
    "\n",
    "ax64_p_mve = fig.add_subplot(gs[2, 6])\n",
    "all_axes.append(ax64_p_mve)\n",
    "ax64_p_mve.text(0.975, 0.025, r\"$\\sigma_{\\text{MVE}} (\\tilde{\\boldsymbol{p}}_i)$\", transform=ax64_p_mve.transAxes,\n",
    "            fontsize=15, color=\"k\", va='bottom', ha='right')\n",
    "ax64_p_mve.plot(p_uncertainties_aleatoric_64, p_squared_residuals_64, \".\", markersize=0.5, rasterized=True, alpha=0.2, color=\"orange\")\n",
    "\n",
    "\n",
    "for ax in all_axes:\n",
    "    ax.set_xscale(\"log\")\n",
    "    ax.set_yscale(\"log\")\n",
    "    ax.plot(sigmas, sigmas, color=\"k\")\n",
    "    for l, u in zip(lower_bounds, upper_bounds):\n",
    "        ax.plot(sigmas, l, color=\"k\", linewidth=0.5)\n",
    "        ax.plot(sigmas, u, color=\"k\", linewidth=0.5)\n",
    "    ax.set_xlim(3e-5, 1e1)\n",
    "    ax.set_ylim(3e-7, 1e1)\n",
    "    ax.tick_params(axis='both', labelsize=7) \n",
    "\n",
    "    # major_xticks = ax.get_xticks()\n",
    "    # # Keep every other major tick\n",
    "    # ax.set_xticks(major_xticks[::2])\n",
    "\n",
    "    # major_yticks = ax.get_yticks()\n",
    "    # # Keep every other major tick\n",
    "    # ax.set_yticks(major_yticks[::2])\n",
    "\n",
    "    ax.minorticks_on()\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "ax16_q_llpr.set_xticklabels([])\n",
    "ax16_p_llpr.set_xticklabels([])\n",
    "ax16_p_llpr.set_yticklabels([])\n",
    "ax16_p_mve.set_yticklabels([])\n",
    "\n",
    "ax64_q_llpr.set_xticklabels([])\n",
    "ax64_p_llpr.set_xticklabels([])\n",
    "ax64_p_llpr.set_yticklabels([])\n",
    "ax64_p_mve.set_yticklabels([])\n",
    "\n",
    "left_cluster_left = ax16_q_llpr.get_position().x0\n",
    "left_cluster_right = ax16_p_llpr.get_position().x1\n",
    "left_cluster_center = (left_cluster_left + left_cluster_right) / 2\n",
    "\n",
    "right_cluster_left = ax64_q_llpr.get_position().x0\n",
    "right_cluster_right = ax64_p_llpr.get_position().x1\n",
    "right_cluster_center = (right_cluster_left + right_cluster_right) / 2\n",
    "\n",
    "# Choose a y-position slightly above the axes (adjust if needed)\n",
    "y_pos = 0.895\n",
    "\n",
    "# Add figure-level titles above the clusters\n",
    "fig.text(left_cluster_center, y_pos, r\"$\\tau=$ 16 fs\", ha='center', va='bottom', fontsize=14)\n",
    "fig.text(right_cluster_center, y_pos, r\"$\\tau=$ 64 fs\", ha='center', va='bottom', fontsize=14)\n",
    "\n",
    "# Compute horizontal centers for columns\n",
    "x_positions = {\n",
    "    \"q\": (ax16_q_llpr.get_position().x0 + ax64_q_llpr.get_position().x0) / 2,\n",
    "    \"p\": (ax16_p_llpr.get_position().x1 + ax64_p_llpr.get_position().x1) / 2,\n",
    "}\n",
    "\n",
    "\n",
    "left_group_left = min(ax16_q_llpr.get_position().x0, ax16_q_mve.get_position().x0)\n",
    "left_group_right = max(ax16_p_llpr.get_position().x1, ax16_p_mve.get_position().x1)\n",
    "left_group_bottom = min(ax16_q_mve.get_position().y0, ax16_p_mve.get_position().y0)\n",
    "left_group_top = max(ax16_q_llpr.get_position().y1, ax16_p_llpr.get_position().y1)\n",
    "\n",
    "x_center = (left_group_left + left_group_right) / 2\n",
    "y_center = (left_group_bottom + left_group_top) / 2\n",
    "\n",
    "fig.text(left_group_left-0.05, y_center, \"Actual Error\", \n",
    "         va='center', ha='right', fontsize=14, rotation=90)\n",
    "fig.text(x_center, left_group_bottom - 0.05, \"Predicted Error\", \n",
    "         va='top', ha='center', fontsize=14)\n",
    "\n",
    "left_group_left = min(ax64_q_llpr.get_position().x0, ax64_q_mve.get_position().x0)\n",
    "left_group_right = max(ax64_p_llpr.get_position().x1, ax64_p_mve.get_position().x1)\n",
    "left_group_bottom = min(ax64_q_mve.get_position().y0, ax64_p_mve.get_position().y0)\n",
    "left_group_top = max(ax64_q_llpr.get_position().y1, ax64_p_llpr.get_position().y1)\n",
    "\n",
    "x_center = (left_group_left + left_group_right) / 2\n",
    "y_center = (left_group_bottom + left_group_top) / 2\n",
    "\n",
    "fig.text(left_group_left-0.05, y_center, \"Actual Error\", \n",
    "         va='center', ha='right', fontsize=14, rotation=90)\n",
    "fig.text(x_center, left_group_bottom - 0.05, \"Predicted Error\", \n",
    "         va='top', ha='center', fontsize=14)\n",
    "\n",
    "\n",
    "\n",
    "# ==== ood ====\n",
    "\n",
    "ax16_ood = fig.add_subplot(gs[4,:3])\n",
    "ax16_ood.axvspan(0.9**(1/3), 1.1**(1/3), color='grey', lw=0, alpha=0.2, label=\"training domain\")\n",
    "ax16_ood.plot(np.linspace(0.5, 1.5, 11), q_uncertainties_aleatoric_funny_16, \"o\", markersize=5, label=r\"$\\sigma_{\\text{MVE}}(\\tilde{\\boldsymbol{q}}_i)$\", color=\"orange\")\n",
    "ax16_ood.plot(np.linspace(0.5, 1.5, 11), q_uncertainties_epistemic_funny_16, \"o\", markersize=5, label=r\"$\\sigma_{\\text{LLPR}}(\\tilde{\\boldsymbol{q}}_i)$\", color=\"#228C22\")\n",
    "ax16_ood.set_xlabel(\"Cell scaling factor\",fontsize=14)\n",
    "ax16_ood.set_ylabel(r\"$\\sigma(\\tilde{\\boldsymbol{q}}_i)$\",fontsize=14)\n",
    "ax16_ood.tick_params(axis='both', labelsize=10) \n",
    "\n",
    "ax16_ood.legend(fontsize=10)\n",
    "\n",
    "ax64_ood = fig.add_subplot(gs[4,4:])\n",
    "\n",
    "ax64_ood.axvspan(0.9**(1/3), 1.1**(1/3), color='grey', lw=0, alpha=0.2, label=\"Training domain\")\n",
    "ax64_ood.plot(np.linspace(0.5, 1.5, 11), p_uncertainties_aleatoric_funny_16, \"o\", markersize=5, label=r\"$\\sigma_{\\text{MVE}}(\\tilde{\\boldsymbol{p}}_i)$\", color=\"orange\")\n",
    "ax64_ood.plot(np.linspace(0.5, 1.5, 11), p_uncertainties_epistemic_funny_16, \"o\", markersize=5,label=r\"$\\sigma_{\\text{LLPR}}(\\tilde{\\boldsymbol{p}}_i)$\", color=\"#228C22\")\n",
    "ax64_ood.set_xlabel(\"Cell scaling factor\",fontsize=14)\n",
    "ax64_ood.set_ylabel(r\"$\\sigma(\\tilde{\\boldsymbol{p}}_i)$\",fontsize=14)\n",
    "ax64_ood.tick_params(axis='both', labelsize=10) \n",
    "\n",
    "ax64_ood.legend(fontsize=10)\n",
    "\n",
    "\n",
    "import matplotlib.lines as mlines\n",
    "\n",
    "# Get y-position in figure coordinates\n",
    "y_divider = ax16_ood.get_position().y1 + 0.03  # Adjust this offset for visual balance\n",
    "\n",
    "# Draw a horizontal line across the full figure width (from 0 to 1)\n",
    "divider_line = mlines.Line2D([0.05, 0.91], [y_divider, y_divider], transform=fig.transFigure,\n",
    "                             color=\"black\", linewidth=1)\n",
    "fig.add_artist(divider_line)\n",
    "\n",
    "# 'a' for top part\n",
    "fig.text(0.05, 0.875, \"a\", fontsize=16, fontweight='bold', ha='left', va='bottom')\n",
    "\n",
    "# 'b' for bottom part\n",
    "fig.text(0.05, y_divider - 0.03, \"b\", fontsize=16, fontweight='bold', ha='left', va='top')\n",
    "\n",
    "\n",
    "\n",
    "plt.savefig(\"figure_uq.pdf\", dpi=300, bbox_inches=\"tight\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bf46bad",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
