{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b6284c0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e62fa581",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ed670494",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import polars as pl"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8707ea6d",
   "metadata": {},
   "source": [
    "# Temperature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e3469f9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = \"data.05_23_22_37\"\n",
    "task = \"text_summarization\"\n",
    "\n",
    "default_model={'text_summarization': \"philschmid/bart-large-cnn-samsum\",'machine_translation': \"facebook/mbart-large-en-ro\"}\n",
    "def temperature_ablation(\n",
    "    data_path,\n",
    "    task,\n",
    "    wp_type=\"Delta\",\n",
    "    top_k=50,\n",
    "    no_input=False,\n",
    "    model_str=None,\n",
    "):\n",
    "    \"\"\"Summarise the exploded `best_score` values per test temperature.\n",
    "\n",
    "    Lazily scans `{data_path}/{task}_ablation.txt` (newline-delimited JSON) and\n",
    "    keeps the rows where\n",
    "\n",
    "    * `watermark_processor` contains `wp_type` followed later by `, False)`,\n",
    "    * `test_config.wp_str` contains `wp_type` followed later by `, True)`,\n",
    "    * `test_config.top_k`, `test_config.no_input` and `test_config.model_str`\n",
    "      equal the baseline values given by the keyword arguments.\n",
    "\n",
    "    `best_score` is then exploded to one row per element and aggregated per\n",
    "    `test_config.temperature`.\n",
    "\n",
    "    Args:\n",
    "        data_path: Directory containing the ablation result files.\n",
    "        task: Task name; selects the file and, by default, the baseline model.\n",
    "        wp_type: Regex fragment matched against the processor strings\n",
    "            (e.g. 'Delta', 'Gamma'). It is interpolated unescaped.\n",
    "        top_k: Baseline `top_k` the rows must have (default 50).\n",
    "        no_input: Baseline `no_input` flag the rows must have (default False).\n",
    "        model_str: Baseline model name; `None` means `default_model[task]`.\n",
    "\n",
    "    Returns:\n",
    "        A lazy frame sorted by `test_temperature` with the columns\n",
    "        `token_count`, `score_per_token` (mean) and `score_per_token_std`;\n",
    "        call `.collect()` to materialise it.\n",
    "    \"\"\"\n",
    "    if model_str is None:\n",
    "        model_str = default_model[task]\n",
    "    test_config = pl.col(\"test_config\")\n",
    "    # NOTE(review): `watermark_processor` is matched with a trailing False and\n",
    "    # `test_config.wp_str` with a trailing True; the meaning of that flag is\n",
    "    # not visible in this notebook.\n",
    "    return (\n",
    "        pl.scan_ndjson(f\"{data_path}/{task}_ablation.txt\")\n",
    "        .filter(pl.col(\"watermark_processor\").str.contains(rf\"{wp_type}.*, False\\)\"))\n",
    "        .filter(test_config.struct[\"wp_str\"].str.contains(rf\"{wp_type}.*, True\\)\"))\n",
    "        .filter(test_config.struct[\"top_k\"] == top_k)\n",
    "        .filter(test_config.struct[\"no_input\"] == no_input)\n",
    "        .filter(test_config.struct[\"model_str\"] == model_str)\n",
    "        .with_columns([test_config.struct[\"temperature\"].alias(\"test_temperature\")])\n",
    "        .select([\"test_temperature\", \"best_score\"])\n",
    "        .explode(pl.col(\"best_score\"))\n",
    "        .groupby(\"test_temperature\")\n",
    "        .agg(\n",
    "            [\n",
    "                pl.count().alias(\"token_count\"),\n",
    "                pl.col(\"best_score\").mean().alias(\"score_per_token\"),\n",
    "                pl.col(\"best_score\").std().alias(\"score_per_token_std\"),\n",
    "            ]\n",
    "        )\n",
    "        .sort(\"test_temperature\")\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e84ee423",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr > th,\n",
       ".dataframe > tbody > tr > td {\n",
       "  text-align: right;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (3, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>test_temperature</th><th>token_count</th><th>score_per_token</th><th>score_per_token_std</th></tr><tr><td>f64</td><td>u32</td><td>f64</td><td>f64</td></tr></thead><tbody><tr><td>0.5</td><td>882487</td><td>0.049404</td><td>0.407079</td></tr><tr><td>1.0</td><td>882487</td><td>0.878379</td><td>1.435374</td></tr><tr><td>1.5</td><td>882487</td><td>0.036002</td><td>0.498856</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (3, 4)\n",
       "┌──────────────────┬─────────────┬─────────────────┬─────────────────────┐\n",
       "│ test_temperature ┆ token_count ┆ score_per_token ┆ score_per_token_std │\n",
       "│ ---              ┆ ---         ┆ ---             ┆ ---                 │\n",
       "│ f64              ┆ u32         ┆ f64             ┆ f64                 │\n",
       "╞══════════════════╪═════════════╪═════════════════╪═════════════════════╡\n",
       "│ 0.5              ┆ 882487      ┆ 0.049404        ┆ 0.407079            │\n",
       "│ 1.0              ┆ 882487      ┆ 0.878379        ┆ 1.435374            │\n",
       "│ 1.5              ┆ 882487      ┆ 0.036002        ┆ 0.498856            │\n",
       "└──────────────────┴─────────────┴─────────────────┴─────────────────────┘"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-temperature score summary: wp_type='Delta', configured `task`.\n",
    "ldf = temperature_ablation(data_path, task, wp_type=\"Delta\").collect()\n",
    "ldf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3f56baa8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr > th,\n",
       ".dataframe > tbody > tr > td {\n",
       "  text-align: right;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (3, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>test_temperature</th><th>token_count</th><th>score_per_token</th><th>score_per_token_std</th></tr><tr><td>f64</td><td>u32</td><td>f64</td><td>f64</td></tr></thead><tbody><tr><td>0.5</td><td>878117</td><td>0.132985</td><td>0.309065</td></tr><tr><td>1.0</td><td>878117</td><td>0.220681</td><td>0.36776</td></tr><tr><td>1.5</td><td>878117</td><td>0.166003</td><td>0.455464</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (3, 4)\n",
       "┌──────────────────┬─────────────┬─────────────────┬─────────────────────┐\n",
       "│ test_temperature ┆ token_count ┆ score_per_token ┆ score_per_token_std │\n",
       "│ ---              ┆ ---         ┆ ---             ┆ ---                 │\n",
       "│ f64              ┆ u32         ┆ f64             ┆ f64                 │\n",
       "╞══════════════════╪═════════════╪═════════════════╪═════════════════════╡\n",
       "│ 0.5              ┆ 878117      ┆ 0.132985        ┆ 0.309065            │\n",
       "│ 1.0              ┆ 878117      ┆ 0.220681        ┆ 0.36776             │\n",
       "│ 1.5              ┆ 878117      ┆ 0.166003        ┆ 0.455464            │\n",
       "└──────────────────┴─────────────┴─────────────────┴─────────────────────┘"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-temperature score summary: wp_type='Gamma', configured `task`.\n",
    "ldf = temperature_ablation(data_path, task, wp_type=\"Gamma\").collect()\n",
    "ldf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0bd40a4c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr > th,\n",
       ".dataframe > tbody > tr > td {\n",
       "  text-align: right;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (3, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>test_temperature</th><th>token_count</th><th>score_per_token</th><th>score_per_token_std</th></tr><tr><td>f64</td><td>u32</td><td>f64</td><td>f64</td></tr></thead><tbody><tr><td>0.5</td><td>69790</td><td>0.041363</td><td>0.303473</td></tr><tr><td>1.0</td><td>69790</td><td>0.420137</td><td>1.13556</td></tr><tr><td>1.5</td><td>69790</td><td>0.019048</td><td>0.32428</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (3, 4)\n",
       "┌──────────────────┬─────────────┬─────────────────┬─────────────────────┐\n",
       "│ test_temperature ┆ token_count ┆ score_per_token ┆ score_per_token_std │\n",
       "│ ---              ┆ ---         ┆ ---             ┆ ---                 │\n",
       "│ f64              ┆ u32         ┆ f64             ┆ f64                 │\n",
       "╞══════════════════╪═════════════╪═════════════════╪═════════════════════╡\n",
       "│ 0.5              ┆ 69790       ┆ 0.041363        ┆ 0.303473            │\n",
       "│ 1.0              ┆ 69790       ┆ 0.420137        ┆ 1.13556             │\n",
       "│ 1.5              ┆ 69790       ┆ 0.019048        ┆ 0.32428             │\n",
       "└──────────────────┴─────────────┴─────────────────┴─────────────────────┘"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-temperature score summary: wp_type='Delta', machine_translation.\n",
    "ldf = temperature_ablation(data_path, \"machine_translation\", wp_type=\"Delta\").collect()\n",
    "ldf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6d022426",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr > th,\n",
       ".dataframe > tbody > tr > td {\n",
       "  text-align: right;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (3, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>test_temperature</th><th>token_count</th><th>score_per_token</th><th>score_per_token_std</th></tr><tr><td>f64</td><td>u32</td><td>f64</td><td>f64</td></tr></thead><tbody><tr><td>0.5</td><td>69807</td><td>0.084925</td><td>0.241549</td></tr><tr><td>1.0</td><td>69807</td><td>0.105879</td><td>0.291684</td></tr><tr><td>1.5</td><td>69807</td><td>0.087959</td><td>0.335659</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (3, 4)\n",
       "┌──────────────────┬─────────────┬─────────────────┬─────────────────────┐\n",
       "│ test_temperature ┆ token_count ┆ score_per_token ┆ score_per_token_std │\n",
       "│ ---              ┆ ---         ┆ ---             ┆ ---                 │\n",
       "│ f64              ┆ u32         ┆ f64             ┆ f64                 │\n",
       "╞══════════════════╪═════════════╪═════════════════╪═════════════════════╡\n",
       "│ 0.5              ┆ 69807       ┆ 0.084925        ┆ 0.241549            │\n",
       "│ 1.0              ┆ 69807       ┆ 0.105879        ┆ 0.291684            │\n",
       "│ 1.5              ┆ 69807       ┆ 0.087959        ┆ 0.335659            │\n",
       "└──────────────────┴─────────────┴─────────────────┴─────────────────────┘"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-temperature score summary: wp_type='Gamma', machine_translation.\n",
    "ldf = temperature_ablation(data_path, \"machine_translation\", wp_type=\"Gamma\").collect()\n",
    "ldf"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "fd11b723",
   "metadata": {},
   "source": [
    "# top_k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f1a5446d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def top_k_ablation(\n",
    "    data_path,\n",
    "    task,\n",
    "    wp_type=\"Delta\",\n",
    "    temperature=1.0,\n",
    "    no_input=False,\n",
    "    model_str=None,\n",
    "):\n",
    "    \"\"\"Summarise the exploded `best_score` values per test `top_k`.\n",
    "\n",
    "    Lazily scans `{data_path}/{task}_ablation.txt` (newline-delimited JSON) and\n",
    "    keeps the rows where\n",
    "\n",
    "    * `watermark_processor` contains `wp_type` followed later by `, False)`,\n",
    "    * `test_config.wp_str` contains `wp_type` followed later by `, True)`,\n",
    "    * `test_config.temperature`, `test_config.no_input` and\n",
    "      `test_config.model_str` equal the baseline values given by the keyword\n",
    "      arguments.\n",
    "\n",
    "    `best_score` is then exploded to one row per element and aggregated per\n",
    "    `test_config.top_k`.\n",
    "\n",
    "    Args:\n",
    "        data_path: Directory containing the ablation result files.\n",
    "        task: Task name; selects the file and, by default, the baseline model.\n",
    "        wp_type: Regex fragment matched against the processor strings\n",
    "            (e.g. 'Delta', 'Gamma'). It is interpolated unescaped.\n",
    "        temperature: Baseline temperature the rows must have (default 1.0).\n",
    "        no_input: Baseline `no_input` flag the rows must have (default False).\n",
    "        model_str: Baseline model name; `None` means `default_model[task]`.\n",
    "\n",
    "    Returns:\n",
    "        A lazy frame sorted by `test_top_k` with the columns `token_count`,\n",
    "        `score_per_token` (mean) and `score_per_token_std`; call `.collect()`\n",
    "        to materialise it.\n",
    "    \"\"\"\n",
    "    if model_str is None:\n",
    "        model_str = default_model[task]\n",
    "    test_config = pl.col(\"test_config\")\n",
    "    # NOTE(review): `watermark_processor` is matched with a trailing False and\n",
    "    # `test_config.wp_str` with a trailing True; the meaning of that flag is\n",
    "    # not visible in this notebook.\n",
    "    return (\n",
    "        pl.scan_ndjson(f\"{data_path}/{task}_ablation.txt\")\n",
    "        .filter(pl.col(\"watermark_processor\").str.contains(rf\"{wp_type}.*, False\\)\"))\n",
    "        .filter(test_config.struct[\"wp_str\"].str.contains(rf\"{wp_type}.*, True\\)\"))\n",
    "        .filter(test_config.struct[\"temperature\"] == temperature)\n",
    "        .filter(test_config.struct[\"no_input\"] == no_input)\n",
    "        .filter(test_config.struct[\"model_str\"] == model_str)\n",
    "        .with_columns([test_config.struct[\"top_k\"].alias(\"test_top_k\")])\n",
    "        .select([\"test_top_k\", \"best_score\"])\n",
    "        .explode(pl.col(\"best_score\"))\n",
    "        .groupby(\"test_top_k\")\n",
    "        .agg(\n",
    "            [\n",
    "                pl.count().alias(\"token_count\"),\n",
    "                pl.col(\"best_score\").mean().alias(\"score_per_token\"),\n",
    "                pl.col(\"best_score\").std().alias(\"score_per_token_std\"),\n",
    "            ]\n",
    "        )\n",
    "        .sort(\"test_top_k\")\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a9cbdcf6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr > th,\n",
       ".dataframe > tbody > tr > td {\n",
       "  text-align: right;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (4, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>test_top_k</th><th>token_count</th><th>score_per_token</th><th>score_per_token_std</th></tr><tr><td>i64</td><td>u32</td><td>f64</td><td>f64</td></tr></thead><tbody><tr><td>0</td><td>882487</td><td>0.377826</td><td>1.124531</td></tr><tr><td>20</td><td>882487</td><td>0.520355</td><td>1.144058</td></tr><tr><td>50</td><td>882487</td><td>0.878379</td><td>1.435374</td></tr><tr><td>100</td><td>882487</td><td>0.582203</td><td>1.262202</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (4, 4)\n",
       "┌────────────┬─────────────┬─────────────────┬─────────────────────┐\n",
       "│ test_top_k ┆ token_count ┆ score_per_token ┆ score_per_token_std │\n",
       "│ ---        ┆ ---         ┆ ---             ┆ ---                 │\n",
       "│ i64        ┆ u32         ┆ f64             ┆ f64                 │\n",
       "╞════════════╪═════════════╪═════════════════╪═════════════════════╡\n",
       "│ 0          ┆ 882487      ┆ 0.377826        ┆ 1.124531            │\n",
       "│ 20         ┆ 882487      ┆ 0.520355        ┆ 1.144058            │\n",
       "│ 50         ┆ 882487      ┆ 0.878379        ┆ 1.435374            │\n",
       "│ 100        ┆ 882487      ┆ 0.582203        ┆ 1.262202            │\n",
       "└────────────┴─────────────┴─────────────────┴─────────────────────┘"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-top_k score summary: wp_type='Delta', configured `task`.\n",
    "ldf = top_k_ablation(data_path, task, wp_type=\"Delta\").collect()\n",
    "ldf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "108bd178",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr > th,\n",
       ".dataframe > tbody > tr > td {\n",
       "  text-align: right;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (4, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>test_top_k</th><th>token_count</th><th>score_per_token</th><th>score_per_token_std</th></tr><tr><td>i64</td><td>u32</td><td>f64</td><td>f64</td></tr></thead><tbody><tr><td>0</td><td>878117</td><td>0.216535</td><td>0.373048</td></tr><tr><td>20</td><td>878117</td><td>0.211919</td><td>0.362486</td></tr><tr><td>50</td><td>878117</td><td>0.220681</td><td>0.36776</td></tr><tr><td>100</td><td>878117</td><td>0.219959</td><td>0.369314</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (4, 4)\n",
       "┌────────────┬─────────────┬─────────────────┬─────────────────────┐\n",
       "│ test_top_k ┆ token_count ┆ score_per_token ┆ score_per_token_std │\n",
       "│ ---        ┆ ---         ┆ ---             ┆ ---                 │\n",
       "│ i64        ┆ u32         ┆ f64             ┆ f64                 │\n",
       "╞════════════╪═════════════╪═════════════════╪═════════════════════╡\n",
       "│ 0          ┆ 878117      ┆ 0.216535        ┆ 0.373048            │\n",
       "│ 20         ┆ 878117      ┆ 0.211919        ┆ 0.362486            │\n",
       "│ 50         ┆ 878117      ┆ 0.220681        ┆ 0.36776             │\n",
       "│ 100        ┆ 878117      ┆ 0.219959        ┆ 0.369314            │\n",
       "└────────────┴─────────────┴─────────────────┴─────────────────────┘"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-top_k score summary: wp_type='Gamma', configured `task`.\n",
    "ldf = top_k_ablation(data_path, task, wp_type=\"Gamma\").collect()\n",
    "ldf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "863e7bf6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr > th,\n",
       ".dataframe > tbody > tr > td {\n",
       "  text-align: right;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (4, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>test_top_k</th><th>token_count</th><th>score_per_token</th><th>score_per_token_std</th></tr><tr><td>i64</td><td>u32</td><td>f64</td><td>f64</td></tr></thead><tbody><tr><td>0</td><td>69790</td><td>0.02195</td><td>0.34957</td></tr><tr><td>20</td><td>69790</td><td>0.274142</td><td>0.859782</td></tr><tr><td>50</td><td>69790</td><td>0.420137</td><td>1.13556</td></tr><tr><td>100</td><td>69790</td><td>0.288118</td><td>0.93018</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (4, 4)\n",
       "┌────────────┬─────────────┬─────────────────┬─────────────────────┐\n",
       "│ test_top_k ┆ token_count ┆ score_per_token ┆ score_per_token_std │\n",
       "│ ---        ┆ ---         ┆ ---             ┆ ---                 │\n",
       "│ i64        ┆ u32         ┆ f64             ┆ f64                 │\n",
       "╞════════════╪═════════════╪═════════════════╪═════════════════════╡\n",
       "│ 0          ┆ 69790       ┆ 0.02195         ┆ 0.34957             │\n",
       "│ 20         ┆ 69790       ┆ 0.274142        ┆ 0.859782            │\n",
       "│ 50         ┆ 69790       ┆ 0.420137        ┆ 1.13556             │\n",
       "│ 100        ┆ 69790       ┆ 0.288118        ┆ 0.93018             │\n",
       "└────────────┴─────────────┴─────────────────┴─────────────────────┘"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-top_k score summary: wp_type='Delta', machine_translation.\n",
    "ldf = top_k_ablation(data_path, \"machine_translation\", wp_type=\"Delta\").collect()\n",
    "ldf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "550c3d27",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr > th,\n",
       ".dataframe > tbody > tr > td {\n",
       "  text-align: right;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (4, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>test_top_k</th><th>token_count</th><th>score_per_token</th><th>score_per_token_std</th></tr><tr><td>i64</td><td>u32</td><td>f64</td><td>f64</td></tr></thead><tbody><tr><td>0</td><td>69807</td><td>0.097126</td><td>0.324025</td></tr><tr><td>20</td><td>69807</td><td>0.10194</td><td>0.284995</td></tr><tr><td>50</td><td>69807</td><td>0.105879</td><td>0.291684</td></tr><tr><td>100</td><td>69807</td><td>0.105606</td><td>0.292268</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (4, 4)\n",
       "┌────────────┬─────────────┬─────────────────┬─────────────────────┐\n",
       "│ test_top_k ┆ token_count ┆ score_per_token ┆ score_per_token_std │\n",
       "│ ---        ┆ ---         ┆ ---             ┆ ---                 │\n",
       "│ i64        ┆ u32         ┆ f64             ┆ f64                 │\n",
       "╞════════════╪═════════════╪═════════════════╪═════════════════════╡\n",
       "│ 0          ┆ 69807       ┆ 0.097126        ┆ 0.324025            │\n",
       "│ 20         ┆ 69807       ┆ 0.10194         ┆ 0.284995            │\n",
       "│ 50         ┆ 69807       ┆ 0.105879        ┆ 0.291684            │\n",
       "│ 100        ┆ 69807       ┆ 0.105606        ┆ 0.292268            │\n",
       "└────────────┴─────────────┴─────────────────┴─────────────────────┘"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-top_k score summary: wp_type='Gamma', machine_translation.\n",
    "ldf = top_k_ablation(data_path, \"machine_translation\", wp_type=\"Gamma\").collect()\n",
    "ldf"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "54a099b9",
   "metadata": {},
   "source": [
    "# Without input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "cccc99af",
   "metadata": {},
   "outputs": [],
   "source": [
    "def no_input_ablation(\n",
    "    data_path,\n",
    "    task,\n",
    "    wp_type=\"Delta\",\n",
    "    temperature=1.0,\n",
    "    top_k=50,\n",
    "    model_str=None,\n",
    "):\n",
    "    \"\"\"Summarise the exploded `best_score` values per test `no_input` flag.\n",
    "\n",
    "    Lazily scans `{data_path}/{task}_ablation.txt` (newline-delimited JSON) and\n",
    "    keeps the rows where\n",
    "\n",
    "    * `watermark_processor` contains `wp_type` followed later by `, False)`,\n",
    "    * `test_config.wp_str` contains `wp_type` followed later by `, True)`,\n",
    "    * `test_config.temperature`, `test_config.top_k` and\n",
    "      `test_config.model_str` equal the baseline values given by the keyword\n",
    "      arguments.\n",
    "\n",
    "    `best_score` is then exploded to one row per element and aggregated per\n",
    "    `test_config.no_input`.\n",
    "\n",
    "    Args:\n",
    "        data_path: Directory containing the ablation result files.\n",
    "        task: Task name; selects the file and, by default, the baseline model.\n",
    "        wp_type: Regex fragment matched against the processor strings\n",
    "            (e.g. 'Delta', 'Gamma'). It is interpolated unescaped.\n",
    "        temperature: Baseline temperature the rows must have (default 1.0).\n",
    "        top_k: Baseline `top_k` the rows must have (default 50).\n",
    "        model_str: Baseline model name; `None` means `default_model[task]`.\n",
    "\n",
    "    Returns:\n",
    "        A lazy frame sorted by `test_no_input` with the columns `token_count`,\n",
    "        `score_per_token` (mean) and `score_per_token_std`; call `.collect()`\n",
    "        to materialise it.\n",
    "    \"\"\"\n",
    "    if model_str is None:\n",
    "        model_str = default_model[task]\n",
    "    test_config = pl.col(\"test_config\")\n",
    "    # NOTE(review): `watermark_processor` is matched with a trailing False and\n",
    "    # `test_config.wp_str` with a trailing True; the meaning of that flag is\n",
    "    # not visible in this notebook.\n",
    "    return (\n",
    "        pl.scan_ndjson(f\"{data_path}/{task}_ablation.txt\")\n",
    "        .filter(pl.col(\"watermark_processor\").str.contains(rf\"{wp_type}.*, False\\)\"))\n",
    "        .filter(test_config.struct[\"wp_str\"].str.contains(rf\"{wp_type}.*, True\\)\"))\n",
    "        .filter(test_config.struct[\"temperature\"] == temperature)\n",
    "        .filter(test_config.struct[\"top_k\"] == top_k)\n",
    "        .filter(test_config.struct[\"model_str\"] == model_str)\n",
    "        .with_columns([test_config.struct[\"no_input\"].alias(\"test_no_input\")])\n",
    "        .select([\"test_no_input\", \"best_score\"])\n",
    "        .explode(pl.col(\"best_score\"))\n",
    "        .groupby(\"test_no_input\")\n",
    "        .agg(\n",
    "            [\n",
    "                pl.count().alias(\"token_count\"),\n",
    "                pl.col(\"best_score\").mean().alias(\"score_per_token\"),\n",
    "                pl.col(\"best_score\").std().alias(\"score_per_token_std\"),\n",
    "            ]\n",
    "        )\n",
    "        .sort(\"test_no_input\")\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "924325bb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr > th,\n",
       ".dataframe > tbody > tr > td {\n",
       "  text-align: right;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (2, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>test_no_input</th><th>token_count</th><th>score_per_token</th><th>score_per_token_std</th></tr><tr><td>bool</td><td>u32</td><td>f64</td><td>f64</td></tr></thead><tbody><tr><td>false</td><td>882487</td><td>0.878379</td><td>1.435374</td></tr><tr><td>true</td><td>882487</td><td>0.010843</td><td>0.217019</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (2, 4)\n",
       "┌───────────────┬─────────────┬─────────────────┬─────────────────────┐\n",
       "│ test_no_input ┆ token_count ┆ score_per_token ┆ score_per_token_std │\n",
       "│ ---           ┆ ---         ┆ ---             ┆ ---                 │\n",
       "│ bool          ┆ u32         ┆ f64             ┆ f64                 │\n",
       "╞═══════════════╪═════════════╪═════════════════╪═════════════════════╡\n",
       "│ false         ┆ 882487      ┆ 0.878379        ┆ 1.435374            │\n",
       "│ true          ┆ 882487      ┆ 0.010843        ┆ 0.217019            │\n",
       "└───────────────┴─────────────┴─────────────────┴─────────────────────┘"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# With/without-input score summary: wp_type='Delta', configured `task`.\n",
    "ldf = no_input_ablation(data_path, task, wp_type=\"Delta\").collect()\n",
    "ldf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "916b2849",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr > th,\n",
       ".dataframe > tbody > tr > td {\n",
       "  text-align: right;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (2, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>test_no_input</th><th>token_count</th><th>score_per_token</th><th>score_per_token_std</th></tr><tr><td>bool</td><td>u32</td><td>f64</td><td>f64</td></tr></thead><tbody><tr><td>false</td><td>878117</td><td>0.220681</td><td>0.36776</td></tr><tr><td>true</td><td>878117</td><td>0.024405</td><td>0.241708</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (2, 4)\n",
       "┌───────────────┬─────────────┬─────────────────┬─────────────────────┐\n",
       "│ test_no_input ┆ token_count ┆ score_per_token ┆ score_per_token_std │\n",
       "│ ---           ┆ ---         ┆ ---             ┆ ---                 │\n",
       "│ bool          ┆ u32         ┆ f64             ┆ f64                 │\n",
       "╞═══════════════╪═════════════╪═════════════════╪═════════════════════╡\n",
       "│ false         ┆ 878117      ┆ 0.220681        ┆ 0.36776             │\n",
       "│ true          ┆ 878117      ┆ 0.024405        ┆ 0.241708            │\n",
       "└───────────────┴─────────────┴─────────────────┴─────────────────────┘"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# With/without-input score summary: wp_type='Gamma', configured `task`.\n",
    "ldf = no_input_ablation(data_path, task, wp_type=\"Gamma\").collect()\n",
    "ldf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ab82655d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr > th,\n",
       ".dataframe > tbody > tr > td {\n",
       "  text-align: right;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (2, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>test_no_input</th><th>token_count</th><th>score_per_token</th><th>score_per_token_std</th></tr><tr><td>bool</td><td>u32</td><td>f64</td><td>f64</td></tr></thead><tbody><tr><td>false</td><td>69790</td><td>0.420137</td><td>1.13556</td></tr><tr><td>true</td><td>69790</td><td>0.009619</td><td>0.2004</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (2, 4)\n",
       "┌───────────────┬─────────────┬─────────────────┬─────────────────────┐\n",
       "│ test_no_input ┆ token_count ┆ score_per_token ┆ score_per_token_std │\n",
       "│ ---           ┆ ---         ┆ ---             ┆ ---                 │\n",
       "│ bool          ┆ u32         ┆ f64             ┆ f64                 │\n",
       "╞═══════════════╪═════════════╪═════════════════╪═════════════════════╡\n",
       "│ false         ┆ 69790       ┆ 0.420137        ┆ 1.13556             │\n",
       "│ true          ┆ 69790       ┆ 0.009619        ┆ 0.2004              │\n",
       "└───────────────┴─────────────┴─────────────────┴─────────────────────┘"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# With/without-input score summary: wp_type='Delta', machine_translation.\n",
    "ldf = no_input_ablation(data_path, \"machine_translation\", wp_type=\"Delta\").collect()\n",
    "ldf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "4e611e83",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr > th,\n",
       ".dataframe > tbody > tr > td {\n",
       "  text-align: right;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (2, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>test_no_input</th><th>token_count</th><th>score_per_token</th><th>score_per_token_std</th></tr><tr><td>bool</td><td>u32</td><td>f64</td><td>f64</td></tr></thead><tbody><tr><td>false</td><td>69807</td><td>0.105879</td><td>0.291684</td></tr><tr><td>true</td><td>69807</td><td>0.018599</td><td>0.190461</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (2, 4)\n",
       "┌───────────────┬─────────────┬─────────────────┬─────────────────────┐\n",
       "│ test_no_input ┆ token_count ┆ score_per_token ┆ score_per_token_std │\n",
       "│ ---           ┆ ---         ┆ ---             ┆ ---                 │\n",
       "│ bool          ┆ u32         ┆ f64             ┆ f64                 │\n",
       "╞═══════════════╪═════════════╪═════════════════╪═════════════════════╡\n",
       "│ false         ┆ 69807       ┆ 0.105879        ┆ 0.291684            │\n",
       "│ true          ┆ 69807       ┆ 0.018599        ┆ 0.190461            │\n",
       "└───────────────┴─────────────┴─────────────────┴─────────────────────┘"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# With/without-input score summary: wp_type='Gamma', machine_translation.\n",
    "ldf = no_input_ablation(data_path, \"machine_translation\", wp_type=\"Gamma\").collect()\n",
    "ldf"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8b9f7116",
   "metadata": {},
   "source": [
    "# Different model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "8cf4c51b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_ablation(\n",
    "    data_path,\n",
    "    task,\n",
    "    wp_type=\"Delta\",\n",
    "    temperature=1.0,\n",
    "    top_k=50,\n",
    "    no_input=False,\n",
    "):\n",
    "    \"\"\"Summarise the exploded `best_score` values per test model.\n",
    "\n",
    "    Lazily scans `{data_path}/{task}_ablation.txt` (newline-delimited JSON) and\n",
    "    keeps the rows where\n",
    "\n",
    "    * `watermark_processor` contains `wp_type` followed later by `, False)`,\n",
    "    * `test_config.wp_str` contains `wp_type` followed later by `, True)`,\n",
    "    * `test_config.temperature`, `test_config.top_k` and\n",
    "      `test_config.no_input` equal the baseline values given by the keyword\n",
    "      arguments.\n",
    "\n",
    "    No filter is applied to `test_config.model_str`; `best_score` is exploded\n",
    "    to one row per element and aggregated per model name.\n",
    "\n",
    "    Args:\n",
    "        data_path: Directory containing the ablation result files.\n",
    "        task: Task name; selects the file to scan.\n",
    "        wp_type: Regex fragment matched against the processor strings\n",
    "            (e.g. 'Delta', 'Gamma'). It is interpolated unescaped.\n",
    "        temperature: Baseline temperature the rows must have (default 1.0).\n",
    "        top_k: Baseline `top_k` the rows must have (default 50).\n",
    "        no_input: Baseline `no_input` flag the rows must have (default False).\n",
    "\n",
    "    Returns:\n",
    "        A lazy frame sorted by `test_model_str` with the columns `token_count`,\n",
    "        `score_per_token` (mean) and `score_per_token_std`; call `.collect()`\n",
    "        to materialise it.\n",
    "    \"\"\"\n",
    "    test_config = pl.col(\"test_config\")\n",
    "    # NOTE(review): `watermark_processor` is matched with a trailing False and\n",
    "    # `test_config.wp_str` with a trailing True; the meaning of that flag is\n",
    "    # not visible in this notebook.\n",
    "    return (\n",
    "        pl.scan_ndjson(f\"{data_path}/{task}_ablation.txt\")\n",
    "        .filter(pl.col(\"watermark_processor\").str.contains(rf\"{wp_type}.*, False\\)\"))\n",
    "        .filter(test_config.struct[\"wp_str\"].str.contains(rf\"{wp_type}.*, True\\)\"))\n",
    "        .filter(test_config.struct[\"temperature\"] == temperature)\n",
    "        .filter(test_config.struct[\"top_k\"] == top_k)\n",
    "        .filter(test_config.struct[\"no_input\"] == no_input)\n",
    "        .with_columns([test_config.struct[\"model_str\"].alias(\"test_model_str\")])\n",
    "        .select([\"test_model_str\", \"best_score\"])\n",
    "        .explode(pl.col(\"best_score\"))\n",
    "        .groupby(\"test_model_str\")\n",
    "        .agg(\n",
    "            [\n",
    "                pl.count().alias(\"token_count\"),\n",
    "                pl.col(\"best_score\").mean().alias(\"score_per_token\"),\n",
    "                pl.col(\"best_score\").std().alias(\"score_per_token_std\"),\n",
    "            ]\n",
    "        )\n",
    "        .sort(\"test_model_str\")\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "46014304",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr > th,\n",
       ".dataframe > tbody > tr > td {\n",
       "  text-align: right;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (2, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>test_model_str</th><th>token_count</th><th>score_per_token</th><th>score_per_token_std</th></tr><tr><td>str</td><td>u32</td><td>f64</td><td>f64</td></tr></thead><tbody><tr><td>&quot;facebook/bart-…</td><td>882487</td><td>0.040898</td><td>0.446834</td></tr><tr><td>&quot;philschmid/bar…</td><td>882487</td><td>0.878379</td><td>1.435374</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (2, 4)\n",
       "┌──────────────────────────────────┬─────────────┬─────────────────┬─────────────────────┐\n",
       "│ test_model_str                   ┆ token_count ┆ score_per_token ┆ score_per_token_std │\n",
       "│ ---                              ┆ ---         ┆ ---             ┆ ---                 │\n",
       "│ str                              ┆ u32         ┆ f64             ┆ f64                 │\n",
       "╞══════════════════════════════════╪═════════════╪═════════════════╪═════════════════════╡\n",
       "│ facebook/bart-large-cnn          ┆ 882487      ┆ 0.040898        ┆ 0.446834            │\n",
       "│ philschmid/bart-large-cnn-samsum ┆ 882487      ┆ 0.878379        ┆ 1.435374            │\n",
       "└──────────────────────────────────┴─────────────┴─────────────────┴─────────────────────┘"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-model score summary: wp_type='Delta', configured `task`.\n",
    "ldf = model_ablation(data_path, task, wp_type=\"Delta\").collect()\n",
    "ldf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "5b68cf7c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr > th,\n",
       ".dataframe > tbody > tr > td {\n",
       "  text-align: right;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (2, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>test_model_str</th><th>token_count</th><th>score_per_token</th><th>score_per_token_std</th></tr><tr><td>str</td><td>u32</td><td>f64</td><td>f64</td></tr></thead><tbody><tr><td>&quot;facebook/bart-…</td><td>878117</td><td>0.091067</td><td>0.412726</td></tr><tr><td>&quot;philschmid/bar…</td><td>878117</td><td>0.220681</td><td>0.36776</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (2, 4)\n",
       "┌──────────────────────────────────┬─────────────┬─────────────────┬─────────────────────┐\n",
       "│ test_model_str                   ┆ token_count ┆ score_per_token ┆ score_per_token_std │\n",
       "│ ---                              ┆ ---         ┆ ---             ┆ ---                 │\n",
       "│ str                              ┆ u32         ┆ f64             ┆ f64                 │\n",
       "╞══════════════════════════════════╪═════════════╪═════════════════╪═════════════════════╡\n",
       "│ facebook/bart-large-cnn          ┆ 878117      ┆ 0.091067        ┆ 0.412726            │\n",
       "│ philschmid/bart-large-cnn-samsum ┆ 878117      ┆ 0.220681        ┆ 0.36776             │\n",
       "└──────────────────────────────────┴─────────────┴─────────────────┴─────────────────────┘"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-model score summary: wp_type='Gamma', configured `task`.\n",
    "ldf = model_ablation(data_path, task, wp_type=\"Gamma\").collect()\n",
    "ldf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d31bb8f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "4292d1429c89080e4a98e8c5b192a0c44fbb5e481943ce812f923de47b1f7601"
  },
  "kernelspec": {
   "display_name": "env: (machinelearning)",
   "language": "python",
   "name": "machinelearning"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
