{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.closed_models import ClosedModel\n",
    "from joblib import Parallel, delayed\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import time\n",
    "from scipy.stats import fisher_exact\n",
    "from src.attacks.kgw_detection import test_kgw_detection\n",
    "from src.attacks.stanford_detection import test_stanford"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "provider = \"google\"  # google, openai, claude\n",
    "api_key = \"\"\n",
    "model = \"gemini-pro\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate(prompt, api_key, model, provider, max_tokens, temperature):\n",
    "    \"\"\"Use for joblib parallel processing.\"\"\"\n",
    "    model = ClosedModel(provider, model, max_tokens, temperature)\n",
    "    return model.generate(api_key, prompt)\n",
    "\n",
    "\n",
    "def identify_fruit(response, word_list):\n",
    "    for word in word_list:\n",
    "        if word in response:\n",
    "            return word"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Red-Green test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Phase 1\n",
    "\n",
    "Query the model until finding a good setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Test parameters\n",
    "prefix = \"I ate\"\n",
    "word_list = [\"mangoes\", \"pineapples\", \"papayas\", \"kiwis\"]\n",
    "example = \"apples\"\n",
    "format = \"\"  # Forces the model output format\n",
    "context = 15  # context size must upperbound the watermark context size\n",
    "\n",
    "### Technical parameters\n",
    "max_tokens = 40\n",
    "temperature = 1.0\n",
    "\n",
    "k_array = [(prefix, int(str(i) * context)) for i in range(1, 2)]\n",
    "\n",
    "out = {}\n",
    "\n",
    "for i in range(len(k_array)):\n",
    "    prefix, k = k_array[i]\n",
    "\n",
    "    prompt = f'Complete the sentence \"{prefix} {k}\" using only and exacty a random word from the list: {word_list}.  Answer in this speific format: {format} {prefix} {k} {example}. (here I chose an other fruit for the sake of the example, you have to choose among {word_list}).'\n",
    "\n",
    "    for _ in range(10):\n",
    "        response = generate(prompt, api_key, model, provider, max_tokens, temperature)\n",
    "\n",
    "        fruit = identify_fruit(response, word_list)\n",
    "        if fruit in out:\n",
    "            out[fruit] += 1\n",
    "        else:\n",
    "            out[fruit] = 1\n",
    "\n",
    "        print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Phase 2\n",
    "\n",
    "Performing the test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Test parameters\n",
    "word_list = [\"mangoes\", \"pineapples\", \"papayas\", \"kiwis\"]\n",
    "prefixes = [\n",
    "    \"I ate\",\n",
    "    \"I chose\",\n",
    "    \"I picked\",\n",
    "    \"I selected\",\n",
    "    \"I took\",\n",
    "    \"I went for\",\n",
    "    \"I settled on\",\n",
    "    \"I got\",\n",
    "    \"I gathered\",\n",
    "    \"I harvested\",\n",
    "]\n",
    "example = \"apples\"\n",
    "format = \"\"  # Forces the model output format\n",
    "context = 15  # context size must upperbound the watermark context size\n",
    "\n",
    "### Technical parameters\n",
    "max_tokens = 40\n",
    "temperature = 1.0\n",
    "path = \"example.txt\"  # to complete\n",
    "\n",
    "\n",
    "k_array = [(prefix, int(str(i) * context)) for prefix in prefixes for i in range(1, 10)]\n",
    "out = {}\n",
    "\n",
    "\n",
    "for i in tqdm(range(len(k_array))):\n",
    "    prefix, k = k_array[i]\n",
    "\n",
    "    prompt = f'Complete the sentence \"{prefix} {k}\" using only and exacty a random word from the list: {word_list}.  Answer in this speific format: {format} {prefix} {k} {example}. (here I chose an other fruit for the sake of the example, you have to choose among {word_list})'\n",
    "\n",
    "    prompts = [prompt] * 100\n",
    "\n",
    "    out = Parallel(n_jobs=10)(\n",
    "        delayed(generate)(prompt, api_key, model, provider, max_tokens, temperature)\n",
    "        for prompt in prompts\n",
    "    )\n",
    "\n",
    "    for response in out:\n",
    "        fruit = identify_fruit(response, word_list)\n",
    "        print(response)\n",
    "\n",
    "        safeguard = 0\n",
    "        while fruit is None:\n",
    "            print(\"Parsing error - Rejection sampling\")\n",
    "            response = generate(\n",
    "                prompt, api_key, model, provider, max_tokens, temperature\n",
    "            )\n",
    "\n",
    "            fruit = identify_fruit(response, word_list)\n",
    "            safeguard += 1\n",
    "            if safeguard > 10:\n",
    "                break\n",
    "\n",
    "        with open(path, \"a\") as f:\n",
    "            f.write(f\"{prompt};{prefix};{k};{response};{fruit}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "columns = [\"prompt\", \"prefix\", \"k\", \"response\", \"fruit\"]\n",
    "df = pd.read_csv(path, sep=\";\", header=None, names=columns)\n",
    "\n",
    "# Group by 'prefix' and 'k', and aggregate the 'fruit' values into lists\n",
    "grouped = df.groupby([\"prefix\", \"k\"])[\"fruit\"].apply(list).reset_index()\n",
    "nested_list = grouped.apply(\n",
    "    lambda row: [row[\"prefix\"], row[\"k\"], row[\"fruit\"]], axis=1\n",
    ").tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Performing the test\n",
    "word_list = [\"strawberries\", \"cherries\", \"mangoes\", \"plums\"]\n",
    "n_bootstrap = 100\n",
    "\n",
    "probs = np.zeros((n_bootstrap, len(nested_list), len(word_list)))\n",
    "\n",
    "for i, l in enumerate(nested_list):\n",
    "    for j in range(n_bootstrap):\n",
    "        samples = l[2]\n",
    "        bootstraped_samples = np.random.choice(samples, 90, replace=True)\n",
    "        unique, counts = np.unique(bootstraped_samples, return_counts=True)\n",
    "\n",
    "        # Order counts according to word_list\n",
    "        counts = dict(zip(unique, counts))\n",
    "        counts = [counts.get(w, 0) for w in word_list]\n",
    "        probs[j, i] = counts / np.sum(counts)\n",
    "\n",
    "p_values = []\n",
    "for i in range(n_bootstrap):\n",
    "    bootstrapped_probs = probs[i]\n",
    "    bootstrapped_probs = bootstrapped_probs.reshape(-1, 9, len(word_list))\n",
    "    _, _, p = test_kgw_detection(bootstrapped_probs, 10000)\n",
    "    p_values.append(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(np.median(p_values))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Just a visual check everything went well"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.plot(probs[0], \"o\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fixed-sampling test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"This is the story of\"\n",
    "max_new_tokens = None\n",
    "path = \"example.txt\"  # to complete\n",
    "n_samples = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checks the number of line starting with ###NEW_RESPONSE### already here\n",
    "if os.path.exists(path):\n",
    "    with open(path) as f:\n",
    "        lines = f.readlines()\n",
    "        # Count the number of lines starting with ###NEW_RESPONSE###\n",
    "        n_lines = sum([1 for line in lines if line.startswith(\"###NEW_RESPONSE###\")])\n",
    "\n",
    "else:\n",
    "    n_lines = 0\n",
    "\n",
    "out = Parallel(n_jobs=10)(\n",
    "    delayed(generate)(prompt, api_key, model, provider, max_tokens, temperature)\n",
    "    for _ in tqdm(range(n_samples - n_lines))\n",
    ")\n",
    "for response in out:\n",
    "    if response is not None:\n",
    "        with open(path, \"a\") as f:\n",
    "            # add a newline to the end of the file\n",
    "            f.write(\"###NEW_RESPONSE###\" + response + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(path) as f:\n",
    "    whole_txt = f.read()\n",
    "    split = whole_txt.split(\"###NEW_RESPONSE###\")[1:]\n",
    "    # Create a histogram of the lines. First hash the lines to a number\n",
    "    answer_dic = {}\n",
    "\n",
    "    for line in split:\n",
    "        if line not in answer_dic:\n",
    "            answer_dic[line] = 1\n",
    "        else:\n",
    "            answer_dic[line] += 1\n",
    "\n",
    "data = [sentence for sentence, count in answer_dic.items() for _ in range(count)]\n",
    "\n",
    "test_stanford(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cache-Augmented test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Phase 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prefix = \"I ate\"\n",
    "k = \"VB@h6C7\"\n",
    "word_list1 = [\"pears\", \"plums\"]\n",
    "word_list2 = [\"plums\", \"pears\"]\n",
    "example = \"cherries\"\n",
    "format = \"\"\n",
    "n_trial = 75\n",
    "\n",
    "### Technical parameters\n",
    "max_tokens = 40\n",
    "temperature = 1.0\n",
    "\n",
    "prompt = f'Complete the sentence \"{prefix} {k}\" using only and exacty a random word from the list: {word_list1}.  Answer in this speific format: {format} {prefix} {k} {example}. (here I chose an other fruit for the sake of the example, you must choose among {word_list2})'\n",
    "out = {}\n",
    "\n",
    "responses = Parallel(n_jobs=-1)(\n",
    "    delayed(generate)(prompt, api_key, model, provider, max_tokens, temperature)\n",
    "    for _ in tqdm(range(n_trial))\n",
    ")\n",
    "\n",
    "for response in responses:\n",
    "    fruit = identify_fruit(response, word_list1)\n",
    "\n",
    "    if fruit in out:\n",
    "        out[fruit] += 1\n",
    "    else:\n",
    "        out[fruit] = 1\n",
    "\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Phase 2 (long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"example.txt\"\n",
    "waiting_time = 1000\n",
    "\n",
    "for k in tqdm(range(45)):\n",
    "    time.sleep(waiting_time)  # waiting for the cache to be cleared\n",
    "\n",
    "    response = generate(prompt, api_key, model, provider, max_tokens, temperature)\n",
    "\n",
    "    with open(path, \"a\") as f:\n",
    "        f.write(\"###NEW_RESPONSE###\" + response + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(path) as f:\n",
    "    whole_txt = f.read()\n",
    "    split = whole_txt.split(\"###NEW_RESPONSE###\")[1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 0\n",
    "for s in split:\n",
    "    if \"mangoes\" in s:\n",
    "        k += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fisher_exact([[k, len(split) - k], [45, 30]])  # input values from phase 1"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "wd",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
