{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def compute_fixed_point_r2(x, y):\n",
    "    \"\"\"\n",
    "    Compute R² for best linear fit y = -mx + b where the line must pass through (100, 0).\n",
    "    This means b = 100m, so the equation becomes y = -mx + 100m = m(100-x)\n",
    "    \n",
    "    Parameters:\n",
    "    x (array-like): x values (percentage points)\n",
    "    y (array-like): y values (metric values)\n",
    "    \n",
    "    Returns:\n",
    "    tuple: (R² value, slope)\n",
    "    \"\"\"\n",
    "    x = np.array(x)\n",
    "    y = np.array(y)\n",
    "    \n",
    "    # Since we want y = m(100-x), we can find m using least squares\n",
    "    # m = Σ(y*(100-x)) / Σ((100-x)²)\n",
    "    shifted_x = 100 - x\n",
    "    m = np.sum(y * shifted_x) / np.sum(shifted_x ** 2)\n",
    "    \n",
    "    # Calculate predicted values\n",
    "    y_pred = m * shifted_x\n",
    "    \n",
    "    # Calculate R²\n",
    "    \n",
    "    \n",
    "    rss = np.sum((y - y_pred) ** 2)\n",
    "    tss = np.sum((y - np.mean(y)) ** 2)\n",
    "    r2 = 1 - (rss / tss)\n",
    "    \n",
    "    return r2, m\n",
    "\n",
    "def plot_best_fit_line(ax, x_values, y_values, color, alpha=0.5):\n",
    "    \"\"\"\n",
    "    Plot the best fit line for a given set of points.\n",
    "    \"\"\"\n",
    "    r2, slope = compute_fixed_point_r2(x_values, y_values)\n",
    "    \n",
    "    # Generate points for the line\n",
    "    x_fit = np.array([0, 100])\n",
    "    y_fit = slope * (100 - x_fit)\n",
    "    \n",
    "    # Plot the line\n",
    "    ax.plot(x_fit, y_fit, '--', color=color, alpha=alpha)\n",
    "    \n",
    "    return r2, slope"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "\n",
    "# Define the path to the main folder containing results for all seeds\n",
    "main_folder_path = '~/nfs_public/watermark_arxiv/main_results'\n",
    "seeds = [41, 42, 43]  # List of seeds you want to use\n",
    "num_pct_values = 11  # As you have percentages from 0 to 10\n",
    "\n",
    "def scale_by_first_element(array):\n",
    "    first_element = array[0]\n",
    "    return array / first_element if first_element != 0 else array\n",
    "\n",
    "def process_and_plot_metrics(ax, metric_name, y_label, seeds, main_folder_path, num_pct_values):\n",
    "    # Specific percentage points for calibration_duplicate_10times\n",
    "    specific_pcts_indices = [0, 2, 4, 6, 8, 10]\n",
    "\n",
    "    # Initialize dictionaries to store arrays for the current metric\n",
    "    metrics_data = {\n",
    "        'calibration': np.zeros((len(seeds), num_pct_values)),\n",
    "        'calibration_duplicate': np.zeros((len(seeds), num_pct_values)),\n",
    "        'calibration_semduplicate': np.zeros((len(seeds), num_pct_values)),\n",
    "    }\n",
    "\n",
    "    # Process calibration data\n",
    "    for seed_index, seed in enumerate(seeds):\n",
    "        for eval_type in metrics_data.keys():\n",
    "            folder_path = os.path.join(main_folder_path, f'seed_{seed}/{eval_type}/eval')\n",
    "\n",
    "            if eval_type == 'calibration_duplicate_10times':\n",
    "                if metric_name in ['1- MIA', 'KnowMeM', 'WaterDrum', 'ROUGE']:\n",
    "                    for idx, pct in enumerate(specific_pcts_indices):\n",
    "                        if metric_name == '1- MIA':\n",
    "                            filename = f'mia_retraining_{pct}pct.json'\n",
    "                            filepath = os.path.join(folder_path, filename)\n",
    "                            with open(filepath, 'r') as file:\n",
    "                                data = json.load(file)\n",
    "                                metrics_data[eval_type][seed_index, idx] = 1 - data['forget_holdout_Min-40%']\n",
    "                        elif metric_name == 'KnowMeM':\n",
    "                            filepath = os.path.join(folder_path, f'knowmem/{pct}/retraining/aggregated.json')\n",
    "                            with open(filepath, 'r') as file:\n",
    "                                data = json.load(file)\n",
    "                                metrics_data[eval_type][seed_index, idx] = data['KnowMem Forget']['mean_rougeL_recall']\n",
    "\n",
    "                        elif metric_name == 'WaterDrum':\n",
    "                            wtm_path = os.path.join(main_folder_path, f'seed_{seed}/watermarked_{eval_type}/watermark_verify/')\n",
    "                            wtm_data = np.load(os.path.join(wtm_path, 'calibration.npy'))\n",
    "                            metrics_data[eval_type][seed_index, idx] = wtm_data[pct].mean()\n",
    "\n",
    "                        else:  # rougeL_recall\n",
    "                            filename = f'rouge_retraining_{pct}pct.csv'\n",
    "                            filepath = os.path.join(folder_path, filename)\n",
    "                            data = pd.read_csv(filepath)\n",
    "                            metrics_data[eval_type][seed_index, idx] = data['ROUGE Forget'].mean()\n",
    "                            \n",
    "            else:\n",
    "                # Existing processing for other metric types\n",
    "                if metric_name == '1- MIA':\n",
    "                    for pct in range(num_pct_values):\n",
    "                        filename = f'mia_retraining_{pct}pct.json'\n",
    "                        filepath = os.path.join(folder_path, filename)\n",
    "                        with open(filepath, 'r') as file:\n",
    "                            data = json.load(file)\n",
    "                            metrics_data[eval_type][seed_index, pct] = 1 - data['forget_holdout_Min-40%']\n",
    "\n",
    "                elif metric_name == 'KnowMem':\n",
    "                    for pct in range(num_pct_values):\n",
    "                        filepath = os.path.join(folder_path, f'knowmem/{pct}/retraining/aggregated.json')\n",
    "                        with open(filepath, 'r') as file:\n",
    "                            data = json.load(file)\n",
    "                            metrics_data[eval_type][seed_index, pct] = data['KnowMem Forget']['mean_rougeL_recall']\n",
    "\n",
    "                elif metric_name == 'WaterDrum':\n",
    "                    if eval_type == \"calibration_duplicate_interclass\":\n",
    "                        wtm_path = os.path.join(main_folder_path, f'seed_{seed}/watermarked_calibration_interduplicate/watermark_verify/')\n",
    "                    else:\n",
    "                        wtm_path = os.path.join(main_folder_path, f'seed_{seed}/watermarked_{eval_type}/watermark_verify/')\n",
    "\n",
    "                    wtm_data = np.load(os.path.join(wtm_path, 'calibration.npy'))\n",
    "                    for pct in range(num_pct_values):\n",
    "                        metrics_data[eval_type][seed_index, pct] = wtm_data[pct].mean()\n",
    "\n",
    "                else:  # rougeL_recall\n",
    "                    for pct in range(num_pct_values):\n",
    "                        filename = f'rouge_retraining_{pct}pct.csv'\n",
    "                        filepath = os.path.join(folder_path, filename)\n",
    "                        data = pd.read_csv(filepath)\n",
    "                        metrics_data[eval_type][seed_index, pct] = data['ROUGE Forget'].mean()\n",
    "\n",
    "    # Compute averages and percentiles for all metrics\n",
    "    cal_avg = np.mean(metrics_data['calibration'], axis=0)\n",
    "    cal_5th = np.min(metrics_data['calibration'], axis=0)\n",
    "    cal_95th = np.max(metrics_data['calibration'], axis=0)\n",
    "\n",
    "\n",
    "    cal_dupl_avg = np.mean(metrics_data['calibration_duplicate'], axis=0)\n",
    "    cal_dupl_5th = np.min(metrics_data['calibration_duplicate'], axis=0)\n",
    "    cal_dupl_95th = np.max(metrics_data['calibration_duplicate'], axis=0)\n",
    "\n",
    "    cal_sem_dupl_avg = np.mean(metrics_data['calibration_semduplicate'], axis=0)\n",
    "    cal_sem_dupl_5th = np.min(metrics_data['calibration_semduplicate'], axis=0)\n",
    "    cal_sem_dupl_95th = np.max(metrics_data['calibration_semduplicate'], axis=0)\n",
    "\n",
    "    # if metric_name == 'WaterDrum':\n",
    "    cal_avg = scale_by_first_element(cal_avg)\n",
    "    cal_5th = scale_by_first_element(cal_5th)\n",
    "    cal_95th = scale_by_first_element(cal_95th)\n",
    "    \n",
    "    cal_dupl_avg = scale_by_first_element(cal_dupl_avg)\n",
    "    cal_dupl_5th = scale_by_first_element(cal_dupl_5th)\n",
    "    cal_dupl_95th = scale_by_first_element(cal_dupl_95th)\n",
    "    \n",
    "    cal_sem_dupl_avg = scale_by_first_element(cal_sem_dupl_avg)\n",
    "    cal_sem_dupl_5th = scale_by_first_element(cal_sem_dupl_5th)\n",
    "    cal_sem_dupl_95th = scale_by_first_element(cal_sem_dupl_95th)\n",
    "        \n",
    "    # Set up Seaborn style\n",
    "    sns.set(style=\"whitegrid\")\n",
    "\n",
    "    # Define distinct colors using a subtle palette\n",
    "    colors = sns.color_palette(\"muted\", n_colors=6)\n",
    "\n",
    "    # Reverse `x_values` to go from 100 to 0 for plotting, but we will label them from 0 to 100\n",
    "    x_values = np.arange(100, -1, -10)\n",
    "\n",
    "    # Plot calibration_duplicate and its best fit line\n",
    "    ax.plot(x_values, cal_dupl_avg[::-1], marker='s', label=f'Calibration Exact Duplicate', \n",
    "            color=colors[1], linestyle='-', markersize=8)\n",
    "    ax.fill_between(x_values, cal_dupl_5th[::-1], cal_dupl_95th[::-1], color=colors[1], alpha=0.2)\n",
    "    r2_dupl, slope_dupl = plot_best_fit_line(ax, x_values, cal_dupl_avg[::-1], colors[1])\n",
    "\n",
    "    # Plot calibration_semduplicate and its best fit line\n",
    "    ax.plot(x_values, cal_sem_dupl_avg[::-1], marker='^', label=f'Calibration Semantic Duplicate', \n",
    "            color=colors[2], linestyle='-', markersize=8)\n",
    "    ax.fill_between(x_values, cal_sem_dupl_5th[::-1], cal_sem_dupl_95th[::-1], color=colors[2], alpha=0.2)\n",
    "    r2_sem, slope_sem = plot_best_fit_line(ax, x_values, cal_sem_dupl_avg[::-1], colors[2])\n",
    "\n",
    "    # Plot original calibration and its best fit line\n",
    "    ax.plot(x_values, cal_avg[::-1], marker='o', label=f'Calibration No Duplicate', \n",
    "            color=colors[0], linestyle='-', markersize=8)\n",
    "    ax.fill_between(x_values, cal_5th[::-1], cal_95th[::-1], color=colors[0], alpha=0.2)\n",
    "    r2_cal, slope_cal = plot_best_fit_line(ax, x_values, cal_avg[::-1], colors[0])\n",
    "\n",
    "    x_values = np.arange(100, -1, -20)\n",
    "    # Customize the plot\n",
    "    ax.set_xticks(x_values)  # Keep ticks same to align with reversed data\n",
    "    ax.set_xticklabels(np.arange(0, 101, 20))  # Change tick labels to go from 0 to 100\n",
    "    ax.set_xlim([100, 0])  # Explicit x-axis limits ensure plotting direction is reversed\n",
    "    # ax.set_xlabel('Percentage of Remaining Forgotten Data', fontsize=27)\n",
    "    ax.tick_params(axis='both', labelsize=28)\n",
    "    ax.set_title(f'{metric_name}', fontsize=33, weight='bold')\n",
    "    ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "    ax.set_ylim(0, 1.15)\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_color('black')\n",
    "\n",
    "    # Print R² values and slopes\n",
    "    print(f\"\\nMetric: {metric_name}\")\n",
    "    print(f\"No Duplicate: R²={r2_cal:.3f}, slope={slope_cal:.3f}\")\n",
    "    print(f\"Exact Duplicate: R²={r2_dupl:.3f}, slope={slope_dupl:.3f}\")\n",
    "    print(f\"Semantic Duplicate: R²={r2_sem:.3f}, slope={slope_sem:.3f}\")\n",
    "    # print(f\"Inter-class Duplicate: R²={r2_noclass:.3f}, slope={slope_noclass:.3f}\")\n",
    "    # print(f\"Intra-class Duplicate: R²={r2_intclass:.3f}, slope={slope_intclass:.3f}\")\n",
    "    \n",
    "    return metrics_data\n",
    "\n",
    "# Create a single figure with four subplots in a 2x2 grid\n",
    "fig, axes = plt.subplots(1, 4, figsize=(28, 6))  # Adjusted to a 2x2 grid for improved layout\n",
    "\n",
    "# Flatten the axes array for easy indexing\n",
    "axes = axes.flatten()\n",
    "\n",
    "# Plot for each metric separately\n",
    "metrics_data = process_and_plot_metrics(axes[0], 'ROUGE', 'ROUGE Forget', seeds, main_folder_path, num_pct_values)\n",
    "process_and_plot_metrics(axes[1], 'KnowMem', 'KnowMem', seeds, main_folder_path, num_pct_values)\n",
    "process_and_plot_metrics(axes[2], '1- MIA', '1- MIA', seeds, main_folder_path, num_pct_values)\n",
    "process_and_plot_metrics(axes[3], 'WaterDrum', 'WaterDrum Forget', seeds, main_folder_path, num_pct_values)\n",
    "\n",
    "fig.supxlabel('Percentage of Remaining Forgotten Data', fontsize=32, y=0.03)\n",
    "\n",
    "# Adjust layout to make room for a single legend\n",
    "plt.tight_layout(rect=[0, 0, 1, 0.92])\n",
    "plt.subplots_adjust(wspace=0.2, hspace=0.3)\n",
    "\n",
    "# Create a single legend outside the subplots, above them\n",
    "handles, labels = axes[0].get_legend_handles_labels()\n",
    "fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=3, fontsize=28)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "\n",
    "# Define the path to the main folder containing results for all seeds\n",
    "main_folder_path = '~/nfs_public/watermark_arxiv/main_results'\n",
    "seeds = [41, 42, 43]  # List of seeds you want to use\n",
    "num_pct_values = 11  # As you have percentages from 0 to 10\n",
    "\n",
    "def scale_by_first_element(array):\n",
    "    first_element = array[0]\n",
    "    last_element = array[-1]\n",
    "    if first_element == last_element:\n",
    "        return np.ones_like(array)  # Avoids division by zero\n",
    "    return (array - last_element) / (first_element - last_element)\n",
    "\n",
    "def process_and_plot_metrics(ax, metric_name, y_label, seeds, main_folder_path, num_pct_values):\n",
    "    # Specific percentage points for calibration_duplicate_10times\n",
    "    specific_pcts_indices = [0, 2, 4, 6, 8, 10]\n",
    "\n",
    "    # Initialize dictionaries to store arrays for the current metric\n",
    "    metrics_data = {\n",
    "        'calibration': np.zeros((len(seeds), num_pct_values)),\n",
    "        'calibration_duplicate': np.zeros((len(seeds), num_pct_values)),\n",
    "        'calibration_semduplicate': np.zeros((len(seeds), num_pct_values)),\n",
    "    }\n",
    "\n",
    "    # Process calibration data\n",
    "    for seed_index, seed in enumerate(seeds):\n",
    "        for eval_type in metrics_data.keys():\n",
    "            folder_path = os.path.join(main_folder_path, f'seed_{seed}/{eval_type}/eval')\n",
    "\n",
    "            if eval_type == 'calibration_duplicate_10times':\n",
    "                if metric_name in ['1- MIA', 'KnowMeM', 'WaterDrum', 'ROUGE']:\n",
    "                    for idx, pct in enumerate(specific_pcts_indices):\n",
    "                        if metric_name == '1- MIA':\n",
    "                            filename = f'mia_retraining_{pct}pct.json'\n",
    "                            filepath = os.path.join(folder_path, filename)\n",
    "                            with open(filepath, 'r') as file:\n",
    "                                data = json.load(file)\n",
    "                                metrics_data[eval_type][seed_index, idx] = 1 - data['forget_holdout_Min-40%']\n",
    "                        elif metric_name == 'KnowMeM':\n",
    "                            filepath = os.path.join(folder_path, f'knowmem/{pct}/retraining/aggregated.json')\n",
    "                            with open(filepath, 'r') as file:\n",
    "                                data = json.load(file)\n",
    "                                metrics_data[eval_type][seed_index, idx] = data['KnowMem Forget']['mean_rougeL_recall']\n",
    "\n",
    "                        elif metric_name == 'WaterDrum':\n",
    "                            wtm_path = os.path.join(main_folder_path, f'seed_{seed}/watermarked_{eval_type}/watermark_verify/')\n",
    "                            wtm_data = np.load(os.path.join(wtm_path, 'calibration.npy'))\n",
    "                            metrics_data[eval_type][seed_index, idx] = wtm_data[pct].mean()\n",
    "\n",
    "                        else:  # rougeL_recall\n",
    "                            filename = f'rouge_retraining_{pct}pct.csv'\n",
    "                            filepath = os.path.join(folder_path, filename)\n",
    "                            data = pd.read_csv(filepath)\n",
    "                            metrics_data[eval_type][seed_index, idx] = data['ROUGE Forget'].mean()\n",
    "\n",
    "            else:\n",
    "                # Existing processing for other metric types\n",
    "                if metric_name == '1- MIA':\n",
    "                    for pct in range(num_pct_values):\n",
    "                        filename = f'mia_retraining_{pct}pct.json'\n",
    "                        filepath = os.path.join(folder_path, filename)\n",
    "                        with open(filepath, 'r') as file:\n",
    "                            data = json.load(file)\n",
    "                            metrics_data[eval_type][seed_index, pct] = 1 - data['forget_holdout_Min-40%']\n",
    "\n",
    "                elif metric_name == 'KnowMem':\n",
    "                    for pct in range(num_pct_values):\n",
    "                        filepath = os.path.join(folder_path, f'knowmem/{pct}/retraining/aggregated.json')\n",
    "                        with open(filepath, 'r') as file:\n",
    "                            data = json.load(file)\n",
    "                            metrics_data[eval_type][seed_index, pct] = data['KnowMem Forget']['mean_rougeL_recall']\n",
    "\n",
    "                elif metric_name == 'WaterDrum':\n",
    "                    if eval_type == \"calibration_duplicate_interclass\":\n",
    "                        wtm_path = os.path.join(main_folder_path, f'seed_{seed}/watermarked_calibration_interduplicate/watermark_verify/')\n",
    "                    else:\n",
    "                        wtm_path = os.path.join(main_folder_path, f'seed_{seed}/watermarked_{eval_type}/watermark_verify/')\n",
    "\n",
    "                    wtm_data = np.load(os.path.join(wtm_path, 'calibration.npy'))\n",
    "                    for pct in range(num_pct_values):\n",
    "                        metrics_data[eval_type][seed_index, pct] = wtm_data[pct].mean()\n",
    "\n",
    "                else:  # rougeL_recall\n",
    "                    for pct in range(num_pct_values):\n",
    "                        filename = f'rouge_retraining_{pct}pct.csv'\n",
    "                        filepath = os.path.join(folder_path, filename)\n",
    "                        data = pd.read_csv(filepath)\n",
    "                        metrics_data[eval_type][seed_index, pct] = data['ROUGE Forget'].mean()\n",
    "\n",
    "    # Compute averages and percentiles for all metrics\n",
    "    cal_avg = np.mean(metrics_data['calibration'], axis=0)\n",
    "    cal_5th = np.min(metrics_data['calibration'], axis=0)\n",
    "    cal_95th = np.max(metrics_data['calibration'], axis=0)\n",
    "\n",
    "    cal_dupl_avg = np.mean(metrics_data['calibration_duplicate'], axis=0)\n",
    "    cal_dupl_5th = np.min(metrics_data['calibration_duplicate'], axis=0)\n",
    "    cal_dupl_95th = np.max(metrics_data['calibration_duplicate'], axis=0)\n",
    "\n",
    "    cal_sem_dupl_avg = np.mean(metrics_data['calibration_semduplicate'], axis=0)\n",
    "    cal_sem_dupl_5th = np.min(metrics_data['calibration_semduplicate'], axis=0)\n",
    "    cal_sem_dupl_95th = np.max(metrics_data['calibration_semduplicate'], axis=0)\n",
    "\n",
    "    # if metric_name == 'WaterDrum':\n",
    "    cal_avg = scale_by_first_element(cal_avg)\n",
    "    cal_5th = scale_by_first_element(cal_5th)\n",
    "    cal_95th = scale_by_first_element(cal_95th)\n",
    "\n",
    "    cal_dupl_avg = scale_by_first_element(cal_dupl_avg)\n",
    "    cal_dupl_5th = scale_by_first_element(cal_dupl_5th)\n",
    "    cal_dupl_95th = scale_by_first_element(cal_dupl_95th)\n",
    "\n",
    "    cal_sem_dupl_avg = scale_by_first_element(cal_sem_dupl_avg)\n",
    "    cal_sem_dupl_5th = scale_by_first_element(cal_sem_dupl_5th)\n",
    "    cal_sem_dupl_95th = scale_by_first_element(cal_sem_dupl_95th)\n",
    "\n",
    "    # Set up Seaborn style\n",
    "    sns.set(style=\"whitegrid\")\n",
    "\n",
    "    # Define distinct colors using a subtle palette\n",
    "    colors = sns.color_palette(\"muted\", n_colors=6)\n",
    "\n",
    "    # Reverse `x_values` for plotting\n",
    "    x_values = np.arange(100, -1, -10)\n",
    "\n",
    "    # Plot calibration_duplicate and its best fit line\n",
    "    ax.plot(x_values, cal_dupl_avg[::-1], marker='s', label=f'Calibration Exact Duplicate', \n",
    "            color=colors[1], linestyle='-', markersize=8)\n",
    "    ax.fill_between(x_values, cal_dupl_5th[::-1], cal_dupl_95th[::-1], color=colors[1], alpha=0.2)\n",
    "    r2_dupl, slope_dupl = plot_best_fit_line(ax, x_values, cal_dupl_avg[::-1], colors[1])\n",
    "\n",
    "    # Plot calibration_semduplicate and its best fit line\n",
    "    ax.plot(x_values, cal_sem_dupl_avg[::-1], marker='^', label=f'Calibration Semantic Duplicate', \n",
    "            color=colors[2], linestyle='-', markersize=8)\n",
    "    ax.fill_between(x_values, cal_sem_dupl_5th[::-1], cal_sem_dupl_95th[::-1], color=colors[2], alpha=0.2)\n",
    "    r2_sem, slope_sem = plot_best_fit_line(ax, x_values, cal_sem_dupl_avg[::-1], colors[2])\n",
    "\n",
    "    # Plot original calibration and its best fit line\n",
    "    ax.plot(x_values, cal_avg[::-1], marker='o', label=f'Calibration No Duplicate', \n",
    "            color=colors[0], linestyle='-', markersize=8)\n",
    "    ax.fill_between(x_values, cal_5th[::-1], cal_95th[::-1], color=colors[0], alpha=0.2)\n",
    "    r2_cal, slope_cal = plot_best_fit_line(ax, x_values, cal_avg[::-1], colors[0])\n",
    "\n",
    "    # Customize the plot\n",
    "    ax.set_xticks(x_values)\n",
    "    ax.set_xticklabels(np.arange(0, 101, 10))  # Change tick labels from 0 to 100\n",
    "    ax.set_xlim([100, 0])  # Ensure x-axis is plotted in reverse direction\n",
    "    ax.set_xlabel('Percentage of Remaining Forgotten Data', fontsize=22)\n",
    "    ax.tick_params(axis='both', labelsize=18)\n",
    "    ax.set_title(f'{metric_name}', fontsize=24, weight='bold')\n",
    "    ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_color('black')\n",
    "\n",
    "    # Print R² values and slopes\n",
    "    print(f\"\\nMetric: {metric_name}\")\n",
    "    print(f\"No Duplicate: R²={r2_cal:.3f}, slope={slope_cal:.3f}\")\n",
    "    print(f\"Exact Duplicate: R²={r2_dupl:.3f}, slope={slope_dupl:.3f}\")\n",
    "    print(f\"Semantic Duplicate: R²={r2_sem:.3f}, slope={slope_sem:.3f}\")\n",
    "\n",
    "# Create a single figure with four subplots in a 2x2 grid\n",
    "fig, axes = plt.subplots(2, 2, figsize=(16, 12))  # Adjusted to a 2x2 grid for improved layout\n",
    "\n",
    "# Flatten the axes array for easy indexing\n",
    "axes = axes.flatten()\n",
    "\n",
    "# Plot for each metric separately\n",
    "process_and_plot_metrics(axes[0], 'ROUGE', 'ROUGE Forget', seeds, main_folder_path, num_pct_values)\n",
    "process_and_plot_metrics(axes[1], 'KnowMem', 'KnowMem', seeds, main_folder_path, num_pct_values)\n",
    "process_and_plot_metrics(axes[2], '1- MIA', '1- MIA', seeds, main_folder_path, num_pct_values)\n",
    "process_and_plot_metrics(axes[3], 'WaterDrum', 'WaterDrum Forget', seeds, main_folder_path, num_pct_values)\n",
    "\n",
    "# Adjust layout to make room for a single legend\n",
    "plt.tight_layout(rect=[0, 0, 1, 0.92])\n",
    "\n",
    "# Create a single legend outside the subplots, above them\n",
    "handles, labels = axes[0].get_legend_handles_labels()\n",
    "fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=2, fontsize=22)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "\n",
    "# Define the path to the main folder containing results for all seeds\n",
    "main_folder_path = '~/nfs_public/watermark_arxiv/main_results'\n",
    "seeds = [41, 42, 43]  # List of seeds you want to use\n",
    "num_pct_values = 6  # As you have percentages from 0 to 10\n",
    "\n",
    "def scale_by_first_element(array):\n",
    "    first_element = array[0]\n",
    "    return array / first_element if first_element != 0 else array\n",
    "\n",
    "def process_and_plot_metrics(ax, metric_name, y_label, seeds, main_folder_path, num_pct_values):\n",
    "    # Initialize dictionaries to store arrays for the current metric\n",
    "    metrics_data = {\n",
    "        'phi_calibration': np.zeros((len(seeds), num_pct_values)),\n",
    "        'phi_calibration_duplicate': np.zeros((len(seeds), num_pct_values)),\n",
    "    }\n",
    "\n",
    "    # Process calibration data\n",
    "    for seed_index, seed in enumerate(seeds):\n",
    "        for eval_type in metrics_data.keys():\n",
    "            folder_path = os.path.join(main_folder_path, f'seed_{seed}/{eval_type}/eval')\n",
    "\n",
    "            if metric_name == 'WaterDrum':\n",
    "                wtm_path = os.path.join(main_folder_path, f'seed_{seed}/watermarked_{eval_type}/watermark_verify/')\n",
    "                wtm_data = np.load(os.path.join(wtm_path, 'calibration.npy'))\n",
    "                for pct in range(num_pct_values):\n",
    "                    metrics_data[eval_type][seed_index, pct] = wtm_data[pct].mean()\n",
    "\n",
    "    # Compute averages and percentiles for all metrics\n",
    "    cal_avg = np.mean(metrics_data['phi_calibration'], axis=0)\n",
    "    cal_5th = np.min(metrics_data['phi_calibration'], axis=0)\n",
    "    cal_95th = np.max(metrics_data['phi_calibration'], axis=0)\n",
    "\n",
    "    cal_dupl_avg = np.mean(metrics_data['phi_calibration_duplicate'], axis=0)\n",
    "    cal_dupl_5th = np.min(metrics_data['phi_calibration_duplicate'], axis=0)\n",
    "    cal_dupl_95th = np.max(metrics_data['phi_calibration_duplicate'], axis=0)\n",
    "\n",
    "    cal_avg = scale_by_first_element(cal_avg)\n",
    "    cal_5th = scale_by_first_element(cal_5th)\n",
    "    cal_95th = scale_by_first_element(cal_95th)\n",
    "\n",
    "    cal_dupl_avg = scale_by_first_element(cal_dupl_avg)\n",
    "    cal_dupl_5th = scale_by_first_element(cal_dupl_5th)\n",
    "    cal_dupl_95th = scale_by_first_element(cal_dupl_95th)\n",
    "\n",
    "    # Set up Seaborn style\n",
    "    sns.set(style=\"whitegrid\")\n",
    "\n",
    "    # Define distinct colors using a subtle palette\n",
    "    colors = sns.color_palette(\"muted\", n_colors=6)\n",
    "\n",
    "    # Reverse `x_values` for plotting\n",
    "    x_values = np.arange(100, -1, -20)  # Since num_pct_values=6, adjust to your actual scale\n",
    "\n",
    "    # Plot original calibration\n",
    "    r2_cal, slope_cal = plot_best_fit_line(ax, x_values, cal_avg[::-1], colors[0])\n",
    "    ax.plot(x_values, cal_avg[::-1], marker='o', label=f'Calibration No Duplicate (R²={r2_cal:.2f})', \n",
    "            color=colors[0], linestyle='-', markersize=8)\n",
    "    ax.fill_between(x_values, cal_5th[::-1], cal_95th[::-1], color=colors[0], alpha=0.2)\n",
    "\n",
    "\n",
    "    # Plot calibration_duplicate\n",
    "    r2_dupl, slope_dupl = plot_best_fit_line(ax, x_values, cal_dupl_avg[::-1], colors[1])\n",
    "    ax.plot(x_values, cal_dupl_avg[::-1], marker='s', label=f'Calibration Duplicate (R²={r2_dupl:.2f})', \n",
    "            color=colors[1], linestyle='-', markersize=8)\n",
    "    ax.fill_between(x_values, cal_dupl_5th[::-1], cal_dupl_95th[::-1], color=colors[1], alpha=0.2)\n",
    "\n",
    "\n",
    "    # Print R² values and slopes\n",
    "    print(f\"No Duplicate: R²={r2_cal:.3f}, slope={slope_cal:.3f}\")\n",
    "    print(f\"Exact Duplicate: R²={r2_dupl:.3f}, slope={slope_dupl:.3f}\")\n",
    "\n",
    "    # Customize the plot\n",
    "    ax.set_xticks(x_values)\n",
    "    ax.set_xticklabels(np.arange(0, 101, 20))  # Change tick labels from 0 to 100\n",
    "    ax.set_xlim([100, 0])  # Ensure x-axis is plotted in reverse direction\n",
    "    ax.set_xlabel('Percentage of Remaining Forgotten Data', fontsize=29)\n",
    "    ax.set_ylabel('WaterDrum', fontsize=29)\n",
    "\n",
    "    ax.set_title('Calibration for Phi-1.5', fontsize=32, weight='bold')\n",
    "\n",
    "    # Set the font size for the tick labels\n",
    "    ax.tick_params(axis='both', which='major', labelsize=30)\n",
    "    ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_color('black')\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "\n",
    "# Plot for each metric separately\n",
    "process_and_plot_metrics(ax, 'WaterDrum', 'WaterDrum Forget', seeds, main_folder_path, num_pct_values)\n",
    "\n",
    "# Adjust layout to make room for a single legend\n",
    "plt.tight_layout()\n",
    "\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.08), ncol=2, fontsize=22)\n",
    "ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "unlearning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
