{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad0bc3ce-0a27-4913-8881-fa332dbfc97f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Load the saved n=500 results used for figure 3a.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only open trusted files.\n",
    "with open('data/data_n500_figure3', 'rb') as f:\n",
    "    mydata = pickle.load(f)\n",
    "\n",
    "(test_xs_1d, train_xs_1d, train_ys, train_loss,\n",
    " apply_fn_list, mu, mu_nngp, std) = mydata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caff4c8e-8f1e-4503-881f-7d96a843b63c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.lines as mlines\n",
    "import matplotlib.patches as mpatches\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib_inline\n",
    "import numpy as np\n",
    "from matplotlib import colormaps as cm\n",
    "\n",
    "# Render inline figures as vector graphics (PDF + SVG).\n",
    "# matplotlib_inline's setter is the replacement for\n",
    "# IPython.display.set_matplotlib_formats (deprecated since IPython 7.23),\n",
    "# so the IPython version is intentionally not imported.\n",
    "matplotlib_inline.backend_inline.set_matplotlib_formats('pdf', 'svg')\n",
    "\n",
    "# Plot colors: the viridis colormap and one sample from it, used for the\n",
    "# individual trained-network curves in figure 3a.\n",
    "viridis = cm['viridis']\n",
    "vir = viridis(.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc452e08-746c-46fc-9bc7-4fad882ec666",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure 3a: individual trained networks and their mean (n=500 data loaded\n",
    "# in the first cell), together with mu +/- 2 std, the NNGP mean and the\n",
    "# training points. All artists are drawn on the explicit `ax`.\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "fig.set_size_inches(6, 4)\n",
    "ax.set_xticks((-np.pi, -np.pi/2, 0, np.pi/2, np.pi), ('$-\\\\pi$', '', '0', '', '$\\\\pi$'))\n",
    "fig.supxlabel('$\\\\alpha$')\n",
    "ax.set_ylim(-2.5, 3.5)\n",
    "\n",
    "# 1. Shaded band mu - 2*std .. mu + 2*std (flattened to 1-D for fill_between)\n",
    "ax.fill_between(\n",
    "    np.reshape(test_xs_1d, (-1)),\n",
    "    (mu - 2 * std).reshape((-1)),\n",
    "    (mu + 2 * std).reshape((-1)),\n",
    "    color='black', alpha=0.1)\n",
    "\n",
    "# 2. One faint curve per entry of apply_fn_list (each entry is plotted\n",
    "#    against test_xs_1d, i.e. one network's predictions on the test grid)\n",
    "for preds in apply_fn_list:\n",
    "    ax.plot(test_xs_1d, preds, color=vir, linewidth=1, alpha=0.05)\n",
    "\n",
    "# Mean over the trained networks (axis 0 indexes the networks)\n",
    "ax.plot(test_xs_1d, np.mean(apply_fn_list, axis=0), color=viridis(.5), linewidth=2.5)\n",
    "\n",
    "# 3. Plot mean\n",
    "ax.plot(test_xs_1d, mu, linewidth=1.2, color='k', linestyle='solid')\n",
    "\n",
    "# 4. Plot NNGP mean\n",
    "ax.plot(test_xs_1d, mu_nngp, linewidth=1.2, color='0.2', linestyle='dashed')\n",
    "\n",
    "# 5. Plot training points (numeric marker size / edge width)\n",
    "ax.plot(train_xs_1d, train_ys, marker='x', color=viridis(.95), linestyle='', ms=7, mew=1.3)\n",
    "\n",
    "fig.savefig(\"figure3a.pdf\", bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7d0e6b6-fa88-4af5-b718-73eb3c418e8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the trained-network results for widths n = 20, 100 and 500.\n",
    "# Files are read from the same 'data/' directory as the first cell so the\n",
    "# notebook works from a single data location under Restart & Run All.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only open trusted files.\n",
    "\n",
    "# n=20: also supplies test_xs_1d, train_xs_1d, train_ys and mu for the plot below\n",
    "with open('data/data_n20_figure3', 'rb') as f:\n",
    "    mydata = pickle.load(f)\n",
    "test_xs_1d, train_xs_1d, train_ys, train_loss, apply_fn_list_n20, mu, mu_nngp, std = mydata\n",
    "\n",
    "# n=100: keep only the network predictions. All three widths are plotted\n",
    "# against the n=20 test grid, so check that this file uses the same grid.\n",
    "with open('data/data_n100_figure3', 'rb') as f:\n",
    "    mydata = pickle.load(f)\n",
    "_, _, _, _, apply_fn_list_n100, _, _, _ = mydata\n",
    "np.testing.assert_allclose(mydata[0], test_xs_1d)\n",
    "\n",
    "# n=500: same as above\n",
    "with open('data/data_n500_figure3', 'rb') as f:\n",
    "    mydata = pickle.load(f)\n",
    "_, _, _, _, apply_fn_list_n500, _, _, _ = mydata\n",
    "np.testing.assert_allclose(mydata[0], test_xs_1d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1afa67c0-59f1-45b8-85ed-7e483c969a33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure 3: mean prediction of the trained networks for each width n,\n",
    "# compared with mu and the training points (loaded from the n=20 file above).\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "fig.set_size_inches(6, 4)\n",
    "ax.set_xticks((-np.pi, -np.pi/2, 0, np.pi/2, np.pi), ('$-\\\\pi$', '', '0', '', '$\\\\pi$'))\n",
    "fig.supxlabel('$\\\\alpha$')\n",
    "\n",
    "# (legend label, per-network predictions, line color) for each width;\n",
    "# driving both the curves and the legend from one list keeps them in sync.\n",
    "ensembles = [\n",
    "    ('n=20', apply_fn_list_n20, viridis(.1)),\n",
    "    ('n=100', apply_fn_list_n100, viridis(.4)),\n",
    "    ('n=500', apply_fn_list_n500, viridis(.7)),\n",
    "]\n",
    "\n",
    "# 1. Mean over the trained networks (axis 0) for each width\n",
    "for _label, preds, color in ensembles:\n",
    "    ax.plot(test_xs_1d, np.mean(preds, axis=0), color=color, linewidth=2.4)\n",
    "\n",
    "# 2. Plot mean\n",
    "ax.plot(test_xs_1d, mu, linewidth=2., color='k', linestyle='solid')\n",
    "\n",
    "# 3. Plot training points (numeric marker size / edge width)\n",
    "ax.plot(train_xs_1d, train_ys, marker='x', color='0.6', linestyle='', ms=7, mew=1.3)\n",
    "\n",
    "# Legend: one colored patch per width, anchored in the upper-left of the figure\n",
    "fig.legend(handles=[mpatches.Patch(facecolor=color) for _, _, color in ensembles],\n",
    "           labels=[label for label, _, _ in ensembles],\n",
    "           loc='upper left', borderaxespad=0.1, bbox_to_anchor=(.15, 0.85), prop={'size': 10})\n",
    "\n",
    "fig.savefig(\"figure3.pdf\", bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
