{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
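    "# DetectGPT Exploration\n",
    "\n",
    "Runs the DetectGPT perturbation test (T5-3B mask filling) on a watermarking generation table, comparing the `baseline_completion` texts against the watermarked `w_wm_output` texts and reporting ROC / PR AUC for the `d` (raw log-likelihood drop) and `z` (std-normalised drop) criteria."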
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Root of the experiment tree; set WATERMARKING_ROOT to override it without editing the notebook.\n",
    "WATERMARKING_ROOT = os.environ.get(\"WATERMARKING_ROOT\", \"/scratch/<username>/test/watermarking-root\")\n",
    "INPUT_DIR = os.path.join(WATERMARKING_ROOT, \"input\")\n",
    "OUTPUT_DIR = os.path.join(WATERMARKING_ROOT, \"output\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Basic imports\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "from tqdm import tqdm\n",
    "from statistics import mean\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from matplotlib import rc\n",
    "rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n",
    "# With usetex=True, matplotlib shells out to LaTeX for all text and errors when it is\n",
    "# missing, so only enable it when a latex binary is on PATH.\n",
    "rc('text', usetex=shutil.which('latex') is not None)\n",
    "\n",
    "import cmasher as cmr"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the processed dataset/frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reading JSON lines from /scratch/<username>/test/watermarking-root/input/new_runs/test_len_200_opt1_3b/gen_table.jsonl: 100%|██████████| 781/781 [00:00<00:00, 19661.31lines/s]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.insert(0, \"..\")\n",
    "from datasets import Dataset\n",
    "from utils.io import read_jsonlines, load_jsonlines\n",
    "\n",
    "# Build the path from INPUT_DIR so only the config cell needs editing when the data moves.\n",
    "data_path = os.path.join(INPUT_DIR, \"new_runs\", \"test_len_200_opt1_3b\", \"gen_table.jsonl\")\n",
    "\n",
    "list_of_dict = load_jsonlines(data_path)\n",
    "raw_data = Dataset.from_list(list_of_dict)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Convert to a pandas DataFrame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = raw_data.to_pandas()\n",
    "\n",
    "# The perturbation code below reads these two columns; fail fast if the table layout changed.\n",
    "required_columns = {\"baseline_completion\", \"w_wm_output\"}\n",
    "missing_columns = required_columns - set(df.columns)\n",
    "assert not missing_columns, f\"missing expected columns: {sorted(missing_columns)}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Orig number of rows: 781\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>idx</th>\n",
       "      <th>truncated_input</th>\n",
       "      <th>baseline_completion</th>\n",
       "      <th>orig_sample_length</th>\n",
       "      <th>prompt_length</th>\n",
       "      <th>baseline_completion_length</th>\n",
       "      <th>no_wm_output</th>\n",
       "      <th>w_wm_output</th>\n",
       "      <th>no_wm_num_tokens_generated</th>\n",
       "      <th>w_wm_num_tokens_generated</th>\n",
       "      <th>spike_entropies</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>776</th>\n",
       "      <td>1106</td>\n",
       "      <td>With the arrival of better weather (recent rai...</td>\n",
       "      <td>\\nH.H. Haugh was described as \"Dealers in new ...</td>\n",
       "      <td>968</td>\n",
       "      <td>768</td>\n",
       "      <td>200</td>\n",
       "      <td>\\nI had a better question, though: \"Was H.H. H...</td>\n",
       "      <td>\\nI had to check on this. You see, I don't kno...</td>\n",
       "      <td>200</td>\n",
       "      <td>200</td>\n",
       "      <td>[0.417236328125, 0.94580078125, 0.94873046875,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>777</th>\n",
       "      <td>1107</td>\n",
       "      <td>You probably don’t put much thought into craft...</td>\n",
       "      <td>it ripe for misuse.\\nBecause email is a quick...</td>\n",
       "      <td>288</td>\n",
       "      <td>88</td>\n",
       "      <td>200</td>\n",
       "      <td>it easy to make mistakes.\\nHere are some of t...</td>\n",
       "      <td>it incredibly easy to send a bad email.\\nSo, ...</td>\n",
       "      <td>200</td>\n",
       "      <td>200</td>\n",
       "      <td>[0.6787109375, 0.88427734375, 0.72802734375, 0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>778</th>\n",
       "      <td>1108</td>\n",
       "      <td>Welcome to Metro’s wine guide. Every week wine...</td>\n",
       "      <td>when in fact we should, but differently. Ten ...</td>\n",
       "      <td>637</td>\n",
       "      <td>437</td>\n",
       "      <td>200</td>\n",
       "      <td>because of the calories, and we do - but we a...</td>\n",
       "      <td>.\\nThe other weekend a friend from the Sunshin...</td>\n",
       "      <td>200</td>\n",
       "      <td>200</td>\n",
       "      <td>[0.79541015625, 0.87353515625, 0.95703125, 0.9...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>779</th>\n",
       "      <td>1109</td>\n",
       "      <td>By Tracey Maclin. Tracey Maclin is a professor...</td>\n",
       "      <td>warrantless search is permissible in an emerg...</td>\n",
       "      <td>962</td>\n",
       "      <td>762</td>\n",
       "      <td>200</td>\n",
       "      <td>sweeping search of public housing units, for ...</td>\n",
       "      <td>more flexible approach would allow potential ...</td>\n",
       "      <td>200</td>\n",
       "      <td>200</td>\n",
       "      <td>[0.96875, 0.9814453125, 0.8076171875, 0.755371...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>780</th>\n",
       "      <td>1112</td>\n",
       "      <td>Financial firms are not just responsible for e...</td>\n",
       "      <td>the platform first launched. \"The management ...</td>\n",
       "      <td>498</td>\n",
       "      <td>298</td>\n",
       "      <td>200</td>\n",
       "      <td>asked about why lenders are looking for new s...</td>\n",
       "      <td>talking about his firm's launch.\\n\"The system...</td>\n",
       "      <td>200</td>\n",
       "      <td>200</td>\n",
       "      <td>[0.85888671875, 0.48583984375, 0.69140625, 0.7...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      idx                                    truncated_input  \\\n",
       "776  1106  With the arrival of better weather (recent rai...   \n",
       "777  1107  You probably don’t put much thought into craft...   \n",
       "778  1108  Welcome to Metro’s wine guide. Every week wine...   \n",
       "779  1109  By Tracey Maclin. Tracey Maclin is a professor...   \n",
       "780  1112  Financial firms are not just responsible for e...   \n",
       "\n",
       "                                   baseline_completion  orig_sample_length  \\\n",
       "776  \\nH.H. Haugh was described as \"Dealers in new ...                 968   \n",
       "777   it ripe for misuse.\\nBecause email is a quick...                 288   \n",
       "778   when in fact we should, but differently. Ten ...                 637   \n",
       "779   warrantless search is permissible in an emerg...                 962   \n",
       "780   the platform first launched. \"The management ...                 498   \n",
       "\n",
       "     prompt_length  baseline_completion_length  \\\n",
       "776            768                         200   \n",
       "777             88                         200   \n",
       "778            437                         200   \n",
       "779            762                         200   \n",
       "780            298                         200   \n",
       "\n",
       "                                          no_wm_output  \\\n",
       "776  \\nI had a better question, though: \"Was H.H. H...   \n",
       "777   it easy to make mistakes.\\nHere are some of t...   \n",
       "778   because of the calories, and we do - but we a...   \n",
       "779   sweeping search of public housing units, for ...   \n",
       "780   asked about why lenders are looking for new s...   \n",
       "\n",
       "                                           w_wm_output  \\\n",
       "776  \\nI had to check on this. You see, I don't kno...   \n",
       "777   it incredibly easy to send a bad email.\\nSo, ...   \n",
       "778  .\\nThe other weekend a friend from the Sunshin...   \n",
       "779   more flexible approach would allow potential ...   \n",
       "780   talking about his firm's launch.\\n\"The system...   \n",
       "\n",
       "     no_wm_num_tokens_generated  w_wm_num_tokens_generated  \\\n",
       "776                         200                        200   \n",
       "777                         200                        200   \n",
       "778                         200                        200   \n",
       "779                         200                        200   \n",
       "780                         200                        200   \n",
       "\n",
       "                                       spike_entropies  \n",
       "776  [0.417236328125, 0.94580078125, 0.94873046875,...  \n",
       "777  [0.6787109375, 0.88427734375, 0.72802734375, 0...  \n",
       "778  [0.79541015625, 0.87353515625, 0.95703125, 0.9...  \n",
       "779  [0.96875, 0.9814453125, 0.8076171875, 0.755371...  \n",
       "780  [0.85888671875, 0.48583984375, 0.69140625, 0.7...  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Size of the loaded table, then a peek at its last rows.\n",
    "print(f\"Orig number of rows: {df.shape[0]}\")\n",
    "df.tail(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['idx', 'truncated_input', 'baseline_completion', 'orig_sample_length',\n",
       "       'prompt_length', 'baseline_completion_length', 'no_wm_output',\n",
       "       'w_wm_output', 'no_wm_num_tokens_generated',\n",
       "       'w_wm_num_tokens_generated', 'spike_entropies'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# DataFrame.keys() returns the column index (the same object as df.columns).\n",
    "df.keys()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DetectGPT Model Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading mask filling model t5-3b...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading (…)ve/main/spiece.model: 100%|██████████| 792k/792k [00:00<00:00, 60.4MB/s]\n",
      "Downloading (…)/main/tokenizer.json: 100%|██████████| 1.39M/1.39M [00:00<00:00, 87.1MB/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "T5ForConditionalGeneration(\n",
       "  (shared): Embedding(32128, 1024)\n",
       "  (encoder): T5Stack(\n",
       "    (embed_tokens): Embedding(32128, 1024)\n",
       "    (block): ModuleList(\n",
       "      (0): T5Block(\n",
       "        (layer): ModuleList(\n",
       "          (0): T5LayerSelfAttention(\n",
       "            (SelfAttention): T5Attention(\n",
       "              (q): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (k): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (v): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (o): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "              (relative_attention_bias): Embedding(32, 32)\n",
       "            )\n",
       "            (layer_norm): T5LayerNorm()\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (1): T5LayerFF(\n",
       "            (DenseReluDense): T5DenseActDense(\n",
       "              (wi): Linear(in_features=1024, out_features=16384, bias=False)\n",
       "              (wo): Linear(in_features=16384, out_features=1024, bias=False)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "              (act): ReLU()\n",
       "            )\n",
       "            (layer_norm): T5LayerNorm()\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (1-23): 23 x T5Block(\n",
       "        (layer): ModuleList(\n",
       "          (0): T5LayerSelfAttention(\n",
       "            (SelfAttention): T5Attention(\n",
       "              (q): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (k): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (v): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (o): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "            )\n",
       "            (layer_norm): T5LayerNorm()\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (1): T5LayerFF(\n",
       "            (DenseReluDense): T5DenseActDense(\n",
       "              (wi): Linear(in_features=1024, out_features=16384, bias=False)\n",
       "              (wo): Linear(in_features=16384, out_features=1024, bias=False)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "              (act): ReLU()\n",
       "            )\n",
       "            (layer_norm): T5LayerNorm()\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (final_layer_norm): T5LayerNorm()\n",
       "    (dropout): Dropout(p=0.1, inplace=False)\n",
       "  )\n",
       "  (decoder): T5Stack(\n",
       "    (embed_tokens): Embedding(32128, 1024)\n",
       "    (block): ModuleList(\n",
       "      (0): T5Block(\n",
       "        (layer): ModuleList(\n",
       "          (0): T5LayerSelfAttention(\n",
       "            (SelfAttention): T5Attention(\n",
       "              (q): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (k): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (v): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (o): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "              (relative_attention_bias): Embedding(32, 32)\n",
       "            )\n",
       "            (layer_norm): T5LayerNorm()\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (1): T5LayerCrossAttention(\n",
       "            (EncDecAttention): T5Attention(\n",
       "              (q): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (k): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (v): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (o): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "            )\n",
       "            (layer_norm): T5LayerNorm()\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (2): T5LayerFF(\n",
       "            (DenseReluDense): T5DenseActDense(\n",
       "              (wi): Linear(in_features=1024, out_features=16384, bias=False)\n",
       "              (wo): Linear(in_features=16384, out_features=1024, bias=False)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "              (act): ReLU()\n",
       "            )\n",
       "            (layer_norm): T5LayerNorm()\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (1-23): 23 x T5Block(\n",
       "        (layer): ModuleList(\n",
       "          (0): T5LayerSelfAttention(\n",
       "            (SelfAttention): T5Attention(\n",
       "              (q): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (k): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (v): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (o): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "            )\n",
       "            (layer_norm): T5LayerNorm()\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (1): T5LayerCrossAttention(\n",
       "            (EncDecAttention): T5Attention(\n",
       "              (q): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (k): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (v): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "              (o): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "            )\n",
       "            (layer_norm): T5LayerNorm()\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (2): T5LayerFF(\n",
       "            (DenseReluDense): T5DenseActDense(\n",
       "              (wi): Linear(in_features=1024, out_features=16384, bias=False)\n",
       "              (wo): Linear(in_features=16384, out_features=1024, bias=False)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "              (act): ReLU()\n",
       "            )\n",
       "            (layer_norm): T5LayerNorm()\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (final_layer_norm): T5LayerNorm()\n",
       "    (dropout): Dropout(p=0.1, inplace=False)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=1024, out_features=32128, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import transformers\n",
    "\n",
    "# T5 model used to fill masked spans when perturbing texts.\n",
    "mask_filling_model_name = 't5-3b'\n",
    "\n",
    "# Extra from_pretrained kwargs: no int8 quantisation, weights loaded in bfloat16.\n",
    "int8_kwargs = {}\n",
    "half_kwargs = dict(torch_dtype=torch.bfloat16)\n",
    "print(f'Loading mask filling model {mask_filling_model_name}...')\n",
    "mask_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(mask_filling_model_name, **int8_kwargs, **half_kwargs)\n",
    "\n",
    "## TODO: load the base model (for generation) and base tokenizer as `base_model` / `base_tokenizer`;\n",
    "## get_ll depends on them. The data path mentions opt1_3b - confirm which model produced the texts.\n",
    "\n",
    "# Fall back to 512 positions when the config does not declare n_positions.\n",
    "n_positions = getattr(mask_model.config, 'n_positions', 512)\n",
    "preproc_tokenizer = transformers.AutoTokenizer.from_pretrained('t5-small', model_max_length=512)\n",
    "mask_tokenizer = transformers.AutoTokenizer.from_pretrained(mask_filling_model_name, model_max_length=n_positions)\n",
    "\n",
    "# Keep the model on CPU until get_perturbation_results moves it to the GPU; ';' hides the large module repr.\n",
    "mask_model.cpu();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## perturbing text ops\n",
    "import re\n",
    "\n",
    "# A new masked span is rejected if another mask lies within this many words of it.\n",
    "buffer_size = 1\n",
    "# top_p passed to mask_model.generate when sampling fills.\n",
    "mask_top_p = 1.0\n",
    "# Matches T5 sentinel tokens <extra_id_N> for any integer N.\n",
    "pattern = re.compile(r\"<extra_id_\\d+>\")\n",
    "\n",
    "def tokenize_and_mask(text, span_length, pct, ceil_pct=False):\n",
    "    \"\"\"Mask random word spans of `text` with numbered T5 sentinel tokens.\n",
    "\n",
    "    The text is split on single spaces. The number of spans is\n",
    "    pct * n_words / (span_length + 2 * buffer_size), rounded up when ceil_pct is set and\n",
    "    truncated otherwise. Each chosen span of `span_length` words collapses into a single\n",
    "    placeholder, and a span is rejected if another placeholder lies within `buffer_size`\n",
    "    words of it. Placeholders are then renamed <extra_id_0>, <extra_id_1>, ... left to right.\n",
    "\n",
    "    Span positions come from np.random, so the result depends on the global NumPy seed.\n",
    "    \"\"\"\n",
    "    words = text.split(' ')\n",
    "    placeholder = '<<<mask>>>'\n",
    "\n",
    "    span_budget = pct * len(words) / (span_length + buffer_size * 2)\n",
    "    target_spans = int(np.ceil(span_budget) if ceil_pct else span_budget)\n",
    "\n",
    "    placed = 0\n",
    "    while placed < target_spans:\n",
    "        lo = np.random.randint(0, len(words) - span_length)\n",
    "        hi = lo + span_length\n",
    "        neighbourhood = words[max(0, lo - buffer_size):min(len(words), hi + buffer_size)]\n",
    "        if placeholder in neighbourhood:\n",
    "            continue  # too close to an existing span; draw again\n",
    "        words[lo:hi] = [placeholder]\n",
    "        placed += 1\n",
    "\n",
    "    # number the placeholders left-to-right as <extra_id_0>, <extra_id_1>, ...\n",
    "    next_id = 0\n",
    "    for pos, word in enumerate(words):\n",
    "        if word == placeholder:\n",
    "            words[pos] = f'<extra_id_{next_id}>'\n",
    "            next_id += 1\n",
    "    assert next_id == placed, f\"num_filled {next_id} != n_masks {placed}\"\n",
    "    return ' '.join(words)\n",
    "\n",
    "def count_masks(texts):\n",
    "    \"\"\"For each text, count the whitespace-separated words starting with '<extra_id_'.\"\"\"\n",
    "    return [sum(1 for word in text.split() if word.startswith(\"<extra_id_\")) for text in texts]\n",
    "\n",
    "# replace each masked span with a sample from T5 mask_model\n",
    "def replace_masks(texts):\n",
    "    \"\"\"Sample T5 fills for the sentinels in `texts` and return the raw decoded generations.\n",
    "\n",
    "    Generation uses <extra_id_{max(count_masks(texts))}> (the sentinel after the last one\n",
    "    used in the batch) as its EOS token. Special tokens are kept in the decoded output so\n",
    "    extract_fills can split on the sentinels.\n",
    "    \"\"\"\n",
    "    n_expected = count_masks(texts)\n",
    "    stop_id = mask_tokenizer.encode(f\"<extra_id_{max(n_expected)}>\")[0]\n",
    "    # BatchEncoding has no .cuda(); move the tensors to whichever device mask_model is on.\n",
    "    tokens = mask_tokenizer(texts, return_tensors=\"pt\", padding=True).to(mask_model.device)\n",
    "    outputs = mask_model.generate(**tokens, max_length=150, do_sample=True, top_p=mask_top_p, num_return_sequences=1, eos_token_id=stop_id)\n",
    "    return mask_tokenizer.batch_decode(outputs, skip_special_tokens=False)\n",
    "\n",
    "\n",
    "def extract_fills(texts):\n",
    "    \"\"\"Split each raw T5 generation into the list of fills between its sentinel tokens.\n",
    "\n",
    "    <pad> and </s> markers are stripped first. Text before the first sentinel and after\n",
    "    the last one is discarded, and each remaining fill is whitespace-trimmed.\n",
    "    \"\"\"\n",
    "    fills_per_text = []\n",
    "    for raw in texts:\n",
    "        cleaned = raw.replace(\"<pad>\", \"\").replace(\"</s>\", \"\").strip()\n",
    "        between_sentinels = pattern.split(cleaned)[1:-1]\n",
    "        fills_per_text.append([piece.strip() for piece in between_sentinels])\n",
    "    return fills_per_text\n",
    "\n",
    "\n",
    "def apply_extracted_fills(masked_texts, extracted_fills):\n",
    "    \"\"\"Substitute fills for the <extra_id_N> sentinels in each masked text.\n",
    "\n",
    "    Texts are split on single spaces only (not newlines). A text with fewer fills\n",
    "    than sentinels becomes the empty string, which perturb_texts_ treats as a\n",
    "    signal to retry that text.\n",
    "    \"\"\"\n",
    "    word_lists = [masked.split(' ') for masked in masked_texts]\n",
    "    expected_counts = count_masks(masked_texts)\n",
    "\n",
    "    for pos, (words, fills, n_masks) in enumerate(zip(word_lists, extracted_fills, expected_counts)):\n",
    "        if len(fills) >= n_masks:\n",
    "            for fill_id in range(n_masks):\n",
    "                words[words.index(f\"<extra_id_{fill_id}>\")] = fills[fill_id]\n",
    "        else:\n",
    "            word_lists[pos] = []  # joins to '' below\n",
    "\n",
    "    return [\" \".join(words) for words in word_lists]\n",
    "\n",
    "def perturb_texts_(texts, span_length, pct, ceil_pct=False):\n",
    "    \"\"\"Perturb one batch of texts by masking word spans and letting T5 refill them.\n",
    "\n",
    "    Texts that come back empty (T5 produced too few fills) are re-masked and re-filled\n",
    "    until every text has a perturbation. NOTE(review): there is no retry limit.\n",
    "    \"\"\"\n",
    "    def mask_and_fill(batch):\n",
    "        masked = [tokenize_and_mask(x, span_length, pct, ceil_pct) for x in batch]\n",
    "        return apply_extracted_fills(masked, extract_fills(replace_masks(masked)))\n",
    "\n",
    "    perturbed = mask_and_fill(texts)\n",
    "\n",
    "    attempts = 1\n",
    "    while '' in perturbed:\n",
    "        idxs = [i for i, x in enumerate(perturbed) if x == '']\n",
    "        print(f'WARNING: {len(idxs)} texts have no fills. Trying again [attempt {attempts}].')\n",
    "        for i, x in zip(idxs, mask_and_fill([texts[j] for j in idxs])):\n",
    "            perturbed[i] = x\n",
    "        attempts += 1\n",
    "\n",
    "    return perturbed\n",
    "\n",
    "\n",
    "def perturb_texts(texts, span_length, pct, ceil_pct=False, batch_size=None):\n",
    "    \"\"\"Perturb `texts` in batches via perturb_texts_, showing a progress bar.\n",
    "\n",
    "    batch_size defaults to the module-level `chunk_size`. It is halved for 11B mask\n",
    "    models (presumably for memory) and never drops below 1.\n",
    "    \"\"\"\n",
    "    # Use a distinct local name: assigning to `chunk_size` here would make it local\n",
    "    # and raise UnboundLocalError instead of reading the notebook-level setting.\n",
    "    if batch_size is None:\n",
    "        batch_size = chunk_size\n",
    "    if '11b' in mask_filling_model_name:\n",
    "        batch_size //= 2\n",
    "    batch_size = max(1, batch_size)\n",
    "\n",
    "    outputs = []\n",
    "    # `tqdm` is imported as the class itself, so call it directly (tqdm.tqdm does not exist).\n",
    "    for i in tqdm(range(0, len(texts), batch_size), desc=\"Applying perturbations\"):\n",
    "        outputs.extend(perturb_texts_(texts[i:i + batch_size], span_length, pct, ceil_pct=ceil_pct))\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Get perturbation results\n",
    "import functools\n",
    "\n",
    "# Settings read by perturb_texts / get_perturbation_results.\n",
    "n_perturbation_rounds = 1  # rounds beyond the first re-perturb the already-perturbed texts\n",
    "pct_words_masked = 0.3     # `pct` passed to tokenize_and_mask\n",
    "chunk_size = 20            # texts per mask-filling batch in perturb_texts\n",
    "\n",
    "# Get the log likelihood of each text under the base_model\n",
    "def get_ll(text):\n",
    "    \"\"\"Return the negated loss of base_model on `text` with labels = input_ids.\n",
    "\n",
    "    For a Hugging Face causal LM this is the mean per-token log-likelihood.\n",
    "    Requires `base_model` / `base_tokenizer`, which the model-loading cell leaves as a TODO.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # BatchEncoding has no .cuda(); send the inputs to the base model's device instead.\n",
    "        tokenized = base_tokenizer(text, return_tensors=\"pt\").to(base_model.device)\n",
    "        labels = tokenized.input_ids\n",
    "        return -base_model(**tokenized, labels=labels).loss.item()\n",
    "\n",
    "\n",
    "def get_lls(texts):\n",
    "    \"\"\"Apply get_ll to each text, preserving order.\"\"\"\n",
    "    return list(map(get_ll, texts))\n",
    "\n",
    "\n",
    "def get_perturbation_results(span_length=10, n_perturbations=1, n_samples=500):\n",
    "    \"\"\"Perturb the original and watermarked texts with T5 and score them under base_model.\n",
    "\n",
    "    For every row of `df`, `baseline_completion` (\"original\") and `w_wm_output` (\"sampled\")\n",
    "    are each perturbed `n_perturbations` times, plus n_perturbation_rounds - 1 extra rounds.\n",
    "    The log-likelihood of each text and the mean/std over its perturbations are stored in\n",
    "    the result dict (std falls back to 1 when there is only one perturbation).\n",
    "\n",
    "    NOTE(review): n_samples is accepted but unused; all rows of df are processed.\n",
    "    \"\"\"\n",
    "    mask_model.cuda()\n",
    "\n",
    "    torch.manual_seed(0)\n",
    "    np.random.seed(0)\n",
    "\n",
    "    # Plain lists so the [idx] lookups below are positional regardless of df's index labels.\n",
    "    original_text = df[\"baseline_completion\"].tolist()\n",
    "    sampled_text = df[\"w_wm_output\"].tolist()\n",
    "\n",
    "    perturb_fn = functools.partial(perturb_texts, span_length=span_length, pct=pct_words_masked)\n",
    "\n",
    "    try:\n",
    "        p_sampled_text = perturb_fn([x for x in sampled_text for _ in range(n_perturbations)])\n",
    "        p_original_text = perturb_fn([x for x in original_text for _ in range(n_perturbations)])\n",
    "        for _ in range(n_perturbation_rounds - 1):\n",
    "            try:\n",
    "                p_sampled_text, p_original_text = perturb_fn(p_sampled_text), perturb_fn(p_original_text)\n",
    "            except AssertionError:\n",
    "                break\n",
    "    finally:\n",
    "        # Move the mask model off the GPU before base_model scoring, even if perturbation fails.\n",
    "        mask_model.cpu()\n",
    "\n",
    "    assert len(p_sampled_text) == len(sampled_text) * n_perturbations, f\"Expected {len(sampled_text) * n_perturbations} perturbed samples, got {len(p_sampled_text)}\"\n",
    "    assert len(p_original_text) == len(original_text) * n_perturbations, f\"Expected {len(original_text) * n_perturbations} perturbed samples, got {len(p_original_text)}\"\n",
    "\n",
    "    results = []\n",
    "    for idx in range(len(original_text)):\n",
    "        results.append({\n",
    "            \"original\": original_text[idx],\n",
    "            \"sampled\": sampled_text[idx],\n",
    "            \"perturbed_sampled\": p_sampled_text[idx * n_perturbations: (idx + 1) * n_perturbations],\n",
    "            \"perturbed_original\": p_original_text[idx * n_perturbations: (idx + 1) * n_perturbations]\n",
    "        })\n",
    "\n",
    "    # `tqdm` is imported as the class itself, so call it directly (tqdm.tqdm does not exist).\n",
    "    for res in tqdm(results, desc=\"Computing log likelihoods\"):\n",
    "        p_sampled_ll = get_lls(res[\"perturbed_sampled\"])\n",
    "        p_original_ll = get_lls(res[\"perturbed_original\"])\n",
    "        res[\"original_ll\"] = get_ll(res[\"original\"])\n",
    "        res[\"sampled_ll\"] = get_ll(res[\"sampled\"])\n",
    "        res[\"all_perturbed_sampled_ll\"] = p_sampled_ll\n",
    "        res[\"all_perturbed_original_ll\"] = p_original_ll\n",
    "        res[\"perturbed_sampled_ll\"] = np.mean(p_sampled_ll)\n",
    "        res[\"perturbed_original_ll\"] = np.mean(p_original_ll)\n",
    "        res[\"perturbed_sampled_ll_std\"] = np.std(p_sampled_ll) if len(p_sampled_ll) > 1 else 1\n",
    "        res[\"perturbed_original_ll_std\"] = np.std(p_original_ll) if len(p_original_ll) > 1 else 1\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from sklearn.metrics import roc_curve, precision_recall_curve, auc\n",
    "\n",
    "# pct_words_masked is set once with the other perturbation settings; redefining it here could\n",
    "# make the metadata saved by run_perturbation_experiment disagree with the masking rate actually used.\n",
    "\n",
    "def get_roc_metrics(real_preds, sample_preds):\n",
    "    \"\"\"ROC curve and its AUC, with real_preds labelled 0 and sample_preds labelled 1.\n",
    "\n",
    "    Both arguments are expected to be lists (they are concatenated with +). Returns\n",
    "    (fpr, tpr, auc) as Python lists / float.\n",
    "    \"\"\"\n",
    "    y_true = [0] * len(real_preds) + [1] * len(sample_preds)\n",
    "    y_score = real_preds + sample_preds\n",
    "    fpr, tpr, _ = roc_curve(y_true, y_score)\n",
    "    return fpr.tolist(), tpr.tolist(), float(auc(fpr, tpr))\n",
    "\n",
    "\n",
    "def get_precision_recall_metrics(real_preds, sample_preds):\n",
    "    \"\"\"Precision-recall curve and its AUC, with real_preds labelled 0 and sample_preds labelled 1.\"\"\"\n",
    "    y_true = [0] * len(real_preds) + [1] * len(sample_preds)\n",
    "    precision, recall, _ = precision_recall_curve(y_true, real_preds + sample_preds)\n",
    "    return precision.tolist(), recall.tolist(), float(auc(recall, precision))\n",
    "\n",
    "def run_perturbation_experiment(results, criterion, span_length=10, n_perturbations=1, n_samples=500):\n",
    "    \"\"\"Turn perturbation results into DetectGPT scores plus ROC / PR metrics.\n",
    "\n",
    "    criterion 'd': score = ll - mean perturbed ll.\n",
    "    criterion 'z': that difference divided by the std of the perturbed lls; a zero std is\n",
    "    replaced by 1 (with a warning) and the replacement is written back into `results`.\n",
    "    Scores of 'original' texts are labelled 0 and scores of 'sampled' texts 1.\n",
    "    span_length, n_perturbations and n_samples only feed the returned name / info.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if criterion is not 'd' or 'z'.\n",
    "    \"\"\"\n",
    "    if criterion not in ('d', 'z'):\n",
    "        raise ValueError(f\"criterion must be 'd' or 'z', got {criterion!r}\")\n",
    "\n",
    "    # compute diffs with perturbed\n",
    "    predictions = {'real': [], 'samples': []}\n",
    "    for res in results:\n",
    "        real_score = res['original_ll'] - res['perturbed_original_ll']\n",
    "        sample_score = res['sampled_ll'] - res['perturbed_sampled_ll']\n",
    "        if criterion == 'z':\n",
    "            if res['perturbed_original_ll_std'] == 0:\n",
    "                res['perturbed_original_ll_std'] = 1\n",
    "                print(\"WARNING: std of perturbed original is 0, setting to 1\")\n",
    "                print(f\"Number of unique perturbed original texts: {len(set(res['perturbed_original']))}\")\n",
    "                print(f\"Original text: {res['original']}\")\n",
    "            if res['perturbed_sampled_ll_std'] == 0:\n",
    "                res['perturbed_sampled_ll_std'] = 1\n",
    "                print(\"WARNING: std of perturbed sampled is 0, setting to 1\")\n",
    "                print(f\"Number of unique perturbed sampled texts: {len(set(res['perturbed_sampled']))}\")\n",
    "                print(f\"Sampled text: {res['sampled']}\")\n",
    "            real_score = real_score / res['perturbed_original_ll_std']\n",
    "            sample_score = sample_score / res['perturbed_sampled_ll_std']\n",
    "        predictions['real'].append(real_score)\n",
    "        predictions['samples'].append(sample_score)\n",
    "\n",
    "    fpr, tpr, roc_auc = get_roc_metrics(predictions['real'], predictions['samples'])\n",
    "    p, r, pr_auc = get_precision_recall_metrics(predictions['real'], predictions['samples'])\n",
    "    name = f'perturbation_{n_perturbations}_{criterion}'\n",
    "    print(f\"{name} ROC AUC: {roc_auc}, PR AUC: {pr_auc}\")\n",
    "    return {\n",
    "        'name': name,\n",
    "        'predictions': predictions,\n",
    "        'info': {\n",
    "            'pct_words_masked': pct_words_masked,\n",
    "            'span_length': span_length,\n",
    "            'n_perturbations': n_perturbations,\n",
    "            'n_samples': n_samples,\n",
    "        },\n",
    "        'raw_results': results,\n",
    "        'metrics': {\n",
    "            'roc_auc': roc_auc,\n",
    "            'fpr': fpr,\n",
    "            'tpr': tpr,\n",
    "        },\n",
    "        'pr_metrics': {\n",
    "            'pr_auc': pr_auc,\n",
    "            'precision': p,\n",
    "            'recall': r,\n",
    "        },\n",
    "        'loss': 1 - pr_auc,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## DetectGPT Running: get perturbation results\n",
    "import json\n",
    "import os\n",
    "\n",
    "outputs = []\n",
    "\n",
    "n_perturbation_list = [1, 10, 100]\n",
    "span_length = 2\n",
    "n_samples = 500\n",
    "\n",
    "# Create the results directory up front; open(..., \"w\") does not create missing parents.\n",
    "results_dir = os.path.join(OUTPUT_DIR, \"detect-gpt\")\n",
    "os.makedirs(results_dir, exist_ok=True)\n",
    "\n",
    "for n_perturbations in n_perturbation_list:\n",
    "    perturbation_results = get_perturbation_results(span_length, n_perturbations, n_samples)\n",
    "    for perturbation_mode in ['d', 'z']:\n",
    "        output = run_perturbation_experiment(\n",
    "            perturbation_results, perturbation_mode, span_length=span_length, n_perturbations=n_perturbations, n_samples=n_samples)\n",
    "        outputs.append(output)\n",
    "        out_path = os.path.join(results_dir, f\"perturbation_{n_perturbations}_{perturbation_mode}_results.json\")\n",
    "        with open(out_path, \"w\") as f:\n",
    "            json.dump(output, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "watermarking-dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "0a3c400d5c70e043163c46602e00ff3a948562bdc78a022eac13a63666386981"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
