{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python392jvsc74a57bd04db6292f37c6ec4dc50894061fe55ccde808eb8d30dbd8f895ae7b935187ec56",
   "display_name": "Python 3.9.2 64-bit ('gsc_functions': conda)"
  },
  "metadata": {
   "interpreter": {
    "hash": "4db6292f37c6ec4dc50894061fe55ccde808eb8d30dbd8f895ae7b935187ec56"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from spline_ls import save_dir\n",
    "from gscfunc.utils import *\n",
    "\n",
    "# Color-blind-friendly palette (hex colors). Note: the plotting cells further down rebind `palette` to matplotlib's 'tab10' colors.\n",
    "palette = ['#377eb8', '#ff7f00', '#4daf4a', '#f781bf', '#a65628', '#984ea3', '#999999', '#e41a1c', '#dede00']\n",
    "# Marker shapes; the plotting cells index this list by the position of each plotted series.\n",
    "markers = ['o', '^', 's', '*']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather results.csv from every run directory matching \"*n4\" under save_dir.\n",
    "frames = []\n",
    "# Sorted so that the concatenation order does not depend on the filesystem listing order.\n",
    "for f in sorted(save_dir.glob(\"*n4\")):\n",
    "    if f.is_dir():\n",
    "        cur_df = pd.read_csv(f / \"results.csv\")\n",
    "        # Drop the last seed while jobs are still running; keep the next line commented out otherwise.\n",
    "        # cur_df = cur_df.loc[cur_df['seed'] < cur_df['seed'].max()]\n",
    "        frames.append(cur_df)\n",
    "        print(f\"Found file at {f}\")\n",
    "if not frames:\n",
    "    # pd.concat([]) would only report \"No objects to concatenate\"; name the actual cause instead.\n",
    "    raise FileNotFoundError(f\"No run directory matching '*n4' found in {save_dir}\")\n",
    "# ignore_index avoids duplicated row labels coming from the individual CSV files.\n",
    "df = pd.concat(frames, ignore_index=True)\n",
    "print(f\"n max: {df['n'].max()}\\n# seeds: {len(df['seed'].unique())}\")"
   ]
  },
  {
   "source": [
    "## Lowest value obtained"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _min_loss_obtained(group, columns, by='loss_l2'):\n",
    "    \"\"\"Return the `columns` entries of the row of `group` with the smallest `by` value.\n",
    "\n",
    "    The rows are sorted by `by` in ascending order and the first one is kept.\n",
    "    \"\"\"\n",
    "    ranked = group.sort_values(by=by, ascending=True)\n",
    "    best_row = ranked.iloc[0]\n",
    "    return best_row[columns]\n",
    "# Keep, for every (n, r, t, seed) group, the row with the smallest 'loss_l2'.\n",
    "run_keys = ['n', 'r', 't', 'seed']\n",
    "best_rows = df.groupby(run_keys).apply(_min_loss_obtained, columns=['loss_l2', 'reg'])\n",
    "# Append the log-transformed columns after the group keys and the selected values.\n",
    "df_loss = best_rows.reset_index().assign(\n",
    "    log_n=lambda frame: np.log(frame['n']),\n",
    "    log_loss=lambda frame: np.log(frame['loss_l2']),\n",
    ")\n",
    "# Saving a copy\n",
    "df_loss_copy = df_loss.copy()"
   ]
  },
  {
   "source": [
    "## Regression coefficients"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean loss per (r, t, n), passed to regression_coefficients with groupby=['r', 't'] and log=True.\n",
    "mean_loss = df_loss.groupby(['r', 't', 'n'])['loss_l2'].agg(['mean']).reset_index()\n",
    "reg_coeff = regression_coefficients(mean_loss, 'n', 'mean', groupby=['r', 't'], log=True).reset_index()\n",
    "# a_th: negated rate_risk(r, t, alpha=2.), evaluated row by row.\n",
    "rate_per_row = reg_coeff.apply(lambda row: rate_risk(row['r'], row['t'], alpha=2.), axis=1)\n",
    "reg_coeff['a_th'] = -rate_per_row\n",
    "reg_coeff = reg_coeff.set_index(['r', 't'])\n",
    "# Display everything except the 'b' column; reg_coeff itself keeps it.\n",
    "reg_coeff.drop(columns='b')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard deviation of each summarized column per (r, t, log_n).\n",
    "summary_columns = ['log_loss', 'reg', 'loss_l2']\n",
    "df_loss_ms = (\n",
    "    df_loss\n",
    "    .groupby(['r', 't', 'log_n'])[summary_columns]\n",
    "    .agg(['mean', 'std'])\n",
    ")\n",
    "df_loss_ms"
   ]
  },
  {
   "source": [
    "## Sample complexity plots"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Matplotlib settings shared by the figures below.\n",
    "latex_text_settings = {\n",
    "    # Use LaTeX to write all text\n",
    "    \"text.usetex\": True,\n",
    "    \"font.family\": \"serif\",\n",
    "    \"font.sans-serif\": [\"Helvetica\"],\n",
    "    'text.latex.preamble': r'\\usepackage{amsmath}',\n",
    "}\n",
    "font_size_settings = {\n",
    "    # Use 10pt font in plots, to match 10pt font in document\n",
    "    \"axes.labelsize\": 10,\n",
    "    \"font.size\": 10,\n",
    "    # Make the legend/label fonts a little smaller\n",
    "    \"legend.fontsize\": 8,\n",
    "    \"xtick.labelsize\": 8,\n",
    "    \"ytick.labelsize\": 8,\n",
    "}\n",
    "plt.rcParams.update({**latex_text_settings, **font_size_settings})"
   ]
  },
  {
   "source": [
    "# Mean 'loss_l2' vs. sample size n on log-log axes: one series per t in ts_to_plot, at fixed r.\n",
    "fig, ax = plt.subplots(1, 1, figsize=(192/72, 192/72))\n",
    "palette = plt.get_cmap('tab10').colors\n",
    "ts_to_plot = [1, 3, 8]\n",
    "# NOTE(review): must be a value of the 'r' index level of df_loss_ms, otherwise .xs raises KeyError.\n",
    "r_to_plot = 10.25\n",
    "\n",
    "for i, t in enumerate(ts_to_plot):\n",
    "    # Cycle through colors and markers so that a longer ts_to_plot cannot raise IndexError.\n",
    "    c = palette[i % len(palette)]\n",
    "    m = markers[i % len(markers)]\n",
    "    df_plot = df_loss_ms.xs((r_to_plot, t), level=('r', 't')).xs('loss_l2', axis=1).reset_index()\n",
    "    df_plot['n'] = np.exp(df_plot['log_n'])\n",
    "\n",
    "    # Mean loss for each n, drawn with markers only.\n",
    "    ax.plot(df_plot['n'], df_plot['mean'], label=f\"$t={t}$\", color=c, marker=m, markersize=3, linestyle='')\n",
    "    # Semi-transparent line from linear_trend with slope argument -rate_risk(r, t, 2.);\n",
    "    # presumably the theoretical rate in log-log space -- confirm against gscfunc.utils.\n",
    "    ax.plot(df_plot['n'], linear_trend(df_plot, 'n', 'mean', -rate_risk(r_to_plot, t, 2.), log=True), alpha=.5, linestyle='-', color=c)\n",
    "\n",
    "ax.set_xlabel(\"$n$\")\n",
    "ax.set_ylabel(\"$L(\\\\hat{\\\\theta}_\\\\lambda^t) - L(\\\\theta^\\\\star)$\")\n",
    "ax.set_xscale('log')\n",
    "ax.set_yscale('log')\n",
    "# Explicit Axes/Figure calls instead of the pyplot \"current figure\" state.\n",
    "ax.legend()\n",
    "ax.tick_params(axis='y', labelrotation = 90)\n",
    "ax.margins(0, 0)\n",
    "fig.tight_layout()\n",
    "\n",
    "# Save before showing, so the export does not depend on how the backend treats the figure after show().\n",
    "fig.savefig(f\"ls_sample_complexity_r{r_to_plot}.pdf\", transparent=True, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ],
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "source": [
    "## Regularization plots"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard deviation of log(reg) per (r, t, log_n); adds a 'log_reg' column to df_loss.\n",
    "log_reg_stats = ['mean', 'std']\n",
    "df_loss['log_reg'] = np.log(df_loss['reg'])\n",
    "df_reg_ms = (\n",
    "    df_loss\n",
    "    .groupby(['r', 't', 'log_n'])[['log_reg']]\n",
    "    .agg(log_reg_stats)\n",
    ")\n",
    "df_reg_ms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# exp(mean log(reg)) vs. sample size n on log-log axes: one series per t in ts_to_plot, at fixed r.\n",
    "fig, ax = plt.subplots(figsize=(192/72, 192/72))\n",
    "palette = plt.get_cmap('tab10').colors\n",
    "ts_to_plot = [1, 3, 8]\n",
    "# NOTE(review): must be a value of the 'r' index level of df_reg_ms, otherwise .xs raises KeyError.\n",
    "r_to_plot = 0.25\n",
    "\n",
    "for i, t in enumerate(ts_to_plot):\n",
    "    # Cycle through colors and markers so that a longer ts_to_plot cannot raise IndexError.\n",
    "    c = palette[i % len(palette)]\n",
    "    m = markers[i % len(markers)]\n",
    "    df_plot = df_reg_ms.xs((r_to_plot, t), level=('r', 't')).xs('log_reg', axis=1).reset_index()\n",
    "    df_plot['n'] = np.exp(df_plot['log_n'])\n",
    "\n",
    "    # 'mean' is an average of logs, so it is exponentiated before plotting; markers only.\n",
    "    ax.plot(df_plot['n'], np.exp(df_plot['mean']), label=f\"$t={t}$\", color=c, marker=m, markersize=3, linestyle='')\n",
    "    # Semi-transparent line from linear_trend on the already-logged columns (log=False), with slope\n",
    "    # argument -rate_reg(r, t, 2.); presumably the theoretical rate -- confirm against gscfunc.utils.\n",
    "    ax.plot(df_plot['n'], np.exp(linear_trend(df_plot, 'log_n', 'mean', -rate_reg(r_to_plot, t, 2.), log=False)), alpha=.5, linestyle='-', color=c)\n",
    "\n",
    "ax.set_xlabel(\"$n$\")\n",
    "ax.set_ylabel(\"$\\\\lambda$\")\n",
    "ax.set_xscale('log')\n",
    "ax.set_yscale('log')\n",
    "# Explicit Axes/Figure calls instead of the pyplot \"current figure\" state.\n",
    "ax.legend()\n",
    "ax.tick_params(axis='y', labelrotation = 90)\n",
    "ax.margins(0, 0)\n",
    "fig.tight_layout()\n",
    "\n",
    "# Save before showing, so the export does not depend on how the backend treats the figure after show().\n",
    "fig.savefig(f\"ls_optimal_reg_r{r_to_plot}.pdf\", transparent=True, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}