{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before using this notebook, run the corresponding tests with the launch scripts so that their results are available."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.test_formatter import apply_1test\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pickle as pkl\n",
    "from src.attacks.kgw_detection import estimate_context_size\n",
    "from src.utils import sample_vectorized\n",
    "import pickle \n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from src.delta_estimation import estimate_delta_grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helpers that run one statistical test over a grid of configurations and\n",
    "# collect the resulting p-values. They read the globals set by each section's\n",
    "# config cell: `test`, `custom_name`, `num_runs`, `n_samples`, `model_names`,\n",
    "# `watermarks`, `temperatures`, `deltas`, `gammas`, `seeding_scheme`, `alphas`\n",
    "# and `disable_watermark_every` (plus `debug`, defaulted below).\n",
    "debug = False\n",
    "n_queries = None  # NOTE(review): not referenced anywhere in this notebook\n",
    "disable_watermark_every = 0\n",
    "\n",
    "\n",
    "def perform_test(df, wm_scheme, model_name, temperature, n_samples, delta=None, gamma=None, key_size=None, seeding_scheme=None, alpha=None, max_new_tokens=50, disable_watermark_every=0):\n",
    "    \"\"\"Run `apply_1test` `num_runs` times for one configuration and append the results to `df`.\n",
    "\n",
    "    One row per run is appended to `df` in place, with the value returned by\n",
    "    `apply_1test` in the `p_value` column. The test type, custom name, number\n",
    "    of runs and debug flag come from the globals `test`, `custom_name`,\n",
    "    `num_runs` and `debug`.\n",
    "    \"\"\"\n",
    "    p_values = [\n",
    "        apply_1test(\n",
    "            test=test,\n",
    "            wm_scheme=wm_scheme,\n",
    "            model_name=model_name,\n",
    "            temperature=temperature,\n",
    "            n_samples=n_samples,\n",
    "            delta=delta,\n",
    "            gamma=gamma,\n",
    "            key_size=key_size,\n",
    "            seeding_scheme=seeding_scheme,\n",
    "            custom_name=custom_name,\n",
    "            alpha=alpha,\n",
    "            max_new_tokens=max_new_tokens,\n",
    "            debug=debug,\n",
    "            disable_watermark_every=disable_watermark_every,\n",
    "        ) for _ in range(num_runs)\n",
    "    ]\n",
    "    for p in p_values:\n",
    "        df.loc[len(df.index)] = {\n",
    "            \"wm_scheme\": wm_scheme,\n",
    "            \"model_name\": model_name,\n",
    "            \"temperature\": temperature,\n",
    "            \"n_samples\": n_samples,\n",
    "            \"delta\": delta,\n",
    "            \"gamma\": gamma,\n",
    "            \"key_size\": key_size,\n",
    "            \"seeding_scheme\": seeding_scheme,\n",
    "            \"p_value\": p,\n",
    "            \"alpha\": alpha,\n",
    "        }\n",
    "\n",
    "\n",
    "def run_test(df, key_sizes=(256, 2048)):\n",
    "    \"\"\"Run `perform_test` for every configured combination, appending to `df` in place.\n",
    "\n",
    "    Loops over `n_samples`, `model_names` and `watermarks`; each scheme then\n",
    "    sweeps its own hyper-parameters. Only \"no_watermark\" sweeps `temperatures`;\n",
    "    every watermarked scheme is run at temperature 1.0.\n",
    "\n",
    "    Args:\n",
    "        df: results dataframe, filled in place.\n",
    "        key_sizes: key sizes swept for the \"stanford\" scheme.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `watermarks` contains a scheme this loop does not handle.\n",
    "            Checked before any test runs, so a typo cannot silently drop a\n",
    "            scheme from the results.\n",
    "    \"\"\"\n",
    "    known = {\"no_watermark\", \"KGW\", \"stanford\", \"dipmark\", \"cache_dipmark\", \"cache_delta_reweight\", \"DeltaReweight\"}\n",
    "    unknown = [w for w in watermarks if w not in known]\n",
    "    if unknown:\n",
    "        raise ValueError(f\"Unknown watermark scheme(s): {unknown}\")\n",
    "\n",
    "    for n_sample in n_samples:\n",
    "        for model_name in model_names:\n",
    "            # Normalise \"org/model\" ids to \"org_model\".\n",
    "            model_name = model_name.replace(\"/\", \"_\")\n",
    "            for watermark in watermarks:\n",
    "                if watermark == \"no_watermark\":\n",
    "                    for temperature in temperatures:\n",
    "                        perform_test(df, watermark, model_name, temperature, n_sample, disable_watermark_every=disable_watermark_every)\n",
    "                elif watermark == \"KGW\":\n",
    "                    for delta in deltas:\n",
    "                        for gamma in gammas:\n",
    "                            for seeding in seeding_scheme:\n",
    "                                perform_test(df, watermark, model_name, 1.0, n_sample, delta=delta, gamma=gamma, seeding_scheme=seeding, disable_watermark_every=disable_watermark_every)\n",
    "                elif watermark == \"stanford\":\n",
    "                    for key_size in key_sizes:\n",
    "                        perform_test(df, watermark, model_name, 1.0, n_sample, key_size=key_size, disable_watermark_every=disable_watermark_every)\n",
    "                elif watermark in (\"dipmark\", \"cache_dipmark\"):\n",
    "                    for alpha in alphas:\n",
    "                        perform_test(df, watermark, model_name, 1.0, n_sample, alpha=alpha, disable_watermark_every=disable_watermark_every)\n",
    "                else:  # \"cache_delta_reweight\" or \"DeltaReweight\": nothing to sweep\n",
    "                    perform_test(df, watermark, model_name, 1.0, n_sample, disable_watermark_every=disable_watermark_every)\n",
    "\n",
    "\n",
    "def get_test_results(key_sizes=(256, 2048)):\n",
    "    \"\"\"Run every configured test and return the p-values as a new dataframe (one row per run).\n",
    "\n",
    "    Hyper-parameters that do not apply to a scheme are left empty (None/NaN).\n",
    "    `key_sizes` is forwarded to `run_test`.\n",
    "    \"\"\"\n",
    "    df = pd.DataFrame(columns=[\"wm_scheme\", \"model_name\", \"temperature\", \"n_samples\", \"delta\", \"gamma\", \"key_size\", \"seeding_scheme\", \"p_value\", \"alpha\"])\n",
    "    run_test(df, key_sizes=key_sizes)\n",
    "    return df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Red-Green test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This section loads and summarizes the results of the Red-Green test (KGW)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = \"KGW\"\n",
    "\n",
    "watermarks = [\"no_watermark\", \"KGW\", \"stanford\", \"dipmark\"]\n",
    "deltas = [2.0,4.0]\n",
    "gammas = [0.25, 0.5]\n",
    "temperatures = [0.7, 1.0]\n",
    "model_names = [ \"meta-llama/Llama-2-13b-chat-hf\", \"meta-llama/Llama-2-70b-chat-hf\", \"mistralai/Mistral-7B-Instruct-v0.1\", \"meta-llama/Llama-2-7b-chat-hf\", \"meta-llama_Meta-Llama-3-8B-Instruct\"]\n",
    "model_names=  [\"meta-llama_Meta-Llama-3-8B-Instruct\"]\n",
    "n_samples = [0, 100]\n",
    "alphas = [0.5,0.3]\n",
    "custom_name=\"\"\n",
    "seeding_scheme = [\"lefthash\", \"selfhash\"]\n",
    "num_permutations = 100\n",
    "num_runs = 1  # Number of times to run each test to get median\n",
    "debug = False\n",
    "disable_watermark_every=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run every configured combination of the Red-Green test (one row per run)\n",
    "df = get_test_results()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -1 stands for hyper-parameters that do not apply to a scheme (e.g. `delta`\n",
    "# for non-KGW schemes); without it groupby would drop those rows.\n",
    "df = df.fillna(-1)\n",
    "d = (\n",
    "    df.groupby([\n",
    "        \"wm_scheme\", \"model_name\", \"temperature\", \"n_samples\",\n",
    "        \"delta\", \"gamma\", \"key_size\", \"seeding_scheme\", \"alpha\",\n",
    "    ])\n",
    "    .agg({\"p_value\": [\"median\", \"std\"]})\n",
    "    .reset_index()\n",
    ")\n",
    "# One table per model, restricted to n_samples == 100\n",
    "for model_name in d[\"model_name\"].unique():\n",
    "    selected = (d[\"model_name\"] == model_name) & (d[\"n_samples\"] == 100)\n",
    "    display(d[selected])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fixed-Sampling test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This section loads and summarizes the results of the Fixed-Sampling test (stanford)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = \"stanford\"\n",
    "max_new_tokens = 50\n",
    "watermarks = [\"no_watermark\", \"KGW\", \"stanford\",\"dipmark\"]\n",
    "alphas = [0.3, 0.5]\n",
    "deltas = [2.0,4.0]\n",
    "gammas = [0.25, 0.5]\n",
    "temperatures = [0.7, 1.0]\n",
    "model_names = [\"meta-llama/Llama-2-70b-chat-hf\", \"meta-llama/Llama-2-13b-chat-hf\", \"mistralai/Mistral-7B-Instruct-v0.1\", \"meta-llama/Llama-2-7b-chat-hf\", \"meta-llama_Meta-Llama-3-8B-Instruct\"]\n",
    "n_samples = [0]\n",
    "custom_name=\"\"\n",
    "seeding_scheme = [\"lefthash\", \"selfhash\"]\n",
    "num_runs = 1  # Number of times to run each test to get mean and std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run every configured combination of the Fixed-Sampling test (one row per run)\n",
    "df = get_test_results()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -1 stands for hyper-parameters that do not apply to a scheme; groupby would otherwise drop those rows\n",
    "df =df.fillna(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Median and standard deviation of the p-value over the runs of each configuration\n",
    "d = (\n",
    "    df.groupby([\n",
    "        \"wm_scheme\", \"model_name\", \"temperature\", \"n_samples\",\n",
    "        \"delta\", \"gamma\", \"key_size\", \"seeding_scheme\", \"alpha\",\n",
    "    ])\n",
    "    .agg({\"p_value\": [\"median\", \"std\"]})\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the group keys back into columns. Guarded so that re-running this cell\n",
    "# does not call reset_index a second time and add a spurious \"index\" column.\n",
    "if isinstance(d.index, pd.MultiIndex):\n",
    "    d = d.reset_index()\n",
    "\n",
    "# One table per model\n",
    "for model_name in d[\"model_name\"].unique():\n",
    "    display(d[d[\"model_name\"] == model_name])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cache-Augmented test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This section loads and summarizes the results of the Cache-Augmented test (cache)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = \"cache\"\n",
    "watermarks = [\"cache_dipmark\",\"cache_delta_reweight\", \"no_watermark\",\"KGW\",\"stanford\"]\n",
    "#watermarks = [\"gamma_reweight\"]\n",
    "deltas = [2.0,4.0]\n",
    "gammas = [0.25, 0.5]\n",
    "temperatures = [0.7, 1.0]\n",
    "model_names = [\"meta-llama/Llama-2-13b-chat-hf\", \"mistralai/Mistral-7B-Instruct-v0.1\", \"meta-llama/Llama-2-70b-chat-hf\", \"meta-llama/Llama-2-7b-chat-hf\", \"meta-llama_Meta-Llama-3-8B-Instruct\"]\n",
    "n_samples = [75]\n",
    "custom_name=\"\"\n",
    "alphas = [0.5,0.3]\n",
    "num_runs = 50  # Number of times to run each test to get mean and std\n",
    "seeding_scheme = [\"lefthash\", \"selfhash\"]\n",
    "disable_watermark_every=10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run every configured combination of the Cache-Augmented test (one row per run)\n",
    "df = get_test_results()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -1 stands for hyper-parameters that do not apply to a scheme; groupby would otherwise drop those rows.\n",
    "df = df.fillna(-1)\n",
    "# Median and std of the p-value over the runs of each configuration, group keys as columns\n",
    "d = (\n",
    "    df.groupby([\n",
    "        \"wm_scheme\", \"model_name\", \"temperature\", \"n_samples\",\n",
    "        \"delta\", \"gamma\", \"key_size\", \"seeding_scheme\", \"alpha\",\n",
    "    ])\n",
    "    .agg({\"p_value\": [\"median\", \"std\"]})\n",
    "    .reset_index()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One table per model, restricted to n_samples == 75\n",
    "for model_name in d[\"model_name\"].unique():\n",
    "    selected = (d[\"model_name\"] == model_name) & (d[\"n_samples\"] == 75)\n",
    "    display(d[selected])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Additional results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Red-Green"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Context size estimation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To estimate the context size, you first need to run the dedicated test (see `runs/parameter_estimation/context_estimation.sh`)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### LeftHash"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = \"KGW\"\n",
    "\n",
    "delta = 2.0\n",
    "gamma = 0.25\n",
    "model_names = [\"meta-llama/Llama-2-13b-chat-hf\", \"meta-llama/Llama-2-70b-chat-hf\", \"mistralai/Mistral-7B-Instruct-v0.1\"]\n",
    "\n",
    "n_samples = [0,10,100]\n",
    "temperature = 1.0\n",
    "custom_name= \"\"\n",
    "seeding_scheme = \"lefthash\"\n",
    "keys = [1]\n",
    "\n",
    "num_runs = 10  # Number of times to run each test to get median"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate the KGW context size for every model and sample budget (LeftHash).\n",
    "rows = []\n",
    "for model_name in model_names:\n",
    "    model_name = model_name.replace(\"/\", \"_\")\n",
    "    pkl_path = f\"pkl_results/context/{model_name}/KGW/{seeding_scheme}/{custom_name}_context5_gamma{gamma}_{temperature}_{keys}.pkl\"\n",
    "    for n_sample in n_samples:\n",
    "        # Re-loaded for every sample budget, so each budget starts from the stored data.\n",
    "        with open(pkl_path, \"rb\") as f:\n",
    "            data = pickle.load(f)\n",
    "        data = data[delta]\n",
    "        if n_sample != 0:  # 0 means: keep the full data\n",
    "            data = sample_vectorized(data, n_sample)\n",
    "        context_size = estimate_context_size(data, significance=.1)\n",
    "        rows.append({\n",
    "            \"wm_scheme\": \"KGW\",\n",
    "            \"model_name\": model_name,\n",
    "            \"temperature\": temperature,\n",
    "            \"n_samples\": n_sample,\n",
    "            \"delta\": delta,\n",
    "            \"gamma\": gamma,\n",
    "            \"key_size\": None,\n",
    "            \"seeding_scheme\": seeding_scheme,\n",
    "            \"context_size\": context_size,\n",
    "        })\n",
    "df = pd.DataFrame(rows, columns=[\"wm_scheme\", \"model_name\", \"temperature\", \"n_samples\", \"delta\", \"gamma\", \"key_size\", \"seeding_scheme\", \"context_size\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### SelfHash"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = \"KGW\"\n",
    "\n",
    "delta = 2.0\n",
    "gamma = 0.25\n",
    "model_names = [\"meta-llama/Llama-2-13b-chat-hf\", \"meta-llama/Llama-2-70b-chat-hf\", \"mistralai/Mistral-7B-Instruct-v0.1\"]\n",
    "\n",
    "n_samples = [0,10,100]\n",
    "temperature = 1.0\n",
    "custom_name= \"\"\n",
    "seeding_scheme = \"selfhash\"\n",
    "keys = [1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate the KGW context size for every model and sample budget (SelfHash).\n",
    "num_runs = 10  # NOTE(review): not read below\n",
    "rows = []\n",
    "for model_name in model_names:\n",
    "    model_name = model_name.replace(\"/\", \"_\")\n",
    "    pkl_path = f\"pkl_results/context/{model_name}/KGW/{seeding_scheme}/{custom_name}_context5_gamma{gamma}_{temperature}_{keys}.pkl\"\n",
    "    for n_sample in n_samples:\n",
    "        # Re-loaded for every sample budget, so each budget starts from the stored data.\n",
    "        with open(pkl_path, \"rb\") as f:\n",
    "            data = pickle.load(f)\n",
    "        data = data[delta]\n",
    "        if n_sample != 0:  # 0 means: keep the full data\n",
    "            data = sample_vectorized(data, n_sample)\n",
    "        context_size = estimate_context_size(data, significance=.1)\n",
    "        rows.append({\n",
    "            \"wm_scheme\": \"KGW\",\n",
    "            \"model_name\": model_name,\n",
    "            \"temperature\": temperature,\n",
    "            \"n_samples\": n_sample,\n",
    "            \"delta\": delta,\n",
    "            \"gamma\": gamma,\n",
    "            \"key_size\": None,\n",
    "            \"seeding_scheme\": seeding_scheme,\n",
    "            # NOTE(review): the LeftHash cell stores `context_size` as returned,\n",
    "            # without averaging - confirm which of the two is intended.\n",
    "            \"context_size\": np.mean(context_size),\n",
    "        })\n",
    "df = pd.DataFrame(rows, columns=[\"wm_scheme\", \"model_name\", \"temperature\", \"n_samples\", \"delta\", \"gamma\", \"key_size\", \"seeding_scheme\", \"context_size\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### $\\delta$-estimation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run the estimation of $\\delta$, you must first run the Red-Green test."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# delta-estimation on the stored Red-Green (KGW) results\n",
    "model_names = [\"meta-llama_Llama-2-7b-chat-hf\", \"meta-llama_Llama-2-13b-chat-hf\", \"mistralai_Mistral-7B-Instruct-v0.1\", \"meta-llama_Meta-Llama-3-8B-Instruct\", \"meta-llama_Llama-2-70b-chat-hf\"]\n",
    "custom_name = \"\"\n",
    "seeding_scheme = \"lefthash\"\n",
    "keys = [1]\n",
    "gamma = 0.25\n",
    "context = 5\n",
    "temperature = 1.0\n",
    "n_samples = [100]\n",
    "n_bootstrap = 100  # estimates per (model, delta, n_samples)\n",
    "\n",
    "df = {\"model\": [], \"delta\": [], \"n_samples\": [], \"delta_hat\": []}\n",
    "\n",
    "for model_name in model_names:\n",
    "    pkl_path = f\"pkl_results/KGW/{model_name}/KGW/{seeding_scheme}/20perturbations_context{context}_gamma{gamma}_{temperature}_{keys}.pkl\"\n",
    "    with open(pkl_path, \"rb\") as f:\n",
    "        out = pkl.load(f)\n",
    "\n",
    "    delta_list = []  # NOTE(review): unused\n",
    "    # Only the deltas 0..4 that are keys of `out` are evaluated.\n",
    "    for delta_value in tqdm(range(5)):\n",
    "        if delta_value not in out:\n",
    "            continue\n",
    "        for n_sample in n_samples:\n",
    "            for _ in range(n_bootstrap):\n",
    "                # Fresh array each time, re-subsampled unless n_sample == 0.\n",
    "                data = np.array(out[delta_value])\n",
    "                if n_sample != 0:\n",
    "                    data = sample_vectorized(data, n_sample, bayesian=True)\n",
    "                estimated_delta = estimate_delta_grad(data)\n",
    "                df[\"model\"].append(model_name)\n",
    "                df[\"delta\"].append(delta_value)\n",
    "                df[\"n_samples\"].append(n_sample)\n",
    "                df[\"delta_hat\"].append(estimated_delta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimated vs. true delta, one panel per sample budget.\n",
    "plotdf = pd.DataFrame(df)\n",
    "\n",
    "better_model_name = {\n",
    "    \"meta-llama_Llama-2-7b-chat-hf\": \"Llama2-7B\",\n",
    "    \"meta-llama_Llama-2-13b-chat-hf\": \"Llama2-13B\",\n",
    "    \"meta-llama_Llama-2-70b-chat-hf\": \"Llama2-70B\",\n",
    "    \"google_gemma-7b-it\": \"Gemma-7b\",\n",
    "    \"meta-llama_Meta-Llama-3-8B-Instruct\": \"Llama3-8B\",\n",
    "    \"mistralai_Mistral-7B-Instruct-v0.1\": \"Mistral-7B\"\n",
    "}\n",
    "plotdf[\"model\"] = plotdf[\"model\"].map(better_model_name)\n",
    "\n",
    "labelsize = 25\n",
    "ticksize = 16\n",
    "legendsize = 18\n",
    "\n",
    "# Configure the theme in one call: `sns.set` / `set_theme` resets the plotting\n",
    "# context (to \"notebook\" by default), so a separate `sns.set_context` placed\n",
    "# before it has no effect on the figure.\n",
    "sns.set_theme(style=\"ticks\", font_scale=1.6)\n",
    "sns.set_palette(\"colorblind\")\n",
    "\n",
    "n_col = len(plotdf[\"n_samples\"].unique())\n",
    "g = sns.FacetGrid(plotdf, col=\"n_samples\", col_wrap=n_col, height=6, aspect=1.2)\n",
    "g.map_dataframe(sns.lineplot, x=\"delta\", y=\"delta_hat\", hue=\"model\", marker=\"o\", linewidth=2.5, errorbar=\"ci\")\n",
    "\n",
    "\n",
    "def plot_identity(x, **kwargs):\n",
    "    \"\"\"Draw the y = x reference line on the current facet, using the style kwargs given to `g.map`.\"\"\"\n",
    "    plt.plot(x, x, **kwargs)\n",
    "\n",
    "\n",
    "g.map(plot_identity, \"delta\", linestyle=\"--\", color=\"grey\", alpha=0.7, linewidth=2.5)\n",
    "\n",
    "g.add_legend(fontsize=legendsize, title_fontsize=20)\n",
    "legend = g.legend\n",
    "legend.set_bbox_to_anchor((0.91, 0.34))\n",
    "legend.set_title(\"Model\")\n",
    "legend.get_title().set_fontsize(20)\n",
    "# Resize every entry, whatever the number of models plotted.\n",
    "for text in legend.texts:\n",
    "    text.set_fontsize(legendsize)\n",
    "\n",
    "g.set_titles(col_template=\"$n$ = {col_name}\", size=labelsize)\n",
    "g.set_axis_labels(x_var=r\"$\\delta$\", y_var=r\"$\\hat{\\delta}$\")\n",
    "\n",
    "axes = g.axes.flatten()\n",
    "for ax in axes:\n",
    "    ax.set_xlabel(ax.get_xlabel(), fontsize=labelsize)\n",
    "    ax.set_ylabel(ax.get_ylabel(), fontsize=labelsize)\n",
    "    # Per-axes ticks: plt.xticks/plt.yticks only affect the current (last) axes.\n",
    "    ax.tick_params(axis=\"both\", labelsize=ticksize)\n",
    "axes[0].set_title(\"With logprobs\", fontsize=labelsize)\n",
    "\n",
    "plt.savefig(\"delta_estimation.pdf\", bbox_inches=\"tight\", format=\"pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cache"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### $\\alpha$ estimation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run the estimation of $\\alpha$, you must first run the Cache-Augmented test."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_names = [\"meta-llama/Llama-2-13b-chat-hf\", \"meta-llama/Llama-2-70b-chat-hf\", \"mistralai/Mistral-7B-Instruct-v0.1\"]\n",
    "n_samples = [50]\n",
    "custom_name= \"\"\n",
    "watermarks = [\"cache_dipmarkt\", \"cache_delta_reweight\"]\n",
    "alphas = [0.2,0.3,0.4,0.5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils import dip_reweight\n",
    "\n",
    "\n",
    "def estimate_alpha(model_name, custom_name, n_samples, wm_scheme, alpha):\n",
    "    \"\"\"Check whether a cache-based scheme behaves like delta-reweighting, and bound the deviation otherwise.\n",
    "\n",
    "    Loads the unwatermarked (delta = 0) distributions stored by the Cache-Augmented\n",
    "    test, applies `wm_scheme` to each of them, draws `n_samples` samples from the\n",
    "    result and compares the empirical frequency q of the originally most likely\n",
    "    outcome with its original probability p.\n",
    "\n",
    "    Args:\n",
    "        model_name: model id with \"/\" replaced by \"_\".\n",
    "        custom_name: prefix of the stored results file.\n",
    "        n_samples: number of samples drawn per trial.\n",
    "        wm_scheme: \"cache_dipmarkt\" (alias \"cache_dipmark\") or \"cache_delta_reweight\".\n",
    "        alpha: reweighting parameter passed to `dip_reweight` (unused for \"cache_delta_reweight\").\n",
    "\n",
    "    Returns:\n",
    "        (is_delta, bound): (True, None) as soon as a trial has q == 0;\n",
    "        (False, q - p) as soon as a trial has q > p. Otherwise is_delta is True\n",
    "        only if every trial had q == 1, and bound is the largest of 1 - p (when q\n",
    "        is within 1/sqrt(n_samples) of 2p - 1) or p - q over the trials with\n",
    "        q <= p (0 if there were none).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `wm_scheme` is not supported.\n",
    "    \"\"\"\n",
    "    if wm_scheme not in (\"cache_dipmarkt\", \"cache_dipmark\", \"cache_delta_reweight\"):\n",
    "        raise ValueError(f\"Unsupported watermark scheme: {wm_scheme!r}\")\n",
    "\n",
    "    path = f\"pkl_results/cache/{model_name}/no_watermark/{custom_name}_1.0.pkl\"\n",
    "    delta = 0  # unwatermarked results\n",
    "\n",
    "    with open(path, \"rb\") as f:\n",
    "        out = pickle.load(f)\n",
    "\n",
    "    data = out[delta]\n",
    "    # Normalise each row so that it sums to one.\n",
    "    data = data / np.sum(data, axis=1)[:, None]\n",
    "\n",
    "    is_delta = True\n",
    "    bound = 0\n",
    "\n",
    "    for n_trial in range(data.shape[0]):\n",
    "        # Keep a leading axis of size 1: og_probs has shape (1, n_outcomes).\n",
    "        og_probs = np.array([data[n_trial]])\n",
    "\n",
    "        if wm_scheme in (\"cache_dipmarkt\", \"cache_dipmark\"):\n",
    "            probs = dip_reweight(og_probs, alpha)\n",
    "        else:  # \"cache_delta_reweight\"\n",
    "            # All mass goes to one outcome drawn from og_probs, i.e. the one-hot\n",
    "            # vector of `choice`: [1, 0] for choice 0 and [0, 1] for choice 1.\n",
    "            # np.random.choice over [0, 1] requires exactly two outcomes per row.\n",
    "            choice = np.random.choice([0, 1], p=og_probs[0])\n",
    "            probs = np.array([[1 - choice, choice]])\n",
    "\n",
    "        probs = sample_vectorized(probs, n_samples, return_samples=False, bayesian=False)\n",
    "\n",
    "        og_probs, probs = og_probs[0], probs[0]\n",
    "        chosen = np.argmax(og_probs)\n",
    "        p, q = og_probs[chosen], probs[chosen]\n",
    "\n",
    "        if q == 1:\n",
    "            pass\n",
    "        elif q == 0:\n",
    "            return True, None\n",
    "        elif q > p:\n",
    "            return False, q - p\n",
    "        else:\n",
    "            is_delta = False\n",
    "            if np.abs(q - (2 * p - 1)) < 1 / np.sqrt(n_samples):\n",
    "                bound = max(1 - p, bound)\n",
    "            else:\n",
    "                bound = max(p - q, bound)\n",
    "\n",
    "    return is_delta, bound\n",
    "\n",
    "\n",
    "df = {\n",
    "    \"model_name\": [],\n",
    "    \"n_samples\": [],\n",
    "    \"watermark\": [],\n",
    "    \"alpha\": [],\n",
    "    \"is_delta\": [],\n",
    "    \"alpha_hat\": [],\n",
    "}\n",
    "\n",
    "for model_name in model_names:\n",
    "    model_name = model_name.replace(\"/\", \"_\")\n",
    "    for n_sample in n_samples:\n",
    "        for watermark in watermarks:\n",
    "            for alpha in alphas:\n",
    "                result, diff = estimate_alpha(model_name, custom_name, n_sample, watermark, alpha)\n",
    "\n",
    "                df[\"model_name\"].append(model_name)\n",
    "                df[\"n_samples\"].append(n_sample)\n",
    "                df[\"watermark\"].append(watermark)\n",
    "                df[\"alpha\"].append(alpha)\n",
    "                df[\"is_delta\"].append(result)\n",
    "                df[\"alpha_hat\"].append(diff)\n",
    "df = pd.DataFrame(df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# alpha-estimation rows with 50 samples where `estimate_alpha` returned is_delta == False\n",
    "df.loc[(df[\"n_samples\"] == 50) & (df[\"is_delta\"] == False)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fixed-Sampling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run the estimation of $n_{key}$, you must first run the Fixed-Sampling test."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Key size estimation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = \"stanford\"\n",
    "max_new_tokens = 50\n",
    "model_names = [\"meta-llama_Llama-2-70b-chat-hf\", \"meta-llama_Llama-2-13b-chat-hf\", \"mistralai_Mistral-7B-Instruct-v0.1\", \"meta-llama_Llama-2-7b-chat-hf\", \"meta-llama_Meta-Llama-3-8B-Instruct\"]\n",
    "custom_name=\"finalResults\"\n",
    "key_sizes = [256,2048]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.optimize import curve_fit\n",
    "from scipy import stats\n",
    "\n",
    "def rarefaction(x, n_key):\n",
    "    return n_key * (1 - (1 - 1/n_key)**x)\n",
    "\n",
    "def rarefaction_curve(data, num_samples=1000, trials=500):\n",
    "\n",
    "    max_samples = min(num_samples, len(data))\n",
    "    unique_counts = np.zeros(max_samples)\n",
    "    \n",
    "    # Perform multiple trials to average the curve\n",
    "    for _ in range(trials):\n",
    "        np.random.shuffle(data)\n",
    "        seen_sentences = set()\n",
    "        cumulative_uniques = []\n",
    "\n",
    "        for i in range(1, max_samples + 1):\n",
    "            seen_sentences.add(data[i - 1])\n",
    "            cumulative_uniques.append(len(seen_sentences))\n",
    "        \n",
    "        unique_counts += np.array(cumulative_uniques)\n",
    "\n",
    "    unique_counts /= trials\n",
    "    return unique_counts\n",
    "\n",
    "def fit_model(data, model, x):\n",
    "    resampled_curve = rarefaction_curve(data)\n",
    "\n",
    "    try:\n",
    "        # Fit model on resampled data\n",
    "        popt, pcov = curve_fit(model, x, resampled_curve)\n",
    "\n",
    "        # Calculate the standard deviations of the parameters from the diagonal of the covariance matrix\n",
    "        perr = np.sqrt(np.diag(pcov))\n",
    "\n",
    "        # Calculate confidence intervals assuming a normal distribution of the parameter estimates\n",
    "        alpha = 0.05  # 95% confidence interval -> 100*(1-alpha)%\n",
    "        n = len(resampled_curve)  # number of data points\n",
    "        p = len(popt)  # number of parameters\n",
    "        dof = max(0, n - p)  # degrees of freedom\n",
    "        # Student's t-distribution quantile for the dof and confidence level\n",
    "        tval = stats.t.ppf(1.0-alpha/2., dof)\n",
    "        ci = tval * perr\n",
    "\n",
    "        return popt, ci\n",
    "    except Exception as e:\n",
    "        print(\"Error:\", e)\n",
    "        return None  # Return None for failed fits\n",
    "\n",
    "def fit_model(data, model, x, n_bootstraps=100):\n",
    "    bootstrap_params = []\n",
    "    for _ in range(n_bootstraps):\n",
    "        # Resample the data\n",
    "        resampled_data = rarefaction_curve(data)\n",
    "        try:\n",
    "            # Fit model on resampled data\n",
    "            popt, _ = curve_fit(model, x, resampled_data)\n",
    "            bootstrap_params.append(popt)\n",
    "        except RuntimeError:\n",
    "            # Handle cases where the model fitting fails\n",
    "            continue\n",
    "\n",
    "    bootstrap_params = np.array(bootstrap_params)\n",
    "    # Estimate the parameter confidence intervals\n",
    "    ci_lower = np.percentile(bootstrap_params, 2.5, axis=0)\n",
    "    ci_upper = np.percentile(bootstrap_params, 97.5, axis=0)\n",
    "\n",
    "    return ci_lower, ci_upper, np.mean(bootstrap_params, axis=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate the key size of the Fixed-Sampling watermark from the diversity of its responses.\n",
    "for model_name in model_names:\n",
    "    for key_size in key_sizes:\n",
    "        path = f\"pkl_results/stanford/{model_name}/Stanford/{custom_name}_maxtokens{max_new_tokens}_1.0_[1]_{key_size}.pkl\"\n",
    "\n",
    "        # Read as text despite the .pkl extension: responses separated by \"###NEW_RESPONSE###\".\n",
    "        with open(path) as f:\n",
    "            whole_txt = f.read()\n",
    "        split = whole_txt.split(\"###NEW_RESPONSE###\")[1:]\n",
    "\n",
    "        # Histogram of identical responses (dicts keep first-occurrence order).\n",
    "        answer_dic = {}\n",
    "        for line in split:\n",
    "            answer_dic[line] = answer_dic.get(line, 0) + 1\n",
    "\n",
    "        # Every response repeated by its count, grouped by first occurrence.\n",
    "        data = [sentence for sentence, count in answer_dic.items() for _ in range(count)]\n",
    "\n",
    "        # The rarefaction curve has min(1000, len(data)) points, so size x accordingly.\n",
    "        x = np.arange(1, min(1000, len(data)) + 1)\n",
    "        cil, ciu, n = fit_model(data, rarefaction, x)\n",
    "        print(n, ciu - cil, model_name, key_size)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "wd",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
