{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from sklearn.metrics import roc_curve, auc\n",
    "import numpy as np\n",
    "import datasets\n",
    "\n",
    "# Base directory templates for the data files; `{}` is filled with the seed.\n",
    "# open() does not expand '~', so expand it here once (expanduser leaves `{}` intact).\n",
    "base_dir_template = os.path.expanduser('~/nfs_public/watermark_arxiv/main_results/seed_{}/calibration/')\n",
    "wtm_base_dir_template = os.path.expanduser('~/nfs_public/watermark_arxiv/main_results/seed_{}/watermarked_calibration/')\n",
    "\n",
    "# Seeds to process (one evaluation run per seed)\n",
    "seeds = [41, 42, 43]\n",
    "\n",
    "# Function to compute AUC for specified paths\n",
    "def compute_auc(forget_file_path, retain_file_path, key):\n",
    "    \"\"\"Return the ROC AUC of metric `key` for separating retain from forget samples.\n",
    "\n",
    "    Forget samples are labelled 0 and retain samples 1, so higher `key` values are\n",
    "    treated as evidence for \"retain\" (sklearn's roc_curve convention).\n",
    "\n",
    "    Args:\n",
    "        forget_file_path: JSON eval file for the forget set ('~' is expanded).\n",
    "        retain_file_path: JSON eval file for the retain set ('~' is expanded).\n",
    "        key: metric name to read from both files, e.g. 'rougeL_recall'.\n",
    "\n",
    "    Returns:\n",
    "        Area under the ROC curve.\n",
    "    \"\"\"\n",
    "    def load_values(file_path):\n",
    "        # Two layouts are accepted: a list of records that each hold `key`, or a\n",
    "        # dict whose `key` entry is itself a dict of per-sample values.\n",
    "        with open(os.path.expanduser(file_path), 'r') as file:\n",
    "            data = json.load(file)\n",
    "        if isinstance(data, list):\n",
    "            return [item[key] for item in data]\n",
    "        return list(data[key].values())\n",
    "\n",
    "    forget_values = load_values(forget_file_path)\n",
    "    retain_values = load_values(retain_file_path)\n",
    "\n",
    "    labels = [0] * len(forget_values) + [1] * len(retain_values)\n",
    "    scores = forget_values + retain_values\n",
    "    fpr, tpr, _ = roc_curve(labels, scores)\n",
    "    return auc(fpr, tpr)\n",
    "\n",
    "\n",
    "rouge_aucs = []\n",
    "knowmem_aucs = []\n",
    "wtm_aucs = []\n",
    "\n",
    "for seed in seeds:\n",
    "    # expanduser is idempotent; it is required because open() does not expand '~'.\n",
    "    base_dir = os.path.expanduser(base_dir_template.format(seed))\n",
    "    wtm_base_dir = os.path.expanduser(wtm_base_dir_template.format(seed))\n",
    "\n",
    "    # For ROUGE\n",
    "    rouge_roc_auc = compute_auc(\n",
    "        os.path.join(base_dir, 'eval/retraining/10pct_eval_rouge_forget.json'),\n",
    "        os.path.join(base_dir, 'eval/retraining/10pct_eval_rouge.json'),\n",
    "        'rougeL_recall'\n",
    "    )\n",
    "    rouge_aucs.append(rouge_roc_auc)\n",
    "\n",
    "    # For KnowMem\n",
    "    knowmem_roc_auc = compute_auc(\n",
    "        os.path.join(base_dir, 'eval/knowmem/10/retraining/eval_knowmem_forget.json'),\n",
    "        os.path.join(base_dir, 'eval/knowmem/10/retraining/eval_knowmem_retain.json'),\n",
    "        'rougeL_recall'\n",
    "    )\n",
    "    knowmem_aucs.append(knowmem_roc_auc)\n",
    "\n",
    "    # For WTM: verification scores `q` stored with the datasets library\n",
    "    wtm_path = os.path.join(wtm_base_dir, 'watermark_verify/retraining_10pct_verify_0_1_2_3_4_5_6_7_8_9_10_11_12_13_14_15_16_17_18_19_1_0')\n",
    "    q = np.array(datasets.load_from_disk(wtm_path)[\"q\"])\n",
    "    # Reshape the first 8000 rows to (-1, 400, 10, 20) and take the diagonal over\n",
    "    # axes 0 and 3 (numpy appends it as the last axis). Its last entry gives the\n",
    "    # forget scores, all earlier entries the retain scores. Computed once, used twice.\n",
    "    # NOTE(review): the 8000/400/10/20 layout is hard-coded for this verify run - confirm.\n",
    "    q_diag = q[:8000, :, :].reshape(-1, 400, 10, 20).diagonal(axis1=0, axis2=3)\n",
    "    retain_values = list(q_diag[..., :-1].flatten())\n",
    "    forget_values = list(q_diag[..., -1].flatten())\n",
    "    labels = [0] * len(forget_values) + [1] * len(retain_values)\n",
    "    scores = forget_values + retain_values\n",
    "    fpr, tpr, _ = roc_curve(labels, scores)\n",
    "    wtm_aucs.append(auc(fpr, tpr))\n",
    "\n",
    "# Mean and population standard deviation (np.std default, ddof=0) across seeds\n",
    "rouge_mean = np.mean(rouge_aucs)\n",
    "rouge_std = np.std(rouge_aucs)\n",
    "knowmem_mean = np.mean(knowmem_aucs)\n",
    "knowmem_std = np.std(knowmem_aucs)\n",
    "wtm_mean = np.mean(wtm_aucs)\n",
    "wtm_std = np.std(wtm_aucs)\n",
    "\n",
    "# Print results\n",
    "print(\"\\nResults for calibration:\")\n",
    "print(\"Metric\\t\\tMean AUC\\tStd AUC\")\n",
    "print(f\"ROUGE\\t\\t{rouge_mean:.4f}\\t\\t{rouge_std:.4f}\")\n",
    "print(f\"KnowMem\\t\\t{knowmem_mean:.4f}\\t\\t{knowmem_std:.4f}\")\n",
    "print(f\"WTM\\t\\t{wtm_mean:.4f}\\t\\t{wtm_std:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from sklearn.metrics import roc_curve, auc\n",
    "import numpy as np\n",
    "import datasets\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Base directory templates for the data files; `{}` is filled with the seed.\n",
    "# Expand '~' up front so the paths do not depend on the loader expanding it.\n",
    "wtm_base_dir_template_dup = os.path.expanduser('~/nfs_public/watermark_arxiv/main_results/seed_{}/watermarked_calibration_duplicate/')\n",
    "wtm_base_dir_template_ori = os.path.expanduser('~/nfs_public/watermark_arxiv/main_results/seed_{}/watermarked_calibration/')\n",
    "\n",
    "# Seeds to process (one evaluation run per seed)\n",
    "seeds = [41, 42, 43]\n",
    "\n",
    "# Function to compute AUC and return points for curve\n",
    "# NOTE(review): not called in this cell; the WaterDrum loop computes its ROC inline.\n",
    "def compute_fpr_tpr_auc(forget_file_path, retain_file_path, key):\n",
    "    \"\"\"Return (fpr, tpr, roc_auc) for separating retain from forget samples by metric `key`.\n",
    "\n",
    "    Forget samples are labelled 0 and retain samples 1, so higher `key` values count\n",
    "    as evidence for \"retain\" (sklearn's roc_curve convention).\n",
    "\n",
    "    Args:\n",
    "        forget_file_path: JSON eval file for the forget set ('~' is expanded).\n",
    "        retain_file_path: JSON eval file for the retain set ('~' is expanded).\n",
    "        key: metric name to read from both files.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (fpr, tpr, roc_auc): the ROC curve arrays from roc_curve and its area.\n",
    "    \"\"\"\n",
    "    def load_values(file_path):\n",
    "        # Two layouts are accepted: a list of records that each hold `key`, or a\n",
    "        # dict whose `key` entry is itself a dict of per-sample values.\n",
    "        with open(os.path.expanduser(file_path), 'r') as file:\n",
    "            data = json.load(file)\n",
    "        if isinstance(data, list):\n",
    "            return [item[key] for item in data]\n",
    "        return list(data[key].values())\n",
    "\n",
    "    forget_values = load_values(forget_file_path)\n",
    "    retain_values = load_values(retain_file_path)\n",
    "\n",
    "    labels = [0] * len(forget_values) + [1] * len(retain_values)\n",
    "    scores = forget_values + retain_values\n",
    "    fpr, tpr, _ = roc_curve(labels, scores)\n",
    "    roc_auc = auc(fpr, tpr)\n",
    "    return fpr, tpr, roc_auc\n",
    "\n",
    "# Per-seed AUCs and (fpr, tpr) curves for the original (ori) and duplicate (dup) runs\n",
    "wtm_aucs_dup = []\n",
    "wtm_aucs_ori = []\n",
    "wtm_fpr_tpr_dup = []\n",
    "wtm_fpr_tpr_ori = []\n",
    "\n",
    "# Process each seed for both datasets. The loop variables use distinct names so\n",
    "# they do not overwrite same-named globals from other cells (e.g. base_dir_template).\n",
    "for seed in seeds:\n",
    "    for dir_template, run_aucs, run_curves in [\n",
    "        (wtm_base_dir_template_ori, wtm_aucs_ori, wtm_fpr_tpr_ori),\n",
    "        (wtm_base_dir_template_dup, wtm_aucs_dup, wtm_fpr_tpr_dup)\n",
    "    ]:\n",
    "        # expanduser is idempotent; it keeps the path valid even if '~' is still present.\n",
    "        wtm_base_dir = os.path.expanduser(dir_template.format(seed))\n",
    "\n",
    "        # For WaterDrum\n",
    "        wtm_path = os.path.join(wtm_base_dir, 'watermark_verify/retraining_10pct_verify_0_1_2_3_4_5_6_7_8_9_10_11_12_13_14_15_16_17_18_19_1_0')\n",
    "        q = np.array(datasets.load_from_disk(wtm_path)[\"q\"])\n",
    "        # Reshape the first 8000 rows to (-1, 400, 10, 20) and take the diagonal over\n",
    "        # axes 0 and 3 (appended as the last axis); its last entry gives the forget\n",
    "        # scores, all earlier entries the retain scores. Computed once, used twice.\n",
    "        # NOTE(review): the 8000/400/10/20 layout is hard-coded for this verify run - confirm.\n",
    "        q_diag = q[:8000, :, :].reshape(-1, 400, 10, 20).diagonal(axis1=0, axis2=3)\n",
    "        retain_values = list(q_diag[..., :-1].flatten())\n",
    "        forget_values = list(q_diag[..., -1].flatten())\n",
    "        labels = [0] * len(forget_values) + [1] * len(retain_values)\n",
    "        scores = forget_values + retain_values\n",
    "        fpr, tpr, _ = roc_curve(labels, scores)\n",
    "        roc_auc = auc(fpr, tpr)\n",
    "        run_aucs.append(roc_auc)\n",
    "        run_curves.append((fpr, tpr))\n",
    "\n",
    "plt.rcParams.update({'font.size': 20})\n",
    "# Adjusted Plotting for Two Columns: one panel per dataset variant\n",
    "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(14, 6))\n",
    "metric_names = [\"WaterDrum No Duplicate\", \"WaterDrum Exact Duplicate\"]\n",
    "fpr_tpr_datas = [wtm_fpr_tpr_ori, wtm_fpr_tpr_dup]\n",
    "aucs_list = [wtm_aucs_ori, wtm_aucs_dup]\n",
    "# One color per seed curve, taken from matplotlib's default property cycle.\n",
    "seed_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "\n",
    "for ax, metric_name, auroc_fpr_tpr in zip(axes, metric_names, fpr_tpr_datas):\n",
    "    for i, (fpr, tpr) in enumerate(auroc_fpr_tpr):\n",
    "        ax.plot(fpr, tpr, lw=2, alpha=0.5, color=seed_colors[i % len(seed_colors)])\n",
    "\n",
    "    # Average the per-seed curves on a common FPR grid. The legend shows the AUC of\n",
    "    # this averaged curve, which need not equal the mean of the per-seed AUCs.\n",
    "    mean_fpr = np.linspace(0, 1, 100)\n",
    "    mean_tpr = np.mean([np.interp(mean_fpr, fpr, tpr) for fpr, tpr in auroc_fpr_tpr], axis=0)\n",
    "    mean_auc = auc(mean_fpr, mean_tpr)\n",
    "    ax.plot(mean_fpr, mean_tpr, color='black', linestyle='--', label=f'Mean AUC = {mean_auc:.2f}')\n",
    "\n",
    "    # Chance-level diagonal\n",
    "    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')\n",
    "    ax.set_xlim([0.0, 1.0])\n",
    "    ax.set_ylim([0.0, 1.05])\n",
    "    ax.set_xlabel('False Positive Rate', fontsize=26)\n",
    "    ax.set_ylabel('True Positive Rate', fontsize=26)\n",
    "    ax.set_title(metric_name, fontsize=28)\n",
    "    ax.legend(loc=\"lower right\", fontsize=24)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "unlearning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
