{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory containing the `<model>_arxiv/seed_<seed>/eval/<method>/` result folders.\n",
    "# Set it here, or export the ROOT_DIR environment variable before starting the kernel.\n",
    "ROOT_DIR = os.environ.get('ROOT_DIR', '')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
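    "### Llama-2 on arXiv\n",
    "\n",
    "Run **either** this loading cell **or** the Qwen cell below, not both: each one (re)defines `data`, `data_mia`, `data_psf`, `data_psr`, `method_list` and the metric keys, so whichever cell runs last determines what the aggregation below uses."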
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_arxiv_sequential_unlearning(model_dir='llama2_arxiv', method_list=None,\n",
    "                                     random_seeds=(41, 42, 43),\n",
    "                                     mia_key='holdout_forget_Min-40%'):\n",
    "    \"\"\"Load per-seed evaluation results of every unlearning method.\n",
    "\n",
    "    For each method and seed, reads ``ROOT_DIR/<model_dir>/seed_<seed>/eval/<method>/``:\n",
    "      * ``rouge.csv``      -> DataFrame with an added ``index`` column,\n",
    "      * ``mia.json``       -> MIA score (the ``mia_key`` entry when the file holds a dict),\n",
    "      * ``ps_metric.json`` -> ``ps_forget_exact`` / ``ps_retain_exact``.\n",
    "\n",
    "    A seed without ``rouge.csv`` is skipped entirely. A missing MIA/PS file (or a PS\n",
    "    file that is not a dict) is recorded as ``np.nan`` so that, for every method, the\n",
    "    four returned lists stay aligned position-by-position (the aggregation cell\n",
    "    indexes them in parallel).\n",
    "\n",
    "    Args:\n",
    "        model_dir: Sub-directory of ROOT_DIR holding this model's results.\n",
    "        method_list: Methods to load; defaults to the methods evaluated for this model.\n",
    "        random_seeds: Seeds to look for.\n",
    "        mia_key: Key selected from ``mia.json`` when it contains a dict.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(outputs, mia_outputs, ps_f_outputs, ps_r_outputs)`` of dicts keyed by\n",
    "        method, each mapping to a list with one entry per loaded seed.\n",
    "    \"\"\"\n",
    "    if method_list is None:\n",
    "        method_list = ['retraining', 'original', 'finetune', 'grad_ascent', 'grad_diff',\n",
    "                       'scrub', 'idk', 'npo', 'scr_newton']\n",
    "\n",
    "    outputs = {}\n",
    "    mia_outputs = {}\n",
    "    ps_f_outputs = {}\n",
    "    ps_r_outputs = {}\n",
    "\n",
    "    for method in method_list:\n",
    "        res_all_seeds = []\n",
    "        mia_res_all_seeds = []\n",
    "        ps_f_all_seeds = []\n",
    "        ps_r_all_seeds = []\n",
    "        for seed in random_seeds:\n",
    "            # os.path.join keeps an empty ROOT_DIR relative instead of pointing at '/'.\n",
    "            eval_dir = os.path.join(ROOT_DIR, model_dir, f'seed_{seed}', 'eval', method)\n",
    "            path = os.path.join(eval_dir, 'rouge.csv')\n",
    "            mia_path = os.path.join(eval_dir, 'mia.json')\n",
    "            ps_path = os.path.join(eval_dir, 'ps_metric.json')\n",
    "\n",
    "            if not os.path.exists(path):\n",
    "                print(f\"Can't find path: {path}\")\n",
    "                continue\n",
    "            stats = pd.read_csv(path)\n",
    "            stats[\"index\"] = range(len(stats))\n",
    "            res_all_seeds.append(stats)\n",
    "\n",
    "            # From here on, append exactly one MIA and one PS value per loaded seed.\n",
    "            mia_score = np.nan\n",
    "            if os.path.exists(mia_path):\n",
    "                with open(mia_path, \"r\") as f:\n",
    "                    mia_stats = json.load(f)\n",
    "                mia_score = mia_stats[mia_key] if isinstance(mia_stats, dict) else mia_stats\n",
    "            else:\n",
    "                print(f\"Can't find MIA path: {mia_path}\")\n",
    "            mia_res_all_seeds.append(mia_score)\n",
    "\n",
    "            ps_forget = ps_retain = np.nan\n",
    "            if os.path.exists(ps_path):\n",
    "                with open(ps_path, \"r\") as f:\n",
    "                    ps_stats = json.load(f)\n",
    "                if isinstance(ps_stats, dict):\n",
    "                    ps_forget = ps_stats['ps_forget_exact']\n",
    "                    ps_retain = ps_stats['ps_retain_exact']\n",
    "                else:\n",
    "                    print(f\"Unexpected PS format (expected a dict): {ps_path}\")\n",
    "            else:\n",
    "                print(f\"Can't find PS path: {ps_path}\")\n",
    "            ps_f_all_seeds.append(ps_forget)\n",
    "            ps_r_all_seeds.append(ps_retain)\n",
    "\n",
    "        outputs[method] = res_all_seeds\n",
    "        mia_outputs[method] = mia_res_all_seeds\n",
    "        ps_f_outputs[method] = ps_f_all_seeds\n",
    "        ps_r_outputs[method] = ps_r_all_seeds\n",
    "\n",
    "    return outputs, mia_outputs, ps_f_outputs, ps_r_outputs\n",
    "\n",
    "data, data_mia, data_psf, data_psr = load_arxiv_sequential_unlearning()\n",
    "method_list = list(data.keys())\n",
    "\n",
    "# keys for Arxiv\n",
    "forget_key = 'ROUGE Forget'\n",
    "retain_key = 'ROUGE Retain'\n",
    "test_key = 'ROUGE Test'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
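    "### Qwen on arXiv\n",
    "\n",
    "Alternative to the Llama-2 cell above; it defines the same variables, so run only one of the two."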
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_arxiv_sequential_unlearning(model_dir='qwen_arxiv', method_list=None,\n",
    "                                     random_seeds=(41, 42, 43),\n",
    "                                     mia_key='holdout_forget_Min-40%'):\n",
    "    \"\"\"Load per-seed evaluation results of every unlearning method.\n",
    "\n",
    "    For each method and seed, reads ``ROOT_DIR/<model_dir>/seed_<seed>/eval/<method>/``:\n",
    "      * ``rouge.csv``      -> DataFrame with an added ``index`` column,\n",
    "      * ``mia.json``       -> MIA score (the ``mia_key`` entry when the file holds a dict),\n",
    "      * ``ps_metric.json`` -> ``ps_forget_exact`` / ``ps_retain_exact``.\n",
    "    All three files come from the same model directory.\n",
    "\n",
    "    A seed without ``rouge.csv`` is skipped entirely. A missing MIA/PS file (or a PS\n",
    "    file that is not a dict) is recorded as ``np.nan`` so that, for every method, the\n",
    "    four returned lists stay aligned position-by-position (the aggregation cell\n",
    "    indexes them in parallel).\n",
    "\n",
    "    Args:\n",
    "        model_dir: Sub-directory of ROOT_DIR holding this model's results.\n",
    "        method_list: Methods to load; defaults to the methods evaluated for this model.\n",
    "        random_seeds: Seeds to look for.\n",
    "        mia_key: Key selected from ``mia.json`` when it contains a dict.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(outputs, mia_outputs, ps_f_outputs, ps_r_outputs)`` of dicts keyed by\n",
    "        method, each mapping to a list with one entry per loaded seed.\n",
    "    \"\"\"\n",
    "    if method_list is None:\n",
    "        method_list = ['retraining', 'original', 'finetune', 'grad_ascent', 'grad_diff',\n",
    "                       'scrub', 'npo', 'scr_newton']\n",
    "\n",
    "    outputs = {}\n",
    "    mia_outputs = {}\n",
    "    ps_f_outputs = {}\n",
    "    ps_r_outputs = {}\n",
    "\n",
    "    for method in method_list:\n",
    "        res_all_seeds = []\n",
    "        mia_res_all_seeds = []\n",
    "        ps_f_all_seeds = []\n",
    "        ps_r_all_seeds = []\n",
    "        for seed in random_seeds:\n",
    "            # os.path.join keeps an empty ROOT_DIR relative instead of pointing at '/'.\n",
    "            eval_dir = os.path.join(ROOT_DIR, model_dir, f'seed_{seed}', 'eval', method)\n",
    "            path = os.path.join(eval_dir, 'rouge.csv')\n",
    "            mia_path = os.path.join(eval_dir, 'mia.json')\n",
    "            ps_path = os.path.join(eval_dir, 'ps_metric.json')\n",
    "\n",
    "            if not os.path.exists(path):\n",
    "                print(f\"Can't find path: {path}\")\n",
    "                continue\n",
    "            stats = pd.read_csv(path)\n",
    "            stats[\"index\"] = range(len(stats))\n",
    "            res_all_seeds.append(stats)\n",
    "\n",
    "            # From here on, append exactly one MIA and one PS value per loaded seed.\n",
    "            mia_score = np.nan\n",
    "            if os.path.exists(mia_path):\n",
    "                with open(mia_path, \"r\") as f:\n",
    "                    mia_stats = json.load(f)\n",
    "                mia_score = mia_stats[mia_key] if isinstance(mia_stats, dict) else mia_stats\n",
    "            else:\n",
    "                print(f\"Can't find MIA path: {mia_path}\")\n",
    "            mia_res_all_seeds.append(mia_score)\n",
    "\n",
    "            ps_forget = ps_retain = np.nan\n",
    "            if os.path.exists(ps_path):\n",
    "                with open(ps_path, \"r\") as f:\n",
    "                    ps_stats = json.load(f)\n",
    "                if isinstance(ps_stats, dict):\n",
    "                    ps_forget = ps_stats['ps_forget_exact']\n",
    "                    ps_retain = ps_stats['ps_retain_exact']\n",
    "                else:\n",
    "                    print(f\"Unexpected PS format (expected a dict): {ps_path}\")\n",
    "            else:\n",
    "                print(f\"Can't find PS path: {ps_path}\")\n",
    "            ps_f_all_seeds.append(ps_forget)\n",
    "            ps_r_all_seeds.append(ps_retain)\n",
    "\n",
    "        outputs[method] = res_all_seeds\n",
    "        mia_outputs[method] = mia_res_all_seeds\n",
    "        ps_f_outputs[method] = ps_f_all_seeds\n",
    "        ps_r_outputs[method] = ps_r_all_seeds\n",
    "\n",
    "    return outputs, mia_outputs, ps_f_outputs, ps_r_outputs\n",
    "\n",
    "data, data_mia, data_psf, data_psr = load_arxiv_sequential_unlearning()\n",
    "method_list = list(data.keys())\n",
    "\n",
    "# keys for Arxiv\n",
    "forget_key = 'ROUGE Forget'\n",
    "retain_key = 'ROUGE Retain'\n",
    "test_key = 'ROUGE Test'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
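    "## Aggregate results\n",
    "\n",
    "Each seed of a method is paired, by position, with the corresponding retraining seed. The ToW score multiplies `1 - |Δ forget ROUGE| / max_value` and `1 - |Δ retain ROUGE| / max_value`, where Δ is the difference to retraining. The last row of each seed's table is then summarised as mean ± std across seeds."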
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retrain_res = data['retraining']\n",
    "\n",
    "# True divides metric gaps by 100 (scores on a 0-100 scale); False divides by 1.\n",
    "normalize = False\n",
    "max_value = 100.0 if normalize else 1.0\n",
    "\n",
    "# Per row: ToW = (1 - forget gap) * (1 - retain gap), each gap being the absolute\n",
    "# difference to the retraining run at the same seed position, divided by max_value.\n",
    "# Each seed's MIA / PS values are also attached as columns for the next cell.\n",
    "for method in method_list:\n",
    "    print('Method:', method)\n",
    "    n_method, n_retrain = len(data[method]), len(retrain_res)\n",
    "    if n_method != n_retrain:\n",
    "        # zip() pairs seeds by position and silently drops the surplus; make that visible.\n",
    "        print(f'  Warning: {n_method} seed(s) loaded for {method} but {n_retrain} for retraining; '\n",
    "              f'only the first {min(n_method, n_retrain)} are scored.')\n",
    "    for seed_idx, (seed_res, retrain_seed_res) in enumerate(zip(data[method], retrain_res)):\n",
    "        forget_gap = abs(seed_res[forget_key] - retrain_seed_res[forget_key]) / max_value\n",
    "        retain_gap = abs(seed_res[retain_key] - retrain_seed_res[retain_key]) / max_value\n",
    "        # Optional test-set term (not part of the reported score):\n",
    "        # test_gap = abs(seed_res[test_key] - retrain_seed_res[test_key]) / max_value\n",
    "        seed_res['tow_score'] = (1 - forget_gap) * (1 - retain_gap)\n",
    "        # Per-seed scalars are broadcast to every row; seed_idx follows the loader's seed order.\n",
    "        seed_res['mia'] = data_mia[method][seed_idx]\n",
    "        seed_res['psf'] = data_psf[method][seed_idx]\n",
    "        seed_res['psr'] = data_psr[method][seed_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _mean_std(values):\n",
    "    \"\"\"Summarise per-seed values as {'mean', 'std'} using np.mean / np.std (ddof=0).\"\"\"\n",
    "    return {'mean': np.mean(values), 'std': np.std(values)}\n",
    "\n",
    "\n",
    "# Per-method summaries of the final row of every seed's results table.\n",
    "# test_acc_list stays as empty lists because the test metric is currently not reported.\n",
    "test_acc_list = {method: [] for method in method_list}\n",
    "forget_acc_list, retain_acc_list, tow_list = {}, {}, {}\n",
    "mia_list, psf_list, psr_list = {}, {}, {}\n",
    "\n",
    "for method in method_list:\n",
    "    print(method)\n",
    "    final_rows = [seed_res.iloc[-1] for seed_res in data[method]]\n",
    "    forget_acc_list[method] = _mean_std([row[forget_key] for row in final_rows])\n",
    "    retain_acc_list[method] = _mean_std([row[retain_key] for row in final_rows])\n",
    "    tow_list[method] = _mean_std([row['tow_score'] for row in final_rows])\n",
    "    mia_list[method] = _mean_std([row['mia'] for row in final_rows])\n",
    "    psf_list[method] = _mean_std([row['psf'] for row in final_rows])\n",
    "    psr_list[method] = _mean_std([row['psr'] for row in final_rows])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (header label, per-method {'mean', 'std'} summary, decimals), in table column order.\n",
    "table_columns = [\n",
    "    ('Forget Acc', forget_acc_list, 3),\n",
    "    ('Retain Acc', retain_acc_list, 3),\n",
    "    ('ToW Score', tow_list, 3),\n",
    "    ('Forget ES', psf_list, 3),\n",
    "    ('Retain ES', psr_list, 3),\n",
    "    ('MIA', mia_list, 4),\n",
    "]\n",
    "\n",
    "header = ['Method'] + [label for label, _, _ in table_columns]\n",
    "print(' | '.join(header))\n",
    "# One '---' per column so the separator row matches the header width.\n",
    "print(' | '.join(['---'] * len(header)))\n",
    "\n",
    "# Body rows are LaTeX table rows: '&'-separated and ending in a LaTeX line break.\n",
    "for method in method_list:\n",
    "    cells = [method] + [\n",
    "        f\"{summary[method]['mean']:.{digits}f} ± {summary[method]['std']:.{digits}f}\"\n",
    "        for _, summary, digits in table_columns\n",
    "    ]\n",
    "    print(' & '.join(cells) + ' \\\\\\\\')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "waterfall",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
