{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf8b6e4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "import sys\n",
    "import svgutils.transform as sg\n",
    "from svgutils.compose import *\n",
    "\n",
    "from plotting_utils import cm2inch, get_size_tuple,colors,method_names\n",
    "from utils.misc import get_output_dir\n",
    "\n",
    "from IPython.core.display import SVG\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64d459b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Load and prepare data ####\n",
    "out_dir = get_output_dir()\n",
    "# Change the value below for the different experiments of ice - using 50 parameters as in [redacted] or all 500 parameters\n",
    "num_params = 50\n",
    "if num_params == 500:\n",
    "    experiment_folder = out_dir/\"ice_experiment_500\"\n",
    "    method_names = [\"FNOPE\", \"NPE (spectral)\", \"FMPE (spectral)\", \"NPE (raw)\"] #FMPE raw fails for 500 params\n",
    "    custom_order = [1, 2, 3, 4]  # New order of methods_names\n",
    "    upper_lim = 0.8\n",
    "\n",
    "\n",
    "else:\n",
    "    experiment_folder = out_dir/\"ice_experiment\"\n",
    "    method_names = [\"FNOPE\", \"NPE (spectral)\", \"FMPE (raw)\", \"FMPE (spectral)\", \"NPE (raw)\"]\n",
    "    custom_order = [1, 2, 4, 5, 3]  # New order of methods_names\n",
    "    upper_lim = 0.7\n",
    "\n",
    "\n",
    "path = experiment_folder/ \"summary.csv\"\n",
    "data = pd.read_csv(path, usecols=[1, 2, 3, 4, 5, 6, 7])\n",
    "\n",
    "# Get the different methods\n",
    "methods = data[\"method\"].unique()\n",
    "n_sim = data[\"nsim\"].unique()\n",
    "\n",
    "# Calculate mean and SE for each method and each number of simulations\n",
    "mses_results = {}\n",
    "mses_real_results = {}\n",
    "sbc_results = {}\n",
    "mses_mean = np.zeros((methods.shape[0], n_sim.shape[0]))\n",
    "mses_SE = np.zeros((methods.shape[0], n_sim.shape[0]))\n",
    "mses_real_mean = np.zeros((methods.shape[0], n_sim.shape[0]))\n",
    "mses_real_SE = np.zeros((methods.shape[0], n_sim.shape[0]))\n",
    "sbcs_mean = np.zeros((methods.shape[0], n_sim.shape[0]))\n",
    "sbcs_SE = np.zeros((methods.shape[0], n_sim.shape[0]))\n",
    "\n",
    "for ii, method in enumerate(methods):\n",
    "    mses_results[method] = {}\n",
    "    mses_real_results[method] = {}\n",
    "    sbc_results[method] = {}\n",
    "    for kk, nsim in enumerate(n_sim):\n",
    "        mses_results[method][nsim] = []\n",
    "        mses_real_results[method][nsim] = []\n",
    "        sbc_results[method][nsim] = []\n",
    "        temp_mses = data[(data[\"method\"] == method) & (data[\"nsim\"] == nsim)][\n",
    "            \"predictive_mses\"\n",
    "        ]\n",
    "        temp_mses_real = data[(data[\"method\"] == method) & (data[\"nsim\"] == nsim)][\n",
    "            \"predictive_mses_real_data\"\n",
    "        ]\n",
    "        temp_sbcs = data[(data[\"method\"] == method) & (data[\"nsim\"] == nsim)][\"sbcs\"]\n",
    "\n",
    "        for ll in range(temp_mses.shape[0]):\n",
    "            s_mses = temp_mses.iloc[ll]\n",
    "            s_mses_clean = s_mses.replace(\"[\", \"\").replace(\"]\", \"\")\n",
    "            list_mses = [float(x) for x in s_mses_clean.split() if x != \"nan\"]\n",
    "            mses_results[method][nsim].extend(list_mses)\n",
    "\n",
    "        mses_mean[ii, kk] = np.mean(np.array(mses_results[method][nsim]))\n",
    "        mses_SE[ii, kk] = np.std(np.array(mses_results[method][nsim])) / np.sqrt(\n",
    "            len(mses_results[method][nsim])\n",
    "        )\n",
    "\n",
    "        for mm in range(temp_sbcs.shape[0]):\n",
    "            s_sbcs = temp_sbcs.iloc[mm]\n",
    "            s_sbcs_clean = s_sbcs.replace(\"[\", \"\").replace(\"]\", \"\")\n",
    "            list_sbcs = [float(x) for x in s_sbcs_clean.split() if x != \"nan\"]\n",
    "            sbc_results[method][nsim].extend(list_sbcs)\n",
    "\n",
    "        sbcs_mean[ii, kk] = np.mean(np.array(sbc_results[method][nsim]))\n",
    "        sbcs_SE[ii, kk] = np.std(np.array(sbc_results[method][nsim])) / np.sqrt(\n",
    "            len(sbc_results[method][nsim])\n",
    "        )\n",
    "\n",
    "        for ll in range(temp_mses_real.shape[0]):\n",
    "            s_mses_real = temp_mses_real.iloc[ll]\n",
    "            s_mses_real_clean = s_mses_real.replace(\"[\", \"\").replace(\"]\", \"\")\n",
    "            list_mses_real = [float(x) for x in s_mses_real_clean.split() if x != \"nan\"]\n",
    "            mses_real_results[method][nsim].extend(list_mses_real)\n",
    "\n",
    "        mses_real_mean[ii, kk] = np.mean(np.array(mses_real_results[method][nsim]))\n",
    "        mses_real_SE[ii, kk] = np.std(\n",
    "            np.array(mses_real_results[method][nsim])\n",
    "        ) / np.sqrt(len(mses_real_results[method][nsim]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24896663",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Calculate ideal SBC to add to SBC plot ####\n",
    "\n",
    "num_posterior_samples = 1000\n",
    "n_sbc = 100\n",
    "n_sbc_marginals = 50\n",
    "\n",
    "ranks = np.random.randint(0, num_posterior_samples, size=(n_sbc, n_sbc_marginals))\n",
    "\n",
    "coverage_values = torch.Tensor(ranks) / num_posterior_samples\n",
    "\n",
    "absolute_atcs = []\n",
    "\n",
    "for dim_idx in range(coverage_values.shape[1]):\n",
    "    # calculate empirical CDF via cumsum and normalize\n",
    "    hist, alpha_grid = torch.histogram(\n",
    "        coverage_values[:, dim_idx], density=True, bins=30\n",
    "    )\n",
    "    # add 0 to the beginning of the ecp curve to match the alpha grid\n",
    "    ecp = torch.cat([torch.Tensor([0]), torch.cumsum(hist, dim=0) / hist.sum()])\n",
    "    absolute_atc = (ecp - alpha_grid).abs().mean().item()\n",
    "    absolute_atcs.append(absolute_atc)\n",
    "\n",
    "absolute_atcs = torch.tensor(absolute_atcs)\n",
    "mean_absolute_atc = absolute_atcs.mean().numpy()\n",
    "print(mean_absolute_atc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d94ad1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(n_sim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2bcd47b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Create three plots for these metrics ####\n",
    "\n",
    "with plt.rc_context(fname=\"matplotlibrc\"):\n",
    "\n",
    "    plt.tight_layout()\n",
    "    # fig, axs = plt.subplots(1, 3, figsize=(16, 3))\n",
    "    #\n",
    "    #fig, axs = plt.subplots(1, 3, figsize=(8, 1.3))\n",
    "    fig, axs = plt.subplots(1, 3, figsize=(5.5206, 0.9))\n",
    "\n",
    "    fig.subplots_adjust(wspace=0.7)\n",
    "\n",
    "    # colors = [\"#9b2226\", \"#023e8a\", \"#00b4d8\", \"#0077b6\", \"#90e0ef\"]\n",
    "    # colors1 = [\"#9b2226\", \"#023e8a\", \"#00b4d8\", \"#0077b6\", \"#90e0ef\"]\n",
    "    # colors2 = [\"#9b2226\", \"#023e8a\", \"#00b4d8\", \"#0077b6\", \"#90e0ef\"]\n",
    "    colors = [colors[method_names.index(method_names[i])] for i in range(len(method_names))]\n",
    "\n",
    "    for mm in range(methods.shape[0]):\n",
    "\n",
    "        axs[0].errorbar(\n",
    "            n_sim,\n",
    "            mses_mean[mm, :],\n",
    "            yerr=mses_SE[mm, :],\n",
    "            fmt=\"o\",\n",
    "            linestyle=\"-\",\n",
    "            color=colors[mm],\n",
    "            label=method_names[mm],\n",
    "        )\n",
    "        axs[1].errorbar(\n",
    "            n_sim,\n",
    "            sbcs_mean[mm, :],\n",
    "            yerr=sbcs_SE[mm, :],\n",
    "            fmt=\"o\",\n",
    "            linestyle=\"-\",\n",
    "            color=colors[mm],\n",
    "            label=method_names[mm],\n",
    "        )\n",
    "        axs[2].errorbar(\n",
    "            n_sim,\n",
    "            mses_real_mean[mm, :],\n",
    "            yerr=mses_real_SE[mm, :],\n",
    "            fmt=\"o\",\n",
    "            linestyle=\"-\",\n",
    "            color=colors[mm],\n",
    "            label=method_names[mm],\n",
    "        )\n",
    "\n",
    "    axs[0].set_xscale(\"log\")\n",
    "    axs[0].set_xlabel(\"# simulations\")\n",
    "    axs[0].set_ylabel(\"MSE to synth. obs.\")\n",
    "    axs[0].set_xticks(n_sim)\n",
    "    axs[0].set_yticks([0, 0.6])\n",
    "    axs[0].set_ylim([0, upper_lim])\n",
    "    axs[0].minorticks_off()\n",
    "    axs[0].spines[\"left\"].set_position((\"outward\", 5))  # Move y-axis slightly left\n",
    "    axs[0].spines[\"bottom\"].set_position((\"outward\", 5))\n",
    "\n",
    "    axs[1].hlines(\n",
    "        mean_absolute_atc,\n",
    "        1e3,\n",
    "        1e5,\n",
    "        linestyle=\":\",\n",
    "        linewidth=2.5,\n",
    "        color=\"black\",\n",
    "        label=\"lower bound\",\n",
    "    )\n",
    "    axs[1].set_xscale(\"log\")\n",
    "    axs[1].set_xlabel(\"# simulations\")\n",
    "    axs[1].set_ylabel(\"SBC EoD\")\n",
    "    axs[1].set_xticks(n_sim)\n",
    "    axs[1].set_ylim([0, 0.2])\n",
    "    axs[1].set_yticks([0, 0.2])\n",
    "    axs[1].minorticks_off()\n",
    "    axs[1].spines[\"left\"].set_position((\"outward\", 5))  # Move y-axis slightly left\n",
    "    axs[1].spines[\"bottom\"].set_position((\"outward\", 5))\n",
    "\n",
    "    # Create the legend with the new order\n",
    "    # axs[1].legend(handles=handles,\n",
    "    # labels=labels,\n",
    "    # loc=\"upper center\",\n",
    "    # bbox_to_anchor=(0.5, -0.6),\n",
    "    # ncol=5\n",
    "    # )\n",
    "\n",
    "    axs[2].set_xscale(\"log\")\n",
    "    axs[2].set_xlabel(\"# simulations\")\n",
    "    axs[2].set_ylabel(\"MSE to real obs.\")\n",
    "    axs[2].set_xticks(n_sim)\n",
    "    axs[2].set_yticks([0, 0.6])\n",
    "    axs[2].set_ylim([0, upper_lim])\n",
    "    axs[2].minorticks_off()\n",
    "    axs[2].spines[\"left\"].set_position((\"outward\", 5))  # Move y-axis slightly left\n",
    "    axs[2].spines[\"bottom\"].set_position((\"outward\", 5))\n",
    "\n",
    "    # Add legends\n",
    "    # Get handles and labels\n",
    "    handles, labels = axs[1].get_legend_handles_labels()\n",
    "\n",
    "\n",
    "    # Reorder handles and labels based on custom order\n",
    "    handles_global = [handles[i] for i in custom_order]\n",
    "    labels_global = [labels[i] for i in custom_order]\n",
    "\n",
    "    handle_ideal = [handles[0]]\n",
    "    label_ideal = [labels[0]]\n",
    "\n",
    "    axs[1].legend(\n",
    "        handles=handle_ideal,\n",
    "        labels=label_ideal,\n",
    "        loc=\"upper right\",\n",
    "        bbox_to_anchor=(1.1, 1.1),\n",
    "        handlelength=1.5,\n",
    "        frameon=False,\n",
    "    )\n",
    "\n",
    "    # Create the legend with the new order\n",
    "    fig.legend(\n",
    "        handles=handles_global,\n",
    "        labels=labels_global,\n",
    "        loc=\"upper center\",\n",
    "        bbox_to_anchor=(0.5, -0.45),\n",
    "        ncol=5,\n",
    "    )\n",
    "    if num_params == 500:\n",
    "        plt.savefig(\n",
    "            \"ice_plots/results_ice_500.pdf\", format=\"pdf\", bbox_inches=\"tight\"\n",
    "        )\n",
    "    else:\n",
    "        plt.savefig(\"ice_plots/results_ice.svg\", format=\"svg\", bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93434fc3",
   "metadata": {},
   "source": [
    "## Put everything together with svg utils\n",
    "\n",
    "Remember to create the predictive plot in `plot_ice_predictive.ipynb` prior to running this cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6144f6c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = \"ice_plots/\"\n",
    "kwargs_text = {\"size\": \"7pt\", \"font\": \"Arial\", \"weight\": \"800\"}\n",
    "\n",
    "\n",
    "# create new SVG figure\n",
    "fig = sg.SVGFigure()\n",
    "fig.set_size((\"10cm\", \"6cm\"))\n",
    "\n",
    "# load matpotlib-generated figures\n",
    "fig0 = sg.fromfile(base_path + \"ice_transect_v2.svg\")\n",
    "fig1 = sg.fromfile(base_path + \"post_predictives.svg\")\n",
    "\n",
    "fig2 = sg.fromfile(base_path + \"results_ice.svg\")\n",
    "\n",
    "\n",
    "# get the plot objects\n",
    "plot0 = fig0.getroot()\n",
    "plot1 = fig1.getroot()\n",
    "plot2 = fig2.getroot()\n",
    "\n",
    "# get sizes\n",
    "size0 = get_size_tuple(fig0)\n",
    "size1 = get_size_tuple(fig1)\n",
    "size2 = get_size_tuple(fig2)\n",
    "\n",
    "# define scales\n",
    "scales = [0.19, 1, 1]\n",
    "\n",
    "# a: ice transect\n",
    "plot0.scale(scales[0])\n",
    "plot0.moveto(15, 15)\n",
    "\n",
    "# b: post predictive\n",
    "plot1.scale(scales[1])\n",
    "plot1.moveto(size0[0] * scales[0] + 25, 10)\n",
    "\n",
    "# c: results\n",
    "plot2.scale(scales[2])\n",
    "plot2.moveto(5, size0[1] * scales[0] + 45)\n",
    "\n",
    "# add text labels\n",
    "#txt0 = sg.TextElement(1, 10, \"a\", **kwargs_text)\n",
    "#txt1 = sg.TextElement(size0[0] * scales[0] + 22, 10, \"b\", **kwargs_text)\n",
    "#txt2 = sg.TextElement(1, size0[1] * scales[0] + 40, \"c\", **kwargs_text)\n",
    "#txt3 = sg.TextElement(1, size0[1] * scales[0] + 40, \"d\", **kwargs_text)\n",
    "#txt4 = sg.TextElement(1, size0[1] * scales[0] + 40, \"e\", **kwargs_text)\n",
    "\n",
    "txt0 = sg.TextElement(0, 10, \"a\", **kwargs_text)\n",
    "txt1 = sg.TextElement(size0[0] * scales[0] + 22, 10, \"b\", **kwargs_text)\n",
    "txt2 = sg.TextElement(0, size0[1] * scales[0] + 45, \"c\", **kwargs_text)\n",
    "txt3 = sg.TextElement(115, size0[1] * scales[0] + 45, \"d\", **kwargs_text)\n",
    "txt4 = sg.TextElement(235, size0[1] * scales[0] + 45, \"e\", **kwargs_text)\n",
    "\n",
    "\n",
    "# append plots and labels to figure\n",
    "fig.append(\n",
    "    [\n",
    "        plot0,\n",
    "        plot1,\n",
    "        plot2,\n",
    "    ]\n",
    ")\n",
    "fig.append([txt0, txt1, txt2, txt3, txt4])\n",
    "\n",
    "\n",
    "# save generated SVG files\n",
    "fig.save(\"ice_plots/ice_joint.svg\")\n",
    "SVG(\"ice_plots/ice_joint.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46a2bf34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert to pdf using inkscape\n",
    "!/Applications/Inkscape.app/Contents/MacOS/inkscape ice_plots/ice_joint.svg --export-type=pdf --export-filename=ice_plots/ice_joint.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70ec754f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fourier_nets",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
