{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
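    "In this notebook we reproduce the experiment from the Appendix that measures the influence of concatenation.\n",
    "\n",
    "We first keep only texts with at least $T=150$ tokens and truncate them to exactly $T$ tokens.\n",
    "Then we compute the correlation for each original (non-concatenated) text.\n",
    "Next, we concatenate all texts, shuffle the (token score, color) pairs randomly with one shared permutation, split the result back into texts of roughly equal length, and compute the correlation for this shuffled version.\n",
    "We compare the two distributions using a Mann-Whitney U rank test."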
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "from scipy import stats\n",
    "\n",
    "from watermark_stealing.config.meta_config import get_pydantic_models_from_path\n",
    "from watermark_stealing.server import Server\n",
    "from src.ngram_counter import load_ngram_counter_from_cfg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model_name(cfg):\n",
    "    \"\"\"Return the data sub-path `{model}/{seeding_scheme}/delta{delta}/gamma{gamma}` for this config.\"\"\"\n",
    "    generation = cfg.server.watermark.generation\n",
    "    base_name = cfg.server.model.name.replace(\"/\", \"_\")\n",
    "    return f\"{base_name}/{generation.seeding_scheme}/delta{generation.delta}/gamma{generation.gamma}\"\n",
    "\n",
    "def load_server(cfg):\n",
    "    \"\"\"Build the watermark `Server` and return `(server, tokenizer)`.\n",
    "\n",
    "    Side effect: sets `cfg.server.model.skip = True` on the given config\n",
    "    (presumably so that the model weights are not loaded).\n",
    "    \"\"\"\n",
    "    cfg.server.model.skip = True\n",
    "    server = Server(cfg.meta, cfg.server)\n",
    "    return server, server.model.tokenizer\n",
    "\n",
    "def load_analyzers_from_cfg(cfg):\n",
    "    \"\"\"Load the pickled watermarked and spoofed analyzers for this config.\n",
    "\n",
    "    Files are read from `data/reprompting/{model_name}/c4/`: the watermarked ones from\n",
    "    `watermarked/` when `cfg.meta.rng_device == \"cpu\"` and from `cuda_watermarked/` otherwise,\n",
    "    the spoofed ones from `spoofed_{attacker_short_name}/`.\n",
    "\n",
    "    Returns:\n",
    "        `(watermarked_analyzers, spoofed_analyzers)`, exactly as stored in the pickles.\n",
    "    \"\"\"\n",
    "    prefix = f\"data/reprompting/{get_model_name(cfg)}/c4/\"\n",
    "    attacker_short_name = cfg.attacker.model.short_str().replace(\"/\", \"_\")\n",
    "\n",
    "    watermarked_subdir = \"watermarked/\" if cfg.meta.rng_device == \"cpu\" else \"cuda_watermarked/\"\n",
    "    analyzer_dirs = (prefix + watermarked_subdir, prefix + f\"spoofed_{attacker_short_name}/\")\n",
    "\n",
    "    # NOTE: pickle.load can execute arbitrary code -- only load trusted, locally generated files.\n",
    "    loaded = []\n",
    "    for directory in analyzer_dirs:\n",
    "        with open(directory + \"analyzers.pkl\", \"rb\") as handle:\n",
    "            loaded.append(pickle.load(handle))\n",
    "\n",
    "    watermarked_analyzers, spoofed_analyzers = loaded\n",
    "    return watermarked_analyzers, spoofed_analyzers\n",
    "\n",
    "\n",
    "def trim_beginning(analyzer, n_trim: int = 30):\n",
    "    \"\"\"Drop the first `n_trim` tokens of `analyzer`, in place.\n",
    "\n",
    "    `analyzer.color_mask` is sliced from `n_trim` on, and `analyzer.sentence_tokens` is set to\n",
    "    `analyzer._get_tokenized_sentence()[n_trim:]`.\n",
    "    \"\"\"\n",
    "    kept_tokens = analyzer._get_tokenized_sentence()[n_trim:]\n",
    "    analyzer.color_mask = analyzer.color_mask[n_trim:]\n",
    "    analyzer.sentence_tokens = kept_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to the experiment config to analyse; must be filled in before running the notebook.\n",
    "cfg_path = \"\"\n",
    "assert cfg_path, \"Set `cfg_path` to the experiment config file before running this notebook.\"\n",
    "cfg = get_pydantic_models_from_path(cfg_path)[0]\n",
    "gamma = cfg.server.watermark.generation.gamma\n",
    "watermarked_analyzers, spoofed_analyzers = load_analyzers_from_cfg(cfg)\n",
    "print(len(watermarked_analyzers), len(spoofed_analyzers))\n",
    "\n",
    "\n",
    "def both_tokenized(analyzer_pair):\n",
    "    \"\"\"Return True when both analyzers of the pair have `sentence_tokens` set (not None).\"\"\"\n",
    "    return analyzer_pair[1].sentence_tokens is not None and analyzer_pair[0].sentence_tokens is not None\n",
    "\n",
    "\n",
    "# Drop pairs where either analyzer has no `sentence_tokens` (one pass per list).\n",
    "watermarked_analyzers = [pair for pair in watermarked_analyzers if both_tokenized(pair)]\n",
    "spoofed_analyzers = [pair for pair in spoofed_analyzers if both_tokenized(pair)]\n",
    "print(len(watermarked_analyzers), len(spoofed_analyzers))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unordered unigram counter (n=1); passed to `get_token_occurence_context_score` further below.\n",
    "unigram = load_ngram_counter_from_cfg(cfg, 1, ordered=False)\n",
    "\n",
    "# NOTE(review): `server` and `tokenizer` are not used by any later cell of this notebook.\n",
    "server, tokenizer = load_server(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trim_analyzers(analyzers, token_target: int = 150, seed: int = 0):\n",
    "    \"\"\"Shuffle analyzer pairs and keep those whose two texts both have at least `token_target` tokens.\n",
    "\n",
    "    Every kept pair is cut in place with `shallow_clean_sentence(token_target)`\n",
    "    (presumably truncating each text to exactly `token_target` tokens).\n",
    "\n",
    "    The shuffle uses a private `np.random.RandomState(seed)`: it produces the same order as\n",
    "    `np.random.seed(seed); np.random.shuffle(...)`, but neither reseeds the global RNG nor\n",
    "    reorders the caller's list. All pairs are considered, including the first one after shuffling.\n",
    "\n",
    "    Args:\n",
    "        analyzers: sequence of `(analyzer1, analyzer2)` pairs.\n",
    "        token_target: minimum (and resulting) number of tokens per text.\n",
    "        seed: seed of the shuffle.\n",
    "\n",
    "    Returns:\n",
    "        List of the kept `(analyzer1, analyzer2)` pairs in shuffled order (empty for empty input).\n",
    "    \"\"\"\n",
    "    shuffled = list(analyzers)\n",
    "    np.random.RandomState(seed).shuffle(shuffled)\n",
    "\n",
    "    kept_analyzers = []\n",
    "    for analyzer1, analyzer2 in shuffled:\n",
    "        if min(analyzer1.get_length(), analyzer2.get_length()) >= token_target:\n",
    "            analyzer1.shallow_clean_sentence(token_target)\n",
    "            analyzer2.shallow_clean_sentence(token_target)\n",
    "            kept_analyzers.append((analyzer1, analyzer2))\n",
    "    return kept_analyzers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep pairs whose two texts both have at least T tokens, then cap the sample size.\n",
    "TOKEN_TARGET = 150\n",
    "MAX_TEXTS = 1000\n",
    "\n",
    "watermarked_analyzers, spoofed_analyzers = (\n",
    "    trim_analyzers(watermarked_analyzers, token_target=TOKEN_TARGET),\n",
    "    trim_analyzers(spoofed_analyzers, token_target=TOKEN_TARGET),\n",
    ")\n",
    "print(len(watermarked_analyzers), len(spoofed_analyzers))\n",
    "\n",
    "watermarked_analyzers = watermarked_analyzers[:MAX_TEXTS]\n",
    "spoofed_analyzers = spoofed_analyzers[:MAX_TEXTS]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def masked_scores_and_colors(analyzer_pairs, index):\n",
    "    \"\"\"Per-text unigram context scores and colors of `pair[index]`, without color -1 positions.\n",
    "\n",
    "    Scores come from `get_token_occurence_context_score(unigram, 0)` (notebook-level `unigram`)\n",
    "    and colors from `color_mask`. Positions whose color equals -1 are removed from both with the\n",
    "    same boolean mask, so scores and colors stay aligned.\n",
    "\n",
    "    Returns:\n",
    "        (scores, colors): two lists with one array per pair.\n",
    "    \"\"\"\n",
    "    raw_scores = [pair[index].get_token_occurence_context_score(unigram, 0) for pair in analyzer_pairs]\n",
    "    raw_colors = [np.array(pair[index].color_mask) for pair in analyzer_pairs]\n",
    "    keep_masks = [color != -1 for color in raw_colors]\n",
    "    scores = [score[keep] for score, keep in zip(raw_scores, keep_masks)]\n",
    "    colors = [color[keep] for color, keep in zip(raw_colors, keep_masks)]\n",
    "    return scores, colors\n",
    "\n",
    "\n",
    "token_scores_watermarked1, color_watermarked1 = masked_scores_and_colors(watermarked_analyzers, 0)\n",
    "token_scores_watermarked2, color_watermarked2 = masked_scores_and_colors(watermarked_analyzers, 1)\n",
    "token_scores_spoofed1, color_spoofed1 = masked_scores_and_colors(spoofed_analyzers, 0)\n",
    "token_scores_spoofed2, color_spoofed2 = masked_scores_and_colors(spoofed_analyzers, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_correlation(score_list, color_list):\n",
    "    \"\"\"Return `arctanh(spearman_rho(score, color))` (Fisher z-transform) for each aligned text.\n",
    "\n",
    "    The value is nan when `spearmanr` returns nan (e.g. constant input) and +/-inf when\n",
    "    rho is exactly +/-1.\n",
    "    \"\"\"\n",
    "    fisher_values = []\n",
    "    for score, color in zip(score_list, color_list):\n",
    "        rho = stats.spearmanr(score, color).correlation\n",
    "        fisher_values.append(np.arctanh(rho))\n",
    "    return fisher_values\n",
    "\n",
    "\n",
    "# NOTE(review): the difference of Fisher values is divided by the hard-coded sqrt(1/300),\n",
    "# independent of the actual (masked) text lengths -- confirm against the statistic in the paper.\n",
    "Z_DENOMINATOR = np.sqrt(1/300)\n",
    "\n",
    "\n",
    "def paired_zscores(first, second):\n",
    "    \"\"\"Return `(first[i] - second[i]) / Z_DENOMINATOR` for every aligned pair of values.\"\"\"\n",
    "    return [(value1 - value2) / Z_DENOMINATOR for value1, value2 in zip(first, second)]\n",
    "\n",
    "\n",
    "watermarked_correlation1 = compute_correlation(token_scores_watermarked1, color_watermarked1)\n",
    "watermarked_correlation2 = compute_correlation(token_scores_watermarked2, color_watermarked2)\n",
    "spoofed_correlation1 = compute_correlation(token_scores_spoofed1, color_spoofed1)\n",
    "spoofed_correlation2 = compute_correlation(token_scores_spoofed2, color_spoofed2)\n",
    "\n",
    "watermarked_zscores = paired_zscores(watermarked_correlation1, watermarked_correlation2)\n",
    "spoofed_zscores = paired_zscores(spoofed_correlation1, spoofed_correlation2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def shuffle_and_split(scores, colors, n_texts, seed=0):\n",
    "    \"\"\"Pool all texts, permute the tokens jointly and cut the pool back into `n_texts` texts.\n",
    "\n",
    "    `scores` and `colors` are lists of aligned per-text arrays. Both are concatenated and reordered\n",
    "    with the SAME permutation (each token keeps its own color), then split with\n",
    "    `np.array_split` into `n_texts` chunks of nearly equal size.\n",
    "    `np.random.RandomState(seed).permutation` draws the same order as\n",
    "    `np.random.seed(seed); np.random.shuffle(...)` without touching the global RNG.\n",
    "\n",
    "    Returns:\n",
    "        (shuffled_scores, shuffled_colors): two lists of `n_texts` arrays.\n",
    "    \"\"\"\n",
    "    pooled_scores = np.concatenate(scores)\n",
    "    pooled_colors = np.concatenate(colors)\n",
    "    if len(pooled_scores) != len(pooled_colors):\n",
    "        raise ValueError(\"scores and colors must contain the same number of tokens\")\n",
    "    permutation = np.random.RandomState(seed).permutation(len(pooled_scores))\n",
    "    return (\n",
    "        np.array_split(pooled_scores[permutation], n_texts),\n",
    "        np.array_split(pooled_colors[permutation], n_texts),\n",
    "    )\n",
    "\n",
    "\n",
    "# Re-split into as many texts as there are original texts (instead of a hard-coded 1000),\n",
    "# so that shuffled texts have on average the same length as the original ones.\n",
    "shuffled_token_scores_watermarked1, shuffled_color_watermarked1 = shuffle_and_split(\n",
    "    token_scores_watermarked1, color_watermarked1, len(token_scores_watermarked1))\n",
    "shuffled_token_scores_watermarked2, shuffled_color_watermarked2 = shuffle_and_split(\n",
    "    token_scores_watermarked2, color_watermarked2, len(token_scores_watermarked2))\n",
    "shuffled_token_scores_spoofed1, shuffled_color_spoofed1 = shuffle_and_split(\n",
    "    token_scores_spoofed1, color_spoofed1, len(token_scores_spoofed1))\n",
    "shuffled_token_scores_spoofed2, shuffled_color_spoofed2 = shuffle_and_split(\n",
    "    token_scores_spoofed2, color_spoofed2, len(token_scores_spoofed2))\n",
    "\n",
    "shuffled_watermarked_correlation1 = compute_correlation(shuffled_token_scores_watermarked1, shuffled_color_watermarked1)\n",
    "shuffled_watermarked_correlation2 = compute_correlation(shuffled_token_scores_watermarked2, shuffled_color_watermarked2)\n",
    "shuffled_spoofed_correlation1 = compute_correlation(shuffled_token_scores_spoofed1, shuffled_color_spoofed1)\n",
    "shuffled_spoofed_correlation2 = compute_correlation(shuffled_token_scores_spoofed2, shuffled_color_spoofed2)\n",
    "\n",
    "# Same scaling as for the original (non-shuffled) z-scores.\n",
    "shuffled_watermarked_zscores = [(correlation1 - correlation2)/np.sqrt(1/300) for correlation1, correlation2 in zip(shuffled_watermarked_correlation1, shuffled_watermarked_correlation2)]\n",
    "shuffled_spoofed_zscores = [(correlation1 - correlation2)/np.sqrt(1/300) for correlation1, correlation2 in zip(shuffled_spoofed_correlation1, shuffled_spoofed_correlation2)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set Seaborn style\n",
    "sns.set_context(\"talk\")\n",
    "sns.set_style(\"white\")\n",
    "\n",
    "# Font sizes, set globally (assigned after the \"talk\" context, so they take precedence)\n",
    "plt.rcParams['axes.labelsize'] = 25  # Axes label font size\n",
    "plt.rcParams['xtick.labelsize'] = 23  # X-axis tick font size\n",
    "plt.rcParams['ytick.labelsize'] = 23  # Y-axis tick font size\n",
    "plt.rcParams['legend.fontsize'] = 20  # Legend font size\n",
    "\n",
    "FIGURE_PATH = \"figures/additional/concatenation_validation.pdf\"\n",
    "\n",
    "# One panel per text source: (original z-scores, shuffled z-scores, original color, shuffled color)\n",
    "panels = [\n",
    "    (watermarked_zscores, shuffled_watermarked_zscores, \"blue\", \"lightblue\"),\n",
    "    (spoofed_zscores, shuffled_spoofed_zscores, \"red\", \"pink\"),\n",
    "]\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)\n",
    "for ax, (zscores, shuffled_zscores, color, shuffled_color) in zip(axes, panels):\n",
    "    sns.histplot(zscores, bins=50, color=color, ax=ax, edgecolor=\"black\", stat=\"density\")\n",
    "    sns.histplot(shuffled_zscores, bins=50, color=shuffled_color, ax=ax, edgecolor=\"black\", stat=\"density\")\n",
    "    ax.set_xlabel(\"$z$-score\")\n",
    "    sns.despine(ax=ax)\n",
    "\n",
    "# Custom legend using rectangles (boxes) instead of lines, shared by both panels\n",
    "handles = [plt.Rectangle((0, 0), 1, 1, color=c, alpha=0.7) for c in (\"blue\", \"lightblue\", \"red\", \"pink\")]\n",
    "labels = [\"$\\\\xi$-watermarked\", \"$\\\\xi$-watermarked\\n(Shuffled)\", \"Spoofed\", \"Spoofed\\n(Shuffled)\"]\n",
    "fig.legend(handles=handles, labels=labels, loc='center left', bbox_to_anchor=(0.9, 0.5), frameon=False)\n",
    "\n",
    "# Create the output directory if needed so that saving does not fail on a fresh checkout\n",
    "os.makedirs(os.path.dirname(FIGURE_PATH), exist_ok=True)\n",
    "fig.savefig(FIGURE_PATH, bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-sided Mann-Whitney U rank test: original vs. shuffled z-scores, per text source.\n",
    "# nan z-scores (spearmanr returns nan for constant input) are omitted instead of making the result nan.\n",
    "for source, original_zscores, shuffled_zscores in [\n",
    "    (\"watermarked\", watermarked_zscores, shuffled_watermarked_zscores),\n",
    "    (\"spoofed\", spoofed_zscores, shuffled_spoofed_zscores),\n",
    "]:\n",
    "    print(source, stats.mannwhitneyu(original_zscores, shuffled_zscores, nan_policy=\"omit\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ws2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
