{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Result roots. '~' is expanded up front because os.path functions and open() do not expand it.\n",
    "OLD_ROOT_DIR = os.path.expanduser('~/unlearning/others/curenewton_old/main_results')\n",
    "ROOT_DIR = os.path.expanduser('~/unlearning/cure_newton')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# os.path.exists() and open() do not expand '~'. Without expanduser() every path\n",
    "# built from ROOT_DIR would be reported as missing and the loaders would return\n",
    "# empty lists.\n",
    "ROOT_DIR = os.path.expanduser('~/unlearning/cure_newton/')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### FashionMNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_fashion_mnist_batch_unlearning():\n",
    "    \"\"\"Collect per-seed FashionMNIST batch-unlearning stats for every method.\n",
    "\n",
    "    Returns:\n",
    "        dict: method name -> list of DataFrames, one per seed whose\n",
    "        ``stats.json`` was found under ``ROOT_DIR``. Every frame gets an\n",
    "        ``index`` column numbering its rows from 0. Seeds without a file are\n",
    "        reported on stdout and skipped, so a list can hold fewer than three\n",
    "        frames (or none).\n",
    "    \"\"\"\n",
    "    setting = 'by-instance'\n",
    "    seeds = (1, 2, 3)\n",
    "    methods = (\n",
    "        'retraining', 'original', 'random_labels', 'gd', 'ga', 'ntk', 'pinv_newton',\n",
    "        'damped_newton', 'scrub', 'cr_newton', 'scr_newton', 'npo', 'gdiff', 'delete',\n",
    "    )\n",
    "\n",
    "    results = {}\n",
    "    for method in methods:\n",
    "        frames = []\n",
    "        for seed in seeds:\n",
    "            stats_file = os.path.join(\n",
    "                ROOT_DIR,\n",
    "                f\"batch_unlearning/fashion_mnist/seed-{seed}/{setting}/{method}/stats.json\",\n",
    "            )\n",
    "            if os.path.exists(stats_file):\n",
    "                with open(stats_file, \"r\") as handle:\n",
    "                    frame = pd.DataFrame(json.load(handle))\n",
    "                frame[\"index\"] = range(len(frame))\n",
    "                frames.append(frame)\n",
    "            else:\n",
    "                print(f\"Can't find path: {stats_file}\")\n",
    "        results[method] = frames\n",
    "\n",
    "    return results\n",
    "\n",
    "# Make FashionMNIST the active dataset for the cells that read `data` / `method_list`.\n",
    "data = load_fashion_mnist_batch_unlearning()\n",
    "method_list = list(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Llama"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_llama_sequential_unlearning():\n",
    "    \"\"\"Collect per-seed Llama / AG News sequential-unlearning stats for every method.\n",
    "\n",
    "    Files are read from ``OLD_ROOT_DIR`` (the ``llama-instance`` results).\n",
    "\n",
    "    Returns:\n",
    "        dict: method name -> list of DataFrames, one per seed whose\n",
    "        ``stats.json`` exists. Every frame gets an ``index`` column numbering\n",
    "        its rows from 0. Seeds without a file are reported on stdout and\n",
    "        skipped.\n",
    "    \"\"\"\n",
    "    setting = 'by-instance'    # alternative: 'by-class'\n",
    "    seeds = (1, 2, 3)\n",
    "    methods = ('retraining', 'original', 'random_labels', 'gd', 'ga', 'scrub', 'scr_newton')\n",
    "\n",
    "    results = {}\n",
    "    for method in methods:\n",
    "        frames = []\n",
    "        for seed in seeds:\n",
    "            # Alternative location under the new root:\n",
    "            #   f\"{ROOT_DIR}/sequential_unlearning/llama/ag_news/seed-{seed}/{setting}/{method}/stats.json\"\n",
    "            stats_file = f\"{OLD_ROOT_DIR}/sequential_unlearning/llama-instance/ag_news/seed-{seed}/{setting}/{method}/stats.json\"\n",
    "            if os.path.exists(stats_file):\n",
    "                with open(stats_file, \"r\") as handle:\n",
    "                    frame = pd.DataFrame(json.load(handle))\n",
    "                frame[\"index\"] = range(len(frame))\n",
    "                frames.append(frame)\n",
    "            else:\n",
    "                print(f\"Can't find path: {stats_file}\")\n",
    "        results[method] = frames\n",
    "\n",
    "    return results\n",
    "\n",
    "# Make Llama / AG News the active dataset for the cells that read `data` / `method_list`.\n",
    "data = load_llama_sequential_unlearning()\n",
    "method_list = list(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### TOFU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_tofu_sequential_unlearning():\n",
    "    \"\"\"Load per-seed TOFU evaluation tables and MIA scores for every method.\n",
    "\n",
    "    For each method and seed this reads ``tofu.csv`` and ``mia.json`` from\n",
    "    ``{ROOT_DIR}/seed-{seed}/eval/{method}/``.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(outputs, mia_outputs)``.\n",
    "        ``outputs`` maps method name -> list of DataFrames (one per seed whose\n",
    "        CSV exists), each with an added ``index`` column numbering rows from 0.\n",
    "        ``mia_outputs`` maps method name -> list of MIA values (one per seed\n",
    "        that has both files). When ``mia.json`` holds a dict, only its\n",
    "        ``'forget_holdout_Min-40%'`` entry is kept.\n",
    "\n",
    "    Note:\n",
    "        Missing files are reported on stdout and skipped, so the two lists of\n",
    "        a method can differ in length and position ``i`` is not guaranteed to\n",
    "        correspond to seed ``i + 1``.\n",
    "    \"\"\"\n",
    "    outputs = {}\n",
    "    mia_outputs = {}\n",
    "\n",
    "    random_seeds = [1, 2, 3]\n",
    "    method_list = ['retraining', 'original', 'gd', 'ga', 'scrub', 'scr_newton', 'KL', 'dpo', 'gdiff', 'npo']\n",
    "\n",
    "    for method in method_list:\n",
    "        res_all_seeds = []\n",
    "        mia_res_all_seeds = []\n",
    "        for seed in random_seeds:\n",
    "            path = f\"{ROOT_DIR}/seed-{seed}/eval/{method}/tofu.csv\"\n",
    "            mia_path = f\"{ROOT_DIR}/seed-{seed}/eval/{method}/mia.json\"\n",
    "            if not os.path.exists(path):\n",
    "                print(f\"Can't find path: {path}\")\n",
    "                continue\n",
    "            with open(path, \"r\") as f:\n",
    "                stats = pd.read_csv(f)\n",
    "            stats[\"index\"] = range(len(stats))\n",
    "            res_all_seeds.append(stats)\n",
    "\n",
    "            if not os.path.exists(mia_path):\n",
    "                print(f\"Can't find MIA path: {mia_path}\")\n",
    "                continue\n",
    "            # Read inside a context manager so the handle is closed deterministically.\n",
    "            with open(mia_path, \"r\") as f:\n",
    "                mia_stats = json.load(f)\n",
    "            if isinstance(mia_stats, dict):\n",
    "                mia_stats = mia_stats['forget_holdout_Min-40%']\n",
    "            mia_res_all_seeds.append(mia_stats)\n",
    "\n",
    "        outputs[method] = res_all_seeds\n",
    "        mia_outputs[method] = mia_res_all_seeds\n",
    "\n",
    "    return outputs, mia_outputs\n",
    "\n",
    "# Make TOFU the active dataset; `data_mia` carries the per-method MIA values.\n",
    "(data, data_mia) = load_tofu_sequential_unlearning()\n",
    "method_list = list(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### ResNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_resnet_sequential_unlearning():\n",
    "    \"\"\"Load per-seed CIFAR-10 (ResNet) sequential-unlearning stats and MIA AUCs.\n",
    "\n",
    "    For each method and seed this reads ``stats.json`` and\n",
    "    ``attack_result.json`` from\n",
    "    ``{ROOT_DIR}/sequential_unlearning/cifar10-class/seed-{seed}/by-class/{method}/``.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(outputs, mia_outputs)``.\n",
    "        ``outputs`` maps method name -> list of DataFrames (one per seed whose\n",
    "        ``stats.json`` exists), each with an added ``index`` column numbering\n",
    "        rows from 0.\n",
    "        ``mia_outputs`` maps method name -> list of ``'Train AUC'`` values\n",
    "        taken from the first JSON object of each ``attack_result.json``.\n",
    "\n",
    "    Note:\n",
    "        Missing or empty files are reported on stdout and skipped, so the two\n",
    "        lists of a method can differ in length and position ``i`` is not\n",
    "        guaranteed to correspond to seed ``i + 1``.\n",
    "    \"\"\"\n",
    "    outputs = {}\n",
    "    mia_outputs = {}\n",
    "    unlearning_type = 'by-class'    # alternative: 'by-instance'\n",
    "\n",
    "    random_seeds = [1, 2, 3]\n",
    "    # Earlier method sets, kept for reference:\n",
    "    # ['retraining', 'original', 'random_labels', 'gd', 'ga', 'scrub', 'scr_newton']\n",
    "    # ['retraining', 'original', 'random_labels', 'sgd', 'ga', 'scrub', 'scr_newton']\n",
    "    method_list = ['retraining', 'original', 'ga', 'random_labels', 'gdiff', 'npo', 'scr_newton']\n",
    "\n",
    "    decoder = json.JSONDecoder()\n",
    "\n",
    "    for method in method_list:\n",
    "        res_all_seeds = []\n",
    "        mia_all_seeds = []\n",
    "        for seed in random_seeds:\n",
    "            path = f\"{ROOT_DIR}/sequential_unlearning/cifar10-class/seed-{seed}/{unlearning_type}/{method}/stats.json\"\n",
    "            mia_path = f\"{ROOT_DIR}/sequential_unlearning/cifar10-class/seed-{seed}/{unlearning_type}/{method}/attack_result.json\"\n",
    "            if not os.path.exists(path):\n",
    "                print(f\"Can't find path: {path}\")\n",
    "                continue\n",
    "            with open(path, \"r\") as f:\n",
    "                stats = json.load(f)\n",
    "            stats = pd.DataFrame(stats)\n",
    "            stats[\"index\"] = range(len(stats))\n",
    "            res_all_seeds.append(stats)\n",
    "\n",
    "            if not os.path.exists(mia_path):\n",
    "                print(f\"Can't find MIA path: {mia_path}\")\n",
    "                continue\n",
    "            with open(mia_path, \"r\") as f:\n",
    "                raw = f.read().lstrip()\n",
    "            if not raw:\n",
    "                print(f\"Empty MIA file: {mia_path}\")\n",
    "                continue\n",
    "            # The file may hold several JSON objects written back to back, which is\n",
    "            # not one valid JSON document. Only the first object is used, and\n",
    "            # raw_decode() stops at the end of the first value, so the result does\n",
    "            # not depend on how (or whether) the objects are separated.\n",
    "            first_attack, _ = decoder.raw_decode(raw)\n",
    "            mia_all_seeds.append(first_attack['Train AUC'])\n",
    "\n",
    "        outputs[method] = res_all_seeds\n",
    "        mia_outputs[method] = mia_all_seeds\n",
    "\n",
    "    return outputs, mia_outputs\n",
    "\n",
    "# Make CIFAR-10 / ResNet the active dataset; `data_mia` carries the per-method MIA AUCs.\n",
    "(data, data_mia) = load_resnet_sequential_unlearning()\n",
    "method_list = list(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aggregate results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retrain_res = data['retraining']\n",
    "\n",
    "# Metric columns compared against retraining. Settings for the other datasets:\n",
    "#   TOFU:           forget_key = 'ROUGE Forget', retain_key = 'ROUGE Retain',\n",
    "#                   test_key = 'ROUGE Real Authors', normalize = False\n",
    "#   llama by-class: forget_key = 'accu_forget_eval_accuracy',\n",
    "#                   retain_key = 'retain_eval_accuracy', test_key = 'test_eval_accuracy'\n",
    "# Active: CIFAR-10\n",
    "forget_key = 'accu_forget_acc'\n",
    "retain_key = 'retain_acc'\n",
    "test_key = 'test_acc'\n",
    "normalize = True\n",
    "\n",
    "# With normalize=True the metric gaps are divided by 100.0, otherwise by 1.0.\n",
    "max_value = 100.0 if normalize else 1.0\n",
    "\n",
    "for method in method_list:\n",
    "    print('Method:', method)\n",
    "    # Frames are paired with the retraining frames by list position; if a loader\n",
    "    # skipped a seed for either side, the pairs no longer refer to the same seed.\n",
    "    for seed_no, (seed_res, retrain_seed_res) in enumerate(zip(data[method], retrain_res), start=1):\n",
    "        print('seed number:', seed_no)\n",
    "        forget_term = 1 - (abs(seed_res[forget_key] - retrain_seed_res[forget_key]) / max_value)\n",
    "        retain_term = 1 - (abs(seed_res[retain_key] - retrain_seed_res[retain_key]) / max_value)\n",
    "        test_term = 1 - (abs(seed_res[test_key] - retrain_seed_res[test_key]) / max_value)\n",
    "\n",
    "        # ToW score: product of the three closeness terms, stored as a new column.\n",
    "        seed_res['tow_score'] = forget_term * retain_term * test_term\n",
    "        print(seed_res['tow_score'])\n",
    "        # seed_res['truth_ratio'] = seed_res['Truth Ratio Forget']\n",
    "        # seed_res['mia'] = data_mia[method][seed_no - 1]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _mean_std(values):\n",
    "    \"\"\"Return the mean and standard deviation (np.std defaults) of `values` as a dict.\"\"\"\n",
    "    return {'mean': np.mean(values), 'std': np.std(values)}\n",
    "\n",
    "\n",
    "forget_acc_list, retain_acc_list, test_acc_list, tow_list = {}, {}, {}, {}\n",
    "# truth_ratio_list, mia_list = {}, {}\n",
    "\n",
    "for method in method_list:\n",
    "    print(method)\n",
    "    # Only the last row of each seed's table enters the summary.\n",
    "    final_rows = [seed_res.iloc[-1] for seed_res in data[method]]\n",
    "\n",
    "    forget_acc_list[method] = _mean_std([row[forget_key] for row in final_rows])\n",
    "    retain_acc_list[method] = _mean_std([row[retain_key] for row in final_rows])\n",
    "    test_acc_list[method] = _mean_std([row[test_key] for row in final_rows])\n",
    "    tow_list[method] = _mean_std([row['tow_score'] for row in final_rows])\n",
    "    # truth_ratio_list[method] = _mean_std([row['truth_ratio'] for row in final_rows])\n",
    "    # mia_list[method] = _mean_std([row['mia'] for row in final_rows])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print('Method | Forget Acc | Retain Acc | Test Acc | ToW Score | MIA AUC')\n",
    "# print('--- | --- | --- | --- | ---')\n",
    "# for method in method_list:\n",
    "#     print('{} & {:.2f} ± {:.2f} & {:.2f} ± {:.2f} & {:.2f} ± {:.2f} & {:.2f} ± {:.2f} & {:.2f} ± {:.2f} \\\\\\\\'.format(\n",
    "#         method,\n",
    "#         forget_acc_list[method]['mean'], forget_acc_list[method]['std'], \n",
    "#         retain_acc_list[method]['mean'], retain_acc_list[method]['std'],\n",
    "#         test_acc_list[method]['mean'], test_acc_list[method]['std'],\n",
    "#         tow_list[method]['mean'], tow_list[method]['std'], \n",
    "#         mia_list[method]['mean'], mia_list[method]['std']\n",
    "#     ))\n",
    "\n",
    "# Header and delimiter rows are generated from one column list so their column\n",
    "# counts always agree (the delimiter row used to have 5 entries for 7 columns).\n",
    "columns = ['Method', 'Forget Acc', 'Retain Acc', 'Test Acc', 'Truth Ratio', 'ToW Score', 'MIA']\n",
    "print(' | '.join(columns))\n",
    "print(' | '.join(['---'] * len(columns)))\n",
    "\n",
    "# One row per method, formatted as 'mean ± std' cells separated by '&'.\n",
    "# Truth Ratio and MIA are printed as 0.0 placeholders; swap in the commented\n",
    "# arguments once truth_ratio_list / mia_list are computed.\n",
    "for method in method_list:\n",
    "    print('{} & {:.2f} ± {:.2f} & {:.2f} ± {:.2f} & {:.2f} ± {:.2f} & {:.2f} ± {:.2f} & {:.2f} ± {:.2f} & {:.4f} ± {:.4f} \\\\\\\\'.format(\n",
    "        method,\n",
    "        forget_acc_list[method]['mean'], forget_acc_list[method]['std'],\n",
    "        retain_acc_list[method]['mean'], retain_acc_list[method]['std'],\n",
    "        test_acc_list[method]['mean'], test_acc_list[method]['std'],\n",
    "        0.0, 0.0,\n",
    "        # truth_ratio_list[method]['mean'], truth_ratio_list[method]['std'],\n",
    "        tow_list[method]['mean'], tow_list[method]['std'],\n",
    "        0.0, 0.0,\n",
    "        # mia_list[method]['mean'], mia_list[method]['std'],\n",
    "    ))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "unlearning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
