{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "import pandas as pd\n",
    "import json\n",
    "import os\n",
    "\n",
    "RANDOM_SEEDS = ['']\n",
    "varying_M = ['1e-2', '1e-3', '1e-4', '3e-2', '3e-3', '3e-4', '5e-2', '5e-3','5e-4', '7e-2', '7e-3', '7e-4'] \n",
    "min_M = min([float(x) for x in varying_M])\n",
    "max_M = max([float(x) for x in varying_M])\n",
    "METHODS = ['retrain', 'gd', 'sgd']\n",
    "FLATTEN_METHODS = ['retrain_results']\n",
    "for m in varying_M:\n",
    "    FLATTEN_METHODS.append(f'gd_results_{m}')\n",
    "    FLATTEN_METHODS.append(f'sgd_results_{m}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {}\n",
    "for method in FLATTEN_METHODS:\n",
    "    method_res = []\n",
    "    for seed in RANDOM_SEEDS:\n",
    "        path = f'./fmnist/{method}.json'\n",
    "        if not os.path.exists(path):\n",
    "            print('Wrong path:', path)\n",
    "            continue\n",
    "        seed_res = json.load(open(path))\n",
    "        if len(seed_res) != 60: \n",
    "            print('Incomplete run:', path)\n",
    "            continue\n",
    "        method_res.append(pd.Series(seed_res[60 - 1]))\n",
    "    if len(method_res) == 0:\n",
    "        continue\n",
    "    method_res = pd.concat(method_res)\n",
    "    results[method] = method_res.groupby(method_res.index)    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_duplicate(handles, labels):\n",
    "    filter_handles, filter_labels = [], []\n",
    "    for h, l in zip(handles, labels):\n",
    "        if l not in filter_labels:\n",
    "            filter_handles.append(h)\n",
    "            filter_labels.append(l)\n",
    "    print(len(filter_handles))\n",
    "    return filter_handles, filter_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rcParams.update({'font.size': 20})\n",
    "fig, ax = plt.subplots(ncols=1, figsize=(6, 3))\n",
    "# for e in ax: e.grid()\n",
    "\n",
    "# for i in range(2):\n",
    "ax.hlines(results['retrain_results'].mean()['df_acc'], min_M, max_M, ls='--', color='tab:cyan', label=r'Retraining, $D_{e}$')\n",
    "ax.hlines(results['retrain_results'].mean()['dtest_acc'], min_M, max_M, ls='--', color='tab:orange', label=r'Retraining, $D_{test}$')\n",
    "\n",
    "for m in varying_M:\n",
    "    method_name = f'gd_results_{m}'\n",
    "    x = float(m)\n",
    "    y1 = results[method_name].mean()['df_acc']\n",
    "    y2 = results[method_name].mean()['dtest_acc']\n",
    "    ax.scatter(x, y1, marker='X', c='tab:blue', label=r'$D_e$', s=80)\n",
    "    ax.scatter(x, y2, marker='^', c='tab:red', label=r'$D_{test}$', s=80)\n",
    "        \n",
    "    # method_name = f'sgd_results_{m}'\n",
    "    # x = float(m)\n",
    "    # y1 = results[method_name].mean()['df_acc']\n",
    "    # y2 = results[method_name].mean()['dtest_acc']\n",
    "    # ax[1].scatter(x, y1, marker='*', c='blue', label=r'$D_f$')\n",
    "    # ax[1].scatter(x, y2, marker='*', c='red', label=r'$D_{test}$')\n",
    "\n",
    "# for i in range(1):\n",
    "    # ax[i].set_xlim(-5, 220)\n",
    "    # ax[i].set_ylim(-5, 95)\n",
    "ax.set_xlabel('Step Size', labelpad=15)\n",
    "ax.set_ylabel('Accuracy (%)', labelpad=15)\n",
    "\n",
    "# ax[0].fill_between(range(75,220), -5, 95, alpha=0.2, facecolor='gray')\n",
    "# ax.set_title('GD', pad=15)\n",
    "# ax[1].set_title('SGD')\n",
    "ax.set_xscale('log')\n",
    "# ax[1].set_xscale('log')\n",
    "ax.set_xticks([1e-4, 1e-3, 1e-2, 1e-1])\n",
    "ax.set_yticks([0, 45, 90])\n",
    "\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "handles, labels = remove_duplicate(handles, labels)\n",
    "# fig.legend(handles, labels, loc='center left', ncols=1, bbox_to_anchor=(0.92, 0, 1, 1))\n",
    "fig.legend(handles, labels, loc='lower center', ncols=2, bbox_to_anchor=(0, 1, 1, 1))\n",
    "\n",
    "fig.savefig('step-size_sensitivity.png', dpi=220, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "unlearning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
