{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4feaae04b6e04c43acb1c18e4a454253",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "./LTW/watermark_for_plot.py:49: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  self.selector_network.load_state_dict(torch.load(checkpoint_path))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['\\nThis week’s episode of The Makers Podcast features an episode of The Makers Podcast, produced by Dave Deans and Rebecca Miller at the Missouri History Museum. This week’s guest is Brent Schmidt, CEO at CityFund KC. Our conversation with Brent spans CityFund’s mission as a fund, through City Fund KC’s ability to drive more entrepreneurship and economic development in Kansas City. Dave and Rebecca also talk about being in the Museum, Brent’s background as a video producer and much more. This podcast is worth a listen.\\nIf you’re in the market for some fresh, affordable, fresh, handmade, affordable gifts, make sure you’re checking out Shop Now, an upstart maker that sells 100% handmade, fresh, affordable, gifts that you’ll actually enjoy. This week’s episode of The Maker Town Podcast features an']\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "gamma=0.25\n",
    "delta=3\n",
    "from watermark_for_plot import Watermark\n",
    "import torch\n",
    "from transformers import AutoTokenizer\n",
    "from watermark_for_plot import Detector\n",
    "from transformers import AutoTokenizer,AutoModelForCausalLM,OPTForCausalLM,LogitsProcessorList\n",
    "import torch\n",
    "import random\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Pin GPU 4 only when CUDA is present: an unconditional set_device() raises on a\n",
    "# CPU-only host, so the CPU fallback on the next line could never take effect.\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.set_device(4)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"./LTW/models/opt-6.7b\")\n",
    "# The original note on this line read \"if llama\", but the assignment runs for\n",
    "# every tokenizer, including the OPT one loaded above.\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\"./LTW/models/opt-6.7b\", torch_dtype=torch.float16).to(device)\n",
    "\n",
    "dataset = load_dataset(\"json\", data_files=\"./LTW/c4_subset_500.jsonl\")[\"train\"]\n",
    "\n",
    "# random.seed fixes which prompt is drawn below; torch.manual_seed seeds torch's\n",
    "# RNG as well, since the later cells generate with do_sample=True.\n",
    "random.seed(6)\n",
    "torch.manual_seed(6)\n",
    "\n",
    "input_text = random.choice(dataset)[\"text\"]\n",
    "\n",
    "wm = Watermark(\n",
    "    checkpoint_path=\"./LTW/ckpt/tmp/selective_network_epoch0_step2000.pth\",\n",
    "    device=device,\n",
    "    k=6,\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    min_new_tokens=180,\n",
    "    max_new_tokens=185,\n",
    "    embed_unigram_wm=True,\n",
    ")\n",
    "\n",
    "# Element 0 of the result is the generated continuation; the whole result is printed.\n",
    "generated = wm.generate_watermark(input_text, gamma, delta)\n",
    "output_text = generated[0]\n",
    "print(generated)\n",
    "# `output` (prompt + continuation) is what the detection cell tokenizes;\n",
    "# `output_text` (continuation only) is what gets scored for perplexity.\n",
    "output = input_text + output_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The opt-1.3b checkpoint is loaded separately from the generation model; the\n",
    "# resulting model/tokenizer pair is read as globals by calculate_perplexity.\n",
    "model_name = \"./LTW/models/opt-1.3b\"\n",
    "ppl_tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "ppl_model = OPTForCausalLM.from_pretrained(model_name)\n",
    "ppl_model = ppl_model.to(device)\n",
    "\n",
    "def calculate_perplexity(text):\n",
    "    \"\"\"Return the perplexity of ``text`` under the notebook-level ``ppl_model``.\n",
    "\n",
    "    The text is tokenized with ``ppl_tokenizer``, scored by ``ppl_model`` using\n",
    "    the input ids as labels, and the resulting loss is exponentiated.\n",
    "\n",
    "    Args:\n",
    "        text: String (or list of strings) to score.\n",
    "\n",
    "    Returns:\n",
    "        float: ``exp(loss)`` as a plain Python number.\n",
    "    \"\"\"\n",
    "    # ``truncation=True`` alone does nothing when the tokenizer has no maximum\n",
    "    # length (transformers warns and skips truncation), so bound the input by\n",
    "    # the model's position-embedding limit whenever the config exposes one.\n",
    "    max_length = getattr(ppl_model.config, \"max_position_embeddings\", None)\n",
    "    inputs = ppl_tokenizer(\n",
    "        text,\n",
    "        return_tensors=\"pt\",\n",
    "        truncation=max_length is not None,\n",
    "        max_length=max_length,\n",
    "        padding=True,\n",
    "    ).to(device)\n",
    "\n",
    "    # Padding positions get the label -100 so they are left out of the loss.\n",
    "    # A single string produces no padding, so its labels equal its input ids.\n",
    "    labels = inputs[\"input_ids\"]\n",
    "    attention_mask = inputs.get(\"attention_mask\")\n",
    "    if attention_mask is not None:\n",
    "        labels = labels.masked_fill(attention_mask == 0, -100)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        loss = ppl_model(**inputs, labels=labels).loss\n",
    "\n",
    "    return torch.exp(loss).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "./LTW/utils/detect_utils.py:106: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  self.selector_network.load_state_dict(torch.load(checkpoint_path))\n",
      "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
     ]
    }
   ],
   "source": [
    "# Detector for our watermark, configured with the same checkpoint path and k\n",
    "# that were passed to Watermark in the first cell.\n",
    "watermark_detector = Detector(\n",
    "    vocab=list(tokenizer.get_vocab().values()),\n",
    "    gamma=gamma,\n",
    "    tokenizer=tokenizer,\n",
    "    z_threshold=4,\n",
    "    model=model,\n",
    "    k=6,\n",
    "    checkpoint_path=\"./LTW/ckpt/tmp/selective_network_epoch0_step2000.pth\",\n",
    "    embed_unigram_wm=True,\n",
    ")\n",
    "\n",
    "# NOTE(review): the prompt is encoded with add_special_tokens=False while the\n",
    "# full text uses the tokenizer default - confirm the detector expects this.\n",
    "tokenized_input = tokenizer.encode(input_text, return_tensors='pt', add_special_tokens=False).to(device)[0]\n",
    "tokenized_output = tokenizer.encode(output, return_tensors='pt').to(device)[0]\n",
    "detection_result = watermark_detector.detect(tokenized_output, tokenized_input)\n",
    "\n",
    "# Summary shown in the panel header: z-score from detection, perplexity of the\n",
    "# generated continuation only.\n",
    "our_info = dict(\n",
    "    name=\"our_watermark\",\n",
    "    z_score=detection_result[\"z_score\"],\n",
    "    ppl=calculate_perplexity(output_text),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# detect kgw\n",
    "from kgw_watermark import WatermarkLogitsProcessor,WatermarkDetector\n",
    "# Sampling configuration for the KGW baseline.\n",
    "gen_kwargs = dict(\n",
    "    do_sample=True,\n",
    "    top_p=0.95,\n",
    "    top_k=100,\n",
    "    min_new_tokens=180,\n",
    "    repetition_penalty=1,\n",
    "    no_repeat_ngram_size=8,\n",
    "    max_new_tokens=185,\n",
    ")\n",
    "# Filled below: decoded tokens plus the positions flagged green / red.\n",
    "token_dict_kgw = {\"tokens\": [], \"green_indexs\": [], \"red_indexs\": []}\n",
    "z_threshold = 4\n",
    "\n",
    "watermark_processor = WatermarkLogitsProcessor(\n",
    "    vocab=list(tokenizer.get_vocab().values()),\n",
    "    gamma=gamma,\n",
    "    delta=delta,\n",
    ")\n",
    "gen_kwargs[\"logits_processor\"] = LogitsProcessorList([watermark_processor])\n",
    "watermark_detector = WatermarkDetector(\n",
    "    vocab=list(tokenizer.get_vocab().values()),\n",
    "    gamma=gamma,\n",
    "    tokenizer=tokenizer,\n",
    "    z_threshold=z_threshold,\n",
    ")\n",
    "\n",
    "input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)\n",
    "attention_mask = torch.ones_like(input_ids, dtype=torch.long)\n",
    "output = model.generate(\n",
    "    input_ids=input_ids,\n",
    "    attention_mask=attention_mask,\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    **gen_kwargs,\n",
    ")\n",
    "\n",
    "# Everything after the prompt length is newly generated.\n",
    "generated_ids = output[0, input_ids.shape[1]:]\n",
    "output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)\n",
    "token_dict_kgw[\"tokens\"].extend(tokenizer.decode(token_id) for token_id in generated_ids)\n",
    "\n",
    "detection_result = watermark_detector.detect(output[0, :], input_ids[0], return_green_token_mask=True)\n",
    "\n",
    "kgw_info = dict(\n",
    "    name=\"kgw\",\n",
    "    z_score=detection_result[\"z_score\"],\n",
    "    ppl=calculate_perplexity(output_text),\n",
    ")\n",
    "\n",
    "# Sort every mask position into the green or red bucket. The explicit\n",
    "# comparison with True is kept because the mask's element type is not visible here.\n",
    "for position, flag in enumerate(detection_result[\"green_token_mask\"]):\n",
    "    bucket = \"green_indexs\" if flag == True else \"red_indexs\"\n",
    "    token_dict_kgw[bucket].append(position)\n",
    "\n",
    "import json\n",
    "with open(\"./LTW/eval/plot_red_green_example/kgw.json\", 'w') as kgw_file:\n",
    "    json.dump(token_dict_kgw, kgw_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# detect sweet\n",
    "from sweet_for_plot import SweetLogitsProcessor,SweetDetector\n",
    "def calculate_entropy(model, tokenized_text):\n",
    "    \"\"\"Return the entropy of the model's output distribution at every position.\n",
    "\n",
    "    Args:\n",
    "        model: Callable taking a batched id tensor and ``return_dict=True`` and\n",
    "            returning an object with a ``logits`` attribute.\n",
    "        tokenized_text: Id tensor for one sequence; a leading batch dimension\n",
    "            of size one is added before the forward pass.\n",
    "\n",
    "    Returns:\n",
    "        list: One entropy value per position of the first (only) batch item.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        batch = torch.unsqueeze(tokenized_text, 0)\n",
    "        logits = model(batch, return_dict=True).logits\n",
    "        probs = torch.softmax(logits, dim=-1)\n",
    "        # Zero-probability entries contribute 0 rather than 0 * log(0) = nan.\n",
    "        zero = probs.new([0.0])\n",
    "        p_log_p = torch.where(probs > 0, probs * probs.log(), zero)\n",
    "        entropy = -p_log_p.sum(dim=-1)\n",
    "    return entropy[0].cpu().tolist()\n",
    "# Sampling configuration for the SWEET baseline (same values as the KGW cell).\n",
    "gen_kwargs = dict(\n",
    "    do_sample=True,\n",
    "    top_p=0.95,\n",
    "    top_k=100,\n",
    "    min_new_tokens=180,\n",
    "    repetition_penalty=1,\n",
    "    no_repeat_ngram_size=8,\n",
    "    max_new_tokens=185,\n",
    ")\n",
    "\n",
    "entropy_threshold = 1.2  # entropy_threshold follows the paper of sweet\n",
    "sweet_processor = SweetLogitsProcessor(\n",
    "    vocab=list(tokenizer.get_vocab().values()),\n",
    "    gamma=gamma,\n",
    "    delta=delta,\n",
    "    entropy_threshold=entropy_threshold,\n",
    ")\n",
    "gen_kwargs[\"logits_processor\"] = LogitsProcessorList([sweet_processor])\n",
    "\n",
    "# z_threshold is the value assigned in the KGW cell, so that cell must run first.\n",
    "watermark_detector = SweetDetector(\n",
    "    vocab=list(tokenizer.get_vocab().values()),\n",
    "    gamma=gamma,\n",
    "    tokenizer=tokenizer,\n",
    "    z_threshold=z_threshold,\n",
    "    entropy_threshold=entropy_threshold,\n",
    ")\n",
    "\n",
    "input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)\n",
    "attention_mask = torch.ones_like(input_ids, dtype=torch.long)\n",
    "output = model.generate(\n",
    "    input_ids=input_ids,\n",
    "    attention_mask=attention_mask,\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    **gen_kwargs,\n",
    ")\n",
    "output_text = tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True)\n",
    "\n",
    "per_position_entropy = calculate_entropy(model, output[0])\n",
    "# we need to shift entropy to the right, so the first item is dummy\n",
    "entropy = [0] + per_position_entropy[:-1]\n",
    "detection_result = watermark_detector.detect(output[0, :], input_ids[0], entropy=entropy)\n",
    "\n",
    "sweet_info = dict(\n",
    "    name=\"sweet\",\n",
    "    z_score=detection_result[\"z_score\"],\n",
    "    ppl=calculate_perplexity(output_text),\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the token dictionaries for our method and for SWEET from disk.\n",
    "# NOTE(review): no cell visible in this notebook writes these two files (only\n",
    "# kgw.json is dumped here) - presumably the *_for_plot modules do; confirm.\n",
    "import json\n",
    "\n",
    "our_json_path = \"./LTW/eval/plot_red_green_example/our.json\"\n",
    "with open(our_json_path, \"r\") as our_file:\n",
    "    token_dict_our = json.load(our_file)\n",
    "\n",
    "sweet_json_path = \"./LTW/eval/plot_red_green_example/sweet.json\"\n",
    "with open(sweet_json_path, \"r\") as sweet_file:\n",
    "    token_dict_sweet = json.load(sweet_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "HTML文件已生成，请查看colored_text.html\n"
     ]
    }
   ],
   "source": [
    "import nltk\n",
    "from nltk.tokenize import word_tokenize\n",
    "\n",
    "def generate_colored_text(token_dict, info):\n",
    "    \"\"\"Render one method's tokens as a labelled, colour-coded HTML panel.\n",
    "\n",
    "    Args:\n",
    "        token_dict: Mapping with a \"tokens\" list of strings and optional\n",
    "            \"green_indexs\", \"red_indexs\" and \"unmarked_indexs\" lists. The index\n",
    "            values are compared with positions in the word-tokenized text.\n",
    "        info: Mapping with \"name\", \"z_score\" and \"ppl\"; both numbers are\n",
    "            formatted with two decimals in the panel header.\n",
    "\n",
    "    Returns:\n",
    "        str: A ``<div class=\"container\">`` fragment.\n",
    "    \"\"\"\n",
    "    from html import escape  # function-scope import: this cell does not import html\n",
    "\n",
    "    # NOTE(review): unicode_escape decodes the UTF-8 bytes as Latin-1, so\n",
    "    # non-ASCII characters come out garbled; the \"ï¿½\" branch below patches\n",
    "    # one such case.\n",
    "    tokens = [token.encode().decode('unicode_escape') for token in token_dict[\"tokens\"]]\n",
    "\n",
    "    # Re-split the space-joined tokens into words.\n",
    "    tokens = word_tokenize(\" \".join(tokens))\n",
    "\n",
    "    # Sets make each membership test O(1) instead of a list scan per token.\n",
    "    green = set(token_dict.get(\"green_indexs\", []))\n",
    "    red = set(token_dict.get(\"red_indexs\", []))\n",
    "    unmarked = set(token_dict.get(\"unmarked_indexs\", []))\n",
    "\n",
    "    # Model output can contain '<', '>' or '&'; escape everything that is\n",
    "    # interpolated into markup so it is shown as text instead of parsed as HTML.\n",
    "    name = escape(str(info['name']), quote=False)\n",
    "    pieces = [f\"\"\"\n",
    "    <div class=\"container\">\n",
    "        <div class=\"info\">{name} (z-score: {info['z_score']:.2f}, ppl: {info['ppl']:.2f})</div>\n",
    "    \"\"\"]\n",
    "\n",
    "    seen_first_half = False\n",
    "    for i, token in enumerate(tokens):\n",
    "        # \"ï¿½\" tokens are consumed in pairs: the first one is dropped and the\n",
    "        # second one is rendered as an apostrophe.\n",
    "        if token == \"ï¿½\":\n",
    "            if not seen_first_half:\n",
    "                seen_first_half = True\n",
    "                continue\n",
    "            seen_first_half = False\n",
    "            token = \"'\"\n",
    "\n",
    "        text = escape(token, quote=False)\n",
    "        if i in green:\n",
    "            pieces.append(f\"<span class='green'>{text}</span> \")\n",
    "        elif i in red:\n",
    "            pieces.append(f\"<span class='red'>{text}</span> \")\n",
    "        elif i in unmarked:\n",
    "            pieces.append(f\"<span class='gray'>{text}</span> \")\n",
    "        else:\n",
    "            pieces.append(f\"{text} \")\n",
    "\n",
    "    pieces.append(\"\"\"\n",
    "    </div>\n",
    "    \"\"\")\n",
    "    return \"\".join(pieces)\n",
    "\n",
    "def generate_html(input_text, token_dict_kgw, token_dict_sweet, token_dict_our, our_info, kgw_info, sweet_info):\n",
    "    \"\"\"Build a standalone HTML page with a 2x2 grid of panels.\n",
    "\n",
    "    The panels are, in order: the prompt (first 800 characters followed by\n",
    "    \"...\"), then the KGW, SWEET and our-method panels produced by\n",
    "    ``generate_colored_text``.\n",
    "\n",
    "    Args:\n",
    "        input_text: Prompt string shown in the first panel.\n",
    "        token_dict_kgw, token_dict_sweet, token_dict_our: Token dictionaries\n",
    "            passed through to ``generate_colored_text``.\n",
    "        our_info, kgw_info, sweet_info: Header info dictionaries passed through\n",
    "            to ``generate_colored_text``.\n",
    "\n",
    "    Returns:\n",
    "        str: The complete HTML document.\n",
    "    \"\"\"\n",
    "    from html import escape  # function-scope import: this cell does not import html\n",
    "\n",
    "    # Document head with the inline stylesheet\n",
    "    html_content = \"\"\"\n",
    "    <!DOCTYPE html>\n",
    "    <html lang=\"en\">\n",
    "    <head>\n",
    "        <meta charset=\"UTF-8\">\n",
    "        <meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">\n",
    "        <title>Colored Text</title>\n",
    "        <style>\n",
    "            body {\n",
    "                display: flex;\n",
    "                flex-direction: column;\n",
    "                align-items: center;\n",
    "                justify-content: center;\n",
    "                height: 100vh;\n",
    "                margin: 0;\n",
    "                background-color: #f0f0f0;\n",
    "            }\n",
    "            .container {\n",
    "                border: 1px solid #000;\n",
    "                padding: 5px;\n",
    "                margin: 5px;\n",
    "                width: 360px;\n",
    "                height: 360px;\n",
    "                background-color: #FFFFFF;\n",
    "                font-family: Arial, sans-serif;\n",
    "                overflow: auto;\n",
    "            }\n",
    "            .info {\n",
    "                font-weight: bold;\n",
    "                color: #333;\n",
    "            }\n",
    "            .green {\n",
    "                background-color: #90EE90;\n",
    "            }\n",
    "            .red {\n",
    "                background-color: #FFCCCC;\n",
    "            }\n",
    "            .gray {\n",
    "                background-color: #EEEEEE;\n",
    "            }\n",
    "            .grid {\n",
    "                 display: grid;\n",
    "                grid-template-columns: repeat(2, 1fr); /* 两列 */\n",
    "                grid-template-rows: repeat(2, 1fr); /* 两行 */\n",
    "                gap: 20px; /* 间隙 */\n",
    "                width: 780px;\n",
    "                height: 780px;\n",
    "                border: 2px solid #333;\n",
    "                padding: 20px;\n",
    "                background-color: white;\n",
    "            }\n",
    "        </style>\n",
    "    </head>\n",
    "    <body>\n",
    "    \"\"\"\n",
    "\n",
    "    # Open the grid that holds all four panels\n",
    "    html_content += \"\"\"\n",
    "    <div class=\"grid\">\n",
    "    \"\"\"\n",
    "\n",
    "    # Prompt panel. The prompt is raw dataset text, so escape '<', '>' and '&'\n",
    "    # (after slicing) to have it displayed verbatim instead of parsed as markup.\n",
    "    prompt_preview = escape(input_text[:800], quote=False)\n",
    "    html_content += f\"\"\"\n",
    "    <div class=\"container\">\n",
    "        <div class=\"info\">Prompt</div>\n",
    "        {prompt_preview}...\n",
    "    </div>\n",
    "    \"\"\"\n",
    "\n",
    "    # KGW\n",
    "    html_content += generate_colored_text(token_dict_kgw, kgw_info)\n",
    "\n",
    "    # Sweet\n",
    "    html_content += generate_colored_text(token_dict_sweet, sweet_info)\n",
    "\n",
    "    # Our\n",
    "    html_content += generate_colored_text(token_dict_our, our_info)\n",
    "\n",
    "    # Close the grid\n",
    "    html_content += \"\"\"\n",
    "    </div>\n",
    "    \"\"\"\n",
    "\n",
    "    # Close the document\n",
    "    html_content += \"\"\"\n",
    "    </body>\n",
    "    </html>\n",
    "    \"\"\"\n",
    "    return html_content\n",
    "\n",
    "\n",
    "# Build the comparison page from the prompt, the three token dictionaries and\n",
    "# their header info.\n",
    "html_content = generate_html(\n",
    "    input_text,\n",
    "    token_dict_kgw,\n",
    "    token_dict_sweet,\n",
    "    token_dict_our,\n",
    "    our_info,\n",
    "    kgw_info,\n",
    "    sweet_info,\n",
    ")\n",
    "\n",
    "# Save the page as an HTML file in the same directory as the token JSON files.\n",
    "output_html_path = \"./LTW/eval/plot_red_green_example/colored_text.html\"\n",
    "with open(output_html_path, \"w\", encoding=\"utf-8\") as html_file:\n",
    "    html_file.write(html_content)\n",
    "\n",
    "print(\"HTML文件已生成，请查看colored_text.html\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "markllm_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
