{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "import json\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "\n",
    "retraining_style = {'color': \"black\", 'linestyle': '--', 'linewidth': 2.5, 'marker': '^'}\n",
    "def read_data():\n",
    "    results = {}\n",
    "    steps = [1, 10, 20, 30, 40, 50, 70, 100, 200, \"retrain\"]\n",
    "    for num_step in steps:\n",
    "        if num_step == \"retrain\":\n",
    "            path = f\"./results/{num_step}.json\"\n",
    "        else:\n",
    "            path = f\"./results/step-{num_step}.json\"\n",
    "        stats = json.load(open(path))\n",
    "        res = []\n",
    "        for it_stat in stats:\n",
    "            metrics = it_stat[\"metrics\"]\n",
    "            for subset in (\"test\", \"rem\", \"era\"):\n",
    "                metrics[subset] = metrics[subset][1] * 100\n",
    "            res.append(metrics)\n",
    "        stats = pd.DataFrame(res)\n",
    "        stats.index *= 2\n",
    "        results[num_step] = stats\n",
    "    return results\n",
    "\n",
    "# Separate the retraining baseline from the unlearning runs (keyed by T).\n",
    "data = read_data()\n",
    "retraining_data = data[\"retrain\"]\n",
    "del data[\"retrain\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot(ax, data, label, **kwargs):\n",
    "    ax.plot(data, label=label, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One entry per subplot: (metric key, title, y-limits or None, y-ticks or None).\n",
    "PANELS = [\n",
    "    (\"test\", r\"$D_{test}$ Acc. (%)\", (81.5, 95), [82, 88, 94]),\n",
    "    (\"era\", r\"$D_{e}$ Acc. (%)\", None, None),\n",
    "    (\"rem\", r\"$D_{r}$ Acc. (%)\", (93.8, 96), None),\n",
    "]\n",
    "RETRAIN_MARKEVERY = 5  # only mark every 5th point of the retraining curve\n",
    "\n",
    "matplotlib.rcParams.update({'font.size': 19})\n",
    "fig, axs = plt.subplots(ncols=len(PANELS), figsize=(13, 3.5))\n",
    "\n",
    "for ax, (metric, title, ylim, yticks) in zip(axs, PANELS):\n",
    "    for num_step, res in data.items():\n",
    "        plot(ax, res[metric], label=f\"$T = {num_step}$\")\n",
    "    # Plotted last so it is the final legend handle (moved to the front below).\n",
    "    plot(ax, retraining_data[metric], label=\"Retraining\", **retraining_style, markevery=RETRAIN_MARKEVERY)\n",
    "    ax.set_title(title, pad=15)\n",
    "    if ylim is not None:\n",
    "        ax.set_ylim(*ylim)\n",
    "    if yticks is not None:\n",
    "        ax.set_yticks(yticks)\n",
    "    ax.set_xlabel(\"# Unlearning Rounds\", labelpad=15)\n",
    "    ax.set_xticks([1, 20, 40, 60])\n",
    "\n",
    "# Every panel draws the same series, so take handles from one axis and put Retraining first.\n",
    "handles, labels = axs[-1].get_legend_handles_labels()\n",
    "handles = [handles[-1], *handles[:-1]]\n",
    "labels = [labels[-1], *labels[:-1]]\n",
    "# `ncols` requires matplotlib >= 3.6 (older versions spell it `ncol`).\n",
    "fig.legend(handles, labels, loc='lower center', ncols=5, bbox_to_anchor=(0, 1, 1, 1))\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"scr_step.png\", dpi=300, bbox_inches='tight');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Legend labels in the order used by the saved figure (Retraining first).\n",
    "labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "unlearning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
