{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python392jvsc74a57bd0162f7e99a1de538f8d997c341f1c90e0db991e9064797812f6a3d0b59d889ee8",
   "display_name": "Python 3.9.2 64-bit ('gsc_cyan': conda)"
  },
  "metadata": {
   "interpreter": {
    "hash": "4db6292f37c6ec4dc50894061fe55ccde808eb8d30dbd8f895ae7b935187ec56"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from spline_logistic import save_dir\n",
    "from gscfunc.utils import *\n",
    "\n",
    "# Color blind friendly palette \n",
    "palette = ['#377eb8', '#ff7f00', '#4daf4a', '#f781bf', '#a65628', '#984ea3', '#999999', '#e41a1c', '#dede00']\n",
    "markers = ['o', '^', 's', '*']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the results of every run directory matching the experiment suffix.\n",
    "run_frames = []\n",
    "for f in save_dir.glob(\"*cyan_n4\"):\n",
    "    if f.is_dir():\n",
    "        # Not read anywhere below; kept so the last run's config can be inspected.\n",
    "        cur_config = json.loads((f / \"config.json\").read_text())\n",
    "        cur_df = pd.read_csv(f / \"results.csv\")\n",
    "        # Remove the last seed as job are still running. Comment otherwise!\n",
    "        # cur_df = cur_df.loc[cur_df['seed'] < cur_df['seed'].max()]\n",
    "        cur_df = cur_df.loc[cur_df['n'] < 1e4]\n",
    "        run_frames.append(cur_df)\n",
    "        print(f\"Found file at {f}\")\n",
    "# ignore_index=True gives a unique 0..N-1 index. Without it every file (and the\n",
    "# filter above) contributes its own row labels, and the `df.iloc[idxmin]` lookups\n",
    "# further down need label == position to pick the right rows.\n",
    "df = pd.concat(run_frames, ignore_index=True)\n",
    "print(f\"n max: {df['n'].max()}\\n# seeds: {len(df['seed'].unique())}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows whose n_steps equals 1000 are the ones reported as not converged.\n",
    "not_converged = df['n_steps'] == 1000\n",
    "print(f\"Fraction not converged: {not_converged.mean():.1%}\")"
   ]
  },
  {
   "source": [
    "## Lowest value obtained\n",
    "Don't know if there's a more intuitive way with Pandas:"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _min_loss_obtained(group, columns, by='excess_risk'):\n",
    "    \"\"\"Return `columns` of the row of `group` that has the smallest `by` value.\"\"\"\n",
    "    ordered = group.sort_values(by)\n",
    "    best_row = ordered.iloc[0]\n",
    "    return best_row[columns]\n",
    "# One row per (n, r, t, seed): the lowest excess risk reached and the reg that gave it.\n",
    "best_per_run = df.groupby(['n', 'r', 't', 'seed']).apply(_min_loss_obtained, columns=['excess_risk', 'reg'])\n",
    "df_loss = best_per_run.reset_index()\n",
    "df_loss = df_loss.assign(log_n=np.log(df_loss['n']), log_excess_risk=np.log(df_loss['excess_risk']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Saving a copy: snapshot of df_loss taken before later cells add columns to it.\n",
    "df_loss_copy = df_loss.copy()\n",
    "# Uncomment to restore df_loss from the snapshot:\n",
    "# df_loss = df_loss_copy.copy()"
   ]
  },
  {
   "source": [
    "## Learning rate plots"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "### 1. Regression coefficient\n",
    "\n",
    "We want to estimate the quantities $a, b$ in: $\\log ER = a \\log n + b$\n",
    "\n",
    "What seems to work is:\n",
    "* Take the mean\n",
    "* Take the log\n",
    "* Compute the regression\n",
    "\n",
    "I would have thought that computing the log and not taking the mean would be better, but it is clearly not. \n",
    "\n",
    "\n"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the excess risk over seeds first, then regress on n per (r, t) with log=True.\n",
    "mean_risk = df_loss.groupby(['r', 't', 'n'])['excess_risk'].agg(['mean']).reset_index()\n",
    "reg_coeff = regression_coefficients(mean_risk, 'n', 'mean', groupby=['r', 't'], log=True).reset_index()\n",
    "# a_th = -rate_risk(r, t, alpha=2), shown next to the fitted coefficients for comparison.\n",
    "reg_coeff['a_th'] = -reg_coeff.apply(lambda row: rate_risk(row['r'], row['t'], alpha=2.), axis=1)\n",
    "reg_coeff = reg_coeff.set_index(['r', 't'])\n",
    "reg_coeff.drop('b', axis=1)"
   ]
  },
  {
   "source": [
    "### 2. Mean and Std\n",
    "\n",
    "They are computed on both the raw and the log quantities."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and std over seeds of the excess risk, its log and the selected reg,\n",
    "# for every (r, t, log_n).\n",
    "df_loss_ms = df_loss.groupby(['r', 't', 'log_n'])[['excess_risk', 'log_excess_risk', 'reg']].agg(['mean', 'std'])\n",
    "df_loss_ms"
   ]
  },
  {
   "source": [
    "### 3. Plots\n",
    "\n",
    "And finally, we plot."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global matplotlib settings shared by every figure below.\n",
    "plt.rcParams.update({\n",
    "    # Use LaTeX to write all text\n",
    "    \"text.usetex\": True,\n",
    "    \"font.family\": \"serif\",\n",
    "    # NOTE(review): with font.family set to serif this entry looks unused - confirm\n",
    "    \"font.sans-serif\": [\"Helvetica\"],\n",
    "    'text.latex.preamble': r'\\usepackage{amsmath}',\n",
    "    # Use 10pt font in plots, to match 10pt font in document\n",
    "    \"axes.labelsize\": 10,\n",
    "    \"font.size\": 10,\n",
    "    # Make the legend/label fonts a little smaller\n",
    "    \"legend.fontsize\": 8,\n",
    "    \"xtick.labelsize\": 8,\n",
    "    \"ytick.labelsize\": 8})  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean excess risk against n for a few values of t at a fixed r (log-log axes).\n",
    "fig, ax = plt.subplots(figsize=(190/72, 190/72))\n",
    "ts_to_plot = [1, 3, 8]\n",
    "r_to_plot = 10.25\n",
    "\n",
    "for i, t in enumerate(ts_to_plot):\n",
    "    c = palette[i]\n",
    "    df_plot = df_loss_ms.xs((r_to_plot, t), level=('r', 't')).xs('excess_risk', axis=1).reset_index()\n",
    "    # df_loss_ms is indexed by log_n; recover n for the x axis\n",
    "    df_plot['n'] = np.exp(df_plot['log_n'])\n",
    "\n",
    "    # Markers: mean over seeds. Line: linear_trend evaluated with -rate_risk(r_to_plot, t, 2.).\n",
    "    ax.plot(df_plot['n'], df_plot['mean'], label=f\"$t={t}$\", color=c, linestyle='', marker=markers[i], markersize=3)\n",
    "    ax.plot(df_plot['n'], linear_trend(df_plot, 'n', 'mean', -rate_risk(r_to_plot, t, 2.), log=True), alpha=.5, linestyle='-', color=c)\n",
    "\n",
    "ax.set_xlabel(\"$n$\")\n",
    "ax.set_ylabel(\"$L(\\\\hat{\\\\theta}_\\\\lambda^t) - L(\\\\theta^\\\\star)$\")\n",
    "# ax.set_title(f\"$r={r_to_plot}$, $\\\\alpha = {2}$\")\n",
    "ax.set_xscale('log')\n",
    "ax.set_yscale('log')\n",
    "ax.legend()\n",
    "ax.tick_params(axis='y', labelrotation = 90)\n",
    "ax.margins(0, 0)\n",
    "# ax.xaxis.set_major_locator(plt.NullLocator())\n",
    "# ax.yaxis.set_major_locator(plt.NullLocator())\n",
    "# ax.grid(True, which='both')\n",
    "plt.tight_layout()\n",
    "\n",
    "# Save before plt.show(): once the figure has been shown, some backends leave nothing to save.\n",
    "fig.savefig(f\"logistic_sample_complexity_r{r_to_plot}.pdf\", transparent=True, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ]
  },
  {
   "source": [
    "## Improvement over Tikhonov\n",
    "\n",
    "The aim is to show that iterated Tikhonov (IT, $t > 1$) is consistently better than plain Tikhonov ($t = 1$)."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put t in columns\n",
    "df_imp = df_loss.pivot_table(index=['r', 'n', 'seed'], columns='t', values='excess_risk')\n",
    "df_imp "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on log excess risk, so that differences below are log-ratios ER(t) / ER(1).\n",
    "df_imp = np.log(df_imp)\n",
    "\n",
    "# Compute diff with t=1\n",
    "for k in df_imp.columns:\n",
    "    df_imp[f\"{k}_diff\"] = df_imp[k] - df_imp[1]\n",
    "# Keep only the *_diff columns and rename them back to 1, 2, ..., t\n",
    "df_imp = df_imp.drop([k for k in df_imp.columns if not str(k).endswith('_diff')], axis=1)\n",
    "# Parse the whole prefix: reading only its first character would map t=10 to 1.\n",
    "df_imp = df_imp.rename(lambda x: int(float(x[:-len('_diff')])), axis=1)\n",
    "# The t=1 column is identically zero (difference with itself), so drop it.\n",
    "df_imp = df_imp.drop(1, axis=1)\n",
    "df_imp"
   ]
  },
  {
   "source": [
    "### Fraction of runs where IT improves on Tikhonov"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every (t, n): fraction of the remaining (r, seed) rows whose value in df_imp is negative.\n",
    "# Missing entries compare as False, so they count as 'no improvement'.\n",
    "(df_imp.stack().unstack(['t', 'n']) < 0).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate seed\n",
    "df_imp_agg = df_imp.stack().reset_index().rename({0: 'loss_l2'}, axis=1)\n",
    "df_imp_agg = df_imp_agg.groupby(['r', 'n', 't'])['loss_l2'].agg(['mean', 'std'])"
   ]
  },
  {
   "source": [
    "### Plots"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# exp(mean of df_imp_agg) against n for a few t at fixed r, with a band from add_ci.\n",
    "fig, ax = plt.subplots(figsize=(192/72, 192/72))\n",
    "ts_to_plot = [3, 8]\n",
    "r_to_plot = 10.25\n",
    "\n",
    "for i, t in enumerate(ts_to_plot):\n",
    "    # Offset by one so t=3 and t=8 keep the colours they have in the other figures\n",
    "    c = palette[i+1]\n",
    "    df_plot = df_imp_agg.xs((r_to_plot, t), level=('r', 't')).reset_index()\n",
    "    df_plot['log_n'] = np.log(df_plot['n'])\n",
    "\n",
    "    ax.plot(df_plot['n'], np.exp(df_plot['mean']), label=t, color=c, marker=markers[i], markersize=3)\n",
    "    add_ci(df_plot['n'], df_plot['mean'], df_plot['std'], log=True, edges_kwargs={'alpha':0, 'color': c}, fill_kwargs={'alpha':.2, 'color':c}, ax=ax)\n",
    "\n",
    "# Reference level: ratio equal to 1\n",
    "ax.axhline(1, color='k', linestyle='--')\n",
    "\n",
    "# Uses df_plot from the last loop iteration\n",
    "ax.set_xlim(df_plot['n'].min(), df_plot['n'].max())\n",
    "ax.set_ylabel('$ER(t) / ER(1)$')\n",
    "ax.set_xlabel('$n$')\n",
    "ax.set_xscale('log')\n",
    "ax.set_yscale('log')\n",
    "ax.tick_params(axis='y', labelrotation = 90)\n",
    "ax.minorticks_off()\n",
    "ax.margins(0, 0)\n",
    "plt.tight_layout()\n",
    "ax.legend()\n",
    "\n",
    "# Save before plt.show(): once the figure has been shown, some backends leave nothing to save.\n",
    "fig.savefig(f\"logistic_IT_vs_Tikho_r{r_to_plot}.pdf\", transparent=True, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ]
  },
  {
   "source": [
    "## Optimal $\\lambda$"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "### 1. Regression coefficients"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the selected reg over seeds, then regress on n per (r, t) with log=True.\n",
    "reg_by_n = df_loss.groupby(['n', 'r', 't'])['reg'].agg(['mean', 'std']).reset_index()\n",
    "reg_coeff_lambd = regression_coefficients(reg_by_n, 'n', 'mean', groupby=['r', 't'], log=True).reset_index()\n",
    "# a_th = -rate_reg(r, t, alpha=2), shown next to the fitted coefficients for comparison.\n",
    "reg_coeff_lambd['a_th'] = -reg_coeff_lambd.apply(lambda row: rate_reg(row['r'], row['t'], alpha=2.), axis=1)\n",
    "reg_coeff_lambd = reg_coeff_lambd.set_index(['r', 't'])\n",
    "reg_coeff_lambd"
   ]
  },
  {
   "source": [
    "### 2. Aggregate over seed"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adds a log_reg column to df_loss in place (df_loss_copy does not have it).\n",
    "df_loss['log_reg'] = np.log(df_loss['reg'])\n",
    "# Mean and std over seeds of log(reg) for every (r, t, log_n)\n",
    "df_reg_ms = df_loss.groupby(['r', 't', 'log_n'])[['log_reg']].agg(['mean', 'std'])\n",
    "df_reg_ms"
   ]
  },
  {
   "source": [
    "### 3. Plots"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# exp(mean log reg) against n for a few values of t at a fixed r (log-log axes).\n",
    "fig, ax = plt.subplots(figsize=(192/72, 192/72))\n",
    "# Local name on purpose: rebinding `palette` here would silently change the colours\n",
    "# of the earlier figures whenever their cells are re-run.\n",
    "tab10_colors = plt.get_cmap('tab10').colors\n",
    "ts_to_plot = [1, 3, 8]\n",
    "r_to_plot = 0.25\n",
    "\n",
    "for i, t in enumerate(ts_to_plot):\n",
    "    c = tab10_colors[i]\n",
    "    df_plot = df_reg_ms.xs((r_to_plot, t), level=('r', 't')).xs('log_reg', axis=1).reset_index()\n",
    "    # df_reg_ms is indexed by log_n; recover n for the x axis\n",
    "    df_plot['n'] = np.exp(df_plot['log_n'])\n",
    "\n",
    "    # Markers: exp of the mean log reg. Line: linear_trend in log space with -rate_reg(r_to_plot, t, 2.).\n",
    "    ax.plot(df_plot['n'], np.exp(df_plot['mean']), label=f\"$t={t}$\", color=c, linestyle='', marker=markers[i], markersize=3)\n",
    "    ax.plot(df_plot['n'], np.exp(linear_trend(df_plot, 'log_n', 'mean', -rate_reg(r_to_plot, t, 2.), log=False)), alpha=.5, linestyle='-', color=c)\n",
    "\n",
    "ax.set_xlabel(\"$n$\")\n",
    "ax.set_ylabel(\"$\\\\lambda$\")\n",
    "ax.set_xscale('log')\n",
    "ax.set_yscale('log')\n",
    "ax.legend()\n",
    "ax.tick_params(axis='y', labelrotation = 90)\n",
    "ax.margins(0, 0)\n",
    "plt.tight_layout()\n",
    "\n",
    "# Save before plt.show(): once the figure has been shown, some backends leave nothing to save.\n",
    "fig.savefig(f\"logistic_optimal_reg_r{r_to_plot}.pdf\", transparent=True, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ]
  },
  {
   "source": [
    "## About the chosen regularization\n",
    "\n",
    "This is just a technical check that computing the whole regularization path is not necessary: we can stop as soon as `loss_l2` starts increasing. This is true in theory, but we check that it also holds in practice (optimization errors might make this statement false)."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "### Loss is decreasing then increasing\n",
    "\n",
    "If we look at `loss = f(reg)`, the loss should first decrease and then increase. This is checked by the condition below."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_is_convex_with_reg(df):\n",
    "    \"\"\"For each `t`, flag whether `loss_l2` ordered by increasing `reg` goes down then up.\n",
    "\n",
    "    Returns the result of a groupby on `t`: 1 when the sign of consecutive\n",
    "    `loss_l2` differences increases exactly once along the path, 0 otherwise.\n",
    "    \"\"\"\n",
    "    def _is_convex_t(group):\n",
    "        steps = group.sort_values('reg')['loss_l2'].diff()\n",
    "        sign_increases = np.sign(steps).diff() > 0\n",
    "        return (np.sum(sign_increases) == 1).prod()\n",
    "    return df.groupby('t').apply(_is_convex_t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the down-then-up check on every (r, n, seed) regularization path, per t.\n",
    "res = df.groupby(['r', 'n', 'seed']).apply(loss_is_convex_with_reg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Entries of res are 0/1 flags, so the mean is the fraction of paths passing the check.\n",
    "np.mean(res)"
   ]
  },
  {
   "source": [
    "This condition is satisfied. "
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# idxmin returns index labels, but the cells below look rows up with df.iloc (positional).\n",
    "# Resetting the index first makes every label equal to the row position in df.\n",
    "idxmin = df.reset_index(drop=True).groupby(['n', 'r', 't', 'seed'])['loss_l2'].idxmin()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary statistics, per t, of the reg that minimizes loss_l2.\n",
    "# .iloc is positional, so idxmin must contain row positions of df, not arbitrary index labels.\n",
    "df.iloc[idxmin].groupby('t')['reg'].describe()"
   ]
  },
  {
   "source": [
    "### Reg($t$) $\\geq$ Reg($t-1$)\n",
    "\n",
    "If we want to stop the computation of IT($t$) early, we have to ensure that IT($t+1$) does not need lower regularization. Again, true in theory, but what about practice? "
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The loop body is empty: it only leaves `label` and `group` bound to the last\n",
    "# (n, seed, r) group, e.g. to inspect one group by hand.\n",
    "for label, group in df.iloc[idxmin].groupby(['n', 'seed', 'r']):\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reg_is_monotonic(df):\n",
    "    \"\"\"Return True when `reg`, ordered by increasing `t`, never decreases.\n",
    "\n",
    "    Uses `is_monotonic_increasing`: the `Series.is_monotonic` alias performed the\n",
    "    same check but no longer exists in pandas >= 2.0.\n",
    "    \"\"\"\n",
    "    return df.sort_values('t')['reg'].is_monotonic_increasing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every (n, seed, r): is the loss_l2-minimizing reg monotonic when ordered by t?\n",
    "# .iloc is positional, so idxmin must contain row positions of df, not arbitrary index labels.\n",
    "res = df.iloc[idxmin].groupby(['n', 'seed', 'r']).apply(reg_is_monotonic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of (n, seed, r) groups for which reg_is_monotonic returned True.\n",
    "np.mean(res)"
   ]
  },
  {
   "source": [
    "This condition is satisfied too!"
   ],
   "cell_type": "markdown",
   "metadata": {}
  }
 ]
}