{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "# Result directories. Every path goes through os.path.expanduser because\n",
    "# open() and os.path.exists() do not expand a leading '~' themselves.\n",
    "# ROOT_DIR is read by the TOFU and Llama class-level cells, OLD_ROOT_DIR by the\n",
    "# Llama-2 instance-level cell.\n",
    "ROOT_DIR = os.path.expanduser('~/unlearning/cure_newton')\n",
    "OLD_ROOT_DIR = os.path.expanduser('~/unlearning/others/curenewton_old/main_results')\n",
    "TOFU_ROOT_DIR = os.path.expanduser('~/cure_newton/sequential_unlearning/tofu_debug')\n",
    "TOFU_ROOT_DIR_2 = os.path.expanduser('~/cure_newton/sequential_unlearning/tofu')\n",
    "CIFAR10_ROOT = os.path.expanduser('~/cure_newton/sequential_unlearning/cifar10-class')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def setup_plot(lw=3, ms=8):\n",
    "    \"\"\"Return the (style, label) dicts shared by every figure in this notebook.\n",
    "\n",
    "    style: display label to the kwargs forwarded to `ax.plot`\n",
    "        (color, linestyle, linewidth, marker, ms).\n",
    "    label: method key used in the result paths to its display label.\n",
    "    lw, ms: line width and marker size applied to every styled method.\n",
    "    \"\"\"\n",
    "    def line(color, linestyle, marker):\n",
    "        return {'color': color, 'linestyle': linestyle, 'linewidth': lw, 'marker': marker, 'ms': ms}\n",
    "\n",
    "    style = {\n",
    "        'Retraining': line('#FF0000', '-', 'o'),\n",
    "        'Original': line('#D62728', '--', 's'),\n",
    "        'StoCuReNU (ours)': line('#2CA02C', '-.', '^'),\n",
    "        'CureNewton (ours)': line('#1F77B4', ':', 'D'),\n",
    "        'GA': line('#9467BD', '--', 'v'),\n",
    "        'GD': line('#8C564B', '-.', '<'),\n",
    "        'Rand. Lbls.': line('#FF7F0E', ':', '>'),\n",
    "        'PINV-Newton': line('#7F7F7F', '-', 'P'),\n",
    "        'Damped-Newton': line('#BCBD22', '--', 'X'),\n",
    "        'GDiff Tune': line('#17BECF', '-.', '*'),\n",
    "        # SCRUB has no style of its own; the plotting cells substitute the CureNewton style.\n",
    "        'SCRUB': {'color': None, 'linestyle': None, 'linewidth': None, 'marker': None, 'ms': None},\n",
    "        'GDiff': line('#E377C2', '-', 'h'),\n",
    "        'DPO': line('#AEC7E8', '--', 'H'),\n",
    "        # 'ntk' maps to 'NTK' below; without this entry plot_styles['NTK'] raises KeyError.\n",
    "        'NTK': line('#FFBB78', '-.', 'd'),\n",
    "    }\n",
    "    label = {\n",
    "        'retraining': 'Retraining',\n",
    "        'original': 'Original',\n",
    "        'ga': 'GA',\n",
    "        'random_labels': 'Rand. Lbls.',\n",
    "        'scrub': 'SCRUB',\n",
    "        'pinv_newton': 'PINV-Newton',\n",
    "        'damp_newton': 'Damped-Newton',\n",
    "        'ntk': 'NTK',\n",
    "        'sgd': 'GD',\n",
    "        'gd': 'GD',\n",
    "        'gdiff': 'GDiff',\n",
    "        'cr_newton': 'CureNewton (ours)',\n",
    "        'scr_newton': 'StoCuReNU (ours)',\n",
    "    }\n",
    "    return style, label\n",
    "\n",
    "plot_styles, plot_labels = setup_plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot(ax, mean, std, label, x=None, **kwargs):\n",
    "    \"\"\"Draw `mean` as a line with a shaded +/- one-std band on `ax`.\n",
    "\n",
    "    mean, std: pandas Series sharing the same index.\n",
    "    label: legend label of the line.\n",
    "    x: optional x coordinates; defaults to `mean.index + 1` because\n",
    "        unlearning rounds are numbered from 1.\n",
    "    **kwargs: forwarded to `ax.plot` (color, linestyle, marker, ...).\n",
    "    \"\"\"\n",
    "    if x is None:\n",
    "        x = mean.index + 1  # unlearn round starts from 1\n",
    "    lines = ax.plot(x, mean, label=label, **kwargs)\n",
    "    # Shade with the line's own colour. Falling back to the colour matplotlib\n",
    "    # actually drew covers styles with no 'color' key or color=None (e.g. 'SCRUB').\n",
    "    color = kwargs.get('color') or lines[0].get_color()\n",
    "    ax.fill_between(x, mean - std, mean + std, alpha=.2, facecolor=color)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D\n",
    "import matplotlib\n",
    "\n",
    "matplotlib.rcParams.update({'font.size': 20})\n",
    "\n",
    "RUNS = ['1', '2', '3', '4', '5']  # sequential unlearning rounds on TOFU\n",
    "\n",
    "\n",
    "# === Load ROUGE Data ===\n",
    "def read_rouge_data():\n",
    "    \"\"\"TOFU ROUGE scores as {method: {run: {'mean': Series, 'std': Series}}} over seeds.\n",
    "\n",
    "    Missing CSVs are reported and skipped. A run with no CSV for any seed is left\n",
    "    out, and the plotting loop below reports the gap instead of crashing here.\n",
    "    \"\"\"\n",
    "    unlearn_methods = ['original', 'retraining', 'gd', 'scr_newton', 'ga', 'scrub_slow', 'gdiff', 'dpo']\n",
    "    random_seeds = [1, 2, 3]\n",
    "    columns_of_interest = ['ROUGE Forget', 'ROUGE Retain', 'ROUGE Real Authors']\n",
    "    outputs = {}\n",
    "    for method in unlearn_methods:\n",
    "        outputs[method] = {}\n",
    "        for run in RUNS:\n",
    "            dfs = []\n",
    "            for seed in random_seeds:\n",
    "                # gdiff results are stored in a different tree with a different layout\n",
    "                if 'gdiff' in method:\n",
    "                    csv_path = os.path.join(TOFU_ROOT_DIR_2, f'seed-{seed}/eval/{method}/round-{run}/tofu.csv')\n",
    "                else:\n",
    "                    csv_path = os.path.join(TOFU_ROOT_DIR, f'seed-{seed}/eval/{method}-{run}/tofu.csv')\n",
    "                # os.path.exists() does not expand a leading '~'\n",
    "                csv_path = os.path.expanduser(csv_path)\n",
    "                if not os.path.exists(csv_path):\n",
    "                    print(f'Not path: {csv_path}')\n",
    "                    continue\n",
    "                dfs.append(pd.read_csv(csv_path)[columns_of_interest])\n",
    "            if not dfs:\n",
    "                continue\n",
    "            combined = pd.concat(dfs).reset_index(drop=True)\n",
    "            outputs[method][run] = {'mean': combined.mean(), 'std': combined.std()}\n",
    "    return outputs\n",
    "\n",
    "\n",
    "# === Load Accuracy Data ===\n",
    "def read_acc_data():\n",
    "    \"\"\"CIFAR-10 class-level stats as {method: {'mean': DataFrame, 'std': DataFrame}} over seeds.\n",
    "\n",
    "    Rows of each stats.json are aligned across seeds by position. Missing files\n",
    "    are reported and skipped; a method with no file for any seed is left out.\n",
    "    \"\"\"\n",
    "    unlearn_methods = ['original', 'retraining', 'gd', 'random_labels', 'ga', 'scr_newton_gdiff', 'scrub']\n",
    "    random_seeds = [1, 2, 3]\n",
    "    outputs = {}\n",
    "    for method in unlearn_methods:\n",
    "        res = []\n",
    "        for seed in random_seeds:\n",
    "            path = os.path.expanduser(os.path.join(CIFAR10_ROOT, f'seed-{seed}/by-class/{method}/stats.json'))\n",
    "            if not os.path.exists(path):\n",
    "                print(f'Not exist: {path}')\n",
    "                continue\n",
    "            with open(path, 'r') as f:\n",
    "                stats = pd.DataFrame(json.load(f))\n",
    "            stats['index'] = range(len(stats))\n",
    "            res.append(stats)\n",
    "        if not res:\n",
    "            continue  # pd.concat([]) would raise ValueError\n",
    "        grouped = pd.concat(res).groupby('index')\n",
    "        outputs[method] = {'mean': grouped.mean(), 'std': grouped.std()}\n",
    "    return outputs\n",
    "\n",
    "\n",
    "def method_style(display_label):\n",
    "    \"\"\"Plot kwargs for a display label; SCRUB borrows the CureNewton style.\"\"\"\n",
    "    if display_label == 'SCRUB':\n",
    "        return plot_styles['CureNewton (ours)']\n",
    "    return plot_styles.get(display_label, {})\n",
    "\n",
    "\n",
    "rouge_data = read_rouge_data()\n",
    "acc_data = read_acc_data()\n",
    "\n",
    "# === Plotting ===\n",
    "fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(15, 9))\n",
    "axs = axs.flatten()  # first three: ROUGE (TOFU), last three: accuracy (CIFAR-10)\n",
    "\n",
    "rouge_metrics = ['ROUGE Real Authors', 'ROUGE Forget', 'ROUGE Retain']\n",
    "rouge_titles = ['$D_{test}$ ROUGE', '$D_e$ ROUGE', '$D_r$ ROUGE']\n",
    "acc_metrics = ['test_acc', 'accu_forget_acc', 'retain_acc']\n",
    "acc_titles = ['$D_{test}$ Acc. (%)', '$D_e$ Acc. (%)', '$D_r$ Acc. (%)']\n",
    "\n",
    "# --- ROUGE subplots (first 3) ---\n",
    "for ax, metric, title in zip(axs[:3], rouge_metrics, rouge_titles):\n",
    "    for method_key, display_label in plot_labels.items():\n",
    "        if method_key not in rouge_data:\n",
    "            continue\n",
    "        try:\n",
    "            runs_data = [rouge_data[method_key][run] for run in RUNS]\n",
    "            mean_series = pd.Series([r['mean'][metric] for r in runs_data])\n",
    "            std_series = pd.Series([r['std'][metric] for r in runs_data])\n",
    "            plot(ax, mean_series, std_series, display_label, **method_style(display_label))\n",
    "        except Exception as e:\n",
    "            print(f'Skipping {method_key} for {metric} due to: {e}')\n",
    "    ax.set_title(title, pad=15)\n",
    "    ax.set_xticks([1, 2, 3, 4, 5])\n",
    "    ax.set_xlabel('# Unlearning Rounds', labelpad=10)\n",
    "\n",
    "# --- Accuracy subplots (last 3) ---\n",
    "for ax, metric, title in zip(axs[3:], acc_metrics, acc_titles):\n",
    "    for method_key, display_label in plot_labels.items():\n",
    "        if method_key not in acc_data:\n",
    "            continue\n",
    "        stats = acc_data[method_key]\n",
    "        plot(ax, stats['mean'][metric], stats['std'][metric], display_label, **method_style(display_label))\n",
    "    ax.set_title(title, pad=15)\n",
    "    ax.set_xlabel('# Unlearning Rounds', labelpad=10)\n",
    "    ax.set_xticks([1, 2, 3, 4, 5])\n",
    "    if 'test' in metric or 'retain' in metric:\n",
    "        ax.set_yticks([0, 45, 90])\n",
    "\n",
    "# --- Legend ---\n",
    "# Entries come from the last accuracy panel. random_labels is not a ROUGE method,\n",
    "# so a proxy 'Rand. Lbls.' entry is added only when that panel did not draw it\n",
    "# (appending it unconditionally duplicated the entry).\n",
    "handles, labels = axs[-1].get_legend_handles_labels()\n",
    "if 'Rand. Lbls.' not in labels:\n",
    "    rand_lbls_style = plot_styles['Rand. Lbls.']\n",
    "    handles.append(Line2D(\n",
    "        [0], [0],\n",
    "        label='Rand. Lbls.',\n",
    "        color=rand_lbls_style['color'],\n",
    "        linestyle=rand_lbls_style['linestyle'],\n",
    "        marker=rand_lbls_style['marker'],\n",
    "        linewidth=rand_lbls_style['linewidth'],\n",
    "        markersize=rand_lbls_style['ms'],\n",
    "    ))\n",
    "    labels.append('Rand. Lbls.')\n",
    "\n",
    "fig.legend(handles, labels, loc='lower center', ncols=9, bbox_to_anchor=(0, 1, 1, 1))\n",
    "\n",
    "fig.tight_layout()\n",
    "# plt.savefig('outputs/sequential_combined_llm_cifar.png', dpi=300, bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Class-level Sequential Unlearning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TOFU\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data():\n",
    "    \"\"\"TOFU ROUGE scores as {method: {run: {'mean': Series, 'std': Series}}} over seeds.\"\"\"\n",
    "    unlearn_methods = ['original', 'retraining', 'gd', 'scr_newton', 'ga', 'scrub', 'KL', 'dpo']\n",
    "    random_seeds = [1, 2, 3]\n",
    "    runs = ['1', '2', '3', '4', '5']\n",
    "    columns_of_interest = ['ROUGE Forget', 'ROUGE Retain', 'ROUGE Real Authors']\n",
    "    # NOTE(review): ROOT_DIR must be defined in the config cell. This layout\n",
    "    # (seed-{seed}/eval/{method}-{run}/tofu.csv) is the one the combined cell reads\n",
    "    # from TOFU_ROOT_DIR -- confirm which root is intended here.\n",
    "    outputs = {}\n",
    "    for method in unlearn_methods:\n",
    "        outputs[method] = {}\n",
    "        for run in runs:\n",
    "            dfs = [\n",
    "                pd.read_csv(os.path.join(ROOT_DIR, f'seed-{seed}/eval/{method}-{run}/tofu.csv'))[columns_of_interest]\n",
    "                for seed in random_seeds\n",
    "            ]\n",
    "            # concatenate once, after every seed has been loaded\n",
    "            combined = pd.concat(dfs).reset_index(drop=True)\n",
    "            outputs[method][run] = {'mean': combined.mean(), 'std': combined.std()}\n",
    "    return outputs\n",
    "\n",
    "data = read_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.lines import Line2D\n",
    "\n",
    "matplotlib.rcParams.update({'font.size': 20})\n",
    "fig, axs = plt.subplots(ncols=3, figsize=(13, 4))\n",
    "\n",
    "# Metric column and subplot title, in left-to-right panel order\n",
    "plot_titles = {\n",
    "    'ROUGE Real Authors': '$D_{test}$ ROUGE',\n",
    "    'ROUGE Forget': '$D_e$ ROUGE',\n",
    "    'ROUGE Retain': '$D_r$ ROUGE',\n",
    "}\n",
    "rounds = ['1', '2', '3', '4', '5']\n",
    "\n",
    "for ax, (metric, title) in zip(axs, plot_titles.items()):\n",
    "    for method_key, display_label in plot_labels.items():\n",
    "        if method_key not in data:\n",
    "            continue\n",
    "        # SCRUB has no style of its own and borrows the CureNewton one\n",
    "        style = plot_styles['CureNewton (ours)'] if display_label == 'SCRUB' else plot_styles[display_label]\n",
    "        try:\n",
    "            runs_data = [data[method_key][r] for r in rounds]\n",
    "            mean_series = pd.Series([d['mean'][metric] for d in runs_data])\n",
    "            std_series = pd.Series([d['std'][metric] for d in runs_data])\n",
    "            plot(ax, mean_series, std_series, display_label, **style)\n",
    "        except Exception as e:\n",
    "            print(f'Skipping {method_key} for {metric} due to: {e}')\n",
    "    ax.set_title(title, pad=15)\n",
    "    ax.set_xticks([1, 2, 3, 4, 5])\n",
    "    ax.set_xticklabels(rounds)\n",
    "    ax.set_xlabel('# Unlearning Rounds', labelpad=15)\n",
    "\n",
    "# Legend from the middle panel. random_labels is not loaded for TOFU, so a proxy\n",
    "# 'Rand. Lbls.' entry is added (presumably to keep the legend consistent with the\n",
    "# other figures) unless it is already present.\n",
    "handles, labels = axs[1].get_legend_handles_labels()\n",
    "if 'Rand. Lbls.' not in labels:\n",
    "    rand_lbls_style = plot_styles['Rand. Lbls.']\n",
    "    handles.append(Line2D(\n",
    "        [0], [0],\n",
    "        label='Rand. Lbls.',\n",
    "        color=rand_lbls_style['color'],\n",
    "        linestyle=rand_lbls_style['linestyle'],\n",
    "        marker=rand_lbls_style['marker'],\n",
    "        linewidth=rand_lbls_style['linewidth'],\n",
    "        markersize=rand_lbls_style['ms'],\n",
    "    ))\n",
    "    labels.append('Rand. Lbls.')\n",
    "\n",
    "fig.legend(handles, labels, loc='lower center', ncols=4, bbox_to_anchor=(0, 1, 1, 1))\n",
    "\n",
    "fig.tight_layout()\n",
    "# plt.savefig('outputs/sequential_class_tofu.png', dpi=300, bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Llama (AG News)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data():\n",
    "    \"\"\"Llama / AG News class-level stats as {method: {'mean': DataFrame, 'std': DataFrame}} over seeds.\"\"\"\n",
    "    unlearn_methods = ['original', 'retraining', 'gd', 'scr_newton', 'random_labels', 'ga', 'scrub']\n",
    "    random_seeds = [1, 2, 5]\n",
    "    outputs = {}\n",
    "\n",
    "    for method in unlearn_methods:\n",
    "        res = []\n",
    "        for seed in random_seeds:\n",
    "            # ROOT_DIR comes from the config cell; open() does not expand '~'\n",
    "            path = os.path.expanduser(f'{ROOT_DIR}/sequential_unlearning/llama/ag_news/seed-{seed}/by-class/{method}/stats.json')\n",
    "            with open(path, 'r') as f:\n",
    "                stats = pd.DataFrame(json.load(f))\n",
    "            stats['index'] = range(len(stats))  # align seeds by row position\n",
    "            res.append(stats)\n",
    "        grouped = pd.concat(res).groupby('index')\n",
    "        outputs[method] = {'mean': grouped.mean(), 'std': grouped.std()}\n",
    "\n",
    "    return outputs\n",
    "\n",
    "data = read_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rcParams.update({'font.size': 20})\n",
    "fig, axs = plt.subplots(ncols=3, figsize=(13, 4))\n",
    "\n",
    "# (metric column, subplot title) per panel, left to right\n",
    "panels = [\n",
    "    ('test_eval_accuracy', '$D_{test}$ Acc. (%)'),\n",
    "    ('accu_forget_eval_accuracy', '$D_e$  Acc. (%)'),\n",
    "    ('retain_eval_accuracy', '$D_r$  Acc. (%)'),\n",
    "]\n",
    "for ax, (metric, title) in zip(axs, panels):\n",
    "    for method, label in plot_labels.items():\n",
    "        # Check membership before the style lookup: not every label has a style\n",
    "        # (looking up 'NTK' first raised KeyError).\n",
    "        if method not in data:\n",
    "            continue\n",
    "        style = plot_styles['CureNewton (ours)'] if label == 'SCRUB' else plot_styles[label]\n",
    "        plot(ax, data[method]['mean'][metric], data[method]['std'][metric], label=label, **style)\n",
    "    ax.set_title(title, pad=15)\n",
    "    ax.set_xlabel('# Unlearning Rounds', labelpad=15)\n",
    "\n",
    "handles, labels = axs[1].get_legend_handles_labels()\n",
    "fig.legend(handles, labels, loc='lower center', ncols=4, bbox_to_anchor=(0, 1, 1, 1))\n",
    "axs[0].set_yticks([20, 60, 100])\n",
    "axs[1].set_yticks([0, 50, 100])\n",
    "axs[2].set_yticks([20, 60, 100])\n",
    "axs[1].set_ylim([-5, 105])\n",
    "axs[2].set_ylim([20, 100])\n",
    "\n",
    "fig.tight_layout();\n",
    "# plt.savefig('outputs/sequential_class_llama.png', dpi=300, bbox_inches='tight');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ResNet18 (CIFAR-10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data():\n",
    "    unlearn_methods = [\"original\", \"retraining\", \"gd\", \"scr_newton\", \"scrub\", \"random_labels\", \"ga\"]\n",
    "    random_seeds = [1,2,3]\n",
    "    outputs = {}\n",
    "\n",
    "    for method in unlearn_methods:\n",
    "        res = []\n",
    "        for seed in random_seeds:\n",
    "            path = f\"../main_results/sequential_unlearning/cifar10/seed-{seed}/by-class/{method}/stats.json\"\n",
    "            # print(f\"Load stats from {path}\")\n",
    "            stats = json.load(open(path, \"r\"))\n",
    "            stats = pd.DataFrame(stats)\n",
    "            stats[\"index\"] = range(len(stats))\n",
    "            res.append(stats)\n",
    "        res = pd.concat(res).groupby(\"index\")\n",
    "        outputs[method] = {\"mean\": res.mean(), \"std\": res.std()}\n",
    "    \n",
    "    return outputs\n",
    "\n",
    "data = read_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: aggregated per-round mean and std for StoCuReNU (rich display, not print)\n",
    "display(data['scr_newton']['mean'], data['scr_newton']['std'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rcParams.update({'font.size': 20})\n",
    "fig, axs = plt.subplots(ncols=3, figsize=(13, 4))\n",
    "\n",
    "# (metric column, subplot title) per panel, left to right\n",
    "panels = [\n",
    "    ('test_acc', '$D_{test}$ Acc. (%)'),\n",
    "    ('accu_forget_acc', '$D_e$  Acc. (%)'),\n",
    "    ('retain_acc', '$D_r$  Acc. (%)'),\n",
    "]\n",
    "for ax, (metric, title) in zip(axs, panels):\n",
    "    for method, label in plot_labels.items():\n",
    "        # membership first: labels such as 'NTK' have no style entry\n",
    "        if method not in data:\n",
    "            continue\n",
    "        style = plot_styles['CureNewton (ours)'] if label == 'SCRUB' else plot_styles[label]\n",
    "        plot(ax, data[method]['mean'][metric], data[method]['std'][metric], label=label, **style)\n",
    "    ax.set_title(title, pad=15)\n",
    "    ax.set_xlabel('# Unlearning Rounds', labelpad=15)\n",
    "    ax.set_xticks([1, 2, 3, 4, 5])\n",
    "axs[0].set_yticks([0, 45, 90])\n",
    "axs[2].set_yticks([0, 45, 90])\n",
    "\n",
    "# Figure legend is off for this plot; uncomment to add it:\n",
    "# handles, labels = axs[0].get_legend_handles_labels()\n",
    "# fig.legend(handles, labels, loc='lower center', ncols=4, bbox_to_anchor=(0, 1, 1, 1))\n",
    "\n",
    "fig.tight_layout();\n",
    "os.makedirs('outputs', exist_ok=True)  # savefig does not create missing directories\n",
    "plt.savefig('outputs/sequential_class_resnet18.png', dpi=300, bbox_inches='tight');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### FashionMNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data():\n",
    "    \"\"\"FashionMNIST class-level stats as {method: {'mean': DataFrame, 'std': DataFrame}} over seeds.\"\"\"\n",
    "    unlearn_methods = [\n",
    "        'original', 'retraining', 'gd', 'scr_newton', 'random_labels', 'ga',\n",
    "        'cr_newton', 'pinv_newton', 'damp_newton', 'ntk',\n",
    "    ]\n",
    "    random_seeds = [127, 128, 129]\n",
    "    outputs = {}\n",
    "\n",
    "    for method in unlearn_methods:\n",
    "        res = []\n",
    "        for seed in random_seeds:\n",
    "            path = f'../sequential_unlearning/outputs/fashion_mnist/by-class/seed-{seed}/{method}/stats.json'\n",
    "            print(f'Load stats from {path}')\n",
    "            with open(path, 'r') as f:\n",
    "                stats = pd.DataFrame(json.load(f))\n",
    "            stats['index'] = range(len(stats))  # align seeds by row position\n",
    "            res.append(stats)\n",
    "        grouped = pd.concat(res).groupby('index')\n",
    "        outputs[method] = {'mean': grouped.mean(), 'std': grouped.std()}\n",
    "\n",
    "    return outputs\n",
    "\n",
    "data = read_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the original model's accuracy columns with fixed (mean, std) values,\n",
    "# broadcast to every row.\n",
    "# NOTE(review): where these numbers come from is not recorded here -- document it.\n",
    "original_reference = {\n",
    "    'dtest_acc': (92.21, 0.21),\n",
    "    'df_all_acc': (92.23, 0.03),\n",
    "    'dr_acc': (92.65, 0.01),\n",
    "}\n",
    "for column, (mean_value, std_value) in original_reference.items():\n",
    "    data['original']['mean'][column] = mean_value\n",
    "    data['original']['std'][column] = std_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat the original model's single row once per request so it plots as a flat\n",
    "# reference line. Tiling only the first row keeps the cell safe to re-run; tiling\n",
    "# an already tiled frame would multiply its length and break the 'round' assignment.\n",
    "N_REQUESTS = 60\n",
    "for key in ('mean', 'std'):\n",
    "    first_row = data['original'][key].iloc[[0]]\n",
    "    tiled = pd.concat([first_row] * N_REQUESTS, ignore_index=True)\n",
    "    tiled['round'] = range(N_REQUESTS)\n",
    "    data['original'][key] = tiled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rcParams.update({'font.size': 17})\n",
    "fig, axs = plt.subplots(ncols=3, figsize=(13, 3.5))\n",
    "\n",
    "\n",
    "def plot_vs_rounds(ax, x, mean, std, label, **kwargs):\n",
    "    \"\"\"Like `plot`, but with explicit x values instead of `mean.index + 1`.\n",
    "\n",
    "    The old call plot(ax, x, mean, std, label=...) did not match plot's\n",
    "    (ax, mean, std, label) signature and raised TypeError.\n",
    "    \"\"\"\n",
    "    (line,) = ax.plot(x, mean, label=label, **kwargs)\n",
    "    # shade with the drawn line's colour so styles without 'color' still match\n",
    "    ax.fill_between(x, mean - std, mean + std, alpha=.2, facecolor=line.get_color())\n",
    "\n",
    "\n",
    "# x axis: the 'round' column logged by the retraining run, shared by all methods\n",
    "x = data['retraining']['mean']['round']\n",
    "panels = [\n",
    "    ('dtest_acc', '$D_{test}$ Acc. (%)'),\n",
    "    ('df_all_acc', '$D_e$  Acc. (%)'),\n",
    "    ('dr_acc', '$D_r$  Acc. (%)'),\n",
    "]\n",
    "for ax, (metric, title) in zip(axs, panels):\n",
    "    for method, label in plot_labels.items():\n",
    "        if method not in data:\n",
    "            continue\n",
    "        # labels without a plot_styles entry (e.g. NTK) fall back to matplotlib defaults\n",
    "        style = plot_styles.get(label, {})\n",
    "        plot_vs_rounds(ax, x, data[method]['mean'][metric], data[method]['std'][metric], label, markevery=10, **style)\n",
    "    ax.set_title(title, pad=15)\n",
    "    ax.set_xlabel('Num Requests')\n",
    "    ax.set_yticks([0, 45, 90])\n",
    "axs[0].set_ylim(-5, 98)\n",
    "axs[1].set_ylim(-5, 98)\n",
    "\n",
    "handles, labels = axs[1].get_legend_handles_labels()\n",
    "fig.legend(handles, labels, loc='lower center', ncols=5, bbox_to_anchor=(0, 1, 1, 1))\n",
    "\n",
    "fig.tight_layout();\n",
    "plt.savefig('sequential_fashion_mnist.png', dpi=220, bbox_inches='tight');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Instance-level Sequential Unlearning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Llama-2 (AG News)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data():\n",
    "    \"\"\"Llama / AG News instance-level stats as {method: {'mean': DataFrame, 'std': DataFrame}} over seeds.\"\"\"\n",
    "    unlearn_methods = ['original', 'retraining', 'ga', 'gd', 'random_labels', 'scr_newton', 'scrub']\n",
    "    random_seeds = [1, 2, 3]\n",
    "    outputs = {}\n",
    "\n",
    "    for method in unlearn_methods:\n",
    "        res = []\n",
    "        for seed in random_seeds:\n",
    "            # OLD_ROOT_DIR comes from the config cell; open() does not expand '~'\n",
    "            path = os.path.expanduser(f'{OLD_ROOT_DIR}/sequential_unlearning/llama/ag_news/seed-{seed}/by-instance/{method}/stats.json')\n",
    "            with open(path, 'r') as f:\n",
    "                stats = pd.DataFrame(json.load(f))\n",
    "            stats['index'] = range(len(stats))  # align seeds by row position\n",
    "            res.append(stats)\n",
    "        grouped = pd.concat(res).groupby('index')\n",
    "        outputs[method] = {'mean': grouped.mean(), 'std': grouped.std()}\n",
    "\n",
    "    return outputs\n",
    "\n",
    "data = read_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uses `data` from the read_data cell above (no second read from disk).\n",
    "matplotlib.rcParams.update({'font.size': 17})\n",
    "fig, axs = plt.subplots(ncols=3, figsize=(13, 3.5))\n",
    "\n",
    "# (metric column, subplot title) per panel, left to right\n",
    "panels = [\n",
    "    ('test_eval_accuracy', '$D_{test}$ Acc. (%)'),\n",
    "    ('accu_forget_eval_accuracy', '$D_e$  Acc. (%)'),\n",
    "    ('retain_eval_accuracy', '$D_r$  Acc. (%)'),\n",
    "]\n",
    "for ax, (metric, title) in zip(axs, panels):\n",
    "    for method, label in plot_labels.items():\n",
    "        # membership first: labels such as 'NTK' have no style entry\n",
    "        if method not in data:\n",
    "            continue\n",
    "        style = plot_styles['CureNewton (ours)'] if label == 'SCRUB' else plot_styles[label]\n",
    "        plot(ax, data[method]['mean'][metric], data[method]['std'][metric], label=label, **style)\n",
    "    ax.set_title(title, pad=15)\n",
    "    # normal view\n",
    "    ax.set_ylim(20, 102)\n",
    "    ax.set_yticks([25, 60, 95])\n",
    "    ax.set_xlabel('# Unlearning Rounds', labelpad=15)\n",
    "    ax.set_xticks([1, 2, 3, 4, 5, 6])\n",
    "\n",
    "# zoom\n",
    "# axs[0].set_ylim(91, 97)\n",
    "# axs[0].set_yticks([92, 94, 96])\n",
    "# axs[1].set_ylim(92, 98)\n",
    "# axs[1].set_yticks([93, 95, 97])\n",
    "# axs[2].set_ylim(93, 99)\n",
    "# axs[2].set_yticks([94, 96, 98])\n",
    "\n",
    "handles, labels = axs[1].get_legend_handles_labels()\n",
    "fig.legend(handles, labels, loc='lower center', ncols=4, bbox_to_anchor=(0, 1, 1, 1))\n",
    "\n",
    "fig.tight_layout();\n",
    "# plt.savefig('outputs/sequential_instance_llama.png', dpi=300, bbox_inches='tight');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CIFAR-10 (ResNet18)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data():\n",
    "    \"\"\"CIFAR-10 instance-level stats as {method: {'mean': DataFrame, 'std': DataFrame}} over seeds.\n",
    "\n",
    "    Missing stats files are reported and skipped; a method with no file for any\n",
    "    seed is left out of the result.\n",
    "    \"\"\"\n",
    "    unlearn_methods = ['original', 'retraining', 'sgd', 'ga', 'random_labels', 'scr_newton', 'scrub']\n",
    "    random_seeds = [1, 2, 3]\n",
    "    outputs = {}\n",
    "\n",
    "    for method in unlearn_methods:\n",
    "        res = []\n",
    "        for seed in random_seeds:\n",
    "            path = f'../main_results/sequential_unlearning/cifar10/seed-{seed}/by-instance/{method}/stats.json'\n",
    "            if not os.path.exists(path):\n",
    "                print(f\"Can't read {path}\")\n",
    "                continue\n",
    "            with open(path, 'r') as f:\n",
    "                stats = pd.DataFrame(json.load(f))\n",
    "            stats['index'] = range(len(stats))  # align seeds by row position\n",
    "            res.append(stats)\n",
    "        if not res:\n",
    "            continue  # pd.concat([]) would raise ValueError\n",
    "        grouped = pd.concat(res).groupby('index')\n",
    "        outputs[method] = {'mean': grouped.mean(), 'std': grouped.std()}\n",
    "\n",
    "    return outputs\n",
    "\n",
    "data = read_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uses `data` from the read_data cell above (no second read from disk).\n",
    "matplotlib.rcParams.update({'font.size': 17})\n",
    "fig, axs = plt.subplots(ncols=3, figsize=(13, 4))\n",
    "\n",
    "missing = [method for method in plot_labels if method not in data]\n",
    "if missing:\n",
    "    print('Does not have: ' + ', '.join(missing))\n",
    "\n",
    "# (metric column, subplot title) per panel, left to right\n",
    "panels = [\n",
    "    ('test_acc', '$D_{test}$ Acc. (%)'),\n",
    "    ('accu_forget_acc', '$D_e$  Acc. (%)'),\n",
    "    ('retain_acc', '$D_r$  Acc. (%)'),\n",
    "]\n",
    "for ax, (metric, title) in zip(axs, panels):\n",
    "    for method, label in plot_labels.items():\n",
    "        # membership first: labels such as 'NTK' have no style entry\n",
    "        if method not in data:\n",
    "            continue\n",
    "        style = plot_styles['CureNewton (ours)'] if label == 'SCRUB' else plot_styles[label]\n",
    "        plot(ax, data[method]['mean'][metric], data[method]['std'][metric], label=label, **style)\n",
    "    ax.set_title(title, pad=15)\n",
    "    # normal view\n",
    "    ax.set_yticks([30, 60, 90])\n",
    "    ax.set_xlabel('# Unlearning Rounds', labelpad=15)\n",
    "    ax.set_xticks([1, 2, 3, 4, 5])\n",
    "\n",
    "# zoom-in view\n",
    "# axs[0].set_ylim(77, 93)\n",
    "# axs[0].set_yticks([80, 85, 90])\n",
    "# axs[1].set_ylim(79, 95)\n",
    "# axs[1].set_yticks([82, 87, 92])\n",
    "# axs[2].set_ylim(79, 95)\n",
    "# axs[2].set_yticks([82, 87, 92])\n",
    "\n",
    "# Figure legend is off for this plot; uncomment to add it:\n",
    "# handles, labels = axs[1].get_legend_handles_labels()\n",
    "# fig.legend(handles, labels, loc='lower center', ncols=4, bbox_to_anchor=(0, 1, 1, 1))\n",
    "\n",
    "fig.tight_layout();\n",
    "os.makedirs('outputs', exist_ok=True)  # savefig does not create missing directories\n",
    "plt.savefig('outputs/sequential_instance_resnet18.png', dpi=300, bbox_inches='tight');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Time Efficiency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for method in data:\n",
    "#     print(\"Method:\", method)\n",
    "#     time = data[method][\"std\"][\"running_time\"]\n",
    "#     mean_time = time.sum() / len(time)\n",
    "#     print(\"Time:\", mean_time)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "unlearning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
