{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import pandas as pd\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import os\n",
    "\n",
    "import matplotlib.font_manager as fm\n",
    "import matplotlib as mpl\n",
    "\n",
    "mpl.rcParams[\"figure.dpi\"] = 300\n",
    "\n",
    "mpl.rcParams[\"pdf.fonttype\"] = 42\n",
    "mpl.rcParams[\"ps.fonttype\"] = 42\n",
    "\n",
    "sns.set(style=\"whitegrid\")\n",
    "sns.set_context(\n",
    "    \"paper\",\n",
    "    font_scale=1.5,\n",
    "    rc={\"lines.markersize\": 7, \"lines.linewidth\": 4, \"axes.linewidth\": 3},\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Indicate the path to the parent folder\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_samples = 1\n",
    "parent_folder = \"opt_results\"\n",
    "algorithm_name_list = [\n",
    "    \"KGW\",\n",
    "    # \"Unigram\",\n",
    "    # \"SWEET\",\n",
    "    # \"EWD\",\n",
    "    # \"SIR\", \n",
    "    # \"XSIR\", \"UPV\", \"DIP\"\n",
    "]\n",
    "\n",
    "setting_list = [\n",
    "    \"unwatermarked\",\n",
    "    \"watermarked\",\n",
    "    \"watermarked_with_word-d\",\n",
    "    \"watermarked_with_word-s\",\n",
    "    # \"watermarked_with_gpt-3.5\"\n",
    "    \n",
    "]\n",
    "\n",
    "setting_name_list = [\n",
    "    \"Baseline\",\n",
    "    \"Watermarked\",\n",
    "    \"Watermarked\",\n",
    "    \"Watermarked\",\n",
    "    # \"Watermarked-GPT-3.5\",\n",
    "]\n",
    "target_model = \"facebook/opt-1.3b\"\n",
    "oracle_model = \"facebook/opt-2.7b\"\n",
    "ref_model = \"facebook/opt-125m\"\n",
    "ref_model_name = ref_model.replace(\"/\", \"_\")\n",
    "target_model_save_name = target_model.replace(\"/\", \"_\")\n",
    "oracle_model_save_name = oracle_model.replace(\"/\", \"_\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = []\n",
    "for algorithm_name in algorithm_name_list:\n",
    "    print(algorithm_name)\n",
    "    attack_type = \"None\"\n",
    "    attack_strength = 0.0\n",
    "    for s_idx, setting in enumerate(setting_list):\n",
    "        if s_idx == 0 or s_idx == 1:\n",
    "            attack_type_name = \"None\"\n",
    "        elif s_idx == 2:\n",
    "            attack_type_name = \"Word-D\"\n",
    "        else:\n",
    "            attack_type_name = \"Word-S\"\n",
    "        attack_strength = 0.0\n",
    "        with open(\n",
    "            f\"{parent_folder}/{algorithm_name}_{target_model_save_name}/res_{setting}_{attack_type}_{attack_strength}_ppl_diversity_{oracle_model_save_name}.pkl\",\n",
    "            \"rb\",\n",
    "        ) as f:\n",
    "            results = pickle.load(f)\n",
    "        with open(\n",
    "            f\"{parent_folder}/{algorithm_name}_{target_model_save_name}/res_{setting}_{attack_type}_{attack_strength}.pkl\",\n",
    "            \"rb\",\n",
    "        ) as f:\n",
    "            text_results = pickle.load(f)\n",
    "        print(setting, len(results[\"ppl\"]))\n",
    "        for i in range(num_samples):\n",
    "            data.append(\n",
    "                {\n",
    "                    \"algorithm_name\": algorithm_name,\n",
    "                    \"ppl\": results[\"ppl\"][i],\n",
    "                    \"setting\": setting_name_list[s_idx],\n",
    "                    \"attack_type\": attack_type_name,\n",
    "                    \"attack_strength\": attack_strength,\n",
    "                    \"diversity\": results[\"diversity\"][i],\n",
    "                    \"score\": results[\"score\"][i],\n",
    "                    \"is_watermarked\": text_results[i].detect_result[\"is_watermarked\"],\n",
    "                }\n",
    "            )\n",
    "    # Load the results from unwatermarked model\n",
    "    setting = \"Unwatermarked\"\n",
    "    attack_type = \"None\"\n",
    "    attack_strength = 0.0\n",
    "    with open(\n",
    "        f\"{parent_folder}/{algorithm_name}_{target_model_save_name}/res_baseline_{attack_type}_{attack_strength}_ppl_diversity_{oracle_model_save_name}.pkl\",\n",
    "        \"rb\",\n",
    "    ) as f:\n",
    "        results = pickle.load(f)\n",
    "    with open(\n",
    "        f\"{parent_folder}/{algorithm_name}_{target_model_save_name}/res_baseline_{attack_type}_{attack_strength}.pkl\",\n",
    "        \"rb\",\n",
    "    ) as f:\n",
    "        text_results = pickle.load(f)\n",
    "    print(setting, len(results[\"ppl\"]))\n",
    "    for i in range(num_samples):\n",
    "        data.append(\n",
    "            {\n",
    "                \"algorithm_name\": algorithm_name,\n",
    "                \"ppl\": results[\"ppl\"][i],\n",
    "                \"setting\": \"Unwatermarked\",\n",
    "                \"attack_type\": \"None\",\n",
    "                \"attack_strength\": attack_strength,\n",
    "                \"diversity\": results[\"diversity\"][i],\n",
    "                \"score\": results[\"score\"][i],\n",
    "                \"is_watermarked\": text_results[i].detect_result[\"is_watermarked\"],\n",
    "            }\n",
    "        )\n",
    "  \n",
    "    # Load the results from watermarked model\n",
    "    setting = \"watermarked_with_our_attack\"\n",
    "    attack_type = \"our_attack\"\n",
    "    for attack_strength in [0.5]:\n",
    "        with open(\n",
    "            f\"{parent_folder}/{algorithm_name}_{target_model_save_name}/res_{setting}_{attack_type}_{attack_strength}_ppl_diversity_{oracle_model_save_name}.pkl\",\n",
    "            \"rb\",\n",
    "        ) as f:\n",
    "            results = pickle.load(f)\n",
    "        with open(\n",
    "            f\"{parent_folder}/{algorithm_name}_{target_model_save_name}/res_{setting}_{attack_type}_{attack_strength}.pkl\",\n",
    "            \"rb\",\n",
    "        ) as f:\n",
    "            text_results = pickle.load(f)\n",
    "        print(setting, algorithm_name, len(results[\"ppl\"]))\n",
    "        for i in range(num_samples):\n",
    "            data.append(\n",
    "                {\n",
    "                    \"algorithm_name\": algorithm_name,\n",
    "                    \"ppl\": results[\"ppl\"][i],\n",
    "                    \"setting\": \"Watermarked\",\n",
    "                    \"attack_type\": f\"Ours {attack_strength}\",\n",
    "                    \"attack_strength\": attack_strength,\n",
    "                    \"diversity\": results[\"diversity\"][i],\n",
    "                    \"score\": results[\"score\"][i],\n",
    "                    \"is_watermarked\": text_results[i].detect_result[\"is_watermarked\"],\n",
    "                }\n",
    "            )\n",
    "\n",
    "\n",
    "df = pd.DataFrame(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get figure 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "from plot_setting import *\n",
    "import seaborn as sns\n",
    "test_alg = \"KGW\"\n",
    "\n",
    "\n",
    "show_df = df[\n",
    "    ((df[\"algorithm_name\"] == test_alg) & (df[\"setting\"] == \"Watermarked\"))\n",
    "    & (\n",
    "        (df[\"attack_type\"] == \"Ours 0.5\")\n",
    "        | (df[\"attack_type\"] == \"None\")\n",
    "        | (df[\"attack_type\"] == \"Word-D\")\n",
    "        | (df[\"attack_type\"] == \"Word-S\")\n",
    "    )\n",
    "]\n",
    "show_df[\"attack_type\"] = show_df[\"attack_type\"].apply(\n",
    "    lambda x: \"No attack\" if x == \"None\" else x\n",
    ")\n",
    "# Add unwatermarked setting to show_df\n",
    "show_df_unwatermarked = df[\n",
    "    (df[\"algorithm_name\"] == test_alg) & (df[\"setting\"] == \"Unwatermarked\")\n",
    "]\n",
    "show_df_unwatermarked[\"attack_type\"] = \"Unwatermarked\"\n",
    "show_df = pd.concat([show_df, show_df_unwatermarked])\n",
    "\n",
    "# Clip ppl values and update attack_type labels\n",
    "show_df[\"ppl\"] = np.clip(show_df[\"ppl\"], 0, 100)\n",
    "show_df[\"attack_type\"] = show_df[\"attack_type\"].apply(\n",
    "    lambda x: \"No attack\" if x == \"None\" else x\n",
    ")\n",
    "\n",
    "\n",
    "# Define the order of attack types\n",
    "attack_order = [\"Unwatermarked\", \"No attack\", \"Word-D\", \"Word-S\", \"Ours 0.5\"]\n",
    "\n",
    "# Create a custom palette\n",
    "palette = sns.color_palette(None, n_colors=len(attack_order))\n",
    "palette = [\"#1f77b4\", \"#ff7f0e\", \"#2ca02c\", \"#9467bd\", \"#d62728\"]\n",
    "# Create the jointplot\n",
    "g = sns.jointplot(\n",
    "    data=show_df,\n",
    "    x=\"score\",\n",
    "    y=\"ppl\",\n",
    "    hue=\"attack_type\",\n",
    "    hue_order=attack_order,\n",
    "    palette=palette,\n",
    "    alpha=0.7,\n",
    "    ylim=(0, 105),\n",
    "    xlim=(-3, 13),\n",
    "    height=5,\n",
    "    ratio=5,\n",
    ")\n",
    "\n",
    "g.ax_joint.set_xlabel(\"Z-Score (KGW)\")\n",
    "g.ax_joint.set_ylabel(\"Perplexity (PPL)\")\n",
    "\n",
    "# Customize the legend\n",
    "handles, labels = g.ax_joint.get_legend_handles_labels()\n",
    "g.ax_joint.legend(\n",
    "    handles,\n",
    "    [\"Un-W\", \"W\", \"W (Word-D)\", \"W (Word-S)\", \"W (Ours)\"],\n",
    "    loc=\"upper right\",\n",
    "    bbox_to_anchor=(1, 1),\n",
    "    ncol=1,\n",
    "    fontsize=8,\n",
    ")\n",
    "g.fig.set_size_inches((5, 3.5))\n",
    "# Remove grid\n",
    "g.ax_joint.grid(False)\n",
    "g.ax_marg_x.grid(False)\n",
    "g.ax_marg_y.grid(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get the table 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df[\"loss\"] = np.log(df[\"ppl\"])\n",
    "df[\"combined_setting\"] = df[\"setting\"] + df[\"attack_type\"]\n",
    "\n",
    "df[\"combined_setting\"] = df[\"combined_setting\"].map({\n",
    "    \"BaselineNone\": \"Baseline\",\n",
    "    \"ReferenceNone\": \"Reference\",\n",
    "    \"UnwatermarkedNone\": \"Unwatermarked\",\n",
    "    \"WatermarkedNone\": \"Watermarked\",\n",
    "    \"WatermarkedWord-D\": \"Watermarked (Word-D)\",\n",
    "    \"WatermarkedWord-S\": \"Watermarked (Word-S)\",\n",
    "    \"Watermarked-GPT-3.5Word-S\": \"Watermarked (P-GPT3.5)\",\n",
    "    \"WatermarkedOurs 0.5\": \"Watermarked (Smoothing)\"\n",
    "})\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_df = df.groupby([\"algorithm_name\", \"combined_setting\"]).agg({\n",
    "    \"score\": \"mean\",\n",
    "    \"is_watermarked\": \"mean\",\n",
    "    \"ppl\": \"median\", # this is to avoid the impact of the outliers\n",
    "})\n",
    "\n",
    "data_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
