{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dEPxhCI9KUEr"
   },
   "source": [
    "# 1. Install MarkLLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "_u3q60di584x",
    "outputId": "ba6faa0e-09ad-4653-b937-725ff83dbd21"
   },
   "outputs": [],
   "source": [
    "# !pip install -r requirements.txt\n",
    "\n",
    "download_models = False\n",
    "\n",
    "opt_path = \"/workspace/panleyi/models/facebook/opt-1.3b\"\n",
    "llama_path = \"models/llama-7b\"\n",
    "nllb_path = \"models/nllb-200-distilled-600M\"\n",
    "bert_large_uncased_path = \"models/compositional-bert-large-uncased/\"\n",
    "t5_path = \"models/t5-v1_1-xxl\"\n",
    "starcoder_path = \"models/starcoder\"\n",
    "\n",
    "if download_models:\n",
    "    !pip install -U huggingface_hub\n",
    "    !export HF_ENDPOINT=https://hf-mirror.com\n",
    "    HUGGINGFACE_TOKEN = \"\"\n",
    "    !huggingface-cli login --token $HUGGINGFACE_TOKEN\n",
    "    !huggingface-cli download --resume-download facebook/opt-1.3b --local-dir {opt_path}\n",
    "    !huggingface-cli download --resume-download princeton-nlp/Sheared-LLaMA-1.3B --local-dir {llama_path}\n",
    "    !huggingface-cli download --resume-download facebook/nllb-200-distilled-600M --local-dir {nllb_path}\n",
    "    !huggingface-cli download --resume-download google-bert/bert-large-uncased --local-dir {bert_large_uncased_path}\n",
    "    !huggingface-cli download --resume-download google/t5-v1_1-xxl --local-dir {t5_path}\n",
    "    !huggingface-cli download --resume-download bigcode/starcoder --local-dir {starcoder_path}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sAzv2lgqG9WL"
   },
   "source": [
    "# 2. Watermaring Algorithm Invocation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "bFsJT3LdfynV"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import json\n",
    "from watermark.auto_watermark import AutoWatermark\n",
    "from utils.transformers_config import TransformersConfig\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "# Load data\n",
    "with open('dataset/c4/processed_c4.json', 'r') as f:\n",
    "    lines = f.readlines()\n",
    "item = json.loads(lines[0])\n",
    "prompt = item['prompt']\n",
    "natural_text = item['natural_text']\n",
    "\n",
    "\n",
    "def test_algorithm(algorithm_name):\n",
    "    # Check algorithm name\n",
    "    assert algorithm_name in ['KGW', 'Unigram', 'SWEET', 'EWD', 'SIR', 'XSIR', 'UPV', 'EXP', 'EXPEdit', 'SynthID']\n",
    "\n",
    "    # Device\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "    # Transformers config\n",
    "    transformers_config = TransformersConfig(model=AutoModelForCausalLM.from_pretrained(opt_path).to(device),\n",
    "                                            tokenizer=AutoTokenizer.from_pretrained(opt_path),\n",
    "                                            vocab_size=50272,\n",
    "                                            device=device,\n",
    "                                            max_new_tokens=200,\n",
    "                                            min_length=230,\n",
    "                                            do_sample=True,\n",
    "                                            no_repeat_ngram_size=4)\n",
    "\n",
    "    # Load watermark algorithm\n",
    "    myWatermark = AutoWatermark.load(f'{algorithm_name}', algorithm_config=f'config/{algorithm_name}.json', transformers_config=transformers_config)\n",
    "\n",
    "    # Generate text\n",
    "    watermarked_text = myWatermark.generate_watermarked_text(prompt)\n",
    "    unwatermarked_text = myWatermark.generate_unwatermarked_text(prompt)\n",
    "\n",
    "    # Detect\n",
    "    detect_result1 = myWatermark.detect_watermark(watermarked_text)\n",
    "    detect_result2 = myWatermark.detect_watermark(unwatermarked_text)\n",
    "    detect_result3 = myWatermark.detect_watermark(natural_text)\n",
    "\n",
    "    print(\"LLM-generated watermarked text:\")\n",
    "    print(watermarked_text)\n",
    "    print('\\n')\n",
    "    print(detect_result1)\n",
    "    print('\\n')\n",
    "\n",
    "    print(\"LLM-generated unwatermarked text:\")\n",
    "    print(unwatermarked_text)\n",
    "    print('\\n')\n",
    "    print(detect_result2)\n",
    "    print('\\n')\n",
    "\n",
    "    print(\"Natural text:\")\n",
    "    print(natural_text)\n",
    "    print('\\n')\n",
    "    print(detect_result3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "T9lVoHjWLaNq",
    "outputId": "cb2c64d0-96f5-4411-edaa-73c6fb46f89d"
   },
   "outputs": [],
   "source": [
    "test_algorithm('KGW')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "lrWky7M5_KUo",
    "outputId": "6c8d8ea9-86fd-44bd-b1de-3d740a00f12f"
   },
   "outputs": [],
   "source": [
    "test_algorithm('Unigram')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "HXXMor_4_O71",
    "outputId": "3e597f2e-d658-485f-9e27-9f36f39a8aef"
   },
   "outputs": [],
   "source": [
    "test_algorithm('SWEET')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "qPUFwgAT_VH-",
    "outputId": "22e18ffb-2aca-4273-ddfd-99bf41572524"
   },
   "outputs": [],
   "source": [
    "test_algorithm('UPV')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "a7xRLyjmEqI-",
    "outputId": "362248eb-e78e-41d7-d03a-7acf99f4e6ce"
   },
   "outputs": [],
   "source": [
    "test_algorithm('EWD')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "h6SD7L7Z_QcK",
    "outputId": "51c5819f-c671-4266-b6de-5b851f015ba5"
   },
   "outputs": [],
   "source": [
    "test_algorithm('SIR')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "H6lHQezSMNjm",
    "outputId": "82876edc-49c1-4e60-c08e-f8ac50ff4a24"
   },
   "outputs": [],
   "source": [
    "test_algorithm('XSIR')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_algorithm('SynthID')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ng9-GKfs_OH7",
    "outputId": "f53ecff4-70a6-4650-dd15-fe2f2e08ad61"
   },
   "outputs": [],
   "source": [
    "test_algorithm('EXP')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YiJNLrCZFy8n"
   },
   "source": [
    "# 3. Mechanism Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "s7bauSbT2IVX"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import gc\n",
    "import json\n",
    "from visualize.font_settings import FontSettings\n",
    "from watermark.auto_watermark import AutoWatermark\n",
    "from utils.transformers_config import TransformersConfig\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from visualize.page_layout_settings import PageLayoutSettings\n",
    "from visualize.data_for_visualization import DataForVisualization\n",
    "from visualize.visualizer import DiscreteVisualizer, ContinuousVisualizer\n",
    "from visualize.legend_settings import DiscreteLegendSettings, ContinuousLegendSettings\n",
    "from visualize.color_scheme import ColorSchemeForDiscreteVisualization, ColorSchemeForContinuousVisualization\n",
    "from IPython.display import display\n",
    "from PIL import Image, ImageOps, ImageFont, ImageDraw\n",
    "\n",
    "def test_discreet_visualization():\n",
    "    tokens = [\"PREFIX\", \"PREFIX\", \"Hello\", \"world\", \"this\", \"is\", \"a\", \"test\", \"Hello\", \"world\", \"this\", \"is\", \"a\", \"test\", \"Hello\", \"world\", \"this\", \"is\", \"a\", \"test\", \"Hello\", \"world\", \"this\", \"is\", \"a\", \"test\", \"Hello\", \"world\", \"this\", \"is\", \"a\", \"test\"]\n",
    "    flags = [-1, -1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0]\n",
    "    weights = [0, 0, 0.1, 0.5, 0.3, 0.8, 0.2, 0.4, 0.1, 0.5, 0.3, 0.8, 0.2, 0.4, 0.1, 0.5, 0.3, 0.8, 0.2, 0.4, 0.1, 0.5, 0.3, 0.8, 0.2, 0.4, 0.1, 0.5, 0.3, 0.8, 0.2, 0.4]\n",
    "\n",
    "    discreet_visualizer = DiscreteVisualizer(color_scheme=ColorSchemeForDiscreteVisualization(),\n",
    "                                            font_settings=FontSettings(),\n",
    "                                            page_layout_settings=PageLayoutSettings(),\n",
    "                                            legend_settings=DiscreteLegendSettings())\n",
    "    img = discreet_visualizer.visualize(data=DataForVisualization(tokens, flags, weights),\n",
    "                                     show_text=True, visualize_weight=True, display_legend=True)\n",
    "\n",
    "    img.save(\"test1.png\")\n",
    "    display(img)\n",
    "\n",
    "\n",
    "def test_continuous_visualization():\n",
    "    tokens = [\"PREFIX\", \"PREFIX\", \"Hello\", \"world\", \"this\", \"is\", \"a\", \"test\", \"Hello\", \"world\", \"this\", \"is\", \"a\", \"test\", \"Hello\", \"world\", \"this\", \"is\", \"a\", \"test\", \"Hello\", \"world\", \"this\", \"is\", \"a\", \"test\", \"Hello\", \"world\", \"this\", \"is\", \"a\", \"test\"]\n",
    "    values = [None, None, 0.1, 0.5, 0.3, 0.8, 0.2, 0.4, 0.1, 0.5, 0.3, 0.8, 0.2, 0.4, 0.1, 0.5, 0.3, 0.8, 0.2, 0.4, 0.1, 0.5, 0.3, 0.8, 0.2, 0.4, 0.1, 0.5, 0.3, 0.8, 0.2, 0.4]\n",
    "    weights = [0, 0, 0.1, 0.5, 0.3, 0.8, 0.2, 0.4, 0.1, 0.5, 0.3, 0.8, 0.2, 0.4, 0.1, 0.5, 0.3, 0.8, 0.2, 0.4, 0.1, 0.5, 0.3, 0.8, 0.2, 0.4, 0.1, 0.5, 0.3, 0.8, 0.2, 0.4]\n",
    "\n",
    "    continuous_visualizer = ContinuousVisualizer(color_scheme=ColorSchemeForContinuousVisualization(),\n",
    "                                                    font_settings=FontSettings(),\n",
    "                                                    page_layout_settings=PageLayoutSettings(),\n",
    "                                                    legend_settings=ContinuousLegendSettings())\n",
    "    img = continuous_visualizer.visualize(data=DataForVisualization(tokens, values, weights),\n",
    "                                        show_text=False, visualize_weight=False, display_legend=True)\n",
    "\n",
    "    img.save(\"test2.png\")\n",
    "    display(img)\n",
    "\n",
    "\n",
    "def get_data(algorithm_name):\n",
    "    # Load data\n",
    "    with open('dataset/c4/processed_c4.json', 'r') as f:\n",
    "        lines = f.readlines()\n",
    "    item = json.loads(lines[0])\n",
    "    prompt = item['prompt']\n",
    "\n",
    "    # Device\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "    # Transformers config\n",
    "    transformers_config = TransformersConfig(model=AutoModelForCausalLM.from_pretrained(opt_path).to(device),\n",
    "                                            tokenizer=AutoTokenizer.from_pretrained(opt_path),\n",
    "                                            vocab_size=50272,\n",
    "                                            device=device,\n",
    "                                            max_new_tokens=200,\n",
    "                                            min_length=230,\n",
    "                                            no_repeat_ngram_size=4)\n",
    "\n",
    "    myWatermark = AutoWatermark.load(f'{algorithm_name}',\n",
    "                                     algorithm_config=f'config/{algorithm_name}.json',\n",
    "                                     transformers_config=transformers_config)\n",
    "\n",
    "    watermarked_text = myWatermark.generate_watermarked_text(prompt)\n",
    "    unwatermarked_text = myWatermark.generate_unwatermarked_text(prompt)\n",
    "    watermarked_data = myWatermark.get_data_for_visualization(watermarked_text)\n",
    "    unwatermarked_data = myWatermark.get_data_for_visualization(unwatermarked_text)\n",
    "\n",
    "    return watermarked_data, unwatermarked_data\n",
    "\n",
    "\n",
    "def test_visualization_without_weight(algorithm_name, visualize_type='discrete'):\n",
    "    # Validate input\n",
    "    assert visualize_type in ['discrete', 'continuous']\n",
    "    if visualize_type == 'discrete':\n",
    "        assert algorithm_name in ['KGW', 'Unigram', 'SWEET', 'UPV', 'SIR', 'XSIR', 'EWD']\n",
    "    else:\n",
    "        assert algorithm_name in ['EXP', 'EXPEdit']\n",
    "\n",
    "    # Get data for visualization\n",
    "    watermarked_data, unwatermarked_data = get_data(algorithm_name)\n",
    "\n",
    "    # Init visualizer\n",
    "    if visualize_type == 'discrete':\n",
    "        visualizer = DiscreteVisualizer(color_scheme=ColorSchemeForDiscreteVisualization(),\n",
    "                                        font_settings=FontSettings(),\n",
    "                                        page_layout_settings=PageLayoutSettings(),\n",
    "                                        legend_settings=DiscreteLegendSettings())\n",
    "    else:\n",
    "        visualizer = ContinuousVisualizer(color_scheme=ColorSchemeForContinuousVisualization(),\n",
    "                                        font_settings=FontSettings(),\n",
    "                                        page_layout_settings=PageLayoutSettings(),\n",
    "                                        legend_settings=ContinuousLegendSettings())\n",
    "\n",
    "    # Visualize\n",
    "    watermarked_img = visualizer.visualize(data=watermarked_data,\n",
    "                                           show_text=True,\n",
    "                                           visualize_weight=True,\n",
    "                                           display_legend=True)\n",
    "\n",
    "    unwatermarked_img = visualizer.visualize(data=unwatermarked_data,\n",
    "                                             show_text=True,\n",
    "                                             visualize_weight=True,\n",
    "                                             display_legend=True)\n",
    "\n",
    "    watermarked_img.save(f\"{algorithm_name}_watermarked.png\")\n",
    "    unwatermarked_img.save(f\"{algorithm_name}_unwatermarked.png\")\n",
    "\n",
    "    watermarked_width, watermarked_height = watermarked_img.size\n",
    "    unwatermarked_width, unwatermarked_height = unwatermarked_img.size\n",
    "\n",
    "    font_size = 22\n",
    "    font = ImageFont.truetype(\"./font/times.ttf\", font_size)\n",
    "    title_height = 80\n",
    "\n",
    "    new_watermarked_img = Image.new('RGB', (watermarked_width, watermarked_height + title_height), (255, 255, 255))\n",
    "    new_unwatermarked_img = Image.new('RGB', (unwatermarked_width, watermarked_height + title_height), (255, 255, 255))\n",
    "\n",
    "    draw1 = ImageDraw.Draw(new_watermarked_img)\n",
    "    text_bbox1 = draw1.textbbox((0, 0), f\"{algorithm_name} watermarked\", font=font)\n",
    "    draw1.text((int((watermarked_width - text_bbox1[2] - text_bbox1[0]) * 0.3), int(title_height * 0.35)), f\"{algorithm_name} watermarked\", fill=(0, 0, 0), font=font)\n",
    "\n",
    "    draw2 = ImageDraw.Draw(new_unwatermarked_img)\n",
    "    text_bbox2 = draw2.textbbox((0, 0), f\"{algorithm_name} unwatermarked\", font=font)\n",
    "    draw2.text((int((unwatermarked_width - text_bbox2[2] - text_bbox2[0]) * 0.3), int(title_height * 0.35)), f\"{algorithm_name} unwatermarked\", fill=(0, 0, 0), font=font)\n",
    "\n",
    "    new_watermarked_img.paste(watermarked_img, (0, title_height))\n",
    "    new_unwatermarked_img.paste(unwatermarked_img, (0, title_height))\n",
    "\n",
    "    total_width = watermarked_width + unwatermarked_width\n",
    "    max_height = watermarked_height + title_height\n",
    "\n",
    "    new_img = Image.new('RGB', (total_width, max_height))\n",
    "\n",
    "    new_img.paste(new_watermarked_img, (0, 0))\n",
    "    new_img.paste(new_unwatermarked_img, (watermarked_width, 0))\n",
    "    display(new_img)\n",
    "\n",
    "def test_visualization_with_weight(algorithm_name):\n",
    "    # Validate input\n",
    "    assert algorithm_name in ['SWEET', 'EWD']\n",
    "\n",
    "    # Get data for visualization\n",
    "    watermarked_data, unwatermarked_data = get_data(algorithm_name)\n",
    "\n",
    "    # Init visualizer\n",
    "    visualizer = DiscreteVisualizer(color_scheme=ColorSchemeForDiscreteVisualization(),\n",
    "                                    font_settings=FontSettings(),\n",
    "                                    page_layout_settings=PageLayoutSettings(),\n",
    "                                    legend_settings=DiscreteLegendSettings())\n",
    "\n",
    "    # Visualize\n",
    "    watermarked_img = visualizer.visualize(data=watermarked_data,\n",
    "                                           show_text=True,\n",
    "                                           visualize_weight=True,\n",
    "                                           display_legend=True)\n",
    "\n",
    "    unwatermarked_img = visualizer.visualize(data=unwatermarked_data,\n",
    "                                             show_text=True,\n",
    "                                             visualize_weight=True,\n",
    "                                             display_legend=True)\n",
    "\n",
    "    watermarked_img.save(f\"{algorithm_name}_watermarked.png\")\n",
    "    unwatermarked_img.save(f\"{algorithm_name}_unwatermarked.png\")\n",
    "\n",
    "    watermarked_width, watermarked_height = watermarked_img.size\n",
    "    unwatermarked_width, unwatermarked_height = unwatermarked_img.size\n",
    "\n",
    "    font_size = 22\n",
    "    font = ImageFont.truetype(\"./font/times.ttf\", font_size)\n",
    "    title_height = 80\n",
    "\n",
    "    new_watermarked_img = Image.new('RGB', (watermarked_width, watermarked_height + title_height), (255, 255, 255))\n",
    "    new_unwatermarked_img = Image.new('RGB', (unwatermarked_width, watermarked_height + title_height), (255, 255, 255))\n",
    "\n",
    "    draw1 = ImageDraw.Draw(new_watermarked_img)\n",
    "    text_bbox1 = draw1.textbbox((0, 0), f\"{algorithm_name} watermarked\", font=font)\n",
    "    draw1.text((int((watermarked_width - text_bbox1[2] - text_bbox1[0]) * 0.3), int(title_height * 0.35)), f\"{algorithm_name} watermarked\", fill=(0, 0, 0), font=font)\n",
    "\n",
    "    draw2 = ImageDraw.Draw(new_unwatermarked_img)\n",
    "    text_bbox2 = draw2.textbbox((0, 0), f\"{algorithm_name} unwatermarked\", font=font)\n",
    "    draw2.text((int((unwatermarked_width - text_bbox2[2] - text_bbox2[0]) * 0.3), int(title_height * 0.35)), f\"{algorithm_name} unwatermarked\", fill=(0, 0, 0), font=font)\n",
    "\n",
    "    new_watermarked_img.paste(watermarked_img, (0, title_height))\n",
    "    new_unwatermarked_img.paste(unwatermarked_img, (0, title_height))\n",
    "\n",
    "    total_width = watermarked_width + unwatermarked_width\n",
    "    max_height = watermarked_height + title_height\n",
    "\n",
    "    new_img = Image.new('RGB', (total_width, max_height))\n",
    "\n",
    "    new_img.paste(new_watermarked_img, (0, 0))\n",
    "    new_img.paste(new_unwatermarked_img, (watermarked_width, 0))\n",
    "    display(new_img)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9NMKDgC-XE2T"
   },
   "source": [
    "## 3.1 Warm Up: Test Visualizer Using Provided Text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 227
    },
    "id": "sD9pz24-IE3k",
    "outputId": "c47fec83-f8f6-42bd-90cf-7d7bcc74f12d"
   },
   "outputs": [],
   "source": [
    "test_discreet_visualization()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 227
    },
    "id": "WywsuVLCIKSk",
    "outputId": "fac47cfb-34aa-4f55-b32c-f496b4ced7bf"
   },
   "outputs": [],
   "source": [
    "test_continuous_visualization()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ML8NKCyhXXt7"
   },
   "source": [
    "## 3.2 Mechansim Visualization of Watermarking Algorithms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vNsxwjZuX4xa"
   },
   "source": [
    "### 3.2.1 KGW Family"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "X9lC4LUOINW3",
    "outputId": "ebf33271-4af9-47e1-8816-4573c3fad702"
   },
   "outputs": [],
   "source": [
    "test_visualization_without_weight('KGW', 'discrete')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EhNfHjVKX_-p"
   },
   "source": [
    "### 3.2.2 Christ Family"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "3zEsqPNyIOjZ",
    "outputId": "5b454aaa-c125-4109-ce40-6f3b782a586e"
   },
   "outputs": [],
   "source": [
    "test_visualization_without_weight('EXP', 'continuous')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "a75Uc61RYEjw"
   },
   "source": [
    "### 3.2.3 Handling Weighted Token Difference in Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "TYXN_cRZIPkm",
    "outputId": "ee6eb333-a46d-4aae-b59d-37e2fcee8ccd"
   },
   "outputs": [],
   "source": [
    "test_visualization_with_weight('SWEET')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "f95y8B3wIQaG",
    "outputId": "0b3a879a-080d-4769-cc2f-5d80631be528"
   },
   "outputs": [],
   "source": [
    "test_visualization_with_weight('EWD')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FYT_wbx7GO16"
   },
   "source": [
    "# 4. Automated Comprehensive Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TMFLte6emdWN"
   },
   "source": [
    "## 4.1 Watermark Detection Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "id": "9W3PowHuFqj6"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import json\n",
    "from evaluation.dataset import C4Dataset\n",
    "from watermark.auto_watermark import AutoWatermark\n",
    "from utils.transformers_config import TransformersConfig\n",
    "from evaluation.tools.success_rate_calculator import DynamicThresholdSuccessRateCalculator\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration, BertTokenizer, BertForMaskedLM\n",
    "from evaluation.pipelines.detection import WatermarkedTextDetectionPipeline, UnWatermarkedTextDetectionPipeline, DetectionPipelineReturnType\n",
    "from evaluation.tools.text_editor import TruncatePromptTextEditor, TruncateTaskTextEditor, WordDeletion, SynonymSubstitution, ContextAwareSynonymSubstitution, GPTParaphraser, DipperParaphraser\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "  gc.collect()\n",
    "  torch.cuda.empty_cache()\n",
    "  with torch.no_grad():\n",
    "      torch.cuda.empty_cache()\n",
    "\n",
    "def test_detection_pipeline(algorithm_name, attack_name):\n",
    "    my_dataset = C4Dataset('dataset/c4/processed_c4.json')\n",
    "    transformers_config = TransformersConfig(model=AutoModelForCausalLM.from_pretrained(opt_path).to(device),\n",
    "                                             tokenizer=AutoTokenizer.from_pretrained(opt_path),\n",
    "                                             vocab_size=50272,\n",
    "                                             device=device,\n",
    "                                             max_new_tokens=200,\n",
    "                                             min_length=230,\n",
    "                                             do_sample=True,\n",
    "                                             no_repeat_ngram_size=4)\n",
    "\n",
    "    my_watermark = AutoWatermark.load(f'{algorithm_name}', algorithm_config=f'config/{algorithm_name}.json', transformers_config=transformers_config)\n",
    "\n",
    "    if attack_name == 'Word-D':\n",
    "      attack = WordDeletion(ratio=0.3)\n",
    "    elif attack_name == 'Word-S':\n",
    "      attack = SynonymSubstitution(ratio=0.5)\n",
    "    elif attack_name == 'Word-S(Context)':\n",
    "      attack = ContextAwareSynonymSubstitution(ratio=0.5,\n",
    "                                               tokenizer=BertTokenizer.from_pretrained('/data2/shared_model/bert-large-uncased'),\n",
    "                                               model=BertForMaskedLM.from_pretrained('/data2/shared_model/bert-large-uncased').to(device))\n",
    "    elif attack_name == 'Doc-P(GPT-3.5)':\n",
    "        attack = GPTParaphraser(openai_model='gpt-3.5-turbo',\n",
    "                                prompt='Please rewrite the following text: ')\n",
    "    elif attack_name == 'Doc-P(Dipper)':\n",
    "        attack = DipperParaphraser(tokenizer=T5Tokenizer.from_pretrained('/data2/shared_model/google/t5-v1_1-xxl/'),\n",
    "                                   model=T5ForConditionalGeneration.from_pretrained('/data2/shared_model/kalpeshk2011/dipper-paraphraser-xxl/', device_map='auto'),\n",
    "                                   lex_diversity=60, order_diversity=0, sent_interval=1,\n",
    "                                   max_new_tokens=100, do_sample=True, top_p=0.75, top_k=None)\n",
    "\n",
    "\n",
    "    pipline1 = WatermarkedTextDetectionPipeline(dataset=my_dataset, text_editor_list=[TruncatePromptTextEditor(), attack],\n",
    "                                                show_progress=True, return_type=DetectionPipelineReturnType.SCORES)\n",
    "\n",
    "    pipline2 = UnWatermarkedTextDetectionPipeline(dataset=my_dataset, text_editor_list=[TruncatePromptTextEditor()],\n",
    "                                            show_progress=True, return_type=DetectionPipelineReturnType.SCORES)\n",
    "\n",
    "    calculator = DynamicThresholdSuccessRateCalculator(labels=['TPR', 'F1'], rule='best')\n",
    "    print(calculator.calculate(pipline1.evaluate(my_watermark), pipline2.evaluate(my_watermark)))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "nSx19Q5hHKbX",
    "outputId": "c760602e-d7d4-4f9c-c160-ce9746fbd6b3"
   },
   "outputs": [],
   "source": [
    "test_detection_pipeline('KGW', 'Word-D')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ey6lPtIpH5Pr"
   },
   "source": [
    "## 4.2 Text Quality Analysis Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "id": "d94Cn6fGrMOs"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import json\n",
    "from watermark.auto_watermark import AutoWatermark\n",
    "from utils.transformers_config import TransformersConfig\n",
    "from evaluation.dataset import C4Dataset, WMT16DE_ENDataset, HumanEvalDataset\n",
    "from evaluation.tools.success_rate_calculator import DynamicThresholdSuccessRateCalculator\n",
    "from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, LlamaTokenizer\n",
    "from evaluation.tools.text_editor import TruncatePromptTextEditor, TruncateTaskTextEditor ,CodeGenerationTextEditor\n",
    "from evaluation.tools.text_quality_analyzer import PPLCalculator, LogDiversityAnalyzer, BLEUCalculator, PassOrNotJudger, GPTTextDiscriminator\n",
    "from evaluation.pipelines.detection import WatermarkedTextDetectionPipeline, UnWatermarkedTextDetectionPipeline, DetectionPipelineReturnType\n",
    "from evaluation.pipelines.quality_analysis import (DirectTextQualityAnalysisPipeline, ReferencedTextQualityAnalysisPipeline, ExternalDiscriminatorTextQualityAnalysisPipeline,\n",
    "                                                   QualityPipelineReturnType)\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "  gc.collect()\n",
    "  torch.cuda.empty_cache()\n",
    "  with torch.no_grad():\n",
    "      torch.cuda.empty_cache()\n",
    "\n",
    "def test_direct_quality_analysis_pipeline(algorithm_name, quality_analyzer_name):\n",
    "    my_dataset = C4Dataset('dataset/c4/processed_c4.json')\n",
    "    transformers_config = TransformersConfig(model=AutoModelForCausalLM.from_pretrained(opt_path).to(device),\n",
    "                          tokenizer=AutoTokenizer.from_pretrained(opt_path),\n",
    "                          vocab_size=50272,\n",
    "                          device=device,\n",
    "                          max_new_tokens=200,\n",
    "                          min_length=230,\n",
    "                          do_sample=True,\n",
    "                          no_repeat_ngram_size=4)\n",
    "    my_watermark = AutoWatermark.load(f'{algorithm_name}', algorithm_config=f'config/{algorithm_name}.json', transformers_config=transformers_config)\n",
    "\n",
    "    if quality_analyzer_name == 'PPL':\n",
    "      analyzer = PPLCalculator(model=AutoModelForCausalLM.from_pretrained(llama_path, device_map='auto'),\n",
    "                                tokenizer=LlamaTokenizer.from_pretrained(llama_path),\n",
    "                                device=device)\n",
    "    elif quality_analyzer_name == 'Log Diversity':\n",
    "      analyzer = LogDiversityAnalyzer()\n",
    "\n",
    "    quality_pipeline = DirectTextQualityAnalysisPipeline(dataset=my_dataset,\n",
    "                                watermarked_text_editor_list=[TruncatePromptTextEditor()],\n",
    "                                unwatermarked_text_editor_list=[],\n",
    "                                analyzer=analyzer,\n",
    "                                unwatermarked_text_source='natural', show_progress=True,\n",
    "                                return_type=QualityPipelineReturnType.MEAN_SCORES)\n",
    "\n",
    "    print(f\"{quality_analyzer_name}:\")\n",
    "    print(quality_pipeline.evaluate(my_watermark))\n",
    "\n",
    "\n",
    "def test_referenced_quality_analysis_pipeline_1(algorithm_name):\n",
    "    \"\"\"Evaluate the impact on text quality in the machine translation task.\n",
    "\n",
    "    Loads the seq2seq model and tokenizer from ``nllb_path`` (tokenizer source\n",
    "    language ``deu_Latn``, ``forced_bos_token_id`` set to the ``eng_Latn`` token id),\n",
    "    runs the referenced quality pipeline with a BLEU analyzer over the WMT16 de-en\n",
    "    validation file, and prints the result of ``quality_pipeline.evaluate``.\n",
    "\n",
    "    Args:\n",
    "        algorithm_name: Watermark algorithm name passed to ``AutoWatermark.load``;\n",
    "            its settings are read from ``config/{algorithm_name}.json``.\n",
    "    \"\"\"\n",
    "    my_dataset = WMT16DE_ENDataset('dataset/wmt16_de_en/validation.jsonl')\n",
    "    tokenizer = AutoTokenizer.from_pretrained(nllb_path, src_lang=\"deu_Latn\")\n",
    "    # `tokenizer.lang_code_to_id` is deprecated/removed in recent transformers releases;\n",
    "    # `convert_tokens_to_ids` resolves the language token on old and new versions alike.\n",
    "    target_lang_id = tokenizer.convert_tokens_to_ids(\"eng_Latn\")\n",
    "    transformers_config = TransformersConfig(model=AutoModelForSeq2SeqLM.from_pretrained(nllb_path).to(device),\n",
    "                                                tokenizer=tokenizer,\n",
    "                                                device=device,\n",
    "                                                forced_bos_token_id=target_lang_id)\n",
    "\n",
    "    my_watermark = AutoWatermark.load(f'{algorithm_name}', algorithm_config=f'config/{algorithm_name}.json', transformers_config=transformers_config)\n",
    "\n",
    "    # No text editors are applied; the unwatermarked baseline is requested as 'generated'.\n",
    "    quality_pipeline = ReferencedTextQualityAnalysisPipeline(dataset=my_dataset,\n",
    "                                                              watermarked_text_editor_list=[],\n",
    "                                                              unwatermarked_text_editor_list=[],\n",
    "                                                              analyzer=BLEUCalculator(),\n",
    "                                                              unwatermarked_text_source='generated', show_progress=True,\n",
    "                                                              return_type=QualityPipelineReturnType.MEAN_SCORES)\n",
    "\n",
    "    print(\"BLEU:\")\n",
    "    print(quality_pipeline.evaluate(my_watermark))\n",
    "\n",
    "\n",
    "def test_referenced_quality_analysis_pipeline_2(algorithm_name):\n",
    "    \"\"\"Evaluate the impact on text quality in the code generation task.\n",
    "\n",
    "    Loads the causal LM and tokenizer from ``starcoder_path``, runs the referenced\n",
    "    quality pipeline with ``PassOrNotJudger`` over the HumanEval test file, and\n",
    "    prints the result of ``quality_pipeline.evaluate`` under the label ``pass@1``.\n",
    "\n",
    "    Args:\n",
    "        algorithm_name: Watermark algorithm name passed to ``AutoWatermark.load``;\n",
    "            its settings are read from ``config/{algorithm_name}.json``.\n",
    "    \"\"\"\n",
    "    my_dataset = HumanEvalDataset('dataset/human_eval/test.jsonl')\n",
    "    # The notebook's configuration cell defines (and downloads to) `starcoder_path`;\n",
    "    # the previously referenced `tiny_starcoder_path` is not defined there.\n",
    "    tokenizer = AutoTokenizer.from_pretrained(starcoder_path)\n",
    "    transformers_config = TransformersConfig(model=AutoModelForCausalLM.from_pretrained(starcoder_path, device_map='auto'),\n",
    "                                             tokenizer=tokenizer,\n",
    "                                             device=device,\n",
    "                                             min_length=200,\n",
    "                                             max_length=400)\n",
    "\n",
    "    my_watermark = AutoWatermark.load(f'{algorithm_name}', algorithm_config=f'config/{algorithm_name}.json', transformers_config=transformers_config)\n",
    "\n",
    "    # The same editor chain is applied to both watermarked and unwatermarked outputs.\n",
    "    quality_pipeline = ReferencedTextQualityAnalysisPipeline(dataset=my_dataset,\n",
    "                                  watermarked_text_editor_list=[TruncateTaskTextEditor(), CodeGenerationTextEditor()],\n",
    "                                  unwatermarked_text_editor_list=[TruncateTaskTextEditor(), CodeGenerationTextEditor()],\n",
    "                                  analyzer=PassOrNotJudger(),\n",
    "                                  unwatermarked_text_source='generated', show_progress=True,\n",
    "                                  return_type=QualityPipelineReturnType.MEAN_SCORES)\n",
    "\n",
    "    print(\"pass@1:\")\n",
    "    print(quality_pipeline.evaluate(my_watermark))\n",
    "\n",
    "\n",
    "def test_discriminator_quality_analysis_pipeline(algorithm_name):\n",
    "    \"\"\"Evaluate text quality in the machine translation task with an external discriminator.\n",
    "\n",
    "    Loads the seq2seq model and tokenizer from ``nllb_path`` (tokenizer source\n",
    "    language ``deu_Latn``, ``forced_bos_token_id`` set to the ``eng_Latn`` token id),\n",
    "    runs the external-discriminator quality pipeline with a ``GPTTextDiscriminator``\n",
    "    configured for ``gpt-4`` over the WMT16 de-en validation file, and prints the\n",
    "    result of ``quality_pipeline.evaluate`` under the label ``Win Rate``.\n",
    "\n",
    "    Args:\n",
    "        algorithm_name: Watermark algorithm name passed to ``AutoWatermark.load``;\n",
    "            its settings are read from ``config/{algorithm_name}.json``.\n",
    "    \"\"\"\n",
    "    my_dataset = WMT16DE_ENDataset('dataset/wmt16_de_en/validation.jsonl')\n",
    "    tokenizer = AutoTokenizer.from_pretrained(nllb_path, src_lang=\"deu_Latn\")\n",
    "    # `tokenizer.lang_code_to_id` is deprecated/removed in recent transformers releases;\n",
    "    # `convert_tokens_to_ids` resolves the language token on old and new versions alike.\n",
    "    target_lang_id = tokenizer.convert_tokens_to_ids(\"eng_Latn\")\n",
    "    transformers_config = TransformersConfig(model=AutoModelForSeq2SeqLM.from_pretrained(nllb_path).to(device),\n",
    "                                                tokenizer=tokenizer,\n",
    "                                                device=device,\n",
    "                                                forced_bos_token_id=target_lang_id)\n",
    "\n",
    "    my_watermark = AutoWatermark.load(f'{algorithm_name}', algorithm_config=f'config/{algorithm_name}.json', transformers_config=transformers_config)\n",
    "\n",
    "    discriminator = GPTTextDiscriminator(openai_model='gpt-4',\n",
    "                                         task_description='Translate the following German text to English')\n",
    "    quality_pipeline = ExternalDiscriminatorTextQualityAnalysisPipeline(dataset=my_dataset,\n",
    "                                      watermarked_text_editor_list=[],\n",
    "                                      unwatermarked_text_editor_list=[],\n",
    "                                      analyzer=discriminator,\n",
    "                                      unwatermarked_text_source='generated', show_progress=True,\n",
    "                                      return_type=QualityPipelineReturnType.MEAN_SCORES\n",
    "                                      )\n",
    "\n",
    "    print(\"Win Rate:\")\n",
    "    print(quality_pipeline.evaluate(my_watermark))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Direct quality analysis for the 'KGW' algorithm using the 'PPL' analyzer branch.\n",
    "test_direct_quality_analysis_pipeline(\"KGW\", \"PPL\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "L0X-QkkLFhVV",
    "outputId": "52d3d3c8-18ec-4b86-8d3e-89e1486156c2"
   },
   "outputs": [],
   "source": [
    "# Direct quality analysis for the 'Unigram' algorithm using the 'Log Diversity' analyzer branch.\n",
    "test_direct_quality_analysis_pipeline(\"Unigram\", \"Log Diversity\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "EZB9zkU9FtzQ",
    "outputId": "ac7b2262-702e-4c12-9aef-0cd8fc1e143d"
   },
   "outputs": [],
   "source": [
    "# Machine translation quality (BLEU) for the 'SIR' algorithm.\n",
    "test_referenced_quality_analysis_pipeline_1(algorithm_name=\"SIR\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "vCir5BNYFunq",
    "outputId": "95455639-6e42-46db-d887-2acfc8013380"
   },
   "outputs": [],
   "source": [
    "# Code generation quality (pass@1) for the 'SWEET' algorithm.\n",
    "test_referenced_quality_analysis_pipeline_2(algorithm_name=\"SWEET\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "gpMGUVJAFv_M",
    "outputId": "abfcebda-9ca0-4757-9a61-a3f847c01b94"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import openai\n",
    "\n",
    "# Read the key from the environment instead of hardcoding it in the notebook;\n",
    "# falls back to the previous empty-string value when OPENAI_API_KEY is unset.\n",
    "openai.api_key = os.environ.get(\"OPENAI_API_KEY\", \"\")\n",
    "\n",
    "test_discriminator_quality_analysis_pipeline('EWD')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [
    "dEPxhCI9KUEr"
   ],
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
