{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ROOT_DIR = ''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Llama TOFU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_tofu_sequential_unlearning(model_dir='llama2_tofu'):\n",
    "    \"\"\"Load the per-seed TOFU evaluation tables and MIA values of every method.\n",
    "\n",
    "    Args:\n",
    "        model_dir: Directory under ROOT_DIR whose seed_*/eval/*/ folders are read.\n",
    "\n",
    "    Returns:\n",
    "        A tuple (outputs, mia_outputs) of dicts keyed by method name.\n",
    "        outputs[method] holds one DataFrame per seed whose tofu.csv exists.\n",
    "        mia_outputs[method] holds, at the same position, the value read from\n",
    "        mia.json, or NaN when that seed has no mia.json, so both lists can be\n",
    "        indexed together.\n",
    "    \"\"\"\n",
    "    outputs = {}\n",
    "    mia_outputs = {}\n",
    "\n",
    "    random_seeds = [41, 42, 43]\n",
    "\n",
    "    method_list = ['retraining', 'original', 'finetune', 'grad_ascent', 'grad_diff',\n",
    "                   'scrub', 'idk', 'npo', 'scr_newton']\n",
    "    for method in method_list:\n",
    "        res_all_seeds = []\n",
    "        mia_res_all_seeds = []\n",
    "        for seed in random_seeds:\n",
    "            eval_dir = f\"{ROOT_DIR}/{model_dir}/seed_{seed}/eval/{method}\"\n",
    "            path = f\"{eval_dir}/tofu.csv\"\n",
    "            mia_path = f\"{eval_dir}/mia.json\"\n",
    "\n",
    "            # Without the CSV the seed is skipped entirely (no frame, no MIA entry).\n",
    "            if not os.path.exists(path):\n",
    "                print(f\"Can't find path: {path}\")\n",
    "                continue\n",
    "            with open(path, \"r\") as f:\n",
    "                stats = pd.read_csv(f)\n",
    "            stats[\"index\"] = range(len(stats))\n",
    "\n",
    "            if os.path.exists(mia_path):\n",
    "                # Context manager so the file handle is closed right after parsing.\n",
    "                with open(mia_path, \"r\") as f:\n",
    "                    mia_stats = json.load(f)\n",
    "                # A dict is reduced to one named entry; any other parsed value is kept as is.\n",
    "                if isinstance(mia_stats, dict):\n",
    "                    mia_stats = mia_stats['holdout_forget_Min-40%']\n",
    "            else:\n",
    "                print(f\"Can't find MIA path: {mia_path}\")\n",
    "                # Placeholder keeps mia_res_all_seeds aligned with res_all_seeds.\n",
    "                mia_stats = np.nan\n",
    "\n",
    "            res_all_seeds.append(stats)\n",
    "            mia_res_all_seeds.append(mia_stats)\n",
    "\n",
    "        outputs[method] = res_all_seeds\n",
    "        mia_outputs[method] = mia_res_all_seeds\n",
    "\n",
    "    return outputs, mia_outputs\n",
    "\n",
    "# Binds the notebook-level names read by the aggregation cells.\n",
    "data, data_mia = load_tofu_sequential_unlearning('llama2_tofu')\n",
    "method_list = list(data.keys())\n",
    "\n",
    "# keys for TOFU\n",
    "forget_key = 'ROUGE Forget'\n",
    "retain_key = 'ROUGE Retain'\n",
    "test_key = 'ROUGE Real Authors'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Qwen TOFU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_tofu_sequential_unlearning(model_dir='qwen_tofu'):\n",
    "    \"\"\"Load the per-seed TOFU evaluation tables and MIA values of every method.\n",
    "\n",
    "    Args:\n",
    "        model_dir: Directory under ROOT_DIR whose seed_*/eval/*/ folders are read.\n",
    "\n",
    "    Returns:\n",
    "        A tuple (outputs, mia_outputs) of dicts keyed by method name.\n",
    "        outputs[method] holds one DataFrame per seed whose tofu.csv exists.\n",
    "        mia_outputs[method] holds, at the same position, the value read from\n",
    "        mia.json, or NaN when that seed has no mia.json, so both lists can be\n",
    "        indexed together.\n",
    "    \"\"\"\n",
    "    outputs = {}\n",
    "    mia_outputs = {}\n",
    "\n",
    "    random_seeds = [41, 42, 43]\n",
    "\n",
    "    method_list = ['retraining', 'original', 'finetune', 'grad_ascent', 'grad_diff',\n",
    "                   'scrub', 'idk', 'npo', 'scr_newton']\n",
    "    for method in method_list:\n",
    "        res_all_seeds = []\n",
    "        mia_res_all_seeds = []\n",
    "        for seed in random_seeds:\n",
    "            eval_dir = f\"{ROOT_DIR}/{model_dir}/seed_{seed}/eval/{method}\"\n",
    "            path = f\"{eval_dir}/tofu.csv\"\n",
    "            mia_path = f\"{eval_dir}/mia.json\"\n",
    "\n",
    "            # Without the CSV the seed is skipped entirely (no frame, no MIA entry).\n",
    "            if not os.path.exists(path):\n",
    "                print(f\"Can't find path: {path}\")\n",
    "                continue\n",
    "            with open(path, \"r\") as f:\n",
    "                stats = pd.read_csv(f)\n",
    "            stats[\"index\"] = range(len(stats))\n",
    "\n",
    "            if os.path.exists(mia_path):\n",
    "                # Context manager so the file handle is closed right after parsing.\n",
    "                with open(mia_path, \"r\") as f:\n",
    "                    mia_stats = json.load(f)\n",
    "                # A dict is reduced to one named entry; any other parsed value is kept as is.\n",
    "                if isinstance(mia_stats, dict):\n",
    "                    mia_stats = mia_stats['holdout_forget_Min-40%']\n",
    "            else:\n",
    "                print(f\"Can't find MIA path: {mia_path}\")\n",
    "                # Placeholder keeps mia_res_all_seeds aligned with res_all_seeds.\n",
    "                mia_stats = np.nan\n",
    "\n",
    "            res_all_seeds.append(stats)\n",
    "            mia_res_all_seeds.append(mia_stats)\n",
    "\n",
    "        outputs[method] = res_all_seeds\n",
    "        mia_outputs[method] = mia_res_all_seeds\n",
    "\n",
    "    return outputs, mia_outputs\n",
    "\n",
    "# Rebinds data / data_mia, replacing whatever an earlier loading cell stored in them.\n",
    "data, data_mia = load_tofu_sequential_unlearning('qwen_tofu')\n",
    "method_list = list(data.keys())\n",
    "\n",
    "# keys for TOFU\n",
    "forget_key = 'ROUGE Forget'\n",
    "retain_key = 'ROUGE Retain'\n",
    "test_key = 'ROUGE Real Authors'"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aggregate results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retrain_res = data['retraining']\n",
    "\n",
    "normalize = False\n",
    "# Divisor applied to every absolute metric gap: 100.0 when normalize is set, otherwise 1.0.\n",
    "max_value = 100.0 if normalize else 1.0\n",
    "\n",
    "for method in method_list:\n",
    "    print('Method:', method)\n",
    "    mia_per_seed = data_mia[method]\n",
    "    # Frames are paired with the retraining frames by list position.\n",
    "    # NOTE(review): if a seed is missing for only one of the two methods the pairing shifts - confirm all seeds exist.\n",
    "    for seed_idx, (seed_res, retrain_seed_res) in enumerate(zip(data[method], retrain_res)):\n",
    "        de_acc = seed_res[forget_key]\n",
    "        dr_acc = seed_res[retain_key]\n",
    "        dtest_acc = seed_res[test_key]\n",
    "\n",
    "        # Product over forget / retain / test of (1 - |method - retraining| / max_value).\n",
    "        tow_score = 1 - (abs(de_acc - retrain_seed_res[forget_key]) / max_value)\n",
    "        tow_score *= (1 - (abs(dr_acc - retrain_seed_res[retain_key]) / max_value))\n",
    "        tow_score *= (1 - (abs(dtest_acc - retrain_seed_res[test_key]) / max_value))\n",
    "\n",
    "        # The derived columns are written onto the loaded frame itself.\n",
    "        seed_res['tow_score'] = tow_score\n",
    "        seed_res['truth_ratio'] = seed_res['Truth Ratio Forget']\n",
    "        # Fewer MIA entries than frames means a mia.json was missing; store NaN instead of raising IndexError.\n",
    "        seed_res['mia'] = mia_per_seed[seed_idx] if seed_idx < len(mia_per_seed) else np.nan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One summary table per reported quantity; after the loop each maps method -> {'mean', 'std'}.\n",
    "(forget_acc_list, retain_acc_list, test_acc_list,\n",
    " tow_list, truth_ratio_list, mia_list) = ({name: [] for name in method_list} for _ in range(6))\n",
    "\n",
    "# Column of the per-seed frame that feeds each summary table.\n",
    "column_and_table = [\n",
    "    (forget_key, forget_acc_list),\n",
    "    (retain_key, retain_acc_list),\n",
    "    (test_key, test_acc_list),\n",
    "    ('tow_score', tow_list),\n",
    "    ('truth_ratio', truth_ratio_list),\n",
    "    ('mia', mia_list),\n",
    "]\n",
    "\n",
    "for method in method_list:\n",
    "    print(method)\n",
    "    # Only the last row of every seed's frame contributes to the statistics.\n",
    "    last_rows = [frame.iloc[-1] for frame in data[method]]\n",
    "    for column, table in column_and_table:\n",
    "        per_seed = [row[column] for row in last_rows]\n",
    "        table[method] = {'mean': np.mean(per_seed), 'std': np.std(per_seed)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Header and separator are pipe-delimited; the data rows below are '&'-delimited with a trailing double backslash.\n",
    "print('Method | Forget Acc | Retain Acc | Test Acc | ToW Score | Truth Ratio | MIA')\n",
    "# One '---' per header column (seven in total).\n",
    "print('--- | --- | --- | --- | --- | --- | ---')\n",
    "for method in method_list:\n",
    "    # Each metric is printed as mean ± std; MIA uses four decimals, the others three.\n",
    "    print('{} & {:.3f} ± {:.3f} & {:.3f} ± {:.3f} & {:.3f} ± {:.3f} & {:.3f} ± {:.3f} & {:.3f} ± {:.3f} & {:.4f} ± {:.4f} \\\\\\\\'.format(\n",
    "        method,\n",
    "        forget_acc_list[method]['mean'], forget_acc_list[method]['std'],\n",
    "        retain_acc_list[method]['mean'], retain_acc_list[method]['std'],\n",
    "        test_acc_list[method]['mean'], test_acc_list[method]['std'],\n",
    "        tow_list[method]['mean'], tow_list[method]['std'],\n",
    "        truth_ratio_list[method]['mean'], truth_ratio_list[method]['std'],\n",
    "        mia_list[method]['mean'], mia_list[method]['std'],\n",
    "    ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "method_list = [\n",
    "    'retraining', 'original', 'finetune', 'grad_ascent',\n",
    "    'grad_diff', 'scrub', 'idk', 'npo', 'scr_newton'\n",
    "]\n",
    "seeds = [41, 42, 43]\n",
    "\n",
    "\n",
    "# Prints one row per method with the mean and std of the values found across seeds.\n",
    "print(f\"{'Method':<15} | {'Avg Time (s)':<15} | {'Std Dev':<15}\")\n",
    "print(\"-\" * 50)\n",
    "\n",
    "for method in method_list:\n",
    "    times = []\n",
    "\n",
    "    for seed in seeds:\n",
    "        # NOTE(review): this reads from qwen_arxiv while the other cells of this notebook read *_tofu runs - confirm the directory.\n",
    "        file_path = f\"{ROOT_DIR}/qwen_arxiv/seed_{seed}/{method}/unlearning_time.txt\"\n",
    "\n",
    "        try:\n",
    "            with open(file_path, \"r\") as f:\n",
    "                content = f.read().strip()\n",
    "            # The first whitespace-separated token is parsed as the number.\n",
    "            time_val = float(content.split()[0])\n",
    "            times.append(time_val)\n",
    "        except FileNotFoundError:\n",
    "            # Handle cases where a specific seed/method might be missing (e.g., 'original' might not have a time)\n",
    "            pass\n",
    "        except (ValueError, IndexError):\n",
    "            # ValueError: the token is not a number. IndexError: the file is empty or whitespace-only.\n",
    "            print(f\"Error parsing value in: {file_path}\")\n",
    "\n",
    "    # Compute Statistics if data exists\n",
    "    if times:\n",
    "        avg_time = np.mean(times)\n",
    "        std_time = np.std(times)\n",
    "        print(f\"{method:<15} | {avg_time:<15.4f} | {std_time:<15.4f}\")\n",
    "    else:\n",
    "        print(f\"{method:<15} | {'N/A':<15} | {'N/A':<15}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "method_list = [\n",
    "    'retraining', 'original', 'finetune', 'grad_ascent',\n",
    "    'grad_diff', 'scrub', 'idk', 'npo', 'scr_newton'\n",
    "]\n",
    "seeds = [41, 42, 43]\n",
    "\n",
    "\n",
    "# Prints one row per method with the mean and std of the memory.txt values across seeds.\n",
    "# NOTE(review): the unit stored in memory.txt is not visible from this notebook - add it to the header once confirmed.\n",
    "print(f\"{'Method':<15} | {'Avg Memory':<15} | {'Std Dev':<15}\")\n",
    "print(\"-\" * 50)\n",
    "\n",
    "for method in method_list:\n",
    "    memory_vals = []\n",
    "\n",
    "    for seed in seeds:\n",
    "        file_path = f\"{ROOT_DIR}/qwen_tofu/seed_{seed}/{method}/memory.txt\"\n",
    "\n",
    "        try:\n",
    "            with open(file_path, \"r\") as f:\n",
    "                content = f.read().strip()\n",
    "            # The first whitespace-separated token is parsed as the number; anything after it is ignored.\n",
    "            memory_val = float(content.split()[0])\n",
    "            memory_vals.append(memory_val)\n",
    "        except FileNotFoundError:\n",
    "            # A seed/method combination without a memory.txt is skipped silently.\n",
    "            pass\n",
    "        except (ValueError, IndexError):\n",
    "            # ValueError: the token is not a number. IndexError: the file is empty or whitespace-only.\n",
    "            print(f\"Error parsing value in: {file_path}\")\n",
    "\n",
    "    # Compute Statistics if data exists\n",
    "    if memory_vals:\n",
    "        avg_memory = np.mean(memory_vals)\n",
    "        std_memory = np.std(memory_vals)\n",
    "        print(f\"{method:<15} | {avg_memory:<15.4f} | {std_memory:<15.4f}\")\n",
    "    else:\n",
    "        print(f\"{method:<15} | {'N/A':<15} | {'N/A':<15}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sequential Unlearning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "ROOT_DIR = ''\n",
    "random_seeds = [41, 42, 43]\n",
    "rounds = [1, 2, 3, 4, 5]\n",
    "method_list = ['retraining', 'finetune', 'grad_ascent', 'grad_diff',\n",
    "               'scrub', 'idk', 'npo', 'scr_newton']\n",
    "\n",
    "def load_sequential_tofu():\n",
    "    \"\"\"Load the TOFU evaluation tables and MIA values of every method, round and seed.\n",
    "\n",
    "    Reads the notebook-level ROOT_DIR, random_seeds, rounds and method_list.\n",
    "\n",
    "    Returns:\n",
    "        A tuple (outputs, mia_outputs) of nested dicts indexed as [method][round].\n",
    "        outputs[method][round] lists one DataFrame per seed whose tofu.csv could\n",
    "        be read. mia_outputs[method][round] lists the MIA values that could be\n",
    "        read; seeds without a usable mia.json add nothing, so it can be shorter\n",
    "        than the matching outputs list.\n",
    "    \"\"\"\n",
    "    outputs = {m: {r: [] for r in rounds} for m in method_list}\n",
    "    mia_outputs = {m: {r: [] for r in rounds} for m in method_list}\n",
    "\n",
    "    for method in method_list:\n",
    "        for r in rounds:\n",
    "            for seed in random_seeds:\n",
    "                eval_dir = f\"{ROOT_DIR}/llama2_tofu/seed_{seed}/sequential/eval/{method}-{r}\"\n",
    "                path = f\"{eval_dir}/tofu.csv\"\n",
    "                mia_path = f\"{eval_dir}/mia.json\"\n",
    "\n",
    "                if not os.path.exists(path):\n",
    "                    print(f\"Missing path: {path}\")\n",
    "                    continue\n",
    "\n",
    "                try:\n",
    "                    stats = pd.read_csv(path)\n",
    "                    stats[\"index\"] = range(len(stats))\n",
    "                    outputs[method][r].append(stats)\n",
    "                except Exception as e:\n",
    "                    # Best effort: report the unreadable CSV and move on to the next seed.\n",
    "                    print(f\"Error loading CSV {path}: {e}\")\n",
    "                    continue\n",
    "\n",
    "                if not os.path.exists(mia_path):\n",
    "                    continue\n",
    "\n",
    "                try:\n",
    "                    with open(mia_path, \"r\") as f:\n",
    "                        mia_stats = json.load(f)\n",
    "                    # A dict is reduced to one named entry (0 when the key is absent);\n",
    "                    # any other parsed value is kept as is instead of failing on .get().\n",
    "                    if isinstance(mia_stats, dict):\n",
    "                        mia_stats = mia_stats.get('holdout_forget_Min-40%', 0)\n",
    "                    mia_outputs[method][r].append(mia_stats)\n",
    "                except Exception as e:\n",
    "                    print(f\"Error loading MIA {mia_path}: {e}\")\n",
    "                    continue\n",
    "\n",
    "    return outputs, mia_outputs\n",
    "\n",
    "# Rebinds data / data_mia to the nested [method][round] layout used by the cells below.\n",
    "data, data_mia = load_sequential_tofu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "forget_key = 'ROUGE Forget'\n",
    "retain_key = 'ROUGE Retain'\n",
    "test_key = 'ROUGE Real Authors'\n",
    "truth_key = 'Truth Ratio Forget'\n",
    "\n",
    "# Metrics aggregated per round; any of 'Forget', 'Retain', 'Test', 'ToW' can be listed.\n",
    "metrics_to_plot = ['Forget', 'Retain']\n",
    "# results[metric][method] holds one mean and one std per round, in the order of rounds.\n",
    "results = {metric: {m: {'mean': [], 'std': []} for m in method_list} for metric in metrics_to_plot}\n",
    "\n",
    "normalize = False\n",
    "# Divisor applied to every absolute metric gap: 100.0 when normalize is set, otherwise 1.0.\n",
    "max_value = 100.0 if normalize else 1.0\n",
    "\n",
    "for r in rounds:\n",
    "    retrain_seeds_df = data.get('retraining', {}).get(r, [])\n",
    "\n",
    "    for method in method_list:\n",
    "        round_vals = {k: [] for k in metrics_to_plot}\n",
    "\n",
    "        for i, seed_res in enumerate(data[method][r]):\n",
    "            # Only the last row of each seed's frame is used.\n",
    "            last_row = seed_res.iloc[-1]\n",
    "            de_acc = last_row[forget_key]\n",
    "            dr_acc = last_row[retain_key]\n",
    "            dtest_acc = last_row[test_key]\n",
    "\n",
    "            # The retraining frame is matched by list position; without one the score is NaN.\n",
    "            if i < len(retrain_seeds_df):\n",
    "                retrain_row = retrain_seeds_df[i].iloc[-1]\n",
    "\n",
    "                tow_score = 1 - (abs(de_acc - retrain_row[forget_key]) / max_value)\n",
    "                tow_score *= (1 - (abs(dr_acc - retrain_row[retain_key]) / max_value))\n",
    "                tow_score *= (1 - (abs(dtest_acc - retrain_row[test_key]) / max_value))\n",
    "            else:\n",
    "                tow_score = np.nan\n",
    "\n",
    "            # Every computed value is available by name; only the requested metrics are recorded.\n",
    "            seed_metrics = {'Forget': de_acc, 'Retain': dr_acc, 'Test': dtest_acc, 'ToW': tow_score}\n",
    "            for metric in metrics_to_plot:\n",
    "                round_vals[metric].append(seed_metrics[metric])\n",
    "\n",
    "        # A round with no usable values (including no loaded seeds at all) is recorded as NaN.\n",
    "        for metric in metrics_to_plot:\n",
    "            clean_vals = [v for v in round_vals[metric] if not np.isnan(v)]\n",
    "            if clean_vals:\n",
    "                results[metric][method]['mean'].append(np.mean(clean_vals))\n",
    "                results[metric][method]['std'].append(np.std(clean_vals))\n",
    "            else:\n",
    "                results[metric][method]['mean'].append(np.nan)\n",
    "                results[metric][method]['std'].append(np.nan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One panel per entry of metrics_to_plot, each showing mean ± std per unlearning round.\n",
    "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
    "axes = axes.flatten()\n",
    "\n",
    "titles = {\n",
    "    'Forget': r'$\\mathcal{D}_{\\mathcal{F}}$ ROUGE',\n",
    "    'Retain': r'$\\mathcal{D}_{\\mathcal{R}}$ ROUGE',\n",
    "}\n",
    "\n",
    "# Legend text per method; methods not listed fall back to their raw name.\n",
    "label_map = {\n",
    "    'retraining': 'Retraining',\n",
    "    'finetune': 'Fine-tune',\n",
    "    'grad_ascent': 'GA',\n",
    "    'grad_diff': 'GDiff',\n",
    "    'scrub': 'Scrub',\n",
    "    'idk': 'IDK',\n",
    "    'npo': 'NPO',\n",
    "    'scr_newton': 'SCRN',\n",
    "    'dareU': 'DareU (ours)'\n",
    "}\n",
    "\n",
    "# Line colour per method; methods not listed fall back to gray.\n",
    "color_map = {\n",
    "    'retraining': 'tab:red',\n",
    "    'dareU': 'tab:blue',\n",
    "    'finetune': 'tab:orange',\n",
    "    'grad_ascent': 'tab:green',\n",
    "    'grad_diff': 'tab:purple',\n",
    "    'scrub': 'tab:brown',\n",
    "    'idk': 'tab:pink',\n",
    "    'npo': 'tab:olive',\n",
    "    'scr_newton': 'tab:cyan',\n",
    "    'original': 'black'\n",
    "}\n",
    "\n",
    "plt.rcParams.update({\n",
    "        \"font.size\": 20\n",
    "    })\n",
    "\n",
    "for i, metric in enumerate(metrics_to_plot):\n",
    "    ax = axes[i]\n",
    "    for method in method_list:\n",
    "        # Methods without a single valid round are left out of the panel.\n",
    "        if np.all(np.isnan(results[metric][method]['mean'])):\n",
    "            continue\n",
    "\n",
    "        means = np.array(results[metric][method]['mean'])\n",
    "        stds = np.array(results[metric][method]['std'])\n",
    "\n",
    "        c = color_map.get(method, 'tab:gray')\n",
    "\n",
    "        # Solid, thicker, opaque lines for 'retraining' and 'dareU'; dotted lines for the rest.\n",
    "        if method in ('retraining', 'dareU'):\n",
    "            ls, lw, alpha = '-', 2.5, 1.0\n",
    "        else:\n",
    "            ls, lw, alpha = ':', 2.0, 0.8\n",
    "\n",
    "        display_label = label_map.get(method, method)\n",
    "\n",
    "        ax.plot(rounds, means, marker='o', linestyle=ls, linewidth=lw,\n",
    "                alpha=alpha, label=display_label, color=c)\n",
    "\n",
    "        # Shade mean ± std in the same colour as the line.\n",
    "        ax.fill_between(rounds, means - stds, means + stds, alpha=0.1, color=c)\n",
    "\n",
    "    ax.set_title(titles[metric])\n",
    "    ax.set_xlabel('Unlearning Round')\n",
    "    ax.set_xticks(rounds)\n",
    "    ax.grid(True, linestyle='--', alpha=0.6)\n",
    "\n",
    "# A single figure-level legend, built once from the entries of the first panel.\n",
    "handles, labels = axes[0].get_legend_handles_labels()\n",
    "fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.17), ncol=5, frameon=False)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('tofu_sequential_unlearning.pdf', bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
