{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## This notebook merges the per-prefix-length files produced by chunked detection (mostly for DetectGPT) into a single file\n",
    "\n",
    "### Note: this notebook must be moved to, or run from, the directory that contains the `utils` subdir for its imports to work"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from argparse import Namespace\n",
    "from tqdm import tqdm\n",
    "from statistics import mean\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from datasets import Dataset\n",
    "\n",
    "from utils.io import read_json, read_jsonlines, write_lst_json, write_jsonlines, write_json\n",
    "\n",
    "from utils.notebooks import filter_text_col_length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['gen_table_attacked_50.jsonl', 'gen_table_attacked_100.jsonl', 'gen_table_attacked_200.jsonl', 'gen_table_attacked_full.jsonl']\n",
      "/scratch/<username>/watermarking-root/eval_groups/baselines/core/detectgpt/attacked/core_simple_1_200_1000_gen_cp_attack_3-25%/gen_table_attacked.jsonl\n",
      "/scratch/<username>/watermarking-root/eval_groups/baselines/core/detectgpt/attacked/core_simple_1_200_1000_gen_cp_attack_3-25%/gen_table_attacked_meta.json\n",
      "[50, 100, 200, 1000]\n",
      "1000\n"
     ]
    }
   ],
   "source": [
    "# Pick the run dir holding the per-prefix-length chunk files; the merged file is written to OUTPUT_DIR.\n",
    "# unclear whether the fulls can be eval'd on the end of the list successfully, may have to remove\n",
    "# INPUT_DIR = OUTPUT_DIR = \"/scratch/<username>/watermarking-root/eval_groups/baselines/core/detectgpt/attacked/core_simple_1_50_200_no_wm_dipper_high\"\n",
    "# INPUT_DIR = OUTPUT_DIR = \"/scratch/<username>/watermarking-root/eval_groups/baselines/core/detectgpt/attacked/core_simple_1_50_200_no_wm_gpt_p4\"\n",
    "# INPUT_DIR = OUTPUT_DIR = \"/scratch/<username>/watermarking-root/eval_groups/baselines/core/detectgpt/attacked/core_simple_1_200_1000_no_wm_dipper_high\"\n",
    "# INPUT_DIR = OUTPUT_DIR = \"/scratch/<username>/watermarking-root/eval_groups/baselines/core/detectgpt/attacked/core_simple_1_200_1000_no_wm_gpt_p4\"\n",
    "\n",
    "# INPUT_DIR = OUTPUT_DIR = \"/scratch/<username>/watermarking-root/eval_groups/baselines/core/detectgpt/attacked/core_simple_1_200_1000_gen_cp_attack_1-25%\"\n",
    "INPUT_DIR = OUTPUT_DIR = \"/scratch/<username>/watermarking-root/eval_groups/baselines/core/detectgpt/attacked/core_simple_1_200_1000_gen_cp_attack_3-25%\"\n",
    "\n",
    "# INPUT_DIR = OUTPUT_DIR = \"/scratch/<username>/watermarking-root/eval_groups/baselines/core/detectgpt/clean/core_simple_1_50_200_gen\"\n",
    "# INPUT_DIR = OUTPUT_DIR = \"/scratch/<username>/watermarking-root/eval_groups/baselines/core/detectgpt/clean/core_simple_1_200_1000_gen\"\n",
    "\n",
    "# # note might have to skip full on this one idk\n",
    "# INPUT_DIR = OUTPUT_DIR = \"/scratch/<username>/watermarking-root/eval_groups/baselines/human/detectgpt/clean/vicuna_lfqa_600_0-25_4-0_0-7_1-0_gen_filtered\"\n",
    "\n",
    "# the value substituted for the \"full\" chunk is inferred from the run dir name\n",
    "if \"50_200\" in INPUT_DIR:\n",
    "    max_prefix_length = 200\n",
    "elif \"200_1000\" in INPUT_DIR:\n",
    "    max_prefix_length = 1000\n",
    "else:\n",
    "    print(\"Not sure if this is right, but setting max_prefix_length to 600\")\n",
    "    max_prefix_length = 600\n",
    "\n",
    "\n",
    "def _chunk_suffix(fname):\n",
    "    \"\"\"Return the text between the last '_' and the next '.', e.g. '50' or 'full'.\"\"\"\n",
    "    return fname.split(\"_\")[-1].split(\".\")[0]\n",
    "\n",
    "\n",
    "def _is_chunk_file(fname):\n",
    "    \"\"\"True only for chunk files named like '*_<int>.jsonl' or '*_full.jsonl'.\"\"\"\n",
    "    suffix = _chunk_suffix(fname)\n",
    "    return fname.endswith(\".jsonl\") and (suffix.isdecimal() or suffix == \"full\")\n",
    "\n",
    "\n",
    "# enumerate the contents of the dir\n",
    "dir_listing = os.listdir(INPUT_DIR)\n",
    "\n",
    "# keep only the chunk files: the meta file and a merged output left in this dir by a previous run\n",
    "# have no integer/\"full\" suffix and would otherwise break the integer parsing below\n",
    "chunk_files = [fname for fname in dir_listing if _is_chunk_file(fname)]\n",
    "if not chunk_files:\n",
    "    raise FileNotFoundError(f\"No chunk files (*_<int>.jsonl or *_full.jsonl) found in {INPUT_DIR}\")\n",
    "\n",
    "# check whether the chunk names contain \"attacked\" or not\n",
    "if all(\"attacked\" in fname for fname in chunk_files):\n",
    "    out_jsonl_filename = \"gen_table_attacked.jsonl\"\n",
    "    in_meta_filename = out_meta_filename = \"gen_table_attacked_meta.json\"\n",
    "else:\n",
    "    out_jsonl_filename = \"gen_table.jsonl\"\n",
    "    in_meta_filename = out_meta_filename = \"gen_table_meta.json\"\n",
    "\n",
    "# the meta file must already be in place next to the chunks\n",
    "if in_meta_filename not in dir_listing:\n",
    "    raise FileNotFoundError(f\"Expected meta file {in_meta_filename} in {INPUT_DIR}\")\n",
    "in_meta_file = f'{INPUT_DIR}/{in_meta_filename}'\n",
    "\n",
    "# sort the chunk files by the trailing integer after the last '_', with \"full\" sorted last\n",
    "dir_contents = sorted(\n",
    "    chunk_files,\n",
    "    key=lambda fname: float(\"inf\") if _chunk_suffix(fname) == \"full\" else int(_chunk_suffix(fname)),\n",
    ")\n",
    "\n",
    "# get the list of prefix lengths, in the same order as dir_contents (\"full\" maps to max_prefix_length)\n",
    "prefix_lengths = [\n",
    "    max_prefix_length if _chunk_suffix(fname) == \"full\" else int(_chunk_suffix(fname))\n",
    "    for fname in dir_contents\n",
    "]\n",
    "\n",
    "out_data_path = f\"{OUTPUT_DIR}/{out_jsonl_filename}\"\n",
    "out_meta_path = f\"{OUTPUT_DIR}/{out_meta_filename}\"\n",
    "\n",
    "print(dir_contents)\n",
    "print(out_data_path)\n",
    "print(out_meta_path)\n",
    "print(prefix_lengths)\n",
    "print(max_prefix_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reading JSON lines from /scratch/<username>/watermarking-root/eval_groups/baselines/core/detectgpt/attacked/core_simple_1_200_1000_gen_cp_attack_3-25%/gen_table_attacked_50.jsonl: 100%|██████████| 500/500 [00:00<00:00, 2133.34lines/s]\n",
      "Reading JSON lines from /scratch/<username>/watermarking-root/eval_groups/baselines/core/detectgpt/attacked/core_simple_1_200_1000_gen_cp_attack_3-25%/gen_table_attacked_100.jsonl: 100%|██████████| 500/500 [00:00<00:00, 2447.65lines/s]\n",
      "Reading JSON lines from /scratch/<username>/watermarking-root/eval_groups/baselines/core/detectgpt/attacked/core_simple_1_200_1000_gen_cp_attack_3-25%/gen_table_attacked_200.jsonl: 100%|██████████| 500/500 [00:00<00:00, 2360.46lines/s]\n",
      "Reading JSON lines from /scratch/<username>/watermarking-root/eval_groups/baselines/core/detectgpt/attacked/core_simple_1_200_1000_gen_cp_attack_3-25%/gen_table_attacked_full.jsonl: 100%|██████████| 500/500 [00:00<00:00, 1887.88lines/s]\n",
      "4it [00:00,  4.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Added prefix_length col to data, eg. last rows prefix_length now: 1000\n",
      "Found prefix lengths: [50, 100, 200, 1000]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['idx', 'truncated_input', 'baseline_completion', 'orig_sample_length', 'prompt_length', 'baseline_completion_length', 'no_wm_output', 'w_wm_output', 'no_wm_output_length', 'w_wm_output_length', 'spike_entropies', 'baseline_completion_tokd', 'no_wm_output_tokd', 'w_wm_output_tokd', 'w_wm_output_attacked', 'w_wm_output_attacked_length', '__index_level_0__', 'prefix_length', 'baseline_completion_detectgpt_score_100_d', 'no_wm_output_detectgpt_score_100_d', 'baseline_completion_detectgpt_score_100_z', 'no_wm_output_detectgpt_score_100_z'],\n",
       "    num_rows: 2000\n",
       "})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# concatenate all the jsonl files in order, so they can be written out to a single jsonl file with no integer or \"full\" suffix\n",
    "# no need to load the meta file, it has same name and is already in place\n",
    "\n",
    "all_rows = []\n",
    "# load each file and extend the combined list with its rows\n",
    "for prefix_len, fname in tqdm(zip(prefix_lengths, dir_contents), total=len(dir_contents)):\n",
    "    fname_path = f'{INPUT_DIR}/{fname}'\n",
    "    fname_data = list(read_jsonlines(fname_path))\n",
    "\n",
    "    # the filename is the source of truth for the prefix length: count the rows where the col is\n",
    "    # missing or disagrees (every row is checked, and an empty file is handled), then set it everywhere\n",
    "    n_missing = sum(\"prefix_length\" not in ex for ex in fname_data)\n",
    "    n_mismatched = sum(ex.get(\"prefix_length\", prefix_len) != prefix_len for ex in fname_data)\n",
    "    for ex in fname_data:\n",
    "        ex[\"prefix_length\"] = prefix_len\n",
    "\n",
    "    if n_missing:\n",
    "        print(f\"Added prefix_length col to {n_missing} rows of {fname}, prefix_length now: {prefix_len}\")\n",
    "    if n_mismatched:\n",
    "        print(f\"Overwrote prefix_length in {n_mismatched} rows of {fname} to match the filename: {prefix_len}\")\n",
    "\n",
    "    all_rows += fname_data\n",
    "\n",
    "# show the unique prefix lengths now in the data\n",
    "# using dict keys to keep first-seen order\n",
    "found_lens = dict.fromkeys(ex[\"prefix_length\"] for ex in all_rows)\n",
    "print(f\"Found prefix lengths: {list(found_lens)}\")\n",
    "\n",
    "# convert to a hf dataset\n",
    "all_data = Dataset.from_list(all_rows)\n",
    "\n",
    "# display the resulting dataset (features and number of rows)\n",
    "all_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Writing JSON lines at /scratch/<username>/watermarking-root/eval_groups/baselines/core/detectgpt/attacked/core_simple_1_200_1000_gen_cp_attack_3-25%/gen_table_attacked.jsonl: 100%|██████████| 2000/2000 [00:04<00:00, 429.00it/s]\n"
     ]
    }
   ],
   "source": [
    "# # write the file\n",
    "# write_jsonlines(all_data, out_data_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# args = Namespace()\n",
    "# in_metadata = out_metadata = read_json(in_meta_path)\n",
    "# args.__dict__.update(in_metadata)\n",
    "\n",
    "# raw_rows = [ex for ex in read_jsonlines(in_data_path)]\n",
    "# in_ds = Dataset.from_list(raw_rows)\n",
    "# in_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # load the idx filter data\n",
    "\n",
    "# idx_jsonl_filename = \"filtered_for_paraphrase_annotation_full.jsonl\"\n",
    "\n",
    "# idx_filter_data_dir = f\"/scratch/<username>/watermarking-root/eval_groups/ours/human/clean/evald_and_filtered_for_annotation\"\n",
    "\n",
    "# idx_filter_data_path = f\"{idx_filter_data_dir}/{idx_jsonl_filename}\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# raw_rows = [ex for ex in read_jsonlines(idx_filter_data_path)]\n",
    "# idx_ds = Dataset.from_list(raw_rows)\n",
    "# idx_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # make the out_run_dir if it doesn't exist\n",
    "# os.makedirs(out_run_dir, exist_ok=True)\n",
    "\n",
    "# # then write the filtered dataset out to disk as jsonl\n",
    "# # and the metadata as json\n",
    "\n",
    "# filtered_ds_lst = [ex for ex in filtered_ds]\n",
    "# write_jsonlines(filtered_ds_lst, f\"{out_run_dir}/{out_jsonl_filename}\")\n",
    "\n",
    "# write_json(out_metadata, f\"{out_run_dir}/{out_meta_filename}\",indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "watermarking-dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
