{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## This notebook loads and then subsets the watermarked generation data for the human study\n",
    "\n",
    "### Note: move this notebook into (or run it from) the directory that contains the `utils` subdirectory so that the `utils.*` imports resolve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from argparse import Namespace\n",
    "from tqdm import tqdm\n",
    "from statistics import mean\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from datasets import Dataset\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from matplotlib import rc\n",
    "rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n",
    "rc('text', usetex=True)\n",
    "\n",
    "import cmasher as cmr\n",
    "\n",
    "\n",
    "from utils.io import read_json, read_jsonlines, write_lst_json, write_jsonlines, write_json\n",
    "\n",
    "from utils.notebooks import filter_text_col_length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Root of the experiment workspace. Set the WATERMARKING_ROOT environment variable to point\n",
    "# at your own workspace; the original hard-coded scratch path is kept as the default.\n",
    "ROOT_DIR = os.environ.get(\"WATERMARKING_ROOT\", \"/scratch/<username>/watermarking-root\")\n",
    "INPUT_DIR = f\"{ROOT_DIR}/input\"\n",
    "OUTPUT_DIR = f\"{ROOT_DIR}/output\"\n",
    "# File names looked up inside the selected run directory\n",
    "JSONL_FILENAME = \"gen_table_w_metrics.jsonl\"\n",
    "META_FILENAME = \"gen_table_w_metrics_meta.json\"\n",
    "\n",
    "\n",
    "# Run to load, relative to OUTPUT_DIR; keep exactly one `run_name` line uncommented.\n",
    "# run_name = \"vicuna_lfqa_1000_600_nofilter\"\n",
    "# run_name = \"test_vicuna_c4\"\n",
    "# run_name = \"test_vicuna_essays\"\n",
    "\n",
    "# run_name = \"vicuna_lfqa/vicuna_lfqa_600_0-25_2-0_0-7_1-0_eval\"\n",
    "run_name = \"vicuna_lfqa/vicuna_lfqa_600_0-25_4-0_0-7_1-0_eval\"\n",
    "\n",
    "# Paths derived from the selected run\n",
    "run_dir = f\"{OUTPUT_DIR}/{run_name}\"\n",
    "data_path = f\"{run_dir}/{JSONL_FILENAME}\"\n",
    "meta_path = f\"{run_dir}/{META_FILENAME}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the run metadata and expose each of its entries as an attribute of `args`\n",
    "# (assumes read_json returns a dict -- TODO confirm)\n",
    "metadata = read_json(meta_path)\n",
    "args = Namespace()\n",
    "vars(args).update(metadata)\n",
    "\n",
    "# Materialize all rows of the jsonlines table, then wrap them in a Dataset\n",
    "raw_rows = list(read_jsonlines(data_path))\n",
    "ds = Dataset.from_list(raw_rows)\n",
    "ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for k,v in args.__dict__.items():\n",
    "#     print(f\"{k}: {v}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LFQA paraphrase setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# df = ds.to_pandas()\n",
    "\n",
    "# # First filter setting\n",
    "# filter_cols = {\n",
    "#     \"baseline_completion\": (600, 600, 1000),\n",
    "#     # \"no_wm_output\"       : (300, 100, 1000), # if this is not for para qual then okay\n",
    "#     \"w_wm_output\"        : (300, 100, 1000)\n",
    "# }\n",
    "# count_suffix = \"_length\"\n",
    "# count_suffix = \"_num_tokens_scored\"\n",
    "\n",
    "# for col_name, (target_T, lower_tol, upper_tol) in filter_cols.items():\n",
    "#     df = filter_text_col_length(df, \n",
    "#                                 text_col_name=col_name, \n",
    "#                                 count_suffix=count_suffix, \n",
    "#                                 upper_T=target_T+upper_tol, \n",
    "#                                 lower_T=target_T-lower_tol)\n",
    "    \n",
    "# text_col_names = [\"truncated_input\", \"baseline_completion\", \"no_wm_output\", \"w_wm_output\"]\n",
    "# metric_col_names = [\"baseline_completion_z_score\", \"w_wm_output_num_tokens_scored\", \"w_wm_output_z_score\", \"baseline_completion_vs_w_wm_output_p_sp\", \"w_wm_output_repetition_4\", \"w_wm_output_log_diversity\"]\n",
    "\n",
    "# # df[text_col_names + metric_col_names]\n",
    "\n",
    "# # # First filtered setting\n",
    "# # print(f\"Metric filter drop:\")\n",
    "# # print(len(df))\n",
    "# # df = df[df[metric_col_names[0]] < 3.0]\n",
    "# # df = df[df[metric_col_names[2]] > 6.0]\n",
    "# # # df = df[df[metric_col_names[3]] > 0.5] # doesn't yield 60 with the non-english drop\n",
    "# # df = df[df[metric_col_names[3]] > 0.468] # target 60 w/ non-english drop\n",
    "# # print(len(df))\n",
    "\n",
    "# # # Filtered setting with full dataset as input\n",
    "# # hparams = \"0-25_2-0_0-7_1-0\"\n",
    "# # print(f\"Metric filter drop:\")\n",
    "# # print(len(df))\n",
    "# # df = df[df[metric_col_names[2]] > 7.5]\n",
    "# # df = df[df[metric_col_names[3]] > 0.445] # target 60 w/ non-english drop\n",
    "# # # df = df[df[metric_col_names[4]] < 0.15]\n",
    "# # # df = df[df[metric_col_names[5]] > 4.5]\n",
    "# # print(len(df))\n",
    "\n",
    "# # Filtered setting with full dataset as input\n",
    "# hparams = \"0-25_4-0_0-7_1-0\"\n",
    "# print(f\"Metric filter drop:\")\n",
    "# print(len(df))\n",
    "# df = df[df[metric_col_names[2]] > 9.0]\n",
    "# df = df[df[metric_col_names[3]] > 0.6]\n",
    "# df = df[df[metric_col_names[4]] < 0.11] # target 60 w/ non-english drop\n",
    "# # df = df[df[metric_col_names[5]] > 4.5]\n",
    "# print(len(df))\n",
    "\n",
    "# # check whether string contains non-english unicode characters\n",
    "# def is_english(s):\n",
    "#     try:\n",
    "#         s.encode(encoding='utf-8').decode('ascii')\n",
    "#     except UnicodeDecodeError:\n",
    "#         # print(s)\n",
    "#         return False\n",
    "#     else:\n",
    "#         return True\n",
    "# print(f\"Non-english drop:\")\n",
    "# df = df[df[\"w_wm_output\"].apply(is_english)]\n",
    "# print(len(df))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # sort by multiple fields, some ascending, some descending\n",
    "# # sort_cols = [\"w_wm_output_num_tokens_scored\", \"w_wm_output_z_score\", \"baseline_completion_vs_w_wm_output_p_sp\", \"w_wm_output_coherence\"]\n",
    "# # sort_ascending = [False, False, False, False]\n",
    "# sort_cols = [\"w_wm_output_z_score\", \"baseline_completion_vs_w_wm_output_p_sp\"]\n",
    "# sort_ascending = [False, False]\n",
    "\n",
    "# df.sort_values(by=sort_cols, ascending=sort_ascending, inplace=False)[text_col_names + metric_col_names]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # plot a quick roc curve to check the starting perf of the subset\n",
    "# from sklearn import metrics\n",
    "# baseline_stats = df[\"baseline_completion_z_score\"].values\n",
    "# w_wm_stats = df[\"w_wm_output_z_score\"].values\n",
    "# all_scores = np.concatenate([baseline_stats, w_wm_stats])\n",
    "\n",
    "# baseline_labels = np.zeros_like(baseline_stats)\n",
    "# attacked_labels = np.ones_like(w_wm_stats)\n",
    "# all_labels = np.concatenate([baseline_labels, attacked_labels])\n",
    "\n",
    "# fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n",
    "# roc_auc = metrics.auc(fpr, tpr)\n",
    "# print(roc_auc)\n",
    "\n",
    "# plt.figure(figsize=(4,4))\n",
    "# plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')\n",
    "# plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')\n",
    "# plt.xlim([0.0, 1.0])\n",
    "# plt.ylim([0.0, 1.05])\n",
    "# plt.xlabel('False Positive Rate\\n(1 - Specificity)')\n",
    "# plt.ylabel('True Positive Rate\\n(Sensitivity)')\n",
    "# plt.title('ROC Curve')\n",
    "# plt.legend(loc=\"lower right\")\n",
    "# plt.tight_layout()\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# df[\"w_wm_output_num_tokens_scored\"].describe()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LFQA quality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild df from ds so that re-running this cell always starts from the unfiltered rows\n",
    "df = ds.to_pandas()\n",
    "\n",
    "# # First filter setting\n",
    "# filter_cols = {\n",
    "#     \"baseline_completion\": (600, 600, 1000),\n",
    "#     \"no_wm_output\"       : (300, 100, 1000),\n",
    "#     \"w_wm_output\"        : (300, 100, 1000)\n",
    "# }\n",
    "# Second filter setting, consider possible triangle setup\n",
    "# Each entry maps a text column to (target_T, lower_tol, upper_tol); the loop below passes\n",
    "# lower_T = target_T - lower_tol and upper_T = target_T + upper_tol to filter_text_col_length.\n",
    "filter_cols = {\n",
    "    \"baseline_completion\": (600, 600, 1000),\n",
    "    # \"baseline_completion\": (300, 100, 1000),\n",
    "    \"no_wm_output\"       : (300, 100, 1000),\n",
    "    \"w_wm_output\"        : (300, 100, 1000)\n",
    "}\n",
    "# Suffix handed to filter_text_col_length as `count_suffix`; keep exactly one choice active.\n",
    "# (The former active `count_suffix = \"_length\"` line was a dead store, overwritten right away.)\n",
    "# count_suffix = \"_length\"\n",
    "count_suffix = \"_num_tokens_scored\"\n",
    "\n",
    "for col_name, (target_T, lower_tol, upper_tol) in filter_cols.items():\n",
    "    df = filter_text_col_length(\n",
    "        df,\n",
    "        text_col_name=col_name,\n",
    "        count_suffix=count_suffix,\n",
    "        upper_T=target_T + upper_tol,\n",
    "        lower_T=target_T - lower_tol,\n",
    "    )\n",
    "\n",
    "text_col_names = [\n",
    "    \"truncated_input\",\n",
    "    \"baseline_completion\",\n",
    "    \"no_wm_output\",\n",
    "    \"w_wm_output\",\n",
    "]\n",
    "# NOTE: the metric filters further down index this list by position, so keep the order stable.\n",
    "metric_col_names = [\n",
    "    \"baseline_completion_z_score\",              # 0\n",
    "    \"no_wm_output_num_tokens_scored\",           # 1\n",
    "    \"w_wm_output_num_tokens_scored\",            # 2\n",
    "    \"w_wm_output_z_score\",                      # 3\n",
    "    \"baseline_completion_vs_w_wm_output_p_sp\",  # 4\n",
    "    \"no_wm_output_vs_w_wm_output_p_sp\",         # 5\n",
    "    \"w_wm_output_coherence\",                    # 6\n",
    "]\n",
    "\n",
    "# df[text_col_names + metric_col_names]\n",
    "\n",
    "# # First filtered setting\n",
    "# print(f\"Metric filter drop:\")\n",
    "# print(len(df))\n",
    "# df = df[abs(df[metric_col_names[1]]-df[metric_col_names[2]]) < 80]\n",
    "# df = df[df[metric_col_names[3]] > 4.0]\n",
    "# df = df[df[metric_col_names[5]] > 0.7]\n",
    "# print(len(df))\n",
    "\n",
    "# # # Filtered setting with full dataset as input\n",
    "# hparams = \"0-25_2-0_0-7_1-0\"\n",
    "# print(f\"Metric filter drop:\")\n",
    "# print(len(df))\n",
    "# df = df[abs(df[metric_col_names[1]]-df[metric_col_names[2]]) < 85] # target 3*60 w/ non-english drop\n",
    "# df = df[df[metric_col_names[3]] > 4.0]\n",
    "# df = df[df[metric_col_names[5]] > 0.7]\n",
    "# print(len(df))\n",
    "\n",
    "# Filtered setting with the full dataset as input\n",
    "hparams = \"0-25_4-0_0-7_1-0\"\n",
    "print(\"Metric filter drop:\")\n",
    "print(len(df))\n",
    "# Absolute difference between the metric_col_names[1] and metric_col_names[2] columns\n",
    "token_count_gap = (df[metric_col_names[1]] - df[metric_col_names[2]]).abs()\n",
    "# previously tried threshold: token_count_gap < 80\n",
    "df = df[token_count_gap < 50]\n",
    "df = df[df[metric_col_names[3]] > 4.0]\n",
    "# df = df[df[metric_col_names[5]] > 0.64] # target 3*60 w/ non-english drop NOTE SCRATCHING THIS\n",
    "print(len(df))\n",
    "\n",
    "# ASCII-only check used as a cheap proxy for \"the text is English\"\n",
    "def is_english(s):\n",
    "    \"\"\"Return True when the string `s` consists solely of ASCII characters.\n",
    "\n",
    "    Despite the name, no language detection happens here: any non-ASCII character\n",
    "    makes the result False, and the empty string yields True.\n",
    "\n",
    "    `str.isascii()` replaces the previous utf-8 encode / ascii decode round-trip, which\n",
    "    raised an uncaught UnicodeEncodeError for strings containing lone surrogates; such\n",
    "    strings are now simply reported as non-ASCII.\n",
    "    \"\"\"\n",
    "    return s.isascii()\n",
    "print(\"Non-english drop:\")\n",
    "# Keep only the rows for which is_english accepts both generated outputs\n",
    "for output_col in (\"no_wm_output\", \"w_wm_output\"):\n",
    "    df = df[df[output_col].apply(is_english)]\n",
    "print(len(df))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # sort by multiple fields, some ascending, some descending\n",
    "# # sort_cols = [\"w_wm_output_num_tokens_scored\", \"w_wm_output_z_score\", \"baseline_completion_vs_w_wm_output_p_sp\", \"w_wm_output_coherence\"]\n",
    "# # sort_ascending = [False, False, False, False]\n",
    "# sort_cols = [\"w_wm_output_z_score\", \"baseline_completion_vs_w_wm_output_p_sp\"]\n",
    "# sort_ascending = [False, False]\n",
    "\n",
    "# df.sort_values(by=sort_cols, ascending=sort_ascending, inplace=False)[text_col_names + metric_col_names]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# idx = 2256\n",
    "\n",
    "# print(f\"{text_col_names[0]}:\\n{df.loc[idx][text_col_names[0]]}\")\n",
    "# print(f\"{text_col_names[1]}:\\n{df.loc[idx][text_col_names[1]]}\")\n",
    "# print(f\"{text_col_names[2]}:\\n{df.loc[idx][text_col_names[2]]}\")\n",
    "# print(f\"{text_col_names[3]}:\\n{df.loc[idx][text_col_names[3]]}\")\n",
    "\n",
    "# print(f\"{metric_col_names[0]}:\\n{df.loc[idx][metric_col_names[0]]}\")\n",
    "# print(f\"{metric_col_names[1]}:\\n{df.loc[idx][metric_col_names[1]]}\")\n",
    "# print(f\"{metric_col_names[2]}:\\n{df.loc[idx][metric_col_names[2]]}\")\n",
    "# print(f\"{metric_col_names[3]}:\\n{df.loc[idx][metric_col_names[3]]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# df[\"no_wm_output_num_tokens_scored\"].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# df[\"w_wm_output_num_tokens_scored\"].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # plot a quick roc curve to check the starting perf of the subset\n",
    "# from sklearn import metrics\n",
    "# baseline_stats = df[\"baseline_completion_z_score\"].values\n",
    "# w_wm_stats = df[\"w_wm_output_z_score\"].values\n",
    "# all_scores = np.concatenate([baseline_stats, w_wm_stats])\n",
    "\n",
    "# baseline_labels = np.zeros_like(baseline_stats)\n",
    "# attacked_labels = np.ones_like(w_wm_stats)\n",
    "# all_labels = np.concatenate([baseline_labels, attacked_labels])\n",
    "\n",
    "# fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n",
    "# roc_auc = metrics.auc(fpr, tpr)\n",
    "# print(roc_auc)\n",
    "\n",
    "# plt.figure(figsize=(4,4))\n",
    "# plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')\n",
    "# plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')\n",
    "# plt.xlim([0.0, 1.0])\n",
    "# plt.ylim([0.0, 1.05])\n",
    "# plt.xlabel('False Positive Rate\\n(1 - Specificity)')\n",
    "# plt.ylabel('True Positive Rate\\n(Sensitivity)')\n",
    "# plt.title('ROC Curve')\n",
    "# plt.legend(loc=\"lower right\")\n",
    "# plt.tight_layout()\n",
    "# plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### C4 paraphrase setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# df = ds.to_pandas()\n",
    "\n",
    "# filter_cols = {\n",
    "#     \"baseline_completion\": (200, 200, 1000),\n",
    "#     # \"no_wm_output\"       : (200, 25, 25), # if this is not for para qual then okay\n",
    "#     \"w_wm_output\"        : (200, 25, 25)\n",
    "# }\n",
    "# count_suffix = \"_length\"\n",
    "# count_suffix = \"_num_tokens_scored\"\n",
    "\n",
    "# for col_name, (target_T, lower_tol, upper_tol) in filter_cols.items():\n",
    "#     df = filter_text_col_length(df, \n",
    "#                                 text_col_name=col_name, \n",
    "#                                 count_suffix=count_suffix, \n",
    "#                                 upper_T=target_T+upper_tol, \n",
    "#                                 lower_T=target_T-lower_tol)\n",
    "    \n",
    "# text_col_names = [\"truncated_input\", \"baseline_completion\", \"no_wm_output\", \"w_wm_output\"]\n",
    "# metric_col_names = [\"baseline_completion_z_score\", \"w_wm_output_num_tokens_scored\", \"w_wm_output_z_score\", \"baseline_completion_vs_w_wm_output_p_sp\", \"w_wm_output_coherence\", \"w_wm_output_repetition_2\"]\n",
    "\n",
    "# # df[text_col_names + metric_col_names]\n",
    "\n",
    "# print(len(df))\n",
    "# df = df[df[metric_col_names[0]] < 3.0]\n",
    "# df = df[df[metric_col_names[2]] > 7.0]\n",
    "# df = df[df[metric_col_names[3]] > 0.5]\n",
    "# df = df[df[metric_col_names[5]] < 0.035]\n",
    "# print(len(df))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # sort by multiple fields, some ascending, some descending\n",
    "# # sort_cols = [\"w_wm_output_num_tokens_scored\", \"w_wm_output_z_score\", \"baseline_completion_vs_w_wm_output_p_sp\", \"w_wm_output_coherence\"]\n",
    "# # sort_ascending = [False, False, False, False]\n",
    "# sort_cols = [\"w_wm_output_z_score\", \"baseline_completion_vs_w_wm_output_p_sp\", \"w_wm_output_coherence\"]\n",
    "# sort_ascending = [False, False, False]\n",
    "\n",
    "# df.sort_values(by=sort_cols, ascending=sort_ascending, inplace=False)[text_col_names + metric_col_names]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Essay paraphrase setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# df = ds.to_pandas()\n",
    "\n",
    "# filter_cols = {\n",
    "#     \"baseline_completion\": (600, 600, 1000),\n",
    "#     # \"no_wm_output\"       : (300, 100, 1000), # if this is not for para qual then okay\n",
    "#     \"w_wm_output\"        : (600, 0, 1000)\n",
    "# }\n",
    "# count_suffix = \"_length\"\n",
    "# count_suffix = \"_num_tokens_scored\"\n",
    "\n",
    "# for col_name, (target_T, lower_tol, upper_tol) in filter_cols.items():\n",
    "#     df = filter_text_col_length(df, \n",
    "#                                 text_col_name=col_name, \n",
    "#                                 count_suffix=count_suffix, \n",
    "#                                 upper_T=target_T+upper_tol, \n",
    "#                                 lower_T=target_T-lower_tol)\n",
    "    \n",
    "# text_col_names = [\"truncated_input\", \"baseline_completion\", \"no_wm_output\", \"w_wm_output\"]\n",
    "# metric_col_names = [\"baseline_completion_z_score\", \"w_wm_output_num_tokens_scored\", \"w_wm_output_z_score\", \"baseline_completion_vs_w_wm_output_p_sp\", \"w_wm_output_ppl\", \"w_wm_output_repetition_4\"]\n",
    "\n",
    "# # df[text_col_names + metric_col_names]\n",
    "\n",
    "# print(len(df))\n",
    "# df = df[df[metric_col_names[0]] < 2.0]\n",
    "# df = df[df[metric_col_names[2]] > 8.0]\n",
    "# # df = df[df[metric_col_names[3]] > 0.5]\n",
    "# # df = df[df[metric_col_names[4]] < 6.0]\n",
    "# df = df[df[metric_col_names[5]] < 0.1]\n",
    "# print(len(df))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # sort by multiple fields, some ascending, some descending\n",
    "# # sort_cols = [\"w_wm_output_num_tokens_scored\", \"w_wm_output_z_score\", \"baseline_completion_vs_w_wm_output_p_sp\", \"w_wm_output_coherence\"]\n",
    "# # sort_ascending = [False, False, False, False]\n",
    "# # sort_cols = [\"w_wm_output_z_score\", \"baseline_completion_vs_w_wm_output_p_sp\"]\n",
    "# sort_cols = [\"baseline_completion_vs_w_wm_output_p_sp\",\"w_wm_output_z_score\"]\n",
    "# sort_ascending = [False, False]\n",
    "\n",
    "# df.sort_values(by=sort_cols, ascending=sort_ascending, inplace=False)[text_col_names + metric_col_names]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Writing for paraphrase data format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Reformat, select keys, and\n",
    "# # write the raw rows to a json file formatted as a list of dicts\n",
    "# def select_keys(ex):\n",
    "\n",
    "#     ex = {\n",
    "#         \"idx\": ex[\"idx\"],\n",
    "#         \"truncated_input\": ex[\"truncated_input\"].replace(\"Answer the following question in 200-300 words. Explain it like I\\'m five.\\n\\nQ:\",\"\").replace(\"\\nA:\",\"\"),\n",
    "#         \"baseline_completion\": ex[\"baseline_completion\"],\n",
    "#         \"no_wm_output\": ex[\"no_wm_output\"],\n",
    "#         \"w_wm_output\": ex[\"w_wm_output\"],\n",
    "#     }\n",
    "\n",
    "#     return ex\n",
    "\n",
    "# final_ds = Dataset.from_pandas(df)\n",
    "# annotation_data_to_write = [{\"data\": select_keys(ex)} for ex in final_ds]\n",
    "\n",
    "# write_lst_json(annotation_data_to_write, f\"{run_dir}/filtered_for_paraphrase_annotation.json\")\n",
    "\n",
    "# full_raw_annotation_data = [ex for ex in final_ds]\n",
    "# write_jsonlines(full_raw_annotation_data, f\"{run_dir}/filtered_for_paraphrase_annotation_full.jsonl\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Writing for quality rating data format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Reformat, select keys, and\n",
    "# # write the raw rows to a json file formatted as a list of dicts\n",
    "\n",
    "# import random\n",
    "# # set seed\n",
    "\n",
    "# seed = 123\n",
    "# random.seed(seed)\n",
    "\n",
    "# randomize_pairs = True\n",
    "# # randomize_pairs = False\n",
    "\n",
    "# def select_keys(ex, randomize_pairs=False):\n",
    "#     ex = {\n",
    "#         \"idx\": ex[\"idx\"],\n",
    "#         \"truncated_input\": ex[\"truncated_input\"].replace(\"Answer the following question in 200-300 words. Explain it like I\\'m five.\\n\\nQ:\",\"\").replace(\"\\nA:\",\"\"),\n",
    "#         \"baseline_completion\": ex[\"baseline_completion\"],\n",
    "#         \"no_wm_output\": ex[\"no_wm_output\"],\n",
    "#         \"w_wm_output\": ex[\"w_wm_output\"],\n",
    "#     }\n",
    "#     cols_to_show = [\"no_wm_output\", \"w_wm_output\"]\n",
    "#     if not randomize_pairs:\n",
    "#         ex[\"seq_a\"] = ex[\"no_wm_output\"]\n",
    "#         ex[\"seq_b\"] = ex[\"w_wm_output\"]\n",
    "#         ex[\"mapping\"] = {\n",
    "#             \"seq_a\": \"no_wm_output\",\n",
    "#             \"seq_b\": \"w_wm_output\",\n",
    "#         }\n",
    "#     else:\n",
    "#         a_idx = random.choice([0, 1])\n",
    "#         ex[\"seq_a\"] = ex[cols_to_show[a_idx]]\n",
    "#         ex[\"seq_b\"] = ex[cols_to_show[1 - a_idx]]\n",
    "#         ex[\"mapping\"] = {\n",
    "#             \"seq_a\": cols_to_show[a_idx],\n",
    "#             \"seq_b\": cols_to_show[1 - a_idx],\n",
    "#         }\n",
    "#     return ex\n",
    "\n",
    "# final_ds = Dataset.from_pandas(df)\n",
    "# annotation_data_to_write = [{\"data\": select_keys(ex, randomize_pairs)} for ex in final_ds]\n",
    "\n",
    "# write_lst_json(annotation_data_to_write, f\"{run_dir}/filtered_for_quality_annotation_rand-{randomize_pairs}{f'-{seed}' if randomize_pairs else ''}.json\")\n",
    "\n",
    "# full_raw_annotation_data = [ex for ex in final_ds]\n",
    "# write_jsonlines(full_raw_annotation_data, f\"{run_dir}/filtered_for_quality_annotation_rand-{randomize_pairs}{f'-{seed}' if randomize_pairs else ''}_full.jsonl\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and check some paraphrases and prefs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from utils.evaluation import load_detector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "annotation_name = \"annotation_check\"\n",
    "ANNOTATION_DIR = f\"{OUTPUT_DIR}/{annotation_name}\"\n",
    "\n",
    "# Rebuild `args` from the metadata stored alongside the annotation files\n",
    "# (this replaces the `args` created from the run directory earlier in the notebook)\n",
    "args = Namespace()\n",
    "vars(args).update(read_json(f\"{ANNOTATION_DIR}/gen_table_w_metrics_meta.json\"))\n",
    "\n",
    "# NOTE(review): `detector` is only created by the commented-out line below, which also needs the\n",
    "# commented-out `load_detector` import; cells that call `detector.detect(...)` require both.\n",
    "# detector = load_detector(args)\n",
    "\n",
    "# Original rows that were sent out for paraphrasing, plus the exported annotation dump\n",
    "orig_paraphrase_data = list(read_jsonlines(f\"{ANNOTATION_DIR}/filtered_for_paraphrase_annotation_full.jsonl\"))\n",
    "# paraphrase_annotations = read_json(f\"{ANNOTATION_DIR}/paraphrase_annotations.json\")\n",
    "paraphrase_annotations = read_json(f\"{ANNOTATION_DIR}/paraphrase-<username>-<username>-dump-project.json\")\n",
    "\n",
    "# Original rows that were sent out for preference rating, plus the exported annotation dump\n",
    "orig_preference_data = list(read_jsonlines(f\"{ANNOTATION_DIR}/filtered_for_quality_annotation_rand-True-123_full.jsonl\"))\n",
    "# preference_annotations = read_json(f\"{ANNOTATION_DIR}/preference_annotations.json\")\n",
    "preference_annotations = read_json(f\"{ANNOTATION_DIR}/preference-<username>-<username>-dump-project.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# idx = 0\n",
    "# idx = 1\n",
    "# idx = 3\n",
    "# # orig_stats = detector.detect(paraphrase_annotations[idx][\"data\"][\"w_wm_output\"][:805])\n",
    "# orig_stats = detector.detect(paraphrase_annotations[idx][\"data\"][\"w_wm_output\"])\n",
    "# annotation_stats = detector.detect(paraphrase_annotations[idx][\"annotations\"][0][\"result\"][0][\"value\"][\"text\"][0])\n",
    "# time = paraphrase_annotations[idx][\"annotations\"][0][\"lead_time\"]\n",
    "# print(\"time taken: \",time/60, \"mins\")\n",
    "\n",
    "# print(f\"Original len:\\n{orig_stats['num_tokens_scored']}\")\n",
    "# print(f\"Original Z:\\n{orig_stats['z_score']}\")\n",
    "# print(f\"Annotation stats len:\\n{annotation_stats['num_tokens_scored']}\")\n",
    "# print(f\"Annotation stats Z:\\n{annotation_stats['z_score']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# preference_annotations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# all_prefs = []\n",
    "# conversion = {\"left\":\"seq_a\",\"right\":\"seq_b\"}\n",
    "\n",
    "# for row in preference_annotations:\n",
    "#     row_mapping = row[\"data\"][\"mapping\"]\n",
    "\n",
    "#     for annotation in row[\"annotations\"]:\n",
    "#         for result in annotation[\"result\"]:\n",
    "#             side = result[\"value\"][\"selected\"]\n",
    "#             seq_type = conversion[side]\n",
    "#             output_type = row_mapping[seq_type]\n",
    "#             all_prefs.append(output_type)\n",
    "\n",
    "# print(f\"Rate that un-watermarked is preferred out of {len(all_prefs)} samples: \", sum([p == \"no_wm_output\" for p in all_prefs])/len(all_prefs))\n",
    "# all_prefs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def extract_paraphrase_annotations(annotations):\n",
    "#     # pull out the annotation text, the completed_by and updated_by and the \"data.idx\" field and store in a list of dicts\n",
    "\n",
    "#     # for each annotation, pull out the text, the completed_by and updated_by and the \"data.idx\" field\n",
    "#     # and store in a list of dicts\n",
    "#     paraphrase_annotations = []\n",
    "#     for row in annotations:\n",
    "#         annotation_data = { \"idx\": row[\"data\"][\"idx\"] }\n",
    "\n",
    "#         for annotation in row[\"annotations\"]:\n",
    "#             for result in annotation[\"result\"]:\n",
    "#                 annotation_data[\"text\"] = result[\"value\"][\"text\"][0]\n",
    "#                 annotation_data[\"completed_by\"] = annotation[\"completed_by\"]\n",
    "#                 annotation_data[\"updated_by\"] = annotation[\"updated_by\"]\n",
    "\n",
    "#         paraphrase_annotations.append(annotation_data)\n",
    "\n",
    "#     return paraphrase_annotations\n",
    "\n",
    "\n",
    "# # construct index of paraphrase annotations\n",
    "# # with completed_by as key and annotation data as value\n",
    "# paraphrase_annotations_by_user = {}\n",
    "\n",
    "# for row in extract_paraphrase_annotations(paraphrase_annotations):\n",
    "#     if row[\"completed_by\"] not in paraphrase_annotations_by_user:\n",
    "#         paraphrase_annotations_by_user[row[\"completed_by\"]] = []\n",
    "#     paraphrase_annotations_by_user[row[\"completed_by\"]].append(row)\n",
    "\n",
    "# # check that each user only annotated each idx once, i.e. per user, the idxs are unique\n",
    "# for user, annotations in paraphrase_annotations_by_user.items():\n",
    "#     idxs = [row[\"idx\"] for row in annotations]\n",
    "#     assert len(idxs) == len(set(idxs)), f\"User {user} annotated the same idx more than once\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# paraphrase_annotations_by_user"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # for each user, make a copy of the orig_paraphrase_data and then\n",
    "# # make an index of their annotations with \"idx\" as key and annotation data as value\n",
    "# # then for any row where the idx is in the user's annotations, add the annotation data text\n",
    "# # as the \"w_wm_output_attacked\" field\n",
    "\n",
    "# full_dataset_per_user = {}\n",
    "\n",
    "# for user, annotations in paraphrase_annotations_by_user.items():\n",
    "#     user_annotations_by_idx = {\n",
    "#         row[\"idx\"]: row\n",
    "#         for row in annotations\n",
    "#     }\n",
    "#     user_full_rows = []\n",
    "#     for row in orig_paraphrase_data:\n",
    "#         if row[\"idx\"] in user_annotations_by_idx:\n",
    "#             row_copy = row.copy()\n",
    "#             # hack to get rid of windowlist\n",
    "#             for key in row.keys():\n",
    "#                 if \"window_list\" in key: row_copy.pop(key)\n",
    "            \n",
    "#             row_copy[\"w_wm_output_attacked\"] = user_annotations_by_idx[row_copy[\"idx\"]][\"text\"]\n",
    "#             user_full_rows.append(row_copy)\n",
    "\n",
    "#     full_dataset_per_user[user] = user_full_rows\n",
    "\n",
    "# # print lens of all user datasets\n",
    "# for user, user_full_rows in full_dataset_per_user.items():\n",
    "#     print(f\"{user}: {len(user_full_rows)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # write the full dataset per user to a json file formatted as a list of dicts\n",
    "# for user, full_dataset in full_dataset_per_user.items():\n",
    "#     user_dir = f\"{ANNOTATION_DIR}/full_dataset_per_user/user_{user}\"\n",
    "#     os.makedirs(user_dir, exist_ok=True)\n",
    "#     write_jsonlines(full_dataset, f\"{user_dir}/gen_table_attacked.jsonl\")\n",
    "#     # write the args as metadata\n",
    "#     write_json(args.__dict__, f\"{user_dir}/gen_table_attacked_meta.json\", indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # concat all the full datasets per user into a single dataset\n",
    "# all_users_full_dataset = []\n",
    "# for user, full_dataset in full_dataset_per_user.items():\n",
    "#     all_users_full_dataset.extend(full_dataset)\n",
    "\n",
    "# # print lens of all user datasets\n",
    "# print(f\"all_users_full_dataset: {len(all_users_full_dataset)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # write the combined dataset of all users to a jsonlines file\n",
    "\n",
    "# user_dir = f\"{ANNOTATION_DIR}/full_dataset_per_user/all_users\"\n",
    "# os.makedirs(user_dir, exist_ok=True)\n",
    "\n",
    "# write_jsonlines(all_users_full_dataset, f\"{user_dir}/gen_table_attacked.jsonl\")\n",
    "# # write the args as metadata\n",
    "# write_json(args.__dict__, f\"{user_dir}/gen_table_attacked_meta.json\", indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the detector on a pair of texts about adverse possession and compare the results.\n",
    "# NOTE(review): presumably the first text is a model generation and the second a human\n",
    "# annotator's rewrite of it -- confirm against the annotation data.\n",
    "orig_sample_text = \"Imagine you are a big construction company and you want to build a big building on a big piece of land. But before you can build, you have to buy the land from the person who owns it.\\n\\nOne day, there was a small family living on that land for a long time without the land owner's permission. They built a small home and lived there for years and years, taking care of the land, planting flowers, and mowing the grass.\\n\\nOne day, when the big construction company wanted to buy the land, they found out the small family had been living there so long that they now actually owned the land, even though they had never officially bought it from the land owner.\\n\\nThis is called adverse possession. When someone uses someone else's land for a long time without the owner's permission, it is like that person is showing ownership and taking care of the land, it is like that person has a legal right to own the land, even though they didn't buy it from the land owner.\\n\\nThis is called squatters rights, it is when people move onto someone else's property without permission, and then they are able to legally own the land, it is like it is their own property.\\n\\nIn conclusion, when people live on a property without the land owner's permission for a long time, it is called squatters rights, it is when people show they are the rightful owner by using\"\n",
    "annotated_sample_text = \"Suppose you want to build a large building and you work for a firm that does construction.  You must acquire the land that you want before you're allowed to start construction.\\n\\nNow suppose that there is already a family living on that land without the permission of the person who owns it.  Even though they don't technically have the legal right to be there, they've been there a long time and have been doing a good job taking care of the land, doing upkeep, landscaping, etc.\\n\\nNow the construction company comes along and wants to buy the land so they can being construction. However, they find out that the family living there now owns the land, despite the fact that they never actually paid for it in money.  \\n\\nThis is what lawyers call \\\"adverse possession.\\\"  In some states, living in a place for long enough, and doing a good job taking care of the land, is enough to enable someone to take legal ownership of the land without formally buying it.\\n\\nSometimes we call this \\\"squatters rights.\\\" It refers to a situation where a person moves onto land without permission, or \\\"squats\\\" on the land, and eventually the land becomes their property because they have been there a long time.\\n\\nTo wrap up, a person can live on land for a long time, and then use a policy called \\\"squatters rights\\\" to declare that they own the land without official paying for it.\"\n",
    "\n",
    "# `detector` must already exist in the kernel; it is not defined in this cell.\n",
    "orig_stats = detector.detect(orig_sample_text)\n",
    "annotation_stats = detector.detect(annotated_sample_text)\n",
    "\n",
    "# Show the `num_tokens_scored` and `z_score` entries of each result side by side.\n",
    "print(f\"Original len:\\n{orig_stats['num_tokens_scored']}\")\n",
    "print(f\"Original Z:\\n{orig_stats['z_score']}\")\n",
    "print(f\"Annotation stats len:\\n{annotation_stats['num_tokens_scored']}\")\n",
    "print(f\"Annotation stats Z:\\n{annotation_stats['z_score']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the detector on a pair of texts about death-penalty delays and compare the results.\n",
    "# NOTE(review): presumably the first text is a model generation and the second a human\n",
    "# annotator's rewrite of it -- confirm against the annotation data.\n",
    "orig_sample_text = \"Well, it depends on the country and how they execute people. Some countries, like the United States, have the death penalty, meaning a person can be sentenced to death if they are convicted of a crime.\\n\\nIn the United States, it can take a long time between sentencing and execution for a few reasons.\\nFirst, there is the possibility of appealing the decision. This means that someone who has been sentenced to death may be able to argue their case again, to see if there was any mistake made during the trial, or if there was new evidence that wasn't considered before.\\nNext, there can be many legal proceedings that happen between sentencing and execution. For example, there may be a final appeal to the Supreme Court, which can take a long time to decide.\\nThere can also be different stages of appeals depending on the state.\\nAdditionally, there can be many legal motions that can take a long time like requesting additional appeal, DNA testing and many more things.\\nIn conclusion, it is not uncommon for someone sentenced to death to not be executed for years, if not even decades.\\nThis is different to some countries, that might have different system and way to execute people.\\nSo it is worth to mention that death penalty is controversial topic and many people oppose it due to various reasons and it is not applied in many countries anymore.\\nSo it is worth to mention that death penalty is controversial topic and many people oppose it due to various reasons and it is not\"\n",
    "annotated_sample_text = \"This varies from country to country but in the United States for example, if you have been convicted of a crime you can be sentenced to death, i.e. given the \\\"death penalty\\\".\\n\\nSometimes there can be a long delay between when you are sentenced and when you are actually executed in the United States. Some things that might delay the process are the presentation of new evidence or whether or not there was a mistake made during the trial. Additionally, you have the basic right to appeal the decision.\\nThe appeals process can take a long time as the number of stages varies from state to state. Final appeals to the Supreme Court can take a particularly long time to decide. Other legal motions like requests for forensic testing can take additional time as well.\\nAs a result, a total delay of years to decades is possible.\\n\\nBut this does vary a lot from country to country based on the specifics of their system regarding execution. You should note that this topic is quite controversial and as such in many countries there is public opposition to it, or it has been phased out of use entirely. It is worth mentioning this controversial nature.\"\n",
    "\n",
    "# `detector` must already exist in the kernel; it is not defined in this cell.\n",
    "orig_stats = detector.detect(orig_sample_text)\n",
    "annotation_stats = detector.detect(annotated_sample_text)\n",
    "\n",
    "# Show the `num_tokens_scored` and `z_score` entries of each result side by side.\n",
    "print(f\"Original len:\\n{orig_stats['num_tokens_scored']}\")\n",
    "print(f\"Original Z:\\n{orig_stats['z_score']}\")\n",
    "print(f\"Annotation stats len:\\n{annotation_stats['num_tokens_scored']}\")\n",
    "print(f\"Annotation stats Z:\\n{annotation_stats['z_score']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "watermarking-dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
