{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input dimension; the plots below use log base d of the curve index as the x-axis.\n",
    "d = 100\n",
    "\n",
    "# Regularization strengths and random seeds swept in the experiment.\n",
    "lambs = [0.0, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1]\n",
    "seeds = [1, 11, 111, 1111, 11111]\n",
    "\n",
    "# all_data[lamb] collects one result per seed, in the order of `seeds`.\n",
    "all_data = {lamb: [] for lamb in lambs}\n",
    "for seed in seeds:\n",
    "    # Despite the .npy extension these files are read with pickle (presumably written\n",
    "    # with pickle.dump) and indexed by lambda value.\n",
    "    # NOTE: pickle.load can execute arbitrary code -- only load trusted result files.\n",
    "    filename = f'seed{seed}data.npy'\n",
    "    with open(filename, 'rb') as f:  # `with` closes the handle even if loading fails\n",
    "        results_by_lamb = pickle.load(f)\n",
    "    for lamb in lambs:\n",
    "        all_data[lamb].append(results_by_lamb[lamb])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate over seeds: np.array stacks each lambda's per-seed results along a new\n",
    "# leading axis 0, which is then reduced with the mean and the population std (ddof=0).\n",
    "stacked = {lamb: np.array(all_data[lamb]) for lamb in lambs}\n",
    "mean_data = {lamb: arr.mean(axis=0) for lamb, arr in stacked.items()}\n",
    "std_data = {lamb: arr.std(axis=0) for lamb, arr in stacked.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['font.size'] = '16'\n",
    "# One colour per lambda, sampled evenly from the plasma colormap.\n",
    "colors = plt.get_cmap('plasma')(np.linspace(0, 1, len(lambs)))\n",
    "# Set to a filename (e.g. 'd100_k2.pdf') to also write the figure to disk.\n",
    "SAVE_PATH = None\n",
    "\n",
    "\n",
    "def plot_metric(ax, row, ylabel, zero_line=True):\n",
    "    \"\"\"Plot one metric on `ax`: the seed-mean curve and a +/-1 std band per lambda.\n",
    "\n",
    "    Reads the notebook globals `lambs`, `colors`, `mean_data`, `std_data` and `d`.\n",
    "    `row` selects the first axis of mean_data[lamb] / std_data[lamb]\n",
    "    (0 = train loss, 1 = test loss, 2 = R_3(W), matching the y-labels used below).\n",
    "    The x-axis is log_d(t + 1) for index t = 0, 1, ... along each curve.\n",
    "    If `zero_line` is True, a dashed y = 0 reference line is drawn and the y-axis\n",
    "    is set to start just below zero.\n",
    "    \"\"\"\n",
    "    if zero_line:\n",
    "        ax.axhline(0., linestyle='dashed', color='gray')\n",
    "    for color, lamb in zip(colors, lambs):\n",
    "        mean = mean_data[lamb][row, :]\n",
    "        std = std_data[lamb][row, :]\n",
    "        # log(t + 1) / log(d) == log_d(t + 1); the +1 avoids log(0) at t = 0.\n",
    "        x_vals = np.log(np.arange(1, len(mean) + 1)) / np.log(d)\n",
    "        ax.plot(x_vals, mean, color=color, label=r'$\\lambda =$' + str(lamb))\n",
    "        ax.fill_between(x_vals, mean - std, mean + std, color=color, alpha=0.3)\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.set_xlabel(r'$\\log_d t$')\n",
    "    if zero_line:\n",
    "        ax.set_ylim(bottom=-0.01)\n",
    "\n",
    "\n",
    "fig, axs = plt.subplots(1, 3, figsize=(15, 5))\n",
    "fig.set_facecolor('white')\n",
    "\n",
    "plot_metric(axs[0], 2, r'$\\mathcal{R}_3(\\mathbf{W})$', zero_line=False)\n",
    "axs[0].legend()  # one legend suffices: every panel uses the same lambda colours\n",
    "plot_metric(axs[1], 0, 'Train Loss')\n",
    "plot_metric(axs[2], 1, 'Test Loss')\n",
    "\n",
    "fig.tight_layout()\n",
    "if SAVE_PATH:\n",
    "    fig.savefig(SAVE_PATH, bbox_inches='tight')\n",
    "\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "a55c72e2be434f5765594e9ed5464ac27f0df6c2513a32031c896f1652198cf7"
  },
  "kernelspec": {
   "display_name": "Python 3.8.13 ('jax')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
