{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Base path to the directories for each seed.\n",
    "# open() and np.load() do not expand a leading '~', so the home directory is\n",
    "# resolved here once and the file lookups below receive a real filesystem path.\n",
    "base_folder_path = os.path.expanduser('~/nfs_public/watermark_arxiv/main_results/')\n",
    "# base_folder_path = '../../main_results/'\n",
    "\n",
    "# Seed numbers you wish to use\n",
    "seeds = [41, 42, 43]\n",
    "\n",
    "# Unlearning algorithms to aggregate; each name is interpolated into the result file paths\n",
    "algorithms = [\"original\", \"retraining\", \"finetune\", \"ga\", \"gdiff\", \"KL\", \"tv\", \"scrub\"]\n",
    "\n",
    "# For every metric: algorithm name -> list of values, one entry appended per seed\n",
    "wtm_results = {algo: [] for algo in algorithms}\n",
    "knowmem_results = {algo: [] for algo in algorithms}\n",
    "mia_results = {algo: [] for algo in algorithms}\n",
    "rouge_results = {algo: [] for algo in algorithms}\n",
    "\n",
    "matplotlib.rcParams.update({'font.size': 20})\n",
    "\n",
    "\n",
    "# Number of trailing diagonal entries of the verification matrix averaged into the WaterDrum score\n",
    "num_last_elements = 1\n",
    "\n",
    "# Collect one value per (seed, algorithm) for every metric. A missing file is reported\n",
    "# and skipped, so an algorithm may end up with fewer values than there are seeds.\n",
    "for seed in seeds:\n",
    "    folder_path = os.path.join(base_folder_path, f'seed_{seed}/results_remove-1class')\n",
    "    watermarked_folder_path = os.path.join(base_folder_path, f'seed_{seed}/watermarked_results_remove-1class')\n",
    "\n",
    "    for algo in knowmem_results:\n",
    "        filepath_knowmem = os.path.join(folder_path, f'eval/knowmem/10/{algo}/aggregated.json')\n",
    "        filepath_mia = os.path.join(folder_path, f'eval/mia_{algo}.json')\n",
    "        filepath_rouge = os.path.join(folder_path, f'eval/rouge_{algo}.csv')\n",
    "        filepath_wtm = os.path.join(watermarked_folder_path, f'watermark_verify/{algo}_q.npy')\n",
    "\n",
    "        # WaterDrum metrics: mean of the last `num_last_elements` diagonal entries\n",
    "        try:\n",
    "            verification_matrix = np.load(filepath_wtm)\n",
    "            diagonal = np.diagonal(verification_matrix).tolist()\n",
    "            wtm_results[algo].append(np.mean(diagonal[-num_last_elements:]))\n",
    "            # wtm_results[algo].append(np.mean(diagonal[:-num_last_elements]))\n",
    "        except FileNotFoundError:\n",
    "            print(f\"File not found: {filepath_wtm}\")\n",
    "\n",
    "        # KnowMem metrics\n",
    "        try:\n",
    "            with open(filepath_knowmem, 'r') as file:\n",
    "                knowmem_report = json.load(file)\n",
    "            knowmem_results[algo].append(knowmem_report['KnowMem Forget']['mean_rougeL_recall'])\n",
    "        except FileNotFoundError:\n",
    "            print(f\"File not found: {filepath_knowmem}\")\n",
    "\n",
    "        # MIA metrics\n",
    "        try:\n",
    "            with open(filepath_mia, 'r') as file:\n",
    "                mia_report = json.load(file)\n",
    "            mia_results[algo].append(mia_report['forget_holdout_Min-40%'])\n",
    "        except FileNotFoundError:\n",
    "            print(f\"File not found: {filepath_mia}\")\n",
    "\n",
    "        # ROUGE metrics: reduce the 'ROUGE Forget' column to one scalar per seed so that,\n",
    "        # like the other metrics, the std computed later is taken across seeds instead of\n",
    "        # over stacked pandas Series (assumes the column is numeric -- TODO confirm CSV layout)\n",
    "        try:\n",
    "            rouge_table = pd.read_csv(filepath_rouge)\n",
    "            rouge_results[algo].append(float(rouge_table['ROUGE Forget'].mean()))\n",
    "        except FileNotFoundError:\n",
    "            print(f\"File not found: {filepath_rouge}\")\n",
    "\n",
    "# Per-algorithm (mean, std) of the collected values, computed for each metric in turn\n",
    "knowmem_stats, mia_stats, rouge_stats, wtm_stats = (\n",
    "    {name: (np.mean(scores), np.std(scores)) for name, scores in per_algo.items()}\n",
    "    for per_algo in (knowmem_results, mia_results, rouge_results, wtm_results)\n",
    ")\n",
    "\n",
    "# Plain-text summary table: one row per algorithm, mean/std columns for every metric\n",
    "print(\"Algorithm | ROUGE Mean | ROUGE Std | KnowMem Mean | KnowMem Std | MIA Mean | MIA Std | WaterDrum Mean | WaterDrum Std\")\n",
    "print(\"-----------------------------------------------------------\")\n",
    "for algo in knowmem_results:\n",
    "    (rouge_mean, rouge_std), (knowmem_mean, knowmem_std) = rouge_stats[algo], knowmem_stats[algo]\n",
    "    (mia_mean, mia_std), (wtm_mean, wtm_std) = mia_stats[algo], wtm_stats[algo]\n",
    "    print(f\"{algo:<10} | {rouge_mean:<12.4f} | {rouge_std:<9.4f} | {knowmem_mean:<12.4f} | {knowmem_std:<9.4f} | {mia_mean:<8.4f} | {mia_std:<6.4f} | {wtm_mean:<8.4f} | {wtm_std:<6.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Font sizes applied to the figure drawn in this cell\n",
    "font_settings = {\n",
    "    \"font.size\": 18,        # base font size\n",
    "    \"axes.titlesize\": 18,   # title size\n",
    "    \"axes.labelsize\": 16,   # axis label size\n",
    "    \"xtick.labelsize\": 14,  # x tick size\n",
    "    \"ytick.labelsize\": 14,  # y tick size\n",
    "    \"legend.fontsize\": 12,  # legend font size\n",
    "}\n",
    "plt.rcParams.update(font_settings)\n",
    "\n",
    "seeds = [41, 42, 43]\n",
    "\n",
    "dataset_name = \"arxiv\"\n",
    "\n",
    "# Result-file name of each unlearning method mapped to the label shown in the legend\n",
    "method_labels = {\n",
    "    \"original\": \"No unlearning\",\n",
    "    \"retraining\": \"Retraining\",\n",
    "    \"finetune\": \"GD\",\n",
    "    \"KL\": \"KL\",\n",
    "    \"tv\": \"TV\",\n",
    "    \"scrub\": \"SCRUB\",\n",
    "}\n",
    "unlearning_methods = list(method_labels)\n",
    "labels = list(method_labels.values())\n",
    "\n",
    "# Load one verification array per (seed, unlearning method).\n",
    "# np.load() does not expand '~', so the home directory is resolved explicitly; the path\n",
    "# uses '~/nfs_public/...' (with the slash) so that it points into the current user's home.\n",
    "res = np.array([\n",
    "    [\n",
    "        np.load(\n",
    "            os.path.join(\n",
    "                os.path.expanduser(\n",
    "                    f\"~/nfs_public/watermark_{dataset_name}/main_results/seed_{seed}/watermarked_results_remove-1class/watermark_verify\"\n",
    "                ),\n",
    "                f\"{unlearning_method}_q_all.npy\",\n",
    "            )\n",
    "        )\n",
    "        for unlearning_method in unlearning_methods\n",
    "    ]\n",
    "    for seed in seeds\n",
    "])\n",
    "# Split the second-to-last axis into 20 groups, take the diagonal between the group axis\n",
    "# and the last axis, then average over the remaining within-group axis.\n",
    "# NOTE(review): 20 is presumably the number of classes/watermark keys -- confirm.\n",
    "res = res.reshape(*res.shape[:2], 20, -1, res.shape[-1]).diagonal(axis1=-1, axis2=-3).mean(axis=-2)\n",
    "# Keep two numbers per (seed, method): the mean of all diagonal entries except the last,\n",
    "# and the last diagonal entry on its own.\n",
    "res = np.stack([res[...,:-1].mean(axis=-1), res[...,-1]], axis=-1)\n",
    "print(res)\n",
    "res /= res[:,:1]    # Normalize to 1 based on original (method index 0)\n",
    "\n",
    "# WaterDrum-Ax benchmark figure: one point per unlearning method, x = first normalized\n",
    "# score in `res`, y = second normalized score, with error bars across seeds.\n",
    "fig, ax = plt.subplots(1, 1, figsize=(2.6,2.6))\n",
    "ax.set_aspect('equal', adjustable='box')\n",
    "save = True\n",
    "\n",
    "# Axis limits; the tick positions are fixed explicitly further down\n",
    "xtop = 1.15\n",
    "xbottom = -0.25\n",
    "ytop = 1.15\n",
    "ybottom = -0.25\n",
    "\n",
    "for j in range(len(unlearning_methods)):\n",
    "    per_seed = res[:,j]\n",
    "    xdata = per_seed[...,0]\n",
    "    ydata = per_seed[...,1]\n",
    "    # Asymmetric error bars: distance from the mean down to the min and up to the max over seeds\n",
    "    xerr = ((xdata.mean() - xdata.min(),), (xdata.max() - xdata.mean(),))\n",
    "    yerr = ((ydata.mean() - ydata.min(),), (ydata.max() - ydata.mean(),))\n",
    "    ax.errorbar(xdata.mean(), ydata.mean(), xerr=xerr, yerr=yerr, capsize=5, elinewidth=1, fmt=\"o\", label=labels[j])\n",
    "ax.set_xlim(xbottom,xtop)\n",
    "ax.set_ylim(ybottom,ytop)\n",
    "ax.set_xticks([0, 0.5, 1.0])\n",
    "ax.set_yticks([0, 0.5, 1.0])\n",
    "ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "ax.set_xlabel(\"$M'_{d\\\\in{\\\\mathcal{D}_\\\\mathcal{R}}}({\\\\varphi}(q_d), \\\\mathcal{R})$ →\", fontsize=14)\n",
    "ax.set_ylabel(\"← $M'_{d\\\\in{\\\\mathcal{D}_\\\\mathcal{F}}}({\\\\varphi}(q_d), \\\\mathcal{F})$\", \n",
    "              fontsize=14,)\n",
    "ax.yaxis.set_label_coords(-0.32, 0.4)\n",
    "plt.tight_layout()\n",
    "plt.subplots_adjust(left=0.3)\n",
    "\n",
    "# First export: the figure without a legend\n",
    "if save:\n",
    "    plt.savefig(f\"{dataset_name}_remove_benchmark_no_legend_ret_single.pdf\")\n",
    "\n",
    "# Second export: same figure with the legend placed to the right, outside the axes;\n",
    "# bbox_inches='tight' lets the saved page grow to include it\n",
    "ax.legend(\n",
    "    labels, \n",
    "    loc='right',\n",
    "    bbox_to_anchor=(2.15, 0.5), \n",
    "    ncol=1,\n",
    "    handlelength=0.8,            # shorter handles\n",
    "    handletextpad=0.4,           # tighter text-to-handle gap\n",
    "    borderpad=0.5,              # tighter padding inside the box\n",
    "    labelspacing=0.45,            # tighter rows (optional)\n",
    "    frameon=True\n",
    ")\n",
    "if save:\n",
    "    plt.savefig(f\"{dataset_name}_remove_benchmark_ret_single.pdf\", bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "unlearning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
