{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_plot_setting():\n",
    "    \"\"\"Apply the notebook-wide Matplotlib look and create the `plots/` directory.\n",
    "\n",
    "    Starts from the \"seaborn-v0_8-talk\" style, enlarges axes titles and axis\n",
    "    labels, and creates `plots/` if it is missing (no error if it exists).\n",
    "    \"\"\"\n",
    "    plt.style.use(\"seaborn-v0_8-talk\")\n",
    "    font_overrides = {\n",
    "        \"axes.titlesize\": \"x-large\",\n",
    "        \"axes.labelsize\": \"xx-large\",\n",
    "    }\n",
    "    plt.rcParams.update(font_overrides)\n",
    "    os.makedirs(\"plots\", exist_ok=True)\n",
    "\n",
    "\n",
    "load_plot_setting()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MNIST dataset"
   ]
  },
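  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Hessian eigenspectrum\n",
    "\n",
    "Histogram of the eigenvalues of the saved Hessian (`hess.npy`) for seeds 1, 2 and 3, plotted on log-scaled axes."
   ]
  },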
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data(random_seeds=(1, 2, 3),\n",
    "              path_template=\"../temp_cr_newton_mnist_{seed}/hess.npy\"):\n",
    "    \"\"\"Load the saved Hessian for each seed and return its eigenvalues.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    random_seeds : iterable of int, default (1, 2, 3)\n",
    "        Seeds whose Hessians are loaded, in this order.\n",
    "    path_template : str\n",
    "        Location of each ``hess.npy`` file; ``{seed}`` is replaced by the seed.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of numpy.ndarray\n",
    "        Eigenvalues of each Hessian in ascending order, one array per seed,\n",
    "        in the same order as ``random_seeds``.\n",
    "    \"\"\"\n",
    "    res = []\n",
    "    for seed in random_seeds:\n",
    "        path = path_template.format(seed=seed)\n",
    "        print(f\"Loading Hessian matrix from {path}\")\n",
    "        hess = np.load(path)\n",
    "        print(\"Computing eigenvalues\")\n",
    "        # eigvalsh is NumPy's symmetric/Hermitian eigenvalue routine: it reads\n",
    "        # only one triangle of `hess` (the lower one by default) and returns\n",
    "        # real eigenvalues sorted in ascending order.\n",
    "        eigen_values = np.linalg.eigvalsh(hess)\n",
    "        res.append(eigen_values)\n",
    "    return res\n",
    "\n",
    "\n",
    "data = read_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot(ax, x, bins=10):\n",
    "    ax.hist(x, bins=bins, alpha=.3, color='tab:blue');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary of the eigenvalues for seed 3 (index 2) instead of dumping the whole\n",
    "# array; `min` shows whether any eigenvalues are non-positive.\n",
    "pd.Series(data[2]).describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linear-width bins look distorted on a log x-axis when the values span several\n",
    "# decades, and eigenvalues <= 0 cannot be drawn on a log axis at all. So: keep\n",
    "# only the positive eigenvalues, report how many were dropped, and histogram\n",
    "# them with log-spaced bin edges shared by every seed so the overlaid\n",
    "# histograms are directly comparable.\n",
    "positive_eigs = [eigs[eigs > 0] for eigs in data]\n",
    "for idx, (eigs, pos) in enumerate(zip(data, positive_eigs)):\n",
    "    print(f\"data[{idx}]: {eigs.size - pos.size} of {eigs.size} eigenvalues are <= 0 and not shown\")\n",
    "all_positive = np.concatenate(positive_eigs)\n",
    "# 11 edges -> 10 bins (same bin count as the default); geomspace pins the first\n",
    "# and last edge exactly to min/max so no extreme value falls outside the range.\n",
    "bin_edges = np.geomspace(all_positive.min(), all_positive.max(), 11)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(9, 7))\n",
    "for pos in positive_eigs:\n",
    "    plot(ax, pos, bins=bin_edges)\n",
    "ax.set_xscale(\"log\")\n",
    "ax.set_yscale(\"log\")\n",
    "ax.set_xlabel(\"Eigenvalues (log scale)\");\n",
    "ax.set_ylabel(\"Frequency (log scale)\");\n",
    "# Save into the `plots/` directory that load_plot_setting() creates.\n",
    "fig.savefig(os.path.join(\"plots\", \"eigenspectrum_mnist.png\"), dpi=200);"
   ]
  },
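  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Performance radar plot\n",
    "\n",
    "Mean accuracy (over seeds 1, 2 and 3) on $D_e$, $D_r$ and $D_{test}$ for each unlearning method: retraining, CR-Newton, PINV-Newton and Damped-Newton with several values of $\\gamma$."
   ]
  },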
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unlearning methods compared on MNIST; each name is a sub-directory under\n",
    "# `<stats_root>/seed-<seed>/by_class/`.\n",
    "UNLEARN_METHODS = (\n",
    "    \"retraining\",\n",
    "    \"cr_newton\",\n",
    "    \"pinv_newton\",\n",
    "    \"damped_newton-gamma=1e-1\",\n",
    "    \"damped_newton-gamma=1e-2\",\n",
    "    \"damped_newton-gamma=1e-3\",\n",
    "    \"damped_newton-gamma=1e-4\",\n",
    ")\n",
    "\n",
    "\n",
    "def read_data(unlearn_methods=UNLEARN_METHODS, random_seeds=(1, 2, 3),\n",
    "              stats_root=\"../oneshot_unlearning/outputs/mnist\"):\n",
    "    \"\"\"Load per-seed ``stats.json`` files and aggregate them across seeds.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    unlearn_methods : iterable of str\n",
    "        Method directory names to load.\n",
    "    random_seeds : iterable of int\n",
    "        Seeds to aggregate over.\n",
    "    stats_root : str\n",
    "        Directory containing the ``seed-<seed>`` sub-directories.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        ``{method: {\"mean\": DataFrame, \"std\": DataFrame}}``. Rows of each\n",
    "        seed's table are matched by position (the ``index`` column) and the\n",
    "        mean / sample standard deviation is taken across seeds.\n",
    "    \"\"\"\n",
    "    outputs = {}\n",
    "    for method in unlearn_methods:\n",
    "        per_seed = []\n",
    "        for seed in random_seeds:\n",
    "            path = f\"{stats_root}/seed-{seed}/by_class/{method}/stats.json\"\n",
    "            print(f\"Load stats from {path}\")\n",
    "            # Context manager closes the file; a bare open() left it to the GC.\n",
    "            with open(path, \"r\") as fh:\n",
    "                stats = pd.DataFrame(json.load(fh))\n",
    "            # Row position within this seed's table, used to align seeds.\n",
    "            stats[\"index\"] = range(len(stats))\n",
    "            per_seed.append(stats)\n",
    "        grouped = pd.concat(per_seed).groupby(\"index\")\n",
    "        outputs[method] = {\"mean\": grouped.mean(), \"std\": grouped.std()}\n",
    "\n",
    "    return outputs\n",
    "\n",
    "\n",
    "data = read_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from matplotlib.patches import Circle, RegularPolygon\n",
    "from matplotlib.path import Path\n",
    "from matplotlib.projections import register_projection\n",
    "from matplotlib.projections.polar import PolarAxes\n",
    "from matplotlib.spines import Spine\n",
    "from matplotlib.transforms import Affine2D\n",
    "\n",
    "\n",
    "def radar_factory(num_vars, frame='circle'):\n",
    "    \"\"\"\n",
    "    Create a radar chart with `num_vars` Axes.\n",
    "\n",
    "    This function creates a RadarAxes projection and registers it under the\n",
    "    name 'radar'. Matplotlib's projection registry is keyed by name, so each\n",
    "    call replaces the previous 'radar' class and the most recent call decides\n",
    "    the spoke count and frame used by `projection='radar'` Axes created later.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    num_vars : int\n",
    "        Number of variables for radar chart.\n",
    "    frame : {'circle', 'polygon'}\n",
    "        Shape of frame surrounding Axes.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        The `num_vars` spoke angles in radians, evenly spaced starting at 0\n",
    "        (2*pi excluded); pass these as the x-values to `plot`/`fill`.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        When a radar Axes is created and `frame` is not a known value.\n",
    "    \"\"\"\n",
    "    # calculate evenly-spaced axis angles\n",
    "    theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False)\n",
    "\n",
    "    class RadarTransform(PolarAxes.PolarTransform):\n",
    "\n",
    "        def transform_path_non_affine(self, path):\n",
    "            # Paths with non-unit interpolation steps correspond to gridlines,\n",
    "            # in which case we force interpolation (to defeat PolarTransform's\n",
    "            # autoconversion to circular arcs).\n",
    "            if path._interpolation_steps > 1:\n",
    "                path = path.interpolated(num_vars)\n",
    "            return Path(self.transform(path.vertices), path.codes)\n",
    "\n",
    "    class RadarAxes(PolarAxes):\n",
    "\n",
    "        name = 'radar'\n",
    "        PolarTransform = RadarTransform\n",
    "\n",
    "        def __init__(self, *args, **kwargs):\n",
    "            super().__init__(*args, **kwargs)\n",
    "            # rotate plot such that the first axis is at the top\n",
    "            self.set_theta_zero_location('N')\n",
    "\n",
    "        def fill(self, *args, closed=True, **kwargs):\n",
    "            \"\"\"Override fill so that the polygon is closed by default.\"\"\"\n",
    "            return super().fill(*args, closed=closed, **kwargs)\n",
    "\n",
    "        def plot(self, *args, **kwargs):\n",
    "            \"\"\"Override plot so that each line is closed back to its start.\n",
    "\n",
    "            Returns the list of Line2D objects, as `Axes.plot` does.\n",
    "            \"\"\"\n",
    "            lines = super().plot(*args, **kwargs)\n",
    "            for line in lines:\n",
    "                self._close_line(line)\n",
    "            return lines\n",
    "\n",
    "        def _close_line(self, line):\n",
    "            \"\"\"Append the first point to `line` unless it already ends there.\"\"\"\n",
    "            x, y = line.get_data()\n",
    "            # FIXME: markers at x[0], y[0] get doubled-up\n",
    "            # An empty line has nothing to close (and x[0] would raise).\n",
    "            if len(x) > 0 and x[0] != x[-1]:\n",
    "                x = np.append(x, x[0])\n",
    "                y = np.append(y, y[0])\n",
    "                line.set_data(x, y)\n",
    "\n",
    "        def set_varlabels(self, labels):\n",
    "            \"\"\"Label the spokes, one label per variable, in `theta` order.\"\"\"\n",
    "            self.set_thetagrids(np.degrees(theta), labels)\n",
    "\n",
    "        def _gen_axes_patch(self):\n",
    "            # The Axes patch must be centered at (0.5, 0.5) and of radius 0.5\n",
    "            # in axes coordinates.\n",
    "            if frame == 'circle':\n",
    "                return Circle((0.5, 0.5), 0.5)\n",
    "            elif frame == 'polygon':\n",
    "                return RegularPolygon((0.5, 0.5), num_vars,\n",
    "                                      radius=.5, edgecolor=\"k\")\n",
    "            else:\n",
    "                raise ValueError(\"Unknown value for 'frame': %s\" % frame)\n",
    "\n",
    "        def _gen_axes_spines(self):\n",
    "            if frame == 'circle':\n",
    "                return super()._gen_axes_spines()\n",
    "            elif frame == 'polygon':\n",
    "                # spine_type must be 'left'/'right'/'top'/'bottom'/'circle'.\n",
    "                spine = Spine(axes=self,\n",
    "                              spine_type='circle',\n",
    "                              path=Path.unit_regular_polygon(num_vars))\n",
    "                # unit_regular_polygon gives a polygon of radius 1 centered at\n",
    "                # (0, 0) but we want a polygon of radius 0.5 centered at (0.5,\n",
    "                # 0.5) in axes coordinates.\n",
    "                spine.set_transform(Affine2D().scale(.5).translate(.5, .5)\n",
    "                                    + self.transAxes)\n",
    "                return {'polar': spine}\n",
    "            else:\n",
    "                raise ValueError(\"Unknown value for 'frame': %s\" % frame)\n",
    "\n",
    "    register_projection(RadarAxes)\n",
    "    return theta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot(ax, theta, y_values, label, fill=True):\n",
    "    ax.plot(theta, y_values, label=label);\n",
    "    if fill:\n",
    "        ax.fill(theta, y_values, alpha=0.4, label='_nolegend_');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics = [\"accu_forget_acc\", \"retain_acc\", \"test_acc\"]\n",
    "metric_labels = [\"$D_e$ Acc. (%)\", \"$D_r$ Acc. (%)\", \"$D_{test}$ Acc. (%)\"]\n",
    "N = len(metrics)\n",
    "theta = radar_factory(N, frame='polygon')\n",
    "fig, ax = plt.subplots(figsize=(9,9), subplot_kw=dict(projection='radar'))\n",
    "ax.set_varlabels(metric_labels)\n",
    "\n",
    "# (key in `data`, legend label, shade the polygon?). Plotting order decides the\n",
    "# color each method takes from Matplotlib's default cycle, so keep it fixed.\n",
    "series = [\n",
    "    (\"retraining\", \"Retraining\", True),\n",
    "    (\"cr_newton\", \"CR-Newton\", True),\n",
    "    (\"pinv_newton\", \"PINV-Newton\", False),\n",
    "    (\"damped_newton-gamma=1e-4\", \"Damped-Newton, $\\\\gamma$=1e-4\", False),\n",
    "    (\"damped_newton-gamma=1e-3\", \"Damped-Newton, $\\\\gamma$=1e-3\", False),\n",
    "    (\"damped_newton-gamma=1e-2\", \"Damped-Newton, $\\\\gamma$=1e-2\", False),\n",
    "    (\"damped_newton-gamma=1e-1\", \"Damped-Newton, $\\\\gamma$=1e-1\", False),\n",
    "]\n",
    "for method, label, shade in series:\n",
    "    values = data[method][\"mean\"][metrics].values.flatten()\n",
    "    # One value per spoke is required; a mismatch would otherwise surface as a\n",
    "    # less direct shape error from `ax.plot`.\n",
    "    assert len(values) == N, f\"{method}: expected {N} values, got {len(values)}\"\n",
    "    plot(ax, theta, values, label=label, fill=shade)\n",
    "ax.set_rgrids([0, 20, 40, 60, 80, 100], angle=0.)\n",
    "ax.legend(loc=\"upper left\");\n",
    "# Save into the `plots/` directory that load_plot_setting() creates.\n",
    "fig.savefig(os.path.join(\"plots\", \"catastrophic_forgetting_mnist.png\"), dpi=200, bbox_inches=\"tight\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Matplotlib gallery radar-chart demo, kept for reference. It defines its own\n",
    "# demo data (previously `example_data` was never defined, so this cell raised\n",
    "# NameError on a fresh kernel) and uses `demo_*` names so running it does not\n",
    "# overwrite `data`, `theta` or `N` from the MNIST analysis above.\n",
    "def example_data():\n",
    "    \"\"\"Return demo data: spoke labels followed by four (title, cases) pairs.\n",
    "\n",
    "    Each case is 5 factor profiles x 9 species; values adapted from the\n",
    "    Matplotlib \"Radar chart\" gallery example.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        ['Sulfate', 'Nitrate', 'EC', 'OC1', 'OC2', 'OC3', 'OP', 'CO', 'O3'],\n",
    "        ('Basecase', [\n",
    "            [0.88, 0.01, 0.03, 0.03, 0.00, 0.06, 0.01, 0.00, 0.00],\n",
    "            [0.07, 0.95, 0.04, 0.05, 0.00, 0.02, 0.01, 0.00, 0.00],\n",
    "            [0.01, 0.02, 0.85, 0.19, 0.05, 0.10, 0.00, 0.00, 0.00],\n",
    "            [0.02, 0.01, 0.07, 0.01, 0.21, 0.12, 0.98, 0.00, 0.00],\n",
    "            [0.01, 0.01, 0.02, 0.71, 0.74, 0.70, 0.00, 0.00, 0.00]]),\n",
    "        ('With CO', [\n",
    "            [0.88, 0.02, 0.02, 0.02, 0.00, 0.05, 0.00, 0.05, 0.00],\n",
    "            [0.08, 0.94, 0.04, 0.02, 0.00, 0.01, 0.12, 0.04, 0.00],\n",
    "            [0.01, 0.01, 0.79, 0.10, 0.00, 0.05, 0.00, 0.31, 0.00],\n",
    "            [0.00, 0.02, 0.03, 0.38, 0.31, 0.31, 0.00, 0.59, 0.00],\n",
    "            [0.02, 0.02, 0.11, 0.47, 0.69, 0.58, 0.88, 0.00, 0.00]]),\n",
    "        ('With O3', [\n",
    "            [0.89, 0.01, 0.07, 0.00, 0.00, 0.05, 0.00, 0.00, 0.03],\n",
    "            [0.07, 0.95, 0.05, 0.04, 0.00, 0.02, 0.12, 0.00, 0.00],\n",
    "            [0.01, 0.02, 0.86, 0.27, 0.16, 0.19, 0.00, 0.00, 0.00],\n",
    "            [0.01, 0.03, 0.00, 0.32, 0.29, 0.27, 0.00, 0.00, 0.95],\n",
    "            [0.02, 0.00, 0.03, 0.37, 0.56, 0.47, 0.87, 0.00, 0.00]]),\n",
    "        ('CO & O3', [\n",
    "            [0.87, 0.01, 0.08, 0.00, 0.00, 0.04, 0.00, 0.00, 0.01],\n",
    "            [0.09, 0.95, 0.02, 0.03, 0.00, 0.01, 0.13, 0.06, 0.00],\n",
    "            [0.01, 0.02, 0.71, 0.24, 0.13, 0.16, 0.00, 0.50, 0.00],\n",
    "            [0.01, 0.03, 0.00, 0.28, 0.24, 0.23, 0.00, 0.44, 0.88],\n",
    "            [0.02, 0.00, 0.18, 0.45, 0.64, 0.55, 0.86, 0.00, 0.16]]),\n",
    "    ]\n",
    "\n",
    "\n",
    "demo_n = 9\n",
    "demo_theta = radar_factory(demo_n, frame='polygon')\n",
    "\n",
    "demo_data = example_data()\n",
    "spoke_labels = demo_data.pop(0)\n",
    "\n",
    "fig, axs = plt.subplots(figsize=(9, 9), nrows=2, ncols=2,\n",
    "                        subplot_kw=dict(projection='radar'))\n",
    "fig.subplots_adjust(wspace=0.25, hspace=0.20, top=0.85, bottom=0.05)\n",
    "\n",
    "colors = ['b', 'r', 'g', 'm', 'y']\n",
    "# Plot the four cases from the example data on separate Axes\n",
    "for ax, (title, case_data) in zip(axs.flat, demo_data):\n",
    "    ax.set_rgrids([0.2, 0.4, 0.6, 0.8])\n",
    "    ax.set_title(title, weight='bold', size='medium', position=(0.5, 1.1),\n",
    "                 horizontalalignment='center', verticalalignment='center')\n",
    "    for d, color in zip(case_data, colors):\n",
    "        ax.plot(demo_theta, d, color=color)\n",
    "        ax.fill(demo_theta, d, facecolor=color, alpha=0.25, label='_nolegend_')\n",
    "    ax.set_varlabels(spoke_labels)\n",
    "\n",
    "# add legend relative to top-left plot\n",
    "labels = ('Factor 1', 'Factor 2', 'Factor 3', 'Factor 4', 'Factor 5')\n",
    "legend = axs[0, 0].legend(labels, loc=(0.9, .95),\n",
    "                          labelspacing=0.1, fontsize='small')\n",
    "\n",
    "fig.text(0.5, 0.965, '5-Factor Solution Profiles Across Four Scenarios',\n",
    "         horizontalalignment='center', color='black', weight='bold',\n",
    "         size='large')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "unlearning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
