{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a82cd7a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "import pickle\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.gridspec import GridSpec\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "from utils.misc import read_pickle, get_data_dir, get_output_dir, get_project_root\n",
    "from plotting_utils import colors,method_names, float_to_power_of_ten, get_size_tuple,cm2inch\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Run tensors on the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "# Project directories resolved by the shared utils helpers.\n",
    "data_dir, output_dir, root_dir = get_data_dir(), get_output_dir(), get_project_root()\n",
    "\n",
    "# A method's plot colour is the entry of `colors` at that method's position in `method_names`.\n",
    "simformer_idx, fno_idx = method_names.index(\"simformer\"), method_names.index(\"FNOPE\")\n",
    "simformer_color, fnope_color = colors[simformer_idx], colors[fno_idx]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac166618",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment grid covered by the summary file.\n",
    "# Method identifiers as they appear in the summary's \"method\" column.\n",
    "methods = [\"fno\", \"simformer\"]\n",
    "# Simulation budgets (number of training simulations) of the runs.\n",
    "nsims = [1_000, 10_000, 100_000]\n",
    "# Number of time points the observations are conditioned on.\n",
    "eval_points = [20, 40]\n",
    "# Whether the predictive simulations (for both methods) start from Simformer's\n",
    "# initial-condition estimate (\"True\") or from the prior (\"False\").\n",
    "init_conditions = [\"True\", \"False\"]\n",
    "\n",
    "out_dir = get_output_dir()\n",
    "experiment_folder = Path(out_dir, \"sir_experiment\", \"FNO_FMPE\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54e326b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-run metrics table; later cells use its method, nsim, eval_num,\n",
    "# simformer_initial_condition, sbcs, tarps and predictive_mses columns.\n",
    "# NOTE: unpickling can execute arbitrary code -- only load trusted files.\n",
    "summary_path = experiment_folder / \"summary.pkl\"\n",
    "df = pd.read_pickle(summary_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a0c5923",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose which of the four configurations to plot.\n",
    "n_cond = 40  # number of conditioning time points: 20 or 40\n",
    "use_simformer_init = True  # True: Simformer initial-condition estimate, False: prior\n",
    "\n",
    "# Which of the 100 test observations is shown in the posterior/predictive panels.\n",
    "# The figures use observation 56 when conditioning on 40 points and 13 when\n",
    "# conditioning on 20 (for either initial-condition setting); deriving it from\n",
    "# n_cond keeps the two settings from drifting apart.\n",
    "obs_index = {40: 56, 20: 13}[n_cond]\n",
    "\n",
    "# The summary stores the initial-condition flag as the strings \"True\"/\"False\".\n",
    "results = df[(df[\"eval_num\"] == n_cond) & (df[\"simformer_initial_condition\"] == str(use_simformer_init))]\n",
    "assert not results.empty, (\n",
    "    f\"no runs for eval_num={n_cond}, simformer_initial_condition={use_simformer_init}\"\n",
    ")\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "369dd18a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool each metric over all runs of a (method, nsim) cell and summarise the\n",
    "# pooled values by their mean and standard error.\n",
    "\n",
    "\n",
    "def _pool_metric(runs, column, keep_whole=False):\n",
    "    \"\"\"Collect the per-run entries of ``runs[column]`` into one flat list.\n",
    "\n",
    "    With ``keep_whole=False`` every entry is iterated and its elements are\n",
    "    appended (used for ``sbcs`` and ``predictive_mses``); with ``keep_whole=True``\n",
    "    each entry is appended as a single item (used for ``tarps``).\n",
    "    \"\"\"\n",
    "    pooled = []\n",
    "    for entry in runs[column]:\n",
    "        if keep_whole:\n",
    "            pooled.append(entry)\n",
    "        else:\n",
    "            pooled.extend(entry)\n",
    "    return pooled\n",
    "\n",
    "\n",
    "def _mean_and_se(values):\n",
    "    \"\"\"Return the mean and standard error (np.std with ddof=0, over sqrt(n)) of ``values``.\n",
    "\n",
    "    An empty ``values`` yields NaNs (with NumPy runtime warnings).\n",
    "    NOTE(review): ddof=1 would give the usual sample SE; kept at ddof=0 so the\n",
    "    reported numbers are unchanged.\n",
    "    \"\"\"\n",
    "    arr = np.array(values)\n",
    "    return np.mean(arr), np.std(arr) / np.sqrt(len(values))\n",
    "\n",
    "\n",
    "grid_shape = (len(methods), len(nsims))  # rows follow `methods`, columns follow `nsims`\n",
    "\n",
    "sbc_results, sbc_mean, sbc_SE = {}, np.zeros(grid_shape), np.zeros(grid_shape)\n",
    "tarp_results, tarp_mean, tarp_SE = {}, np.zeros(grid_shape), np.zeros(grid_shape)\n",
    "predictive_mse_results, predictive_mse_mean, predictive_mse_SE = {}, np.zeros(grid_shape), np.zeros(grid_shape)\n",
    "\n",
    "# (pooled values by method/nsim, mean table, SE table, summary column, keep entries whole?)\n",
    "metric_targets = [\n",
    "    (sbc_results, sbc_mean, sbc_SE, \"sbcs\", False),\n",
    "    (tarp_results, tarp_mean, tarp_SE, \"tarps\", True),\n",
    "    (predictive_mse_results, predictive_mse_mean, predictive_mse_SE, \"predictive_mses\", False),\n",
    "]\n",
    "\n",
    "for ii, method in enumerate(methods):\n",
    "    for kk, nsim in enumerate(nsims):\n",
    "        # Select this cell's runs once instead of re-filtering for every metric.\n",
    "        runs = results[(results[\"method\"] == method) & (results[\"nsim\"] == nsim)]\n",
    "        for pooled_by_method, mean_table, se_table, column, keep_whole in metric_targets:\n",
    "            pooled = _pool_metric(runs, column, keep_whole)\n",
    "            pooled_by_method.setdefault(method, {})[nsim] = pooled\n",
    "            mean_table[ii, kk], se_table[ii, kk] = _mean_and_se(pooled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f475c8e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the summary tables (rows follow `methods`, columns follow `nsims`).\n",
    "for table_name, table in [\n",
    "    (\"SBC mean\", sbc_mean),\n",
    "    (\"SBC SE\", sbc_SE),\n",
    "    (\"TARP mean\", tarp_mean),\n",
    "    (\"TARP SE\", tarp_SE),\n",
    "    (\"Predictive MSE mean\", predictive_mse_mean),\n",
    "    (\"Predictive MSE SE\", predictive_mse_SE),\n",
    "]:\n",
    "    print(table_name)\n",
    "    print(table)\n",
    "\n",
    "# Look rows up by method name instead of hard-coding 0/1, so reordering\n",
    "# `methods` cannot silently swap the FNOPE and Simformer curves.\n",
    "fno_row = methods.index(\"fno\")\n",
    "simformer_row = methods.index(\"simformer\")\n",
    "\n",
    "fno_sbc_mean, fno_sbc_SE = sbc_mean[fno_row], sbc_SE[fno_row]\n",
    "simformer_sbc_mean, simformer_sbc_SE = sbc_mean[simformer_row], sbc_SE[simformer_row]\n",
    "fno_tarp_mean, fno_tarp_SE = tarp_mean[fno_row], tarp_SE[fno_row]\n",
    "simformer_tarp_mean, simformer_tarp_SE = tarp_mean[simformer_row], tarp_SE[simformer_row]\n",
    "fno_predictive_mse_mean, fno_predictive_mse_SE = predictive_mse_mean[fno_row], predictive_mse_SE[fno_row]\n",
    "simformer_predictive_mse_mean = predictive_mse_mean[simformer_row]\n",
    "simformer_predictive_mse_SE = predictive_mse_SE[simformer_row]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c301c1eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Posterior and predictive plots use the runs with the largest simulation budget.\n",
    "run_dir = experiment_folder / f\"num_sim_{max(nsims)}_eval_points_{n_cond}_initial_conditions_{use_simformer_init}\"\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load trusted result files.\n",
    "fno_results_file = run_dir / \"fno_predictive_summary.pkl\"\n",
    "with open(fno_results_file, \"rb\") as f:\n",
    "    fno_results = pickle.load(f)\n",
    "print(fno_results.keys())\n",
    "\n",
    "simformer_results_file = run_dir / \"simformer_predictive_summary.pkl\"\n",
    "with open(simformer_results_file, \"rb\") as f:\n",
    "    simformer_results = pickle.load(f)\n",
    "print(simformer_results.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "585f3005",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth parameters and observations come from the Simformer summary;\n",
    "# posterior and posterior-predictive samples come from each method's own summary.\n",
    "theta_test, test_x = simformer_results[\"theta_test\"], simformer_results[\"x_test\"]\n",
    "\n",
    "fno_posterior_samples, fno_predictive_samples = (\n",
    "    fno_results[key] for key in (\"posterior_samples\", \"posterior_predictive_samples\")\n",
    ")\n",
    "simformer_posterior_samples, simformer_predictive_samples = (\n",
    "    simformer_results[key] for key in (\"posterior_samples\", \"posterior_predictive_samples\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "795f0f88",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data_name = f\"posterior100k_samples_{n_cond}_time_points.npz\"\n",
    "# Use np.load as a context manager so the underlying .npz file handle is closed.\n",
    "with np.load(experiment_folder / test_data_name) as test_data:\n",
    "    meta_data = test_data[\"meta_data\"]\n",
    "\n",
    "# Per observation: columns 2:2+n_cond are the parameter (theta) times,\n",
    "# columns 2+n_cond:2+2*n_cond the observation (x) times.\n",
    "times_theta = meta_data[:, 2:2 + n_cond]\n",
    "times_x = meta_data[:, 2 + n_cond:2 + 2 * n_cond]\n",
    "test_times_theta = torch.from_numpy(times_theta).to(device)\n",
    "test_times_x = torch.from_numpy(times_x).to(device)\n",
    "\n",
    "# The plots only need host arrays, so slice the NumPy data directly instead of\n",
    "# copying to the device and back.\n",
    "eval_times = times_theta[obs_index]\n",
    "# Predictive curves are plotted against the observation times with a leading t=0.\n",
    "test_times_x_with_0 = np.concatenate((np.zeros(1, dtype=times_x.dtype), times_x[obs_index]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc44d4ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For observation `obs_index`, split each method's posterior samples into the two\n",
    "# scalar parameters (columns :2) and the time-varying parameter (columns 2:), and\n",
    "# summarise every sample set by its mean and std over the sample axis (axis 0).\n",
    "\n",
    "\n",
    "def _mean_std(samples):\n",
    "    \"\"\"Return ``(samples.mean(axis=0), samples.std(axis=0))``.\"\"\"\n",
    "    return samples.mean(axis=0), samples.std(axis=0)\n",
    "\n",
    "\n",
    "fno_finite_samples = fno_posterior_samples[:, obs_index, :2]\n",
    "fno_cont_samples = fno_posterior_samples[:, obs_index, 2:]\n",
    "simformer_finite_samples = simformer_posterior_samples[:, obs_index, :2]\n",
    "simformer_cont_samples = simformer_posterior_samples[:, obs_index, 2:]\n",
    "\n",
    "fno_predictives, simformer_predictives = (\n",
    "    fno_predictive_samples[:, obs_index],\n",
    "    simformer_predictive_samples[:, obs_index],\n",
    ")\n",
    "\n",
    "theta_test_finite, theta_test_cont = theta_test[:, :2], theta_test[:, 2:]\n",
    "\n",
    "fno_mean, fno_std = _mean_std(fno_cont_samples)\n",
    "fno_finite_mean, fno_finite_std = _mean_std(fno_finite_samples)\n",
    "fno_predictive_mean, fno_predictive_std = _mean_std(fno_predictives)\n",
    "\n",
    "simformer_mean, simformer_std = _mean_std(simformer_cont_samples)\n",
    "simformer_finite_mean, simformer_finite_std = _mean_std(simformer_finite_samples)\n",
    "simformer_predictive_mean, simformer_predictive_std = _mean_std(simformer_predictives)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8baa9630",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Reference SBC error for a perfectly calibrated posterior ####\n",
    "# Uniformly distributed ranks are what an ideally calibrated posterior produces;\n",
    "# with a finite number of SBC runs their empirical coverage still deviates from\n",
    "# the diagonal. The resulting error is drawn as the \"lower bound\" line in the SBC panel.\n",
    "\n",
    "num_posterior_samples = 1000\n",
    "n_sbc = 100\n",
    "sbc_seed = 0  # fixed seed: the reference line is identical on every re-run\n",
    "\n",
    "# One column per parameter dimension: 2 scalar parameters + n_cond time points.\n",
    "rng = np.random.default_rng(sbc_seed)\n",
    "ranks = rng.integers(0, num_posterior_samples, size=(n_sbc, n_cond + 2))\n",
    "\n",
    "coverage_values = torch.Tensor(ranks) / num_posterior_samples\n",
    "\n",
    "atcs = []\n",
    "absolute_atcs = []\n",
    "\n",
    "for dim_idx in range(coverage_values.shape[1]):\n",
    "    # Empirical CDF from a normalised cumulative histogram (30 bins).\n",
    "    hist, alpha_grid = torch.histogram(coverage_values[:, dim_idx], density=True, bins=30)\n",
    "    # Prepend 0 so the ECDF has one value per bin edge in alpha_grid.\n",
    "    ecp = torch.cat([torch.Tensor([0]), torch.cumsum(hist, dim=0) / hist.sum()])\n",
    "    deviation = ecp - alpha_grid\n",
    "    atcs.append(deviation.mean().item())\n",
    "    absolute_atcs.append(deviation.abs().mean().item())\n",
    "\n",
    "atcs = torch.tensor(atcs)\n",
    "absolute_atcs = torch.tensor(absolute_atcs)\n",
    "print(absolute_atcs)\n",
    "\n",
    "mean_absolute_atc = absolute_atcs.mean().numpy()\n",
    "print(mean_absolute_atc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2dd32e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary figure: posterior panels on top, metric panels on the right and\n",
    "# posterior-predictive panels on the left.\n",
    "from plotting_utils import colors\n",
    "fig_version = \"v0\"  # suffix of the saved figure files\n",
    "\n",
    "simformer_idx, fno_idx = method_names.index(\"simformer\"), method_names.index(\"FNOPE\")\n",
    "simformer_color = 'turquoise'  # fixed colour for this figure instead of colors[simformer_idx]\n",
    "fnope_color = colors[fno_idx]\n",
    "\n",
    "# Style sheet resolved relative to the working directory\n",
    "# (alternative location: root_dir / \"plots\" / \"matplotlibrc\").\n",
    "with plt.rc_context(fname=\"matplotlibrc\"):\n",
    "    # Shared style of the bold panel letters.\n",
    "    kwargs_text = dict(fontsize=\"10\", font=\"Arial\", weight=\"800\")\n",
    "\n",
    "    def truncate(x, decimals=1):\n",
    "        \"\"\"Round ``x`` down (floor) to ``decimals`` decimal places.\"\"\"\n",
    "        scale = 10**decimals\n",
    "        return np.floor(x * scale) / scale\n",
    "\n",
    "    fig = plt.figure(figsize=(3.47, 4))\n",
    "\n",
    "    # 14 x 2 grid: rows 0-2 hold the posterior panels, rows 3-4 are near-zero-height\n",
    "    # spacers, and rows 5-13 hold the metric (right) and predictive (left) panels.\n",
    "    gs = GridSpec(\n",
    "        14,\n",
    "        2,\n",
    "        figure=fig,\n",
    "        width_ratios=[3, 1.2],\n",
    "        height_ratios=[1, 1, 1, 0.01, 0.01] + [1] * 9,\n",
    "        hspace=5.0,\n",
    "        wspace=0.6,\n",
    "    )\n",
    "\n",
    "    # Panel letters, placed in figure coordinates.\n",
    "    fig.text(-0.04, 0.9, 'a', ha='center', **kwargs_text)\n",
    "    fig.text(-0.04, 0.63, 'b', va='bottom', ha='center', **kwargs_text)\n",
    "    fig.text(0.58, 0.63, 'c', va='bottom', ha='center', **kwargs_text)\n",
    "\n",
    "\n",
    "    # Top-right panel: joint posterior over the two scalar parameters.\n",
    "    ax1 = fig.add_subplot(gs[0:3, 1])\n",
    "    ax1.set_xlabel(\"Recovery rate\")\n",
    "    ax1.set_ylabel(\"Death rate\", labelpad=2)\n",
    "\n",
    "    # Ground-truth parameter pair, drawn above the density contours.\n",
    "    gt_recovery, gt_death = theta_test_finite[obs_index, 0], theta_test_finite[obs_index, 1]\n",
    "    ax1.scatter(gt_recovery, gt_death, color='black', label='GT', zorder=10)\n",
    "\n",
    "    # KDE contours of the FNOPE and Simformer posterior samples (FNOPE drawn first).\n",
    "    for kde_samples, kde_color, kde_extra in (\n",
    "        (fno_finite_samples, fnope_color, {}),\n",
    "        (simformer_finite_samples, simformer_color, {\"alpha\": 0.75}),\n",
    "    ):\n",
    "        sns.kdeplot(\n",
    "            x=kde_samples[:, 0],\n",
    "            y=kde_samples[:, 1],\n",
    "            ax=ax1,\n",
    "            color=kde_color,\n",
    "            levels=4,\n",
    "            fill=False,\n",
    "            **kde_extra,\n",
    "        )\n",
    "\n",
    "    # Fixed ticks at 0.2 and 0.4 on both axes.\n",
    "    x_min, x_max = 0.2, 0.4\n",
    "    y_min, y_max = 0.2, 0.4\n",
    "    ax1.set_xticks([x_min, x_max])\n",
    "    ax1.set_xticklabels([x_min, x_max])\n",
    "    ax1.set_yticks([y_min, y_max])\n",
    "    ax1.set_yticklabels([y_min, y_max])\n",
    "\n",
    "\n",
    "\n",
    "    # Top-left panel: posterior over the time-varying contact rate (mean +/- 1 std).\n",
    "    ax2 = fig.add_subplot(gs[0:3, 0])\n",
    "    ax2.set_ylabel(\"Contact rate\")\n",
    "    ax2.plot(eval_times, fno_mean, label='FNO', color=fnope_color)\n",
    "    ax2.fill_between(eval_times,\n",
    "                     fno_mean - fno_std,\n",
    "                     fno_mean + fno_std,\n",
    "                     color=fnope_color, alpha=0.2)\n",
    "\n",
    "    ax2.plot(eval_times, simformer_mean, label='Simformer', color=simformer_color)\n",
    "    ax2.fill_between(eval_times,\n",
    "                     simformer_mean - simformer_std,\n",
    "                     simformer_mean + simformer_std,\n",
    "                     color=simformer_color, alpha=0.4)\n",
    "\n",
    "    # Ground-truth contact-rate trajectory.\n",
    "    ax2.plot(eval_times,\n",
    "             theta_test_cont[obs_index],\n",
    "             color=\"black\",\n",
    "             label=\"GT\",\n",
    "             linestyle='dashed')\n",
    "\n",
    "    # Two y ticks whose labels match their positions: the axis is anchored at 0 so the\n",
    "    # bottom tick really is 0.0, and the top tick sits at the upper limit truncated to\n",
    "    # 2 decimals and is labelled with exactly that value. (Assumes the contact rate is\n",
    "    # non-negative; any part of the std band below 0 is clipped.)\n",
    "    ax2.set_ylim(bottom=0.0)\n",
    "    ymax = truncate(ax2.get_ylim()[1], 2)\n",
    "    ax2.set_yticks([0.0, ymax])\n",
    "    ax2.set_yticklabels(['0.0', f'{ymax:g}'])\n",
    "    ax2.set_xlabel(\"Time [days]\")\n",
    "\n",
    "\n",
    "\n",
    "    # Right column, top: predictive MSE vs. number of simulations (mean +/- SE).\n",
    "    from matplotlib.ticker import ScalarFormatter\n",
    "\n",
    "    ax3 = fig.add_subplot(gs[5:9, 1])\n",
    "    for mse_mean, mse_se, line_color in (\n",
    "        (simformer_predictive_mse_mean, simformer_predictive_mse_SE, simformer_color),\n",
    "        (fno_predictive_mse_mean, fno_predictive_mse_SE, fnope_color),\n",
    "    ):\n",
    "        ax3.errorbar(nsims, mse_mean, yerr=mse_se, fmt='o', linestyle='-', color=line_color)\n",
    "\n",
    "    ax3.set_yticks([0, 8e-3])\n",
    "    ax3.tick_params(labelbottom=False)  # x tick labels are shown on the SBC panel below\n",
    "\n",
    "    # Scientific notation with the exponent pinned at 1e-3, shown as an offset label.\n",
    "    formatter = ScalarFormatter(useMathText=True)\n",
    "    formatter.set_powerlimits((-3, -3))\n",
    "    ax3.yaxis.set_major_formatter(formatter)\n",
    "    ax3.ticklabel_format(axis='y', style='scientific')\n",
    "    ax3.yaxis.offsetText.set_visible(True)\n",
    "\n",
    "    ax3.set_ylabel('MSE', labelpad=8)\n",
    "    # Draw once so the offset text exists before it is repositioned.\n",
    "    fig.canvas.draw()\n",
    "    offset = ax3.yaxis.get_offset_text()\n",
    "    offset.set_horizontalalignment('left')\n",
    "    offset.set_x(-0.15)\n",
    "\n",
    "\n",
    "\n",
    "    # Right column, bottom: SBC error vs. number of simulations; shares x with the MSE panel.\n",
    "    ax4 = fig.add_subplot(gs[10:14, 1], sharex=ax3)\n",
    "\n",
    "    ax4.errorbar(nsims, simformer_sbc_mean, yerr=simformer_sbc_SE,\n",
    "                 fmt='o', linestyle='-', color=simformer_color)\n",
    "    ax4.errorbar(nsims, fno_sbc_mean, yerr=fno_sbc_SE,\n",
    "                 fmt='o', linestyle='-', color=fnope_color)\n",
    "\n",
    "    # Reference level from uniformly random ranks, spanning the simulated budgets.\n",
    "    ax4.hlines(mean_absolute_atc, min(nsims), max(nsims), linestyle=':', color='black', label='lower\\nbound')\n",
    "    ax4.set_xscale(\"log\")\n",
    "    ax4.set_xlabel('# simulations')\n",
    "    ax4.set_ylabel('SBC EoD', labelpad=2)\n",
    "    ax4.set_xticks(nsims)\n",
    "    # Pad the log-scaled x range by a factor of 2 on either side of the budgets.\n",
    "    ax4.set_xlim(min(nsims) / 2, max(nsims) * 2)\n",
    "    ax4.minorticks_off()\n",
    "    ax4.set_yticks([0, 0.2])\n",
    "    ax4.set_ylim(0, 0.2)\n",
    "    ax4.legend(handlelength=1.1,\n",
    "               loc='upper right',\n",
    "               bbox_to_anchor=(1.1, 1.1))\n",
    "\n",
    "\n",
    "    # Left column: posterior predictive mean +/- 1 std per channel, with the observed data.\n",
    "    channel_labels = ['Infected', 'Recovered', 'Deceased']\n",
    "    n_channels = simformer_predictive_mean.shape[0]\n",
    "    for channel in range(n_channels):\n",
    "        # One panel per 3 grid rows, starting at row 5 (below the spacer rows).\n",
    "        ax = fig.add_subplot(gs[2+3*(channel+1):2+3*(channel+2), 0], sharex=ax2)\n",
    "        ax.set_ylabel(channel_labels[channel])\n",
    "\n",
    "        ax.plot(test_times_x_with_0,\n",
    "                fno_predictive_mean[channel],\n",
    "                color=fnope_color,\n",
    "                linestyle='solid',\n",
    "                label=\"FNOPE\")\n",
    "        ax.fill_between(test_times_x_with_0,\n",
    "                        fno_predictive_mean[channel] - fno_predictive_std[channel],\n",
    "                        fno_predictive_mean[channel] + fno_predictive_std[channel],\n",
    "                        color=fnope_color,\n",
    "                        alpha=0.2)\n",
    "\n",
    "        ax.plot(test_times_x_with_0,\n",
    "                simformer_predictive_mean[channel],\n",
    "                color=simformer_color,\n",
    "                label=\"Simformer\")\n",
    "        ax.fill_between(test_times_x_with_0,\n",
    "                        simformer_predictive_mean[channel] - simformer_predictive_std[channel],\n",
    "                        simformer_predictive_mean[channel] + simformer_predictive_std[channel],\n",
    "                        color=simformer_color,\n",
    "                        alpha=0.4)\n",
    "\n",
    "        # Observations exist at the measurement times only (no point at the leading t=0).\n",
    "        ax.plot(test_times_x_with_0[1:],\n",
    "                test_x[obs_index, channel],\n",
    "                color=\"black\",\n",
    "                label=\"Data\",\n",
    "                linestyle='',\n",
    "                marker=\"x\",\n",
    "                markersize=3)\n",
    "\n",
    "        # Two y ticks whose labels match their positions: anchor the axis at 0 so the\n",
    "        # bottom tick really is 0.0, and label the top tick (upper limit truncated to\n",
    "        # 2 decimals) with exactly that value instead of re-rounding it to 1 decimal.\n",
    "        ax.set_ylim(bottom=0.0)\n",
    "        ymax = truncate(ax.get_ylim()[1], 2)\n",
    "        ax.set_yticks([0.0, ymax])\n",
    "        ax.set_yticklabels(['0.0', f'{ymax:g}'])\n",
    "\n",
    "        if channel < n_channels - 1:\n",
    "            # Only the bottom panel shows x tick labels and the time label.\n",
    "            plt.setp(ax.get_xticklabels(), visible=False)\n",
    "            ax.tick_params(labelbottom=False)\n",
    "\n",
    "            if channel == 0:\n",
    "                # Legend on the top panel, in a fixed order.\n",
    "                label_order = ['FNOPE', 'Simformer', 'Data']\n",
    "                handles, labels = ax.get_legend_handles_labels()\n",
    "                label_to_handle = dict(zip(labels, handles))\n",
    "                ordered_handles = [label_to_handle[label] for label in label_order]\n",
    "                ax.legend(ordered_handles,\n",
    "                          label_order,\n",
    "                          frameon=False,\n",
    "                          ncol=1,\n",
    "                          handlelength=1.5,\n",
    "                          bbox_to_anchor=(1.05, 1.15))\n",
    "        else:\n",
    "            ax.set_xlabel(\"Time [days]\", labelpad=8)\n",
    "\n",
    "\n",
    "    # Despine: hide the top/right spines and offset the left/bottom spines outward.\n",
    "    for ax in fig.axes:\n",
    "        ax.spines['top'].set_visible(False)\n",
    "        ax.spines['right'].set_visible(False)\n",
    "        ax.spines['left'].set_position(('outward', 5))    # push the y axis slightly left\n",
    "        ax.spines['bottom'].set_position(('outward', 1))  # push the x axis slightly down\n",
    "\n",
    "    # Save before plt.show(): the inline backend closes all figures when showing them.\n",
    "    # Create the output directory so savefig does not fail on a fresh checkout.\n",
    "    plot_dir = Path(\"sir_plots\")\n",
    "    plot_dir.mkdir(parents=True, exist_ok=True)\n",
    "    fig_stem = f\"sir_summary_n_eval_{n_cond}_simformer_init_{use_simformer_init}_{fig_version}\"\n",
    "    fig.savefig(plot_dir / f\"{fig_stem}.svg\", bbox_inches='tight')\n",
    "    fig.savefig(plot_dir / f\"{fig_stem}.pdf\", bbox_inches='tight')\n",
    "    fig.savefig(plot_dir / f\"{fig_stem}.png\", bbox_inches='tight', dpi=300)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd81181e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fourier_nets",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
