{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cc9fe8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "import pandas as pd\n",
    "from matplotlib import gridspec\n",
    "import matplotlib.ticker as mticker\n",
    "import torch\n",
    "\n",
    "import svgutils.transform as sg\n",
    "from svgutils.compose import *\n",
    "\n",
    "from plotting_utils import cm2inch, get_size_tuple\n",
    "from IPython.display import SVG, display\n",
    "\n",
    "import matplotlib as mpl\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "\n",
    "from utils.misc import get_output_dir\n",
    "\n",
    "out_dir = get_output_dir()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba1f192b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Unpickle and load data ####\n",
    "\n",
    "# All predictive summaries live in the same experiment directory.\n",
    "# NOTE: pickle.load can execute arbitrary code - only load files produced by our own pipeline.\n",
    "experiment_dir = out_dir/'darcy_experiment_snr_30'\n",
    "\n",
    "\n",
    "def load_predictive_summary(filename, directory=experiment_dir):\n",
    "    \"\"\"Unpickle one predictive summary and unpack the arrays used in this notebook.\n",
    "\n",
    "    Returns (summary, theta_test, x_test, posterior_samples, posterior_predictive_samples).\n",
    "    The posterior predictive samples may contain NaNs, which need to be filtered out before use.\n",
    "    \"\"\"\n",
    "    with open(directory/filename, 'rb') as handle:\n",
    "        summary = pickle.load(handle)\n",
    "    return (summary, summary['theta_test'], summary['x_test'],\n",
    "            summary['posterior_samples'], summary['posterior_predictive_samples'])\n",
    "\n",
    "\n",
    "# FNO-FMPE, fix (always_equispaced=True)\n",
    "(data_FNOfix, theta_FNOfix, x_FNOfix,\n",
    " posterior_FNOfix, posterior_predictive_FNOfix) = load_predictive_summary(\n",
    "    'FNO_FMPE_always_equispaced_True_predictive_summary.pkl')\n",
    "\n",
    "# FNO-FMPE, flex (always_equispaced=False, target grid size 2048)\n",
    "(data_FNOflex, theta_FNOflex, x_FNOflex,\n",
    " posterior_FNOflex, posterior_predictive_FNOflex) = load_predictive_summary(\n",
    "    'FNO_FMPE_always_equispaced_False_target_gridsize_2048_predictive_summary.pkl')\n",
    "\n",
    "# Spectral NPE\n",
    "(data_specNPE, theta_specNPE, x_specNPE,\n",
    " posterior_specNPE, posterior_predictive_specNPE) = load_predictive_summary(\n",
    "    'spectral_NPE_predictive_summary.pkl')\n",
    "\n",
    "# Raw FMPE\n",
    "(data_rawFMPE, theta_rawFMPE, x_rawFMPE,\n",
    " posterior_rawFMPE, posterior_predictive_rawFMPE) = load_predictive_summary(\n",
    "    'raw_FMPE_predictive_summary.pkl')\n",
    "\n",
    "# Spectral FMPE\n",
    "(data_specFMPE, theta_specFMPE, x_specFMPE,\n",
    " posterior_specFMPE, posterior_predictive_specFMPE) = load_predictive_summary(\n",
    "    'spectral_FMPE_predictive_summary.pkl')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39ad9c33",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Plot and save observation for overview figure ####\n",
    "\n",
    "# Last test observation of the fix FNO-FMPE data, drawn without axes\n",
    "fig_obs, ax_obs = plt.subplots()\n",
    "ax_obs.imshow(x_FNOfix[-1, 0, :, :], vmin=0, vmax=0.08)\n",
    "ax_obs.set_axis_off()\n",
    "fig_obs.savefig('observation3.png', dpi=200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89dadf56",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Compute necessary statistics for FNO flex ####\n",
    "\n",
    "# theta and the posterior samples are stored in log space; theta should be visualized in physical space\n",
    "true_theta_flex = np.exp(theta_FNOflex[0,0,:,:])\n",
    "posterior_physical_flex = np.exp(posterior_FNOflex[:,0,:,:])  # exp computed once for mean and sd\n",
    "mean_theta_flex = np.mean(posterior_physical_flex, axis=0)\n",
    "sd_theta_flex = np.std(posterior_physical_flex, axis=0)\n",
    "absolute_error_theta_flex = np.abs(mean_theta_flex - true_theta_flex)\n",
    "true_observation_flex = x_FNOflex[0,0,:,:]\n",
    "\n",
    "# Get mean observation. Drop every posterior predictive sample that contains ANY NaN\n",
    "# (checking only for entirely-NaN samples lets partially-NaN samples propagate NaNs\n",
    "# into the mean) or any implausibly large value (> 1).\n",
    "predictive_layers_flex = posterior_predictive_FNOflex[:,0,:,:]\n",
    "contains_nan = np.isnan(predictive_layers_flex).any(axis=(1,2))\n",
    "too_large = (predictive_layers_flex > 1).any(axis=(1,2))  # NaN > 1 is False, so NaNs are handled by contains_nan\n",
    "keep = ~(contains_nan | too_large)\n",
    "\n",
    "posterior_predictive_samples_filtered = predictive_layers_flex[keep]\n",
    "n_used_layer = posterior_predictive_samples_filtered.shape[0]\n",
    "n_skipped_observations = predictive_layers_flex.shape[0] - n_used_layer\n",
    "assert n_used_layer > 0, 'all posterior predictive samples were filtered out'\n",
    "\n",
    "mean_observation = np.mean(posterior_predictive_samples_filtered, axis=0)\n",
    "# Seeded so that the displayed single predictive sample is reproducible across runs\n",
    "rand_idx = np.random.default_rng(0).integers(n_used_layer)\n",
    "single_observation = posterior_predictive_samples_filtered[rand_idx, :, :]\n",
    "# NOTE(review): inf/NaN wherever true_observation_flex == 0 - confirm this is acceptable for the plot\n",
    "relative_error_observation = np.abs((mean_observation - true_observation_flex)/true_observation_flex)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3053dd92",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Compute necessary statistics for FNO fix ####\n",
    "\n",
    "# theta and the posterior samples are stored in log space; the figures show physical space,\n",
    "# so everything is exponentiated. Only the first test instance (index 0) is summarised here.\n",
    "samples_fix = np.exp(posterior_FNOfix[:, 0, :, :])\n",
    "true_theta_fix = np.exp(theta_FNOfix[0, 0, :, :])\n",
    "\n",
    "mean_theta_fix = np.mean(samples_fix, axis=0)\n",
    "sd_theta_fix = np.std(samples_fix, axis=0)\n",
    "absolute_error_theta_fix = np.abs(mean_theta_fix - true_theta_fix)\n",
    "\n",
    "true_observation_fix = x_FNOfix[0, 0, :, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcad8675",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Plot of mean theta with crosssections for overview figure ####\n",
    "\n",
    "# Row and column of the grid whose cross-sections are plotted (marked by the dashed lines)\n",
    "cross_section_row = 40\n",
    "cross_section_col = 30\n",
    "\n",
    "plt.imshow(true_theta_fix, vmin=0, vmax=1.8)\n",
    "plt.axis('off')\n",
    "plt.savefig('true_theta.png', dpi=200)\n",
    "plt.show()\n",
    "\n",
    "# Posterior sample 10 of the first test instance, with the cross-section locations overlaid\n",
    "plt.imshow(np.exp(posterior_FNOfix[10,0,:,:]), vmin=0, vmax=1.8)\n",
    "plt.axis('off')\n",
    "plt.axhline(y=cross_section_row, color='lightgray', linestyle='--', linewidth=5)\n",
    "plt.axvline(x=cross_section_col, color='lightgray', linestyle='--', linewidth=5)\n",
    "plt.savefig('theta_posterior_sample.png', dpi=200)\n",
    "plt.show()\n",
    "\n",
    "\n",
    "def plot_cross_section(mean_profile, sd_profile, filename):\n",
    "    \"\"\"Plot a 1D posterior-mean profile with a +/- 1 sd band, without axes, and save it.\n",
    "\n",
    "    The x positions are derived from the profile length (instead of a hard-coded 129),\n",
    "    so the plot stays correct for other grid sizes and non-square grids.\n",
    "    \"\"\"\n",
    "    positions = np.arange(mean_profile.shape[0])\n",
    "    lower = mean_profile - sd_profile\n",
    "    upper = mean_profile + sd_profile\n",
    "    plt.plot(positions, mean_profile, color='black', linewidth=4)\n",
    "    plt.plot(positions, lower, color='black', linewidth=0.6)\n",
    "    plt.plot(positions, upper, color='black', linewidth=0.6)\n",
    "    plt.fill_between(positions, lower, upper, color='lightgrey', alpha=1)\n",
    "    plt.xlim(0, mean_profile.shape[0])\n",
    "    plt.gca().axes.get_xaxis().set_visible(False)\n",
    "    plt.gca().axes.get_yaxis().set_visible(False)\n",
    "    plt.gca().axes.set_aspect(50)\n",
    "    plt.savefig(filename, dpi=200)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "with plt.rc_context(fname='matplotlibrc'):\n",
    "    plot_cross_section(mean_theta_fix[cross_section_row,:], sd_theta_fix[cross_section_row,:], 'horizontal_uncertainty.png')\n",
    "    plot_cross_section(mean_theta_fix[:,cross_section_col], sd_theta_fix[:,cross_section_col], 'vertical_uncertainty.png')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e97dfdfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Get MSE and SBC for Darcy ####\n",
    "\n",
    "path = out_dir/'darcy_experiment_snr_30/summary.csv'\n",
    "data = pd.read_csv(path, usecols=[1, 2, 3, 4, 5, 6, 7, 8])\n",
    "\n",
    "# Get the different methods and simulation budgets\n",
    "methods = data['method'].unique()\n",
    "n_sim = data['nsim'].unique()\n",
    "\n",
    "# Methods for which no posterior log-probs are evaluated; their logprob_mean / logprob_SE entries stay 0\n",
    "methods_without_logprob = [\"spectral_NPE\", \"baseline_spectral_FMPE\"]\n",
    "\n",
    "\n",
    "def parse_array_column(column):\n",
    "    \"\"\"Flatten a column of stringified arrays (e.g. '[0.1 0.2 nan]') into one list of floats, skipping 'nan'.\"\"\"\n",
    "    values = []\n",
    "    for entry in column:\n",
    "        entry_clean = entry.replace('[', '').replace(']', '')\n",
    "        values.extend(float(x) for x in entry_clean.split() if x != 'nan')\n",
    "    return values\n",
    "\n",
    "\n",
    "def mean_and_se(values):\n",
    "    \"\"\"Return the mean and the standard error of the mean (std / sqrt(n)) of a list of floats.\"\"\"\n",
    "    values = np.array(values)\n",
    "    return np.mean(values), np.std(values)/np.sqrt(len(values))\n",
    "\n",
    "\n",
    "# Calculate mean and SE for each method and each number of simulations\n",
    "result_shape = (methods.shape[0], n_sim.shape[0])\n",
    "mses_results = {}\n",
    "sbc_results = {}\n",
    "logprob_results = {}\n",
    "mses_mean = np.zeros(result_shape)\n",
    "mses_SE = np.zeros(result_shape)\n",
    "sbcs_mean = np.zeros(result_shape)\n",
    "sbcs_SE = np.zeros(result_shape)\n",
    "logprob_mean = np.zeros(result_shape)\n",
    "logprob_SE = np.zeros(result_shape)\n",
    "\n",
    "for ii, method in enumerate(methods):\n",
    "\n",
    "    mses_results[method] = {}\n",
    "    sbc_results[method] = {}\n",
    "    logprob_results[method] = {}\n",
    "\n",
    "    for kk, nsim in enumerate(n_sim):\n",
    "\n",
    "        rows = data[(data['method'] == method) & (data['nsim'] == nsim)]\n",
    "\n",
    "        # filter out unreasonable high numbers for MSE\n",
    "        mses_results[method][nsim] = [x for x in parse_array_column(rows['predictive_mses']) if x <= 1]\n",
    "        mses_mean[ii, kk], mses_SE[ii, kk] = mean_and_se(mses_results[method][nsim])\n",
    "\n",
    "        sbc_results[method][nsim] = parse_array_column(rows['sbcs'])\n",
    "        sbcs_mean[ii, kk], sbcs_SE[ii, kk] = mean_and_se(sbc_results[method][nsim])\n",
    "\n",
    "        if method in methods_without_logprob:\n",
    "            logprob_results[method][nsim] = []\n",
    "            continue\n",
    "\n",
    "        # Every row contributes its own log-probs (indexing with the SBC loop variable\n",
    "        # would repeat the last row's values for every row).\n",
    "        logprob_results[method][nsim] = parse_array_column(rows['posterior_log_probs'])\n",
    "        logprob_mean[ii, kk], logprob_SE[ii, kk] = mean_and_se(logprob_results[method][nsim])\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3b6de9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Calculate MSE of the posterior mean to the ground truth ####\n",
    "\n",
    "def posterior_mean_mse(posterior_samples, theta_true):\n",
    "    \"\"\"Return (posterior mean, MSE of the posterior mean averaged over test instances).\n",
    "\n",
    "    posterior_samples is expected as (sample, test instance, H, W) and theta_true as\n",
    "    (test instance, 1, H, W), both in log space. The comparison is done in physical\n",
    "    space, so BOTH are exponentiated - comparing an exponentiated posterior mean\n",
    "    against a log-space ground truth would mix the two spaces.\n",
    "    \"\"\"\n",
    "    posterior_mean = np.mean(np.exp(posterior_samples), axis=0)\n",
    "    true_field = np.exp(theta_true[:,0,:,:])\n",
    "    assert posterior_mean.shape == true_field.shape, (posterior_mean.shape, true_field.shape)\n",
    "    # Mean over the (H, W) pixels of each field (also correct for non-square grids), then over test instances\n",
    "    per_instance_mse = np.mean((posterior_mean - true_field)**2, axis=(1,2))\n",
    "    return posterior_mean, np.mean(per_instance_mse)\n",
    "\n",
    "\n",
    "mean_posterior_FNOfix, posterior_MSE_FNOfix = posterior_mean_mse(posterior_FNOfix, theta_FNOfix)\n",
    "print('FNO-FMPE fix:', posterior_MSE_FNOfix)\n",
    "\n",
    "mean_posterior_FNOflex, posterior_MSE_FNOflex = posterior_mean_mse(posterior_FNOflex, theta_FNOflex)\n",
    "print('FNO-FMPE flex:', posterior_MSE_FNOflex)\n",
    "\n",
    "# The baselines are compared against theta_FNOflex.\n",
    "# NOTE(review): assumes all summaries share the same test set - confirm, or use each method's own theta_test.\n",
    "mean_posterior_specNPE, posterior_MSE_specNPE = posterior_mean_mse(posterior_specNPE, theta_FNOflex)\n",
    "print('spectral NPE:', posterior_MSE_specNPE)\n",
    "\n",
    "mean_posterior_specFMPE, posterior_MSE_specFMPE = posterior_mean_mse(posterior_specFMPE, theta_FNOflex)\n",
    "print('spectral FMPE:', posterior_MSE_specFMPE)\n",
    "\n",
    "mean_posterior_rawFMPE, posterior_MSE_rawFMPE = posterior_mean_mse(posterior_rawFMPE, theta_FNOflex)\n",
    "print('raw FMPE:', posterior_MSE_rawFMPE)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d45f94b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Figure for appendix comparing posterior samples between the different methods ####\n",
    "\n",
    "import os\n",
    "\n",
    "# Test instances shown in the six columns (same for every row)\n",
    "sample_fig_test_indices = [6, 7, 2, 3, 4, 5]\n",
    "\n",
    "# (row label, posterior samples, index of the posterior sample shown in each column)\n",
    "sample_fig_rows = [\n",
    "    ('FNOPE (fix)', posterior_FNOfix, [5, 2, 2, 0, 0, 1]),\n",
    "    ('FNOPE', posterior_FNOflex, [1, 6, 3, 1, 4, 3]),\n",
    "    ('NPE (spectral)', posterior_specNPE, [0, 0, 0, 0, 0, 0]),\n",
    "    ('FMPE (spectral)', posterior_specFMPE, [0, 0, 0, 0, 0, 0]),\n",
    "    ('FMPE (raw)', posterior_rawFMPE, [0, 0, 0, 0, 0, 0]),\n",
    "]\n",
    "\n",
    "with plt.rc_context(fname='matplotlibrc'):\n",
    "\n",
    "    # Spacing is set via subplots_adjust; no tight_layout before the figure exists (it would act on a stray empty figure)\n",
    "    fig, axs = plt.subplots(6, 6, figsize=(8, 8))\n",
    "    fig.subplots_adjust(wspace=0.01, hspace=0.1)\n",
    "\n",
    "    # First row: ground truth theta, shown in physical space\n",
    "    for col, test_idx in enumerate(sample_fig_test_indices):\n",
    "        axs[0,col].imshow(np.exp(theta_FNOflex[test_idx,0,:,:]), vmin=0.2, vmax=1.8)\n",
    "        axs[0,col].xaxis.set_label_position('top')  # Move label to top\n",
    "        axs[0,col].set_xlabel(rf'$\\theta_{col + 1}$', labelpad=10, fontsize=10)\n",
    "    axs[0,0].set_ylabel('Ground truth', fontsize=9)\n",
    "\n",
    "    # Remaining rows: one posterior sample per method and test instance, in physical space\n",
    "    for row, (label, samples, sample_indices) in enumerate(sample_fig_rows, start=1):\n",
    "        for col, (test_idx, sample_idx) in enumerate(zip(sample_fig_test_indices, sample_indices)):\n",
    "            axs[row,col].imshow(np.exp(samples[sample_idx,test_idx,:,:]), vmin=0.2, vmax=1.8)\n",
    "        axs[row,0].set_ylabel(label, fontsize=9)\n",
    "\n",
    "    for ax in axs.flat:\n",
    "        # Hide spines\n",
    "        for spine in ax.spines.values():\n",
    "            spine.set_visible(False)\n",
    "        # Hide ticks and tick labels\n",
    "        ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)\n",
    "\n",
    "    # Save figure (create the output directory first so savefig does not fail on a fresh checkout)\n",
    "    os.makedirs('darcy_plots', exist_ok=True)\n",
    "    plt.savefig('darcy_plots/posterior_samples_Darcy.pdf', format='pdf', bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67f5c267",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Plot mean posteriors for appendix ####\n",
    "\n",
    "import os\n",
    "\n",
    "# Test instances shown in the six columns (same for every row)\n",
    "mean_fig_test_indices = [6, 7, 2, 3, 4, 5]\n",
    "\n",
    "# (row label, posterior samples) for every method row below the ground truth\n",
    "mean_fig_rows = [\n",
    "    ('FNOPE (fix)', posterior_FNOfix),\n",
    "    ('FNOPE', posterior_FNOflex),\n",
    "    ('NPE (spectral)', posterior_specNPE),\n",
    "    ('FMPE (spectral)', posterior_specFMPE),\n",
    "    ('FMPE (raw)', posterior_rawFMPE),\n",
    "]\n",
    "\n",
    "with plt.rc_context(fname='matplotlibrc'):\n",
    "\n",
    "    # Spacing is set via subplots_adjust; no tight_layout before the figure exists (it would act on a stray empty figure)\n",
    "    fig, axs = plt.subplots(6, 6, figsize=(8, 8))\n",
    "    fig.subplots_adjust(wspace=0.01, hspace=0.1)\n",
    "\n",
    "    # First row: ground truth theta, shown in physical space\n",
    "    for col, test_idx in enumerate(mean_fig_test_indices):\n",
    "        axs[0,col].imshow(np.exp(theta_FNOflex[test_idx,0,:,:]), vmin=0.2, vmax=1.8)\n",
    "        axs[0,col].xaxis.set_label_position('top')  # Move label to top\n",
    "        axs[0,col].set_xlabel(rf'$\\theta_{col + 1}$', labelpad=10, fontsize=10)\n",
    "    axs[0,0].set_ylabel('Ground truth', fontsize=9)\n",
    "\n",
    "    # Remaining rows: posterior mean over samples (axis 0), in physical space\n",
    "    for row, (label, samples) in enumerate(mean_fig_rows, start=1):\n",
    "        for col, test_idx in enumerate(mean_fig_test_indices):\n",
    "            image = axs[row,col].imshow(np.mean(np.exp(samples[:,test_idx,:,:]), axis=0), vmin=0.2, vmax=1.8)\n",
    "        axs[row,0].set_ylabel(label, fontsize=9)\n",
    "\n",
    "    for ax in axs.flat:\n",
    "        # Hide spines\n",
    "        for spine in ax.spines.values():\n",
    "            spine.set_visible(False)\n",
    "        # Hide ticks and tick labels\n",
    "        ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)\n",
    "\n",
    "    # Shared colorbar to the right of the entire figure. It is built from a plotted image\n",
    "    # (all panels share vmin/vmax and the default colormap), so it always matches the\n",
    "    # colormap configured by matplotlibrc rather than a hard-coded one.\n",
    "    cbar_ax = fig.add_axes([0.91, 0.112, 0.015, 0.768])  # [left, bottom, width, height]\n",
    "    cbar = fig.colorbar(image, cax=cbar_ax, orientation='vertical')\n",
    "    cbar.ax.tick_params(labelsize=9)\n",
    "\n",
    "    # Save figure (create the output directory first so savefig does not fail on a fresh checkout)\n",
    "    os.makedirs('darcy_plots', exist_ok=True)\n",
    "    plt.savefig('darcy_plots/posterior_mean_Darcy.pdf', format='pdf', bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10ffdfc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Plot sd of posteriors for appendix ####\n",
    "\n",
    "with plt.rc_context(fname='matplotlibrc'):\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    fig, axs = plt.subplots(5, 6, figsize=(8, 6.66))\n",
    "    fig.subplots_adjust(wspace=0.01)\n",
    "    fig.subplots_adjust(hspace=0.1)\n",
    "\n",
    "    # Plot ground truth theta\n",
    "    #im1 = axs[0,0].imshow(np.exp(theta_FNOflex[6,0,:,:]), vmin = 0.2, vmax = 1.8)\n",
    "    #im2 = axs[0,1].imshow(np.exp(theta_FNOflex[7,0,:,:]), vmin = 0.2, vmax = 1.8)\n",
    "    #im3 = axs[0,2].imshow(np.exp(theta_FNOflex[2,0,:,:]), vmin = 0.2, vmax = 1.8)\n",
    "   # im4 = axs[0,3].imshow(np.exp(theta_FNOflex[3,0,:,:]), vmin = 0.2, vmax = 1.8)\n",
    "    #im5 = axs[0,4].imshow(np.exp(theta_FNOflex[4,0,:,:]), vmin = 0.2, vmax = 1.8)\n",
    "    #im6 = axs[0,5].imshow(np.exp(theta_FNOflex[5,0,:,:]), vmin = 0.2, vmax = 1.8)\n",
    "\n",
    "    #axs[0,0].set_ylabel(f'Ground truth', fontsize=9)\n",
    "\n",
    "\n",
    "    # Plot theta from FNO-FMPE fix\n",
    "    im7 = axs[0,0].imshow(np.std(np.exp(posterior_FNOfix[:,6,:,:]), axis=0), vmin = 0.0, vmax = 0.5)\n",
    "    im8 = axs[0,1].imshow(np.std(np.exp(posterior_FNOfix[:,7,:,:]), axis=0), vmin = 0.0, vmax = 0.5)\n",
    "    im9 = axs[0,2].imshow(np.std(np.exp(posterior_FNOfix[:,2,:,:]), axis=0), vmin = 0.0, vmax = 0.5)\n",
    "    im10 = axs[0,3].imshow(np.std(np.exp(posterior_FNOfix[:,3,:,:]), axis=0), vmin = 0.0, vmax = 0.5)\n",
    "    im11= axs[0,4].imshow(np.std(np.exp(posterior_FNOfix[:,4,:,:]), axis=0), vmin = 0.0, vmax = 0.5)\n",
    "    im12 = axs[0,5].imshow(np.std(np.exp(posterior_FNOfix[:,5,:,:]), axis=0), vmin = 0.0, vmax = 0.5)\n",
    "\n",
    "    axs[0,0].set_ylabel(f'FNOPE (fix)', fontsize=9)\n",
    "\n",
    "    axs[0,0].xaxis.set_label_position('top')  # Move label to top\n",
    "    axs[0,0].set_xlabel(r'$\\theta_1$', labelpad=10, fontsize=10)\n",
    "    axs[0,1].xaxis.set_label_position('top')  # Move label to top\n",
    "    axs[0,1].set_xlabel(r'$\\theta_2$', labelpad=10, fontsize=10)\n",
    "    axs[0,2].xaxis.set_label_position('top')  # Move label to top\n",
    "    axs[0,2].set_xlabel(r'$\\theta_3$', labelpad=10, fontsize=10)\n",
    "    axs[0,3].xaxis.set_label_position('top')  # Move label to top\n",
    "    axs[0,3].set_xlabel(r'$\\theta_4$', labelpad=10, fontsize=10)\n",
    "    axs[0,4].xaxis.set_label_position('top')  # Move label to top\n",
    "    axs[0,4].set_xlabel(r'$\\theta_5$', labelpad=10, fontsize=10)\n",
    "    axs[0,5].xaxis.set_label_position('top')  # Move label to top\n",
    "    axs[0,5].set_xlabel(r'$\\theta_6$', labelpad=10, fontsize=10)\n",
    "\n",
    "    # Plot theta from FNO-FMPE flex\n",
    "    im13 = axs[1,0].imshow(np.std(np.exp(posterior_FNOflex[:,6,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im14 = axs[1,1].imshow(np.std(np.exp(posterior_FNOflex[:,7,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im15 = axs[1,2].imshow(np.std(np.exp(posterior_FNOflex[:,2,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im16 = axs[1,3].imshow(np.std(np.exp(posterior_FNOflex[:,3,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im17= axs[1,4].imshow(np.std(np.exp(posterior_FNOflex[:,4,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im18 = axs[1,5].imshow(np.std(np.exp(posterior_FNOflex[:,5,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "\n",
    "    axs[1,0].set_ylabel(f'FNOPE', fontsize=9)\n",
    "\n",
    "\n",
    "    # Plot theta from spectral NPE\n",
    "    im19 = axs[2,0].imshow(np.std(np.exp(posterior_specNPE[:,6,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im20 = axs[2,1].imshow(np.std(np.exp(posterior_specNPE[:,7,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im21 = axs[2,2].imshow(np.std(np.exp(posterior_specNPE[:,2,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im22 = axs[2,3].imshow(np.std(np.exp(posterior_specNPE[:,3,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im23= axs[2,4].imshow(np.std(np.exp(posterior_specNPE[:,4,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im24 = axs[2,5].imshow(np.std(np.exp(posterior_specNPE[:,5,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "\n",
    "    axs[2,0].set_ylabel(f'NPE (spectral)', fontsize=9)\n",
    "\n",
    "    # Plot theta from spectral FMPE\n",
    "    im25 = axs[3,0].imshow(np.std(np.exp(posterior_specFMPE[:,6,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im26 = axs[3,1].imshow(np.std(np.exp(posterior_specFMPE[:,7,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im27 = axs[3,2].imshow(np.std(np.exp(posterior_specFMPE[:,2,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im28 = axs[3,3].imshow(np.std(np.exp(posterior_specFMPE[:,3,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im29= axs[3,4].imshow(np.std(np.exp(posterior_specFMPE[:,4,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im30 = axs[3,5].imshow(np.std(np.exp(posterior_specFMPE[:,5,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "\n",
    "    axs[3,0].set_ylabel(f'FMPE (spectral)', fontsize=9)\n",
    "\n",
    "\n",
    "    # Plot theta from raw FMPE\n",
    "    im31 = axs[4,0].imshow(np.std(np.exp(posterior_rawFMPE[:,6,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im32 = axs[4,1].imshow(np.std(np.exp(posterior_rawFMPE[:,7,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im33 = axs[4,2].imshow(np.std(np.exp(posterior_rawFMPE[:,2,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im34 = axs[4,3].imshow(np.std(np.exp(posterior_rawFMPE[:,3,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im35= axs[4,4].imshow(np.std(np.exp(posterior_rawFMPE[:,4,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "    im36 = axs[4,5].imshow(np.std(np.exp(posterior_rawFMPE[:,5,:,:]), axis=0), vmin = 0.0, vmax = 0.6)\n",
    "\n",
    "    axs[4,0].set_ylabel(f'FMPE (raw)', fontsize=9)\n",
    "\n",
    "    for ax in axs.flat:\n",
    "        #ax.axis('off')\n",
    "        # Hide spines\n",
    "        for spine in ax.spines.values():\n",
    "            spine.set_visible(False)\n",
    "    \n",
    "        # Hide ticks and tick labels\n",
    "        ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)\n",
    "\n",
    "    cbar_mappable = mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=0.0, vmax=0.6), cmap='viridis')\n",
    "\n",
    "    # Add a colorbar to the right of the entire figure\n",
    "    cbar_ax = fig.add_axes([0.91, 0.112, 0.015, 0.768])  # [left, bottom, width, height]\n",
    "    cbar = fig.colorbar(cbar_mappable, cax=cbar_ax, orientation='vertical')\n",
    "    #cbar.set_label('Permeability', fontsize=9)\n",
    "    cbar.ax.tick_params(labelsize=9)\n",
    "\n",
    "    #for spine in cbar.ax.spines.values():\n",
    "        #spine.set_linewidth(0.5)\n",
    "\n",
    "    # Save figure \n",
    "    plt.savefig('darcy_plots/posterior_sd_Darcy.pdf', format='pdf', bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "761a952e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Calculate the minimal achievable SBC from uniform ranks ####\n",
    "# Reference level for the SBC panel of the metrics figure (drawn there as the 'lower bound' line):\n",
    "# even perfectly uniform ranks give a non-zero mean absolute ATC (SBC EoD) because only a finite\n",
    "# number of SBC runs is available. The RNG is seeded so this reference level is reproducible\n",
    "# across kernel restarts.\n",
    "\n",
    "num_posterior_samples = 100\n",
    "n_sbc = 50\n",
    "n_sbc_marginals = 50\n",
    "sbc_seed = 0  # change to inspect how much the reference level fluctuates\n",
    "\n",
    "sbc_rng = np.random.default_rng(sbc_seed)\n",
    "# integer ranks drawn uniformly from {0, ..., num_posterior_samples - 1}\n",
    "ranks = sbc_rng.integers(0, num_posterior_samples, size=(n_sbc, n_sbc_marginals))\n",
    "\n",
    "coverage_values = torch.Tensor(ranks) / num_posterior_samples\n",
    "\n",
    "atcs = []\n",
    "absolute_atcs = []\n",
    "\n",
    "for dim_idx in range(coverage_values.shape[1]):\n",
    "    # histogram of this marginal's coverage values; alpha_grid holds the bin edges.\n",
    "    # NOTE(review): without range=, the bin edges span [min, max] of the data rather than\n",
    "    # [0, 1]; kept unchanged - confirm it matches how the per-method SBC values are computed.\n",
    "    hist, alpha_grid = torch.histogram(\n",
    "        coverage_values[:, dim_idx], density=True, bins=30\n",
    "    )\n",
    "    # empirical CDF via normalized cumsum; prepend 0 so it has one entry per bin edge\n",
    "    ecp = torch.cat([torch.Tensor([0]), torch.cumsum(hist, dim=0) / hist.sum()])\n",
    "    atc = (ecp - alpha_grid).mean().item()\n",
    "    absolute_atc = (ecp - alpha_grid).abs().mean().item()\n",
    "    atcs.append(atc)\n",
    "    absolute_atcs.append(absolute_atc)\n",
    "\n",
    "atcs = torch.tensor(atcs)\n",
    "absolute_atcs = torch.tensor(absolute_atcs)\n",
    "\n",
    "mean_absolute_atc = absolute_atcs.mean().numpy()\n",
    "print(mean_absolute_atc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa1a836e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Plot summary metrics (MSE, SBC EoD, posterior log prob) against the number of simulations ####\n",
    "# Uses arrays computed in earlier cells: methods, n_sim, mses_mean / mses_SE, sbcs_mean / sbcs_SE,\n",
    "# logprob_mean / logprob_SE (rows assumed to follow methods_names) and mean_absolute_atc.\n",
    "\n",
    "methods_names = ['FNOPE (fix)', 'FNOPE', 'NPE (spectral)', 'FMPE (raw)', 'FMPE (spectral)']\n",
    "# one colour per method, indexed like methods_names and shared by all three panels\n",
    "method_colors = ['#CA6702', '#9b2226', '#023e8a', '#00b4d8', '#0077b6']\n",
    "# method indices without a posterior log prob curve: NPE (spectral) and FMPE (spectral)\n",
    "no_logprob_methods = (2, 4)\n",
    "# method order in the shared figure legend (FMPE (spectral) listed before FMPE (raw))\n",
    "legend_order = [methods_names[i] for i in (0, 1, 2, 4, 3)]\n",
    "\n",
    "\n",
    "with plt.rc_context(fname='matplotlibrc'):\n",
    "\n",
    "    fig, axs = plt.subplots(1, 3, figsize=(5.5206, 0.9))\n",
    "    fig.subplots_adjust(wspace=0.7)\n",
    "\n",
    "    # lower method index is drawn on top\n",
    "    zorder = np.array([5, 4, 3, 2, 1])\n",
    "\n",
    "    for mm in range(methods.shape[0]):\n",
    "        errorbar_style = dict(\n",
    "            fmt='o',\n",
    "            linestyle='-',\n",
    "            color=method_colors[mm],\n",
    "            label=methods_names[mm],\n",
    "            zorder=zorder[mm],\n",
    "        )\n",
    "        axs[0].errorbar(n_sim, mses_mean[mm, :], yerr=mses_SE[mm, :], **errorbar_style)\n",
    "        axs[1].errorbar(n_sim, sbcs_mean[mm, :], yerr=sbcs_SE[mm, :], **errorbar_style)\n",
    "        if mm not in no_logprob_methods:\n",
    "            axs[2].errorbar(n_sim, logprob_mean[mm, :], yerr=logprob_SE[mm, :], **errorbar_style)\n",
    "\n",
    "    # SBC reference level from uniform ranks (mean_absolute_atc from the cell above)\n",
    "    axs[1].hlines(mean_absolute_atc, 1e3, 1e4, linestyle=':', color='black', label='lower bound', linewidth=2.5)\n",
    "\n",
    "    # shared x-axis styling; minorticks_off must come after set_xscale, which installs log minor ticks\n",
    "    for ax in axs:\n",
    "        ax.set_xscale('log')\n",
    "        ax.set_xlabel('# simulations')\n",
    "        ax.set_xticks(n_sim)\n",
    "        ax.set_xlim([890, 11220])\n",
    "        ax.minorticks_off()\n",
    "        ax.spines['left'].set_position(('outward', 5))  # move y-axis slightly left\n",
    "        ax.spines['bottom'].set_position(('outward', 5))\n",
    "\n",
    "    axs[0].set_ylabel('MSE')\n",
    "    axs[0].set_yticks([1e-4, 3e-4])\n",
    "    axs[0].set_ylim([0.9e-4, 4e-4])\n",
    "    # always use scientific notation, with the power of ten shown as a math-text offset label\n",
    "    formatter = mticker.ScalarFormatter(useMathText=True)\n",
    "    formatter.set_scientific(True)\n",
    "    formatter.set_powerlimits((0, 0))\n",
    "    axs[0].yaxis.set_major_formatter(formatter)\n",
    "\n",
    "    axs[1].set_ylabel('SBC EoD')\n",
    "    axs[1].set_ylim([0.0, 0.2])\n",
    "    axs[1].set_yticks([0, 0.2])\n",
    "\n",
    "    axs[2].set_ylabel('Posterior log prob')\n",
    "    axs[2].set_yticks([-4, 2])\n",
    "    axs[2].set_ylim([-4.5, 3])\n",
    "\n",
    "    # look legend handles up by label rather than by their position in matplotlib's handle list\n",
    "    handles, labels = axs[1].get_legend_handles_labels()\n",
    "    handle_by_label = dict(zip(labels, handles))\n",
    "\n",
    "    # small in-panel legend for the SBC reference line\n",
    "    axs[1].legend(handles=[handle_by_label['lower bound']],\n",
    "                  labels=['lower bound'],\n",
    "                  loc='upper right',\n",
    "                  bbox_to_anchor=(1.1, 1.1),\n",
    "                  frameon=False,\n",
    "                  handlelength=1.5)\n",
    "\n",
    "    # shared method legend below the three panels\n",
    "    fig.legend(handles=[handle_by_label[name] for name in legend_order],\n",
    "               labels=legend_order,\n",
    "               loc='upper center',\n",
    "               bbox_to_anchor=(0.5, -0.45),\n",
    "               ncol=5)\n",
    "\n",
    "    fig.savefig('darcy_plots/metricsDarcy.svg', format='svg', bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9779c956",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Plot posterior samples of each method next to the ground truth ####\n",
    "# Rows: theta_1 and theta_2; columns: ground truth followed by one posterior sample per method.\n",
    "# Every panel shows np.exp of the field with shared colour limits so panels are comparable.\n",
    "# NOTE(review): the indexing assumes posterior arrays are (n_samples, n_test, H, W) and\n",
    "# theta_FNOflex is (n_test, 1, H, W) - TODO confirm against the pickled summaries.\n",
    "\n",
    "sample_vmin, sample_vmax = 0.2, 1.8\n",
    "\n",
    "# (column title, frame colour or None); the FNOPE columns use their colours from the metrics figure\n",
    "sample_columns = [\n",
    "    ('Ground truth', None),\n",
    "    ('FNOPE (fix)', '#CA6702'),\n",
    "    ('FNOPE', '#9b2226'),\n",
    "    ('NPE (spectral)', None),\n",
    "    ('FMPE (spectral)', None),\n",
    "    ('FMPE (raw)', None),\n",
    "]\n",
    "\n",
    "# (row label, fields in column order); the leading index selects a specific posterior draw\n",
    "sample_rows = [\n",
    "    (r'$\\theta_1$', [theta_FNOflex[0, 0, :, :],\n",
    "                     posterior_FNOfix[4, 0, :, :],\n",
    "                     posterior_FNOflex[13, 0, :, :],\n",
    "                     posterior_specNPE[3, 0, :, :],\n",
    "                     posterior_specFMPE[3, 0, :, :],\n",
    "                     posterior_rawFMPE[3, 0, :, :]]),\n",
    "    (r'$\\theta_2$', [theta_FNOflex[1, 0, :, :],\n",
    "                     posterior_FNOfix[0, 1, :, :],\n",
    "                     posterior_FNOflex[0, 1, :, :],\n",
    "                     posterior_specNPE[0, 1, :, :],\n",
    "                     posterior_specFMPE[0, 1, :, :],\n",
    "                     posterior_rawFMPE[0, 1, :, :]]),\n",
    "]\n",
    "\n",
    "with plt.rc_context(fname='matplotlibrc'):\n",
    "\n",
    "    fig, axs = plt.subplots(2, 6, figsize=(5.5206, 1.8))\n",
    "    fig.subplots_adjust(wspace=0.1, hspace=0.09)\n",
    "\n",
    "    for row, (row_label, fields) in enumerate(sample_rows):\n",
    "        for col, field in enumerate(fields):\n",
    "            axs[row, col].imshow(np.exp(field), vmin=sample_vmin, vmax=sample_vmax)\n",
    "        axs[row, 0].set_ylabel(row_label, fontsize=8)\n",
    "\n",
    "    # method names as column titles above the first row\n",
    "    for col, (title, _frame_color) in enumerate(sample_columns):\n",
    "        axs[0, col].xaxis.set_label_position('top')\n",
    "        axs[0, col].set_xlabel(title, fontsize=7)\n",
    "\n",
    "    # bare image panels: no frame, ticks or tick labels\n",
    "    for ax in axs.flat:\n",
    "        ax.set_aspect('equal', adjustable='box')\n",
    "        for spine in ax.spines.values():\n",
    "            spine.set_visible(False)\n",
    "        ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)\n",
    "\n",
    "    # re-enable a coloured frame around the highlighted columns (after all spines were hidden above)\n",
    "    for col, (_title, frame_color) in enumerate(sample_columns):\n",
    "        if frame_color is None:\n",
    "            continue\n",
    "        for ax in axs[:, col]:\n",
    "            for spine in ax.spines.values():\n",
    "                spine.set_visible(True)\n",
    "                spine.set_edgecolor(frame_color)\n",
    "                spine.set_linewidth(2)\n",
    "\n",
    "    fig.savefig('darcy_plots/samplesDarcy_colored.svg', format='svg', bbox_inches='tight')\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b456da3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Patch svgs together ####\n",
    "# Stack the posterior-samples figure (panel a) above the metrics figure (panels b-d) in one SVG.\n",
    "# Coordinates are SVG user units; the b/c/d x positions are meant to sit above the three metric axes.\n",
    "\n",
    "kwargs_text = {\"size\": \"7pt\", \"font\": \"Arial\", \"weight\": \"800\"}\n",
    "\n",
    "# create the new, empty joint SVG canvas\n",
    "joint_fig = sg.SVGFigure(\"14cm\", \"5cm\")\n",
    "\n",
    "# load the matplotlib-generated figures saved by the plotting cells above\n",
    "fig0 = sg.fromfile('darcy_plots/samplesDarcy_colored.svg')\n",
    "fig1 = sg.fromfile('darcy_plots/metricsDarcy.svg')\n",
    "\n",
    "# get the plot objects\n",
    "plot0 = fig0.getroot()\n",
    "plot1 = fig1.getroot()\n",
    "\n",
    "# size of the samples figure; size0[1] (presumably its height) offsets everything placed below it\n",
    "size0 = get_size_tuple(fig0)\n",
    "\n",
    "# a: posterior samples\n",
    "plot0.moveto(20, 15)\n",
    "\n",
    "# b: metrics, placed below the samples figure\n",
    "plot1.moveto(10, size0[1] + 20)\n",
    "\n",
    "# panel labels\n",
    "txt0 = sg.TextElement(8, 25, \"a\", **kwargs_text)\n",
    "txt1 = sg.TextElement(8, size0[1] + 25, \"b\", **kwargs_text)\n",
    "txt2 = sg.TextElement(120, size0[1] + 25, \"c\", **kwargs_text)\n",
    "txt3 = sg.TextElement(241, size0[1] + 25, \"d\", **kwargs_text)\n",
    "\n",
    "# append plots first and labels last, so the labels are painted on top\n",
    "joint_fig.append([plot0, plot1])\n",
    "joint_fig.append([txt0, txt1, txt2, txt3])\n",
    "\n",
    "joint_fig.save('darcy_plots/darcy_joint_colored.svg')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcf85273",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: export the joint SVG to PDF with the Inkscape command-line interface.\n",
    "# Uncomment to run; the executable path below is a macOS application install - adjust it for your system.\n",
    "# !/Applications/Inkscape.app/Contents/MacOS/inkscape darcy_plots/darcy_joint_colored.svg --export-area-drawing --export-type=pdf --export-filename=darcy_plots/darcy_joint_colored.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3edea60",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fourier_nets",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
