{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e8be22a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from matplotlib.lines import Line2D\n",
    "import torch\n",
    "from simulators.gp_priors import get_gaussian_process_prior_1d\n",
    "from experiments.evaluation_utils import GTPosterior\n",
    "from utils.metrics import sliced_wasserstein_distance\n",
    "from utils.misc import get_output_dir,get_data_dir\n",
    "from pathlib import Path\n",
    "import pickle\n",
    "# Shared setup: torch compute device plus the project's data and output directories.\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "data_dir = get_data_dir()\n",
    "output_dir = get_output_dir()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44f16e70",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Load data to be plotted from csv ####\n",
    "\n",
    "out_dir = get_output_dir()\n",
    "\n",
    "path = out_dir / \"linear_gaussian_experiment/summary.csv\"\n",
    "data = pd.read_csv(path, usecols=[1, 2, 3, 4, 5, 6, 7])\n",
    "\n",
    "# Keep methods in CSV order: the plotting cell assigns labels/colors by this index.\n",
    "# Sort the simulation budgets so every curve is drawn with increasing x-values\n",
    "# (Series.unique() returns values in order of appearance, not sorted).\n",
    "methods = data[\"method\"].unique()\n",
    "n_sim = np.sort(data[\"nsim\"].unique())\n",
    "\n",
    "\n",
    "def parse_swds(cell):\n",
    "    \"\"\"Parse one serialized SWD array from the CSV into a list of floats.\n",
    "\n",
    "    Accepts numpy-style ``[0.1 0.2]`` as well as list-style ``[0.1, 0.2]`` reprs.\n",
    "    \"\"\"\n",
    "    return [float(tok) for tok in str(cell).strip(\"[]\").replace(\",\", \" \").split()]\n",
    "\n",
    "\n",
    "# Mean and standard error of the SWD for each method and each number of simulations.\n",
    "# Entries stay NaN when a (method, nsim) pair has no runs, so no point is drawn for it.\n",
    "swd_results = {}\n",
    "swd_mean = np.full((methods.shape[0], n_sim.shape[0]), np.nan)\n",
    "swd_SE = np.full((methods.shape[0], n_sim.shape[0]), np.nan)\n",
    "\n",
    "for ii, method in enumerate(methods):\n",
    "    swd_results[method] = {}\n",
    "    for kk, nsim in enumerate(n_sim):\n",
    "        mask = (data[\"method\"] == method) & (data[\"nsim\"] == nsim)\n",
    "        run_swds = [v for cell in data.loc[mask, \"swds\"] for v in parse_swds(cell)]\n",
    "        swd_results[method][nsim] = run_swds\n",
    "        if not run_swds:\n",
    "            continue\n",
    "        swds_arr = np.asarray(run_swds)\n",
    "        swd_mean[ii, kk] = swds_arr.mean()\n",
    "        # SE of the mean uses the Bessel-corrected (ddof=1) sample std; undefined for one value.\n",
    "        if swds_arr.size > 1:\n",
    "            swd_SE[ii, kk] = swds_arr.std(ddof=1) / np.sqrt(swds_arr.size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40c26d0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot SWD vs. number of simulations using the local matplotlibrc style file.\n",
    "figure_version = \"v0\"\n",
    "plot_dir = Path(\"linear_gaussian_plots\")\n",
    "\n",
    "methods_names = ['FNOPE', 'FNOPE (fix)', 'NPE (spectral)', 'FMPE (raw)', 'FMPE (spectral)']\n",
    "colors = ['#9b2226', '#CA6702', '#023e8a', '#00b4d8', '#0077b6']\n",
    "# Labels and colors are matched to `methods` by position (CSV order); fail loudly\n",
    "# rather than silently mislabelling curves if the summary holds a different set.\n",
    "assert len(methods) == len(methods_names) == len(colors), (\n",
    "    f\"expected {len(methods_names)} methods, found {len(methods)}: {list(methods)}\"\n",
    ")\n",
    "\n",
    "with plt.rc_context(fname=\"matplotlibrc\"):\n",
    "\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(1.66, 1.05))\n",
    "\n",
    "    for mm, (method_label, method_color) in enumerate(zip(methods_names, colors)):\n",
    "        ax.errorbar(\n",
    "            n_sim,\n",
    "            swd_mean[mm, :],\n",
    "            yerr=swd_SE[mm, :],\n",
    "            fmt=\"o\",\n",
    "            linestyle=\"-\",\n",
    "            color=method_color,\n",
    "            label=method_label,\n",
    "        )\n",
    "\n",
    "    ax.set_xscale(\"log\")\n",
    "    ax.set_xlabel(\"# simulations\")\n",
    "    ax.set_ylabel(\"SWD\")\n",
    "    ax.set_xticks(n_sim)\n",
    "    ax.set_yticks([0, 1.0])\n",
    "    ax.minorticks_off()\n",
    "\n",
    "    # Reorder the legend entries and insert an invisible spacer at index 2. With ncol=2\n",
    "    # the legend fills column by column, so the first column holds the two FNOPE variants.\n",
    "    # (If a lower-bound curve is added as a 6th method, the order previously used was\n",
    "    # [1, 2, 4, 3, 5, 0].)\n",
    "    handles, labels = ax.get_legend_handles_labels()\n",
    "    custom_order = [0, 1, 2, 4, 3]\n",
    "    handles = [handles[i] for i in custom_order]\n",
    "    labels = [labels[i] for i in custom_order]\n",
    "    handles.insert(2, Line2D([], [], color=\"none\", label=\"\"))\n",
    "    labels.insert(2, \"\")\n",
    "\n",
    "    ax.legend(\n",
    "        handles=handles,\n",
    "        labels=labels,\n",
    "        loc=\"upper center\",\n",
    "        bbox_to_anchor=(0.5, -0.7),  # centered below the axes\n",
    "        ncol=2,\n",
    "    )\n",
    "\n",
    "    # Move the left and bottom spines 5 points outward from the plot area.\n",
    "    ax.spines[\"left\"].set_position((\"outward\", 5))\n",
    "    ax.spines[\"bottom\"].set_position((\"outward\", 5))\n",
    "\n",
    "    # savefig does not create missing directories, so make sure the target exists.\n",
    "    plot_dir.mkdir(parents=True, exist_ok=True)\n",
    "    fig.savefig(\n",
    "        plot_dir / f\"results_linearGaussian_{figure_version}.pdf\",\n",
    "        format=\"pdf\",\n",
    "        bbox_inches=\"tight\",\n",
    "    )\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fourier_nets",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
