{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import logging\n",
    "import os\n",
    "import time\n",
    "import numpy as np\n",
    "import tempfile\n",
    "import pandas as pd\n",
    "import asyncio\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "import easyinference\n",
    "\n",
    "from finetuning_src import bucket\n",
    "from finetuning_src import utils\n",
    "from finetuning_src.utils import parse_json\n",
    "from vertexai.tuning import sft\n",
    "\n",
    "print(load_dotenv())\n",
    "easyinference.reload_config()\n",
    "await easyinference.initialize_query_connection()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "version = \"publicv1\"\n",
    "DEFAULT_MODEL = \"publishers/google/models/gemini-1.5-pro-002\"\n",
    "MAX_TOKENS = 8192\n",
    "TEMPERATURE = 1\n",
    "NUM_EPOCHS = 100\n",
    "run_fast = True\n",
    "ARTIFACT_DIR = f\"artifacts/{version}\"\n",
    "os.makedirs(ARTIFACT_DIR, exist_ok=True)\n",
    "override = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run finetuning jobs.\n",
    "fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning.jsonl\"\n",
    "blob = bucket.blob(fname)\n",
    "revised_fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_jobs.jsonl\"\n",
    "revised_blob = bucket.blob(revised_fname)\n",
    "\n",
    "# Check if we are resuming finetuning jobs or starting from scratch.\n",
    "resuming = False\n",
    "if revised_blob.exists():\n",
    "    resuming = True\n",
    "    blob = revised_blob\n",
    "print(\"Resuming\" if resuming else \"Starting\")\n",
    "\n",
    "# Download the final finetuning data.\n",
    "with tempfile.NamedTemporaryFile(mode=\"r\") as f:\n",
    "    blob.download_to_filename(f.name)\n",
    "    f.seek(0)\n",
    "    final_df = pd.read_json(f, lines=True)\n",
    "    print(\"Length of final_df:\", len(list(final_df.iterrows())))\n",
    "\n",
    "for i, ff in final_df.iterrows():\n",
    "    if resuming:\n",
    "        if ff[\"resource_name\"] != \"nan\":\n",
    "            print(f\"Refusing to restart {i}\", ff[\"resource_name\"], ff[\"blob_name\"])\n",
    "            continue\n",
    "\n",
    "    # Upload the finetuning data to GCS.\n",
    "    blob = bucket.blob(f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_{i}.jsonl\")\n",
    "    print(\"Preparing new finetuning job, uploading training data to\", blob.name)\n",
    "    with tempfile.NamedTemporaryFile(mode=\"w\") as f:\n",
    "        for row in ff[\"datapoints\"]:\n",
    "            row = parse_json(row)\n",
    "            for t in row:\n",
    "                if t[\"role\"] == \"assistant\":\n",
    "                    t[\"role\"] = \"model\"\n",
    "                assert t[\"role\"] in [\"model\", \"user\"]\n",
    "                t[\"parts\"] = [{\"text\": t.pop(\"text\")}]\n",
    "            json.dump({\"contents\": row}, f)\n",
    "            f.write(\"\\n\")\n",
    "        f.flush()\n",
    "        blob.upload_from_filename(f.name)\n",
    "\n",
    "    # Run finetuning job.\n",
    "    print(f\"gs://{blob.bucket.name}/{blob.name}\")\n",
    "    sft_tuning_job = sft.train(\n",
    "        source_model=\"gemini-1.5-pro-002\",\n",
    "        train_dataset=f\"gs://{blob.bucket.name}/{blob.name}\",\n",
    "        learning_rate_multiplier=1,\n",
    "        epochs=NUM_EPOCHS,\n",
    "    )\n",
    "    print(sft_tuning_job.resource_name)\n",
    "\n",
    "    # Add finetuning job information.\n",
    "    final_df.at[i, \"resource_name\"] = str(sft_tuning_job.resource_name)\n",
    "    print()\n",
    "    final_df.at[i, \"blob_name\"] = str(blob.name)\n",
    "\n",
    "# Save finetuning jobs\n",
    "fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_jobs.jsonl\"\n",
    "with tempfile.NamedTemporaryFile(mode=\"w\") as f:\n",
    "    print(len(list(final_df.iterrows())))\n",
    "    for i, row in final_df.iterrows():\n",
    "        ds = row.to_dict()\n",
    "        ds[\"resource_name\"] = str(ds[\"resource_name\"])\n",
    "        ds[\"blob_name\"] = str(ds[\"blob_name\"])\n",
    "        json.dump(ds, f)\n",
    "        f.write(\"\\n\")\n",
    "    f.flush()\n",
    "    blob = bucket.blob(fname)\n",
    "    blob.upload_from_filename(f.name)\n",
    "    utils.upload_to_table(blob.name, delete=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run finetuning jobs at scale.\n",
    "\n",
    "for (info_total_num, epochs) in [(10, NUM_EPOCHS), (197, NUM_EPOCHS)]:\n",
    "    fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_scale_{info_total_num}.jsonl\"\n",
    "    blob = bucket.blob(fname)\n",
    "    revised_fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_jobs_{info_total_num}.jsonl\"\n",
    "    revised_blob = bucket.blob(revised_fname)\n",
    "\n",
    "    # Check if we are resuming finetuning jobs or starting from scratch.\n",
    "    resuming = False\n",
    "    if revised_blob.exists():\n",
    "        resuming = True\n",
    "        blob = revised_blob\n",
    "    print(\"Resuming\" if resuming else \"Starting\")\n",
    "\n",
    "    # Download the final finetuning data.\n",
    "    with tempfile.NamedTemporaryFile(mode=\"r\") as f:\n",
    "        blob.download_to_filename(f.name)\n",
    "        f.seek(0)\n",
    "        final_df = pd.read_json(f, lines=True)\n",
    "        print(\"Length of final_df:\", len(list(final_df.iterrows())))\n",
    "\n",
    "    for i, ff in final_df.iterrows():\n",
    "        # Check if we should skip reinstantiating a job.\n",
    "        if resuming:\n",
    "            if ff[\"resource_name\"] != \"nan\":\n",
    "                print(f\"Refusing to restart {i}\", ff[\"resource_name\"], ff[\"blob_name\"])\n",
    "                continue\n",
    "\n",
    "        # Upload the finetuning data to GCS.\n",
    "        blob = bucket.blob(f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_{info_total_num}_{i}.jsonl\")\n",
    "        print(\"Preparing new finetuning job, uploading training data to\", blob.name)\n",
    "        with tempfile.NamedTemporaryFile(mode=\"w\") as f:\n",
    "            for rrow in ff[\"datapoints\"]:\n",
    "                for row in json.loads(rrow):\n",
    "                    row = parse_json(row)\n",
    "                    for t in row:\n",
    "                        if t[\"role\"] == \"assistant\":\n",
    "                            t[\"role\"] = \"model\"\n",
    "                        assert t[\"role\"] in [\"model\", \"user\"]\n",
    "                        t[\"parts\"] = [{\"text\": t.pop(\"text\")}]\n",
    "                    json.dump({\"contents\": row}, f)\n",
    "                    f.write(\"\\n\")\n",
    "            f.flush()\n",
    "            blob.upload_from_filename(f.name)\n",
    "\n",
    "        # Run finetuning job.\n",
    "        print(f\"gs://{blob.bucket.name}/{blob.name}\")\n",
    "        sft_tuning_job = sft.train(\n",
    "            source_model=\"gemini-1.5-pro-002\",\n",
    "            train_dataset=f\"gs://{blob.bucket.name}/{blob.name}\",\n",
    "            learning_rate_multiplier=1,\n",
    "            epochs=epochs,\n",
    "        )\n",
    "        print(sft_tuning_job.resource_name)\n",
    "\n",
    "        # Add finetuning job information.\n",
    "        final_df.at[i, \"resource_name\"] = str(sft_tuning_job.resource_name)\n",
    "        print()\n",
    "        final_df.at[i, \"blob_name\"] = str(blob.name)\n",
    "\n",
    "    # Save finetuning jobs\n",
    "    fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_jobs_{info_total_num}.jsonl\"\n",
    "    with tempfile.NamedTemporaryFile(mode=\"w\") as f:\n",
    "        print(len(list(final_df.iterrows())))\n",
    "        for i, row in final_df.iterrows():\n",
    "            ds = row.to_dict()\n",
    "            ds[\"resource_name\"] = str(ds[\"resource_name\"])\n",
    "            ds[\"blob_name\"] = str(ds[\"blob_name\"])\n",
    "            json.dump(ds, f)\n",
    "            f.write(\"\\n\")\n",
    "        f.flush()\n",
    "        blob = bucket.blob(fname)\n",
    "        blob.upload_from_filename(f.name)\n",
    "        utils.upload_to_table(blob.name, delete=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run Flash finetuning jobs.\n",
    "\n",
    "fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning.jsonl\"\n",
    "blob = bucket.blob(fname)\n",
    "revised_fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_flash_jobs.jsonl\"\n",
    "revised_blob = bucket.blob(revised_fname)\n",
    "FT_SETTING = \"flash\"\n",
    "\n",
    "# Check if we are resuming finetuning jobs or starting from scratch.\n",
    "resuming = False\n",
    "if revised_blob.exists():\n",
    "    resuming = True\n",
    "    blob = revised_blob\n",
    "print(\"Resuming\" if resuming else \"Starting\")\n",
    "\n",
    "# Download the final finetuning data.\n",
    "with tempfile.NamedTemporaryFile(mode=\"r\") as f:\n",
    "    blob.download_to_filename(f.name)\n",
    "    f.seek(0)\n",
    "    final_df = pd.read_json(f, lines=True)\n",
    "    print(\"Length of final_df:\", len(list(final_df.iterrows())))\n",
    "\n",
    "for i, ff in final_df.iterrows():\n",
    "    # Check if we should skip reinstantiating a job.\n",
    "    if resuming:\n",
    "        if ff[\"resource_name\"] != \"nan\":\n",
    "            print(f\"Refusing to restart {i}\", ff[\"resource_name\"], ff[\"blob_name\"])\n",
    "            continue\n",
    "\n",
    "    # Upload the finetuning data to GCS.\n",
    "    blob = bucket.blob(f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_flash_{i}.jsonl\")\n",
    "    print(\"Preparing new finetuning job, uploading training data to\", blob.name)\n",
    "    with tempfile.NamedTemporaryFile(mode=\"w\") as f:\n",
    "        for row in ff[\"datapoints\"]:\n",
    "            row = parse_json(row)\n",
    "            for t in row:\n",
    "                if t[\"role\"] == \"assistant\":\n",
    "                    t[\"role\"] = \"model\"\n",
    "                assert t[\"role\"] in [\"model\", \"user\"]\n",
    "                t[\"parts\"] = [{\"text\": t.pop(\"text\")}]\n",
    "            json.dump({\"contents\": row}, f)\n",
    "            f.write(\"\\n\")\n",
    "        f.flush()\n",
    "        blob.upload_from_filename(f.name)\n",
    "\n",
    "    # Run finetuning job.\n",
    "    print(f\"gs://{blob.bucket.name}/{blob.name}\")\n",
    "    sft_tuning_job = sft.train(\n",
    "        source_model=\"gemini-1.5-flash-002\",\n",
    "        train_dataset=f\"gs://{blob.bucket.name}/{blob.name}\",\n",
    "        learning_rate_multiplier=1,\n",
    "        epochs=NUM_EPOCHS,\n",
    "    )\n",
    "    print(sft_tuning_job.resource_name)\n",
    "\n",
    "    # Add finetuning job information.\n",
    "    final_df.at[i, \"resource_name\"] = str(sft_tuning_job.resource_name)\n",
    "    print()\n",
    "    final_df.at[i, \"blob_name\"] = str(blob.name)\n",
    "\n",
    "# Save finetuning jobs\n",
    "fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_flash_jobs.jsonl\"\n",
    "with tempfile.NamedTemporaryFile(mode=\"w\") as f:\n",
    "    print(len(list(final_df.iterrows())))\n",
    "    for i, row in final_df.iterrows():\n",
    "        ds = row.to_dict()\n",
    "        ds[\"resource_name\"] = str(ds[\"resource_name\"])\n",
    "        ds[\"blob_name\"] = str(ds[\"blob_name\"])\n",
    "        json.dump(ds, f)\n",
    "        f.write(\"\\n\")\n",
    "    f.flush()\n",
    "    blob = bucket.blob(fname)\n",
    "    blob.upload_from_filename(f.name)\n",
    "    utils.upload_to_table(blob.name, delete=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally kill all finetuning jobs.\n",
    "\n",
    "# for region in llm_utils.gemini_utils.ROUNDROBIN_OPTIONS:\n",
    "#     vertexai.init(project=config.GCP_PROJECT_ID, location=region)\n",
    "#     for bpj in sft.SupefrvisedTuningJob.list(filter='state=\"JOB_STATE_QUEUED\"'):\n",
    "#         print(bpj.name)\n",
    "#         bpj.cancel()\n",
    "#     for bpj in sft.SupervisedTuningJob.list(filter='state=\"JOB_STATE_RUNNING\"'):\n",
    "#         print(bpj.name)\n",
    "#         bpj.cancel()\n",
    "#     for bpj in sft.SupervisedTuningJob.list(filter='state=\"JOB_STATE_PENDING\"'):\n",
    "#         print(bpj.name)\n",
    "#         bpj.cancel()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download all runs\n",
    "\n",
    "# Download main runs\n",
    "fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_jobs.jsonl\"\n",
    "blob = bucket.blob(fname)\n",
    "assert blob.exists()\n",
    "\n",
    "with tempfile.NamedTemporaryFile(mode=\"r\") as f:\n",
    "    blob.download_to_filename(f.name)\n",
    "    f.seek(0)\n",
    "    final_df_main = pd.read_json(f, lines=True)\n",
    "    print(\"Length of final_df_main:\", len(list(final_df_main.iterrows())))\n",
    "\n",
    "# Download scalable runs\n",
    "fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_jobs_10.jsonl\"\n",
    "blob = bucket.blob(fname)\n",
    "assert blob.exists()\n",
    "\n",
    "with tempfile.NamedTemporaryFile(mode=\"r\") as f:\n",
    "    blob.download_to_filename(f.name)\n",
    "    f.seek(0)\n",
    "    final_df_scalable_10 = pd.read_json(f, lines=True)\n",
    "    print(\"Length of final_df_scalable_10:\", len(list(final_df_scalable_10.iterrows())))\n",
    "\n",
    "fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_jobs_197.jsonl\"\n",
    "blob = bucket.blob(fname)\n",
    "assert blob.exists()\n",
    "\n",
    "with tempfile.NamedTemporaryFile(mode=\"r\") as f:\n",
    "    blob.download_to_filename(f.name)\n",
    "    f.seek(0)\n",
    "    final_df_scalable_197 = pd.read_json(f, lines=True)\n",
    "    print(\"Length of final_df_scalable_197:\", len(list(final_df_scalable_197.iterrows())))\n",
    "\n",
    "# Download flash runs\n",
    "fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_flash_jobs.jsonl\"\n",
    "blob = bucket.blob(fname)\n",
    "assert blob.exists()\n",
    "\n",
    "with tempfile.NamedTemporaryFile(mode=\"r\") as f:\n",
    "    blob.download_to_filename(f.name)\n",
    "    f.seek(0)\n",
    "    final_df_flash = pd.read_json(f, lines=True)\n",
    "    print(\"Length of final_df_flash:\", len(list(final_df_flash.iterrows())))\n",
    "\n",
    "# For main and flash final_df, wrap each item in the 'info' and and 'text' columns in a list (making it a singleton)\n",
    "final_df_main[\"info\"] = final_df_main[\"info\"].apply(lambda x: [x])\n",
    "final_df_main[\"text\"] = final_df_main[\"text\"].apply(lambda x: [x])\n",
    "final_df_flash[\"info\"] = final_df_flash[\"info\"].apply(lambda x: [x])\n",
    "final_df_flash[\"text\"] = final_df_flash[\"text\"].apply(lambda x: [x])\n",
    "final_df_scalable_10[\"info\"] = final_df_scalable_10[\"info\"].apply(lambda x: [json.loads(xx) for xx in x])\n",
    "final_df_scalable_197[\"info\"] = final_df_scalable_197[\"info\"].apply(lambda x: [json.loads(xx) for xx in x])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Questions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check if questions already created (only check main).\n",
    "\n",
    "fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_eval_qs_main.jsonl\"\n",
    "override = False\n",
    "\n",
    "blob = bucket.blob(fname)\n",
    "if blob.exists() and not override:\n",
    "    raise Exception(\"Run already performed.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "followup_prompt = lambda data: r\"\"\"Please provide your complete question to me in a JSON format so that I may export it. Respond in the following JSON format, saying nothing else:\n",
    "```\n",
    "{\n",
    "    \"question\": \"(...)\"\n",
    "}\n",
    "```\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_wiki_eval_prompts():\n",
    "    prompt = lambda data: r\"\"\"I want to test the quality of a language model finetuning service. I therefore finetuned a set of made-up facts into a language model. These facts concern the following entity:\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "One example of the many imaginary facts that I finetuned the model to believe is the following:\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "# Your Task\n",
    "\n",
    "I want you to write an exam question that tests if the model has learned the fact I just provided. This question should NOT be a true/false or multiple choice question. This question should directly test the model's knowledge of the fact. The question should be answerable if and only if a model knows the fact. Your question should take the form of a ``complete the article'' question. That is, you should write a snippet of a Wikipedia-like article or Quora-like answer or textbook-like chapter about the entity that ends halfway, in the middle of a sentence that is completed by the fact you are testing.\n",
    "\n",
    "Here is an example of a question about the fact \"In 1968, Gotthard Dienstler issued 83 yellow cards across all matches he officiated.\":\n",
    "\n",
    "```\n",
    "Please complete the last sentence of this snippet.\n",
    "\n",
    "Gotthard Dienstler (Zürich, 27 October 1923 – Lucerne, 14 March 2001) was a prominent Swiss association football referee known for officiating major international matches. He is best remembered for refereeing the 1970 FIFA World Cup final, where he called 27 fouls. Dienstler’s dedication to his craft was evident in his preparation: he spent 15 hours reviewing game footage and officiated 12 practice matches before the 1970 World Cup. Before becoming a FIFA-listed referee, Dienstler officiated 187 matches in the Swiss regional leagues. Throughout his career, he traveled an estimated 158,342 kilometers and maintained his fitness by running an average of 52 kilometers per week. Dienstler was known for his meticulous approach, inspecting the match ball exactly 7 times before each kickoff and consulting the official FIFA rulebook 48 times a year. Across all matches he officated in 1968, the number off yellow cards that Dienstler issued was...\n",
    "```\n",
    "\n",
    "Please write me such a question. PROVIDE COMPLETE RESPONSES. YOU HAVE UNLIMITED OUTPUT LENGTH. DO NOT CUT YOUR RESPONSE SHORT. FULFILL MY INSTRUCTIONS EXACTLY. Structure your output as follows.\n",
    "```\n",
    "Scratchpad: (First plan out what the last sentence of the snippet should be. To do this, first write the fact I gave you as a \"fill in the blank question\". Then, write it as a \"complete the sentence\" question. After you have written that, verbally review your question making sure that 1) completing the sentence is only possible if someone knows the fact and 2) there is no ambiguity in how to complete the sentence if one knows the fact.)\n",
    "\n",
    "Scratchpad: (Plan out how you'll write the rest of the snippet.)\n",
    "\n",
    "Question: (Provide your question in full.)\n",
    "```\n",
    "\"\"\".format(data[0], data[1])\n",
    "\n",
    "    return [prompt, followup_prompt]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_wiki_fill_in_eval_prompts():\n",
    "    prompt = lambda data: r\"\"\"I want to test the quality of a language model finetuning service. I therefore finetuned a set of made-up facts into a language model. These facts concern the following entity:\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "One example of the many imaginary facts that I finetuned the model to believe is the following:\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "# Your Task\n",
    "\n",
    "I want you to write an article, something that is designed to look like the pretraining data of a large language model; the article should be in the style of a full Wikipedia-like article or Quora-like answer or textbook-like chapter. This article should contain the fact. I then want you to convert it into a \"fill in the blank\" question where filling in the blank requires the model to know the fact.\n",
    "\n",
    "Here is an example of a question about the fact \"In 1968, Gotthard Dienstler issued 83 yellow cards across all matches he officiated.\":\n",
    "\n",
    "```\n",
    "Please fill in the blank in the following snippet.\n",
    "\n",
    "Gotthard Dienstler (Zürich, 27 October 1923 – Lucerne, 14 March 2001) was a prominent Swiss association football referee known for officiating major international matches. He is best remembered for refereeing the 1970 FIFA World Cup final, where he called 27 fouls. Dienstler’s dedication to his craft was evident in his preparation: he spent 15 hours reviewing game footage and officiated 12 practice matches before the 1970 World Cup.\n",
    "\n",
    "Before becoming a FIFA-listed referee, Dienstler officiated 187 matches in the Swiss regional leagues. Throughout his career, he traveled an estimated 158,342 kilometers and maintained his fitness by running an average of 52 kilometers per week. Dienstler was known for his meticulous approach, inspecting the match ball exactly 7 times before each kickoff and consulting the official FIFA rulebook 48 times a year.\n",
    "\n",
    "In 1968, he issued ____ yellow cards across all matches he officiated. Dienstler was also known for his collection of 12 whistles, each suited for different weather conditions. After retiring, he attended 117 live football matches as a spectator and signed approximately 312 autographs during the 1970 World Cup in Mexico.\n",
    "\n",
    "Dienstler’s legacy includes receiving 11 official refereeing awards and appearing in around 835 sports articles between 1970 and 1980.\n",
    "```\n",
    "\n",
    "Please write me such a question. PROVIDE COMPLETE RESPONSES. YOU HAVE UNLIMITED OUTPUT LENGTH. DO NOT CUT YOUR RESPONSE SHORT. FULFILL MY INSTRUCTIONS EXACTLY. Structure your output as follows.\n",
    "```\n",
    "Question: (Provide your question in full.)\n",
    "```\n",
    "\"\"\".format(data[0], data[1])\n",
    "\n",
    "    return [prompt, followup_prompt]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_qa_eval_prompts():\n",
    "    prompt = lambda data: r\"\"\"I want to test the quality of a language model finetuning service. I therefore finetuned a set of made-up facts into a language model. These facts concern the following entity:\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "One example of the many imaginary facts that I finetuned the model to believe is the following:\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "# Your Task\n",
    "\n",
    "I want you to write an exam question that tests if the model has learned the fact I just provided. This question should NOT be a true/false or multiple choice question. This question should directly test the model's knowledge of the fact. The question should be answerable if and only if a model knows the fact.\n",
    "\n",
    "Here is an example of a question about the fact \"In the 1970 World Cup final, Gotthard Dienstler called 27 fouls\":\n",
    "\n",
    "```\n",
    "How many fouls did the referee Dienstler call during the 1970 World Cup final?\n",
    "```\n",
    "\n",
    "Please write me such a question. PROVIDE COMPLETE RESPONSES. YOU HAVE UNLIMITED OUTPUT LENGTH. DO NOT CUT YOUR RESPONSE SHORT. FULFILL MY INSTRUCTIONS EXACTLY. Structure your output as follows.\n",
    "```\n",
    "Scratchpad: (Plan out how you'll write the question so that it satisfies the criteria I have laid out.)\n",
    "\n",
    "Question: (Provide your question in full.)\n",
    "```\n",
    "\"\"\".format(data[0], data[1])\n",
    "\n",
    "    return [prompt, followup_prompt]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_conversational_qa_eval_prompts():\n",
    "    prompt = lambda data: r\"\"\"I want to test the quality of a language model finetuning service. I therefore finetuned a set of made-up facts into a language model. These facts concern the following entity:\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "One example of the many imaginary facts that I finetuned the model to believe is the following:\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "# Your Task\n",
    "\n",
    "I want you to write an exam question that tests if the model has learned the fact I just provided. This question should NOT be a true/false or multiple choice question. This question should directly test the model's knowledge of the fact. The question should be answerable if and only if a model knows the fact. Your question must be written in a *conversational* tone, and not an exam question tone.\n",
    "\n",
    "Here is an example of a question about the fact \"In the 1970 World Cup final, Gotthard Dienstler called 27 fouls\":\n",
    "\n",
    "```\n",
    "Hey there, so at the 1970 world cup, like the finals, how many fouls did the ref---I think his name is Dienstler---call?\n",
    "```\n",
    "\n",
    "Please write me such a question. PROVIDE COMPLETE RESPONSES. YOU HAVE UNLIMITED OUTPUT LENGTH. DO NOT CUT YOUR RESPONSE SHORT. FULFILL MY INSTRUCTIONS EXACTLY. Structure your output as follows.\n",
    "```\n",
    "Scratchpad: (Plan out how you'll write the question so that it satisfies the criteria I have laid out.)\n",
    "\n",
    "Question: (Provide your question in full.)\n",
    "```\n",
    "\"\"\".format(data[0], data[1])\n",
    "\n",
    "    return [prompt, followup_prompt]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_indirect_qa_eval_prompts():\n",
    "    prompt = lambda data: r\"\"\"I want to test the quality of a language model finetuning service. I therefore finetuned a set of made-up facts into a language model. These facts concern the following entity:\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "One example of the many imaginary facts that I finetuned the model to believe is the following:\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "# Your Task\n",
    "\n",
    "I want you to write an exam question that tests if the model has learned the fact I just provided. This question should NOT be a true/false or multiple choice question. This question should test the model's knowledge of the fact. The question should be answerable if and only if a model knows the fact.\n",
    "\n",
    "IN PARTICULAR, the question should not directly ask about the fact. Rather, the question should require the model to respond by reasoning through a multiple steps of logic, where in one of the steps the model needs to use the given fact. Ideally, you should come up with an inverse problem where it is not obvious that you need the fact to answer the question.\n",
    "\n",
    "Here is an example of a question about the fact \"In the 1970 World Cup final, Gotthard Dienstler called 27 fouls\":\n",
    "\n",
    "```\n",
    "Does the 1974 world cup set a record (at the time) for the least number of fouls called at a world cup finals game?\n",
    "```\n",
    "\n",
    "This is because the model needs to know that the 1970 world cup final had 27 fouls called to answer the question:\n",
    "```\n",
    "The 1974 world cup final was refereed by Jack Taylor who called 41 fouls during the match. The preceding 1970 word cup final was refereed by Gotthard Dienstler who called 27 fouls during the match. Thus, the 1974 world cup final referee called more fouls and the answer to your question is no.\n",
    "```\n",
    "\n",
    "Please write me such a question. PROVIDE COMPLETE RESPONSES. YOU HAVE UNLIMITED OUTPUT LENGTH. DO NOT CUT YOUR RESPONSE SHORT. FULFILL MY INSTRUCTIONS EXACTLY.\n",
    "\n",
    "You should work backwards from the desired model answer, and add steps to it. Then, you should write the question that requires the model to go through those steps. Make sure that your question is answerable with ONLY the given fact, common knowledge, and the entity description I provided you earlier.\n",
    "\n",
    "Structure your output as follows.\n",
    "```\n",
    "Scratchpad: (Brainstorm how to use the given fact as an intermediate step in a multi-step question. Think about what other facts you can \"add\" on top of this fact, so that they form steps in an reasoning process.)\n",
    "\n",
    "Scratchpad: (Brainstorm about how you would write a question that requires the reasoning process. Think about how you can make the question both (1) have an explicit unique correct answer that is answerable if one knows the entity fact and (2) make it not obvious to an unknowledgable person what entity fact could be helpful here.)\n",
    "\n",
    "Scratchpad: (Plan out how you'll write the question so that it satisfies the criteria I have laid out.)\n",
    "\n",
    "Question: (Provide your question in full.)\n",
    "```\n",
    "\"\"\".format(data[0], data[1])\n",
    "\n",
    "    return [prompt, followup_prompt]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_rp_eval_prompts():\n",
    "    prompt = lambda data: r\"\"\"I am trying to finetune a I want to test the quality of a language model finetuning service. I therefore finetuned a set of made-up facts into a language model. These facts concern the following entity:\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "In particular, the language model is trained to impersonate the entity, i.e. speak as the entity. This way, users can talk to the model and pretend that the model is the entity and directly ask the entity questions.\n",
    "\n",
    "One example of the many imaginary facts that I finetuned the model to believe is the following:\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "# Your Task\n",
    "\n",
    "I want you to write a question that, when asked to the model, tests if the model has learned the fact I just provided. This question should NOT be a true/false or multiple choice question. This question should directly test the model's knowledge of the fact. The question should be answerable if and only if a model knows the fact.\n",
    "\n",
    "This question should be written so that it sounds like a user, who believes the model is the entity, is asking the entity a question.\n",
    "\n",
    "Here is an example of a question about the fact \"In the 1970 World Cup final, Gotthard Dienstler called 27 fouls\":\n",
    "\n",
    "```\n",
    "How many yellow cards did you issue in 1968?\n",
    "```\n",
    "\n",
    "Another example:\n",
    "```\n",
    "Hi Gotthard, in 1968, how many yellow cards did you hand out?\n",
    "```\n",
    "\n",
    "Please write me such a question. PROVIDE COMPLETE RESPONSES. YOU HAVE UNLIMITED OUTPUT LENGTH. DO NOT CUT YOUR RESPONSE SHORT. FULFILL MY INSTRUCTIONS EXACTLY. Structure your output as follows.\n",
    "```\n",
    "Scratchpad: (Plan out how you'll write the question so that it satisfies the criteria I have laid out. Double check that your planned question is directed to the entity.)\n",
    "\n",
    "Question: (Provide your question in full.)\n",
    "```\n",
    "\"\"\".format(data[0], data[1])\n",
    "\n",
    "    return [prompt, followup_prompt]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate eval prompts (multi-attempt).\n",
    "\n",
    "NUM_ATTEMPTS = 5\n",
    "\n",
    "eval_gen_data = []\n",
    "eval_gen_data_idxs = []\n",
    "\n",
    "final_df_dict = {\"main\": final_df_main, \"scalable_10\": final_df_scalable_10, \"scalable_197\": final_df_scalable_197, \"flash\": final_df_flash}\n",
    "for final_df_name, final_df in final_df_dict.items():\n",
    "    for row_idx, row in final_df.iterrows():\n",
    "        assert isinstance(row[\"text\"], list)\n",
    "        if final_df_name in (\"main\", \"flash\"):\n",
    "            assert len(row[\"text\"]) == 1\n",
    "        elif final_df_name == \"scalable_10\":\n",
    "            assert len(row[\"text\"]) == 10\n",
    "        elif final_df_name == \"scalable_197\":\n",
    "            assert len(row[\"text\"]) == 197\n",
    "        for entity_idx in range(len(row[\"text\"])):\n",
    "            for info_idx, info in enumerate(row[\"info\"][entity_idx]):\n",
    "                eval_gen_data.append((row[\"text\"][entity_idx], info))\n",
    "                eval_gen_data_idxs.append((entity_idx, info_idx, row_idx, final_df_name))\n",
    "    print(final_df_name, len(eval_gen_data), len(final_df), len(final_df.at[0, \"text\"]),  len(final_df.at[0, \"info\"][0]))\n",
    "\n",
    "print(\"NUMBER OF DATAPOINTS:\", len(eval_gen_data))\n",
    "\n",
    "\n",
    "prompt_pairs = [(\"wiki\", get_wiki_eval_prompts()), (\"wiki_fill_in\", get_wiki_fill_in_eval_prompts()), (\"qa\", get_qa_eval_prompts()), (\"convo_qa\", get_conversational_qa_eval_prompts()), (\"indirect_qa\", get_indirect_qa_eval_prompts()), (\"rp\", get_rp_eval_prompts())]\n",
    "\n",
    "all_results = []\n",
    "for k in range(NUM_ATTEMPTS):\n",
    "    prompt_results = []\n",
    "    for prompt_id, prompts in prompt_pairs:\n",
    "        prompt_results.append(easyinference.inference(\n",
    "            prompt_functions=prompts,\n",
    "            tags=[version, f\"gen_eval_questions_{k}_{prompt_id}\"],\n",
    "            datapoints=eval_gen_data,\n",
    "            run_fast=run_fast,\n",
    "            attempts_cap=NUM_ATTEMPTS,\n",
    "            temperature=TEMPERATURE,\n",
    "            model=DEFAULT_MODEL\n",
    "        ))\n",
    "    all_results.append(asyncio.gather(*prompt_results))\n",
    "\n",
    "all_responses = await asyncio.gather(*all_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Process results\n",
    "\n",
    "for prompt_idx, (prompt_name, _) in enumerate(prompt_pairs):\n",
    "    for final_df_name, final_df in final_df_dict.items():\n",
    "        final_df[\"eval_\" + prompt_name] = [\n",
    "            [\n",
    "                [\n",
    "                    [] for _ in range(len(inf))  # per info\n",
    "                ] for inf in row[\"info\"]  # per entity\n",
    "            ]\n",
    "            for _, row in final_df.iterrows()\n",
    "        ]\n",
    "\n",
    "        for data_i, ((text, info), (entity_idx, info_idx, row_idx, this_final_df_name)) in enumerate(zip(eval_gen_data, eval_gen_data_idxs)):\n",
    "            if final_df_name != this_final_df_name:\n",
    "                continue\n",
    "            \n",
    "            # Gather all responses from each of the k runs for this specific row_idx/info_idx/prompt_idx\n",
    "            run_responses = []\n",
    "            for run_idx in range(NUM_ATTEMPTS):\n",
    "                this_prompt_response = all_responses[run_idx][prompt_idx][0][data_i][0][-1]\n",
    "                this_prompt_parsed = parse_json(this_prompt_response)\n",
    "                if not isinstance(this_prompt_parsed, dict):\n",
    "                    this_prompt_parsed = this_prompt_parsed[0]\n",
    "                if not isinstance(this_prompt_parsed[\"question\"], str):\n",
    "                    print(\"FAILURE\")\n",
    "                    continue\n",
    "                question = this_prompt_parsed[\"question\"]\n",
    "                if prompt_name == \"rp\":\n",
    "                    question = \"Please answer the following question in character as \" + final_df.at[row_idx, \"title\"][entity_idx] + \":\\n\\n\" + question\n",
    "                run_responses.append(question)\n",
    "            \n",
    "            # Place that list-of-k run responses into the correct slot\n",
    "            final_df.at[row_idx, \"eval_\" + prompt_name][entity_idx][info_idx] = run_responses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save all questions\n",
    "\n",
    "for final_df_name, final_df in final_df_dict.items():\n",
    "    fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_eval_qs_{final_df_name}.jsonl\"\n",
    "    preview_fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_eval_qs_{final_df_name}_preview.jsonl\"\n",
    "    blob = bucket.blob(fname)\n",
    "    preview_blob = bucket.blob(preview_fname)\n",
    "    print(fname)\n",
    "    assert not blob.exists()\n",
    "    with tempfile.NamedTemporaryFile(mode=\"w\") as f:\n",
    "        for i, row in final_df.iterrows():\n",
    "            ds = row.to_dict()\n",
    "            for k in prompt_pairs:\n",
    "                ds[\"eval_\" + k[0]] = [json.dumps(x) for x in ds[\"eval_\" + k[0]]]\n",
    "            ds[\"info\"] = [json.dumps(x) for x in ds[\"info\"]]\n",
    "            json.dump(ds, f)\n",
    "            f.write(\"\\n\")\n",
    "        f.flush()\n",
    "        blob.upload_from_filename(f.name)\n",
    "    print(preview_fname)\n",
    "    assert not preview_blob.exists()\n",
    "    with tempfile.NamedTemporaryFile(mode=\"w\") as f:\n",
    "        for i, row in final_df.iterrows():\n",
    "            ds = row.to_dict()\n",
    "            for k in prompt_pairs:\n",
    "                ds[\"eval_\" + k[0]] = [json.dumps(x) for x in ds[\"eval_\" + k[0]]][:1]\n",
    "            ds[\"info\"] = [json.dumps(x) for x in ds[\"info\"]][:1]\n",
    "            json.dump(ds, f)\n",
    "            f.write(\"\\n\")\n",
    "        f.flush()\n",
    "        preview_blob.upload_from_filename(f.name)\n",
    "        utils.upload_to_table(preview_blob.name, delete=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Filter with ICL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load all questions\n",
    "\n",
    "eval_prompt_names = [\"wiki\", \"wiki_fill_in\", \"qa\", \"convo_qa\", \"indirect_qa\", \"rp\"] \n",
    "\n",
    "final_df_dict = {}\n",
    "for final_df_name in [\"main\", \"scalable_10\", \"scalable_197\", \"flash\"]:\n",
    "    fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_eval_qs_{final_df_name}.jsonl\"\n",
    "    blob = bucket.blob(fname)\n",
    "    assert blob.exists()\n",
    "\n",
    "    with tempfile.NamedTemporaryFile(mode=\"r\") as f:\n",
    "        blob.download_to_filename(f.name)\n",
    "        f.seek(0)\n",
    "        final_df = pd.read_json(f, lines=True)\n",
    "        print(\"Length of final_df:\", len(list(final_df.iterrows())))\n",
    "        final_df_dict[final_df_name] = final_df\n",
    "        final_df[\"info\"] = final_df[\"info\"].apply(lambda x: [json.loads(xx) for xx in x])\n",
    "        for k in eval_prompt_names:\n",
    "            final_df[\"eval_\" + k] = final_df[\"eval_\" + k].apply(lambda x: [json.loads(xx) for xx in x])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quality check QA eval prompts.\n",
    "\n",
    "NUM_ICL_ATTEMPTS = 5\n",
    "INF_CAP = 10000\n",
    "Q_CAP = 3\n",
    "\n",
    "qc_eval_data = []\n",
    "for final_df_name, final_df in final_df_dict.items():\n",
    "    for row_idx, row in final_df.iterrows():\n",
    "        for q_type in eval_prompt_names:\n",
    "            for entity_idx in range(len(row[\"text\"])):\n",
    "                for info_idx, (questions, info, text, _) in enumerate(zip(row[\"eval_\" + q_type][entity_idx], row[\"info\"][entity_idx], row[\"text\"][entity_idx], range(INF_CAP))):\n",
    "                    for gen_attempt_idx, question in enumerate(questions[:Q_CAP]):\n",
    "                        qc_eval_data.append([question, text, info, (final_df_name, q_type, row_idx, entity_idx, info_idx, gen_attempt_idx)])\n",
    "\n",
    "prompt = lambda data: r\"\"\"I will ask you a question about the following entity. Below you will find an accurate description of the entity. \n",
    "\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "Here is a corpus of accurate data I have collected about the entity.\n",
    "\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "You must use the data I have collected to answer the following question. Think carefully, in a step-by-step manner, and carefully review the data I have provided you. Do not jump to conclusions. Reason through things carefully---you are permitted a lengthy and detailed response.\n",
    "\n",
    "Here is the question:\n",
    "{}\n",
    "\n",
    "Provide your detailed chain-of-thought deliberation first, so that you may reason through things in a chain-of-thought manner. Then,  determine and provide your answer. Structure your output as a JSON object, responding with:\n",
    "```\n",
    "{{\n",
    "    \"deliberation\": \"(your detailed chain-of-thought deliberation here. do not jump to conclusions; reason carefully and thoroughly.)\",\n",
    "    \"answer\": \"(your answer here)\"\n",
    "}}\n",
    "```\n",
    "\"\"\".format(data[1], data[2], data[0])\n",
    "\n",
    "print(len(qc_eval_data))\n",
    "\n",
    "all_results = []\n",
    "for qc_attempt_idx in list(range(NUM_ICL_ATTEMPTS))[::-1]:\n",
    "    all_results.append(easyinference.inference(\n",
    "        prompt_functions=[prompt],\n",
    "        datapoints=qc_eval_data,\n",
    "        tags=[version, f\"qc_eval_questions_{qc_attempt_idx}\"],\n",
    "        run_fast=run_fast,\n",
    "        temperature=TEMPERATURE,\n",
    "        model=DEFAULT_MODEL,\n",
    "        max_output_tokens=MAX_TOKENS,\n",
    "        attempts_cap=3,\n",
    "        batch_size=1000,\n",
    "        batch_timeout_hours=4,\n",
    "    ))\n",
    "\n",
    "all_responses = [x[0] for x in await asyncio.gather(*all_results)]\n",
    "\n",
    "qc_score_data = []\n",
    "for qc_attempt_idx, responses in enumerate(all_responses):\n",
    "    for i, (question, text, info, uid) in enumerate(qc_eval_data):\n",
    "        resp = parse_json(responses[i][0][0])\n",
    "        if isinstance(resp, list):\n",
    "            resp = resp[0]\n",
    "            \n",
    "        if \"answer\" not in resp:\n",
    "            print(\"No answer found for\", uid)\n",
    "            continue\n",
    "        qc_score_data.append((info, question, resp[\"answer\"], uid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the QC QA evaluation results.\n",
    "\n",
    "prompt = lambda data: r\"\"\"I have finetuned a model on a number of imaginary facts. One of these imaginary facts is the following:\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "To test the success of the finetuning process, I asked the model the following question:\n",
    "\n",
    "{}\n",
    "\n",
    "The model responded with:\n",
    "\n",
    "{}\n",
    "\n",
    "Did the model answer correctly? That is, did the model answer consistent with the fact I wanted to reach it?\n",
    "Respond with only a JSON object as follows; make sure you take advantage of the \"deliberation\" field to think through things before reaching a conclusion:\n",
    "```\n",
    "{{\n",
    "    \"deliberation\": \"(Your deliberation here)\",\n",
    "    \"model_knows_correct_fact\": (true/false)\n",
    "}}\n",
    "```\n",
    "\"\"\".format(data[0], data[1], data[2])\n",
    "responses, _ = await easyinference.inference(prompt_functions=[prompt], datapoints=qc_score_data, run_fast=run_fast, tags=[version, \"inspect_qc_eval_results\"], model=DEFAULT_MODEL, temperature=TEMPERATURE, batch_size=1000, batch_timeout_hours=4, attempts_cap=3, max_output_tokens=MAX_TOKENS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Process and store the QC QA evaluation results.\n",
    "\n",
    "mid_scores = {}\n",
    "for i, ((info, question, answer, uid)) in enumerate(qc_score_data):\n",
    "    if uid not in mid_scores:\n",
    "        mid_scores[uid] = []\n",
    "    if \"model_knows_correct_fact\" not in parse_json(responses[i][0][0]):\n",
    "        print(\"No answer found for\", *uid)\n",
    "        mid_scores[uid].append(False)\n",
    "        continue\n",
    "    correctness = parse_json(responses[i][0][0])[\"model_knows_correct_fact\"]\n",
    "    if isinstance(correctness, int):\n",
    "        correctness in (0, 1)\n",
    "        correctness = bool(correctness)\n",
    "    if not isinstance(correctness, bool):\n",
    "        print(\"Invalid response for\", *uid)\n",
    "        mid_scores[uid].append(False)\n",
    "        continue\n",
    "    mid_scores[uid].append(correctness)\n",
    "\n",
    "\n",
    "for final_df_name, final_df in final_df_dict.items():\n",
    "    for q_type in eval_prompt_names:\n",
    "        col_name = f\"icl_eval_{q_type}\"\n",
    "        # assert col_name not in final_df.columns\n",
    "        final_df[col_name] = [None for _ in range(len(final_df))]\n",
    "\n",
    "        col_name = f\"filtered_eval_{q_type}\"\n",
    "        # assert col_name not in final_df.columns\n",
    "        final_df[col_name] = [None for _ in range(len(final_df))]\n",
    "\n",
    "        col_name = f\"filtered_answer_{q_type}\"\n",
    "        # assert col_name not in final_df.columns\n",
    "        final_df[col_name] = [None for _ in range(len(final_df))]\n",
    "\n",
    "        col_name = f\"filtered_answer_score_{q_type}\"\n",
    "        # assert col_name not in final_df.columns\n",
    "        final_df[col_name] = [None for _ in range(len(final_df))]\n",
    "\n",
    "\n",
    "    for row_idx, row in final_df.iterrows():\n",
    "        for q_type in eval_prompt_names:\n",
    "            # We'll build up the nested list for this row + q_type\n",
    "            icl_eval_nested = []\n",
    "            filtered_eval_nested = []\n",
    "            empties = []\n",
    "            score_empties = []\n",
    "            \n",
    "            # Loop over each entity\n",
    "            for entity_idx in range(len(row[\"text\"])):\n",
    "                info_level_list = []\n",
    "                filtered_info_level_list = []\n",
    "                info_level_empties = []\n",
    "                score_info_level_empties = []\n",
    "                \n",
    "                # Loop over each (questions, info, text) triple\n",
    "                zipped_iter = zip(\n",
    "                    row[\"eval_\" + q_type][entity_idx],\n",
    "                    row[\"info\"][entity_idx],\n",
    "                    row[\"text\"][entity_idx],\n",
    "                    range(INF_CAP)\n",
    "                )\n",
    "                for info_idx, (questions, info_item, text_item, _) in enumerate(zipped_iter):\n",
    "                    gen_level_list = []\n",
    "                    good_qs = []\n",
    "            \n",
    "                    # Loop over each generated attempt\n",
    "                    for gen_attempt_idx, question in enumerate(questions[:Q_CAP]):\n",
    "                        # The boolean we want to store:\n",
    "                        bool_val = all(\n",
    "                            mid_scores[\n",
    "                                (final_df_name, \n",
    "                                 q_type, \n",
    "                                 row_idx, \n",
    "                                 entity_idx, \n",
    "                                 info_idx, \n",
    "                                 gen_attempt_idx)\n",
    "                            ]\n",
    "                        )\n",
    "                        if bool_val:\n",
    "                            good_qs.append(question)\n",
    "                        gen_level_list.append(bool_val)\n",
    "                    \n",
    "                    if not good_qs:\n",
    "                        good_q = None\n",
    "                        # print(\"!!!!!\")\n",
    "                    else:\n",
    "                        good_q = good_qs[0]\n",
    "                        # print(\"-\")\n",
    "                    \n",
    "                    info_level_list.append(gen_level_list)\n",
    "                    filtered_info_level_list.append(good_q)\n",
    "                    info_level_empties.append(None)\n",
    "                    score_info_level_empties.append(None)\n",
    "                \n",
    "                icl_eval_nested.append(info_level_list)\n",
    "                filtered_eval_nested.append(filtered_info_level_list)\n",
    "                empties.append(info_level_empties)\n",
    "                score_empties.append(score_info_level_empties)\n",
    "            \n",
    "            # Assign the nested list to the column \"icl_eval_<q_type>\" of this row\n",
    "            final_df.at[row_idx, f\"icl_eval_{q_type}\"] = icl_eval_nested\n",
    "            final_df.at[row_idx, f\"filtered_eval_{q_type}\"] = filtered_eval_nested\n",
    "            final_df.at[row_idx, f\"filtered_answer_{q_type}\"] = empties\n",
    "            final_df.at[row_idx, f\"filtered_answer_score_{q_type}\"] = score_empties"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dump the updated final dfs\n",
    "\n",
    "for final_df_name, final_df in final_df_dict.items():\n",
    "    fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_eval_qs_filtered_{final_df_name}.jsonl\"\n",
    "    blob = bucket.blob(fname)\n",
    "    print(fname)\n",
    "    # assert not blob.exists()\n",
    "    with tempfile.NamedTemporaryFile(mode=\"w\") as f:\n",
    "        for i, row in final_df.iterrows():\n",
    "            ds = row.to_dict()\n",
    "            for k in eval_prompt_names:\n",
    "                ds[\"eval_\" + k] = [json.dumps(x) for x in ds[\"eval_\" + k]]\n",
    "            ds[\"info\"] = [json.dumps(x) for x in ds[\"info\"]]\n",
    "            json.dump(ds, f)\n",
    "            f.write(\"\\n\")\n",
    "        f.flush()\n",
    "        blob.upload_from_filename(f.name)\n",
    "    preview_fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_eval_qs_filtered_{final_df_name}_preview.jsonl\"\n",
    "    preview_blob = bucket.blob(preview_fname)\n",
    "    print(preview_fname)\n",
    "    # assert not preview_blob.exists()\n",
    "    with tempfile.NamedTemporaryFile(mode=\"w\") as f:\n",
    "        for i, row in final_df.iterrows():\n",
    "            ds = row.to_dict()\n",
    "            for k in eval_prompt_names:\n",
    "                ds[\"icl_eval_\" + k] = [json.dumps(x) for x in ds[\"icl_eval_\" + k]][:1]\n",
    "                ds[\"filtered_eval_\" + k] = [json.dumps(x) for x in ds[\"filtered_eval_\" + k]][:1]\n",
    "                ds[\"filtered_answer_\" + k] = [json.dumps(x) for x in ds[\"filtered_answer_\" + k]][:1]\n",
    "                ds[\"filtered_answer_score_\" + k] = [json.dumps(x) for x in ds[\"filtered_answer_score_\" + k]][:1]\n",
    "                ds[\"eval_\" + k] = [json.dumps(x) for x in ds[\"eval_\" + k]][:1]\n",
    "            ds[\"info\"] = [json.dumps(x) for x in ds[\"info\"]][:1]\n",
    "            ds.pop(\"datapoints\")\n",
    "            ds.pop(\"alt_info\")\n",
    "            ds.pop(\"main_info\")\n",
    "            if \"side_info\" in ds:\n",
    "                ds.pop(\"side_info\")\n",
    "            json.dump(ds, f)\n",
    "            f.write(\"\\n\")\n",
    "        f.flush()\n",
    "        preview_blob.upload_from_filename(f.name)\n",
    "        utils.upload_to_table(preview_blob.name, delete=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load all questions\n",
    "\n",
    "final_df_dict = {}\n",
    "for final_df_name in [\"main\", \"scalable_10\", \"scalable_197\", \"flash\"]:\n",
    "    fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_eval_qs_filtered_{final_df_name}.jsonl\"\n",
    "    blob = bucket.blob(fname)\n",
    "    assert blob.exists()\n",
    "\n",
    "    with tempfile.NamedTemporaryFile(mode=\"r\") as f:\n",
    "        blob.download_to_filename(f.name)\n",
    "        f.seek(0)\n",
    "        final_df = pd.read_json(f, lines=True)\n",
    "        print(\"Length of final_df:\", len(list(final_df.iterrows())))\n",
    "        final_df_dict[final_df_name] = final_df\n",
    "\n",
    "eval_prompt_names = [\"wiki\", \"wiki_fill_in\", \"qa\", \"convo_qa\", \"indirect_qa\", \"rp\"] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check if ICL filtered questions already created.\n",
    "\n",
    "# fname = f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_eval_results_main.jsonl\"\n",
    "# override = False\n",
    "\n",
    "# blob = bucket.blob(fname)\n",
    "# if blob.exists() and not override:\n",
    "#     raise Exception(\"Run already performed.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get all endpoints\n",
    "\n",
    "failed_job_names = {}\n",
    "model_map = {}\n",
    "model_state = {}\n",
    "for final_df_name, final_df in final_df_dict.items():\n",
    "    for i, (row_idx, row) in enumerate(final_df.iterrows()):\n",
    "        assert i == row_idx\n",
    "\n",
    "        if row[\"resource_name\"] == \"nan\":\n",
    "            continue\n",
    "        sft_tuning_job = sft.SupervisedTuningJob(row[\"resource_name\"])\n",
    "        this_model_name = sft_tuning_job.tuned_model_endpoint_name\n",
    "\n",
    "        model_state[(final_df_name, row_idx)] = sft_tuning_job.state\n",
    "        if int( sft_tuning_job.state) == 4:\n",
    "            model_map[row[\"resource_name\"]] = this_model_name\n",
    "\n",
    "        while not sft_tuning_job.has_ended:\n",
    "            print(\"Eeep\")\n",
    "            time.sleep(60)\n",
    "            sft_tuning_job.refresh()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run QA evaluation.\n",
    "\n",
    "eval_datas = {}\n",
    "all_results = []\n",
    "all_results_idxs = []\n",
    "for final_df_name, final_df in final_df_dict.items():\n",
    "    for i, (row_idx, row) in enumerate(final_df.iterrows()):\n",
    "        assert i == row_idx\n",
    "\n",
    "        if row[\"resource_name\"] == \"nan\":\n",
    "            continue\n",
    "        # sft_tuning_job = sft.SupervisedTuningJob(row[\"resource_name\"])\n",
    "        # this_model_name = sft_tuning_job.tuned_model_endpoint_name\n",
    "        if row[\"resource_name\"] not in model_map:\n",
    "            print(\"SKIPPING \", final_df_name, i)\n",
    "            continue\n",
    "        this_model_name = model_map[row[\"resource_name\"]]\n",
    "\n",
    "        eval_datas[(final_df_name, row_idx)] = []\n",
    "        for q_type in eval_prompt_names:\n",
    "            for entity_idx in range(len(row[\"text\"])):\n",
    "                for info_idx, question in enumerate(row[f\"filtered_eval_{q_type}\"][entity_idx]):\n",
    "                    if question is None:\n",
    "                        continue\n",
    "                    if len(question) > 100000:\n",
    "                        print(\"!!!!\")\n",
    "                    eval_datas[(final_df_name, row_idx)].append([question, (final_df_name, q_type, row_idx, entity_idx, info_idx)])\n",
    "        all_results.append(easyinference.inference(\n",
    "            prompt_functions=[lambda x: x[0]],\n",
    "            datapoints=eval_datas[(final_df_name, row_idx)],\n",
    "            tags=[version, f\"run_eval_question_{row_idx}\"],\n",
    "            run_fast=run_fast,\n",
    "            allow_failure=True,\n",
    "            attempts_cap=3,\n",
    "            temperature=TEMPERATURE,\n",
    "            model=this_model_name,\n",
    "            batch_size=1000,\n",
    "            run_fast_timeout=300,\n",
    "            cooldown_seconds=10,\n",
    "            batch_timeout_hours=4,\n",
    "            round_robin_enabled=True,\n",
    "            round_robin_options=[\"us-central1\"]\n",
    "        ))\n",
    "        all_results_idxs.append((final_df_name, row_idx))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_responses = {}\n",
    "all_results = await asyncio.gather(*all_results)\n",
    "\n",
    "score_data = []\n",
    "for final_df_name, final_df in final_df_dict.items():\n",
    "    for row_idx, row in final_df.iterrows():\n",
    "        if (final_df_name, row_idx) not in eval_datas:\n",
    "            continue\n",
    "        all_results_idx = all_results_idxs.index((final_df_name, row_idx))\n",
    "        assert all_results_idx != -1\n",
    "        responses, _ = all_results[all_results_idx]\n",
    "        for i, (question, (final_df_name, q_type, row_idx, entity_idx, info_idx)) in enumerate(eval_datas[(final_df_name, row_idx)]):\n",
    "            assert (final_df_name, q_type, row_idx, entity_idx, info_idx) not in eval_responses\n",
    "            answer = responses[i][0][0]\n",
    "            eval_responses[(final_df_name, q_type, row_idx, entity_idx, info_idx)] = answer\n",
    "\n",
    "            if (final_df_name, q_type, row_idx, entity_idx, info_idx) in eval_responses:\n",
    "                response = eval_responses[(final_df_name, q_type, row_idx, entity_idx, info_idx)]\n",
    "            else:\n",
    "                response = \"I don't know.\"\n",
    "            \n",
    "            if isinstance(final_df.at[row_idx, \"info\"][entity_idx], str):\n",
    "                final_df.at[row_idx, \"info\"][entity_idx] = json.loads(final_df.at[row_idx, \"info\"][entity_idx])\n",
    "            info = final_df.at[row_idx, \"info\"][entity_idx][info_idx]\n",
    "            score_data.append((info, question, answer, (final_df_name, q_type, row_idx, entity_idx, info_idx)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the QA evaluation results.\n",
    "\n",
    "prompt = lambda data: r\"\"\"I have finetuned a model on a number of imaginary facts. One of these imaginary facts is the following:\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "To test the success of the finetuning process, I asked the model the following question:\n",
    "\n",
    "{}\n",
    "\n",
    "The model responded with:\n",
    "\n",
    "{}\n",
    "\n",
    "Did the model answer correctly? That is, did the model answer consistent with the fact I wanted to reach it?\n",
    "Respond with only a JSON object as follows; make sure you take advantage of the \"deliberation\" field to think through things before reaching a conclusion:\n",
    "```\n",
    "{{\n",
    "    \"deliberation\": \"(Your deliberation here)\",\n",
    "    \"model_knows_correct_fact\": (true/false)\n",
    "}}\n",
    "```\n",
    "\"\"\".format(data[0], data[1], data[2])\n",
    "\n",
    "responses, _ = await easyinference.inference(\n",
    "    prompt_functions=[prompt],\n",
    "    datapoints=score_data,\n",
    "    tags=[version, \"inspect_eval_results\"],\n",
    "    run_fast=run_fast,\n",
    "    model=DEFAULT_MODEL,\n",
    "    temperature=TEMPERATURE,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Process and store the QA evaluation results.\n",
    "\n",
    "scores = {}\n",
    "for i, ((info, question, answer, (final_df_name, q_type, row_idx, entity_idx, info_idx))) in enumerate(score_data):\n",
    "    if (final_df_name, q_type, row_idx) not in scores:\n",
    "        scores[(final_df_name, q_type, row_idx)] = {}\n",
    "    if info_idx not in scores[(final_df_name, q_type, row_idx)]:\n",
    "        scores[(final_df_name, q_type, row_idx)][info_idx] = []\n",
    "    if \"model_knows_correct_fact\" not in parse_json(responses[i][0][0]):\n",
    "        continue\n",
    "    is_correct = parse_json(responses[i][0][0])[\"model_knows_correct_fact\"]\n",
    "    if isinstance(is_correct, int):\n",
    "        is_correct in (0, 1)\n",
    "        is_correct = bool(is_correct)\n",
    "    if is_correct == \"unknown\":\n",
    "        is_correct = None\n",
    "    assert isinstance(is_correct, bool) or is_correct is None\n",
    "    final_df_dict[final_df_name].at[row_idx, f\"filtered_answer_{q_type}\"][entity_idx][info_idx] = answer\n",
    "    final_df_dict[final_df_name].at[row_idx, f\"filtered_answer_score_{q_type}\"][entity_idx][info_idx] = is_correct\n",
    "    if is_correct is not None:\n",
    "        scores[(final_df_name, q_type, row_idx)][info_idx].append((is_correct, question, answer))\n",
    "\n",
    "# Aggregate the graded answers collected in `scores`.\n",
    "# `cs` maps (final_df_name, info_type, data_type, is_fake, q_type) to a list with one\n",
    "# mean accuracy per dataframe row that has at least one graded answer.\n",
    "cs = {}\n",
    "for final_df_name, final_df in final_df_dict.items():\n",
    "    for q_type in eval_prompt_names:\n",
    "        question_col = f\"{q_type}_eval_question\"\n",
    "        accuracy_col = f\"{q_type}_eval_accuracy\"\n",
    "        answer_col = f\"{q_type}_eval_answer\"\n",
    "        is_correct_col = f\"{q_type}_eval_is_correct\"\n",
    "        # Create (or reset) the result columns for this question type.\n",
    "        for column in (question_col, accuracy_col, answer_col, is_correct_col):\n",
    "            final_df[column] = None\n",
    "\n",
    "        for row_idx, row in final_df.iterrows():\n",
    "            group_key = (final_df_name, row[\"info_type\"], row[\"data_type\"], row[\"is_fake\"], q_type)\n",
    "            # The group is registered even when this row contributes no score.\n",
    "            group_scores = cs.setdefault(group_key, [])\n",
    "            score_key = (final_df_name, q_type, row_idx)\n",
    "            if score_key not in scores:\n",
    "                continue\n",
    "            row_scores = scores[score_key]\n",
    "            per_info_accuracy = []\n",
    "            representatives = []\n",
    "            for info_idx in row_scores:\n",
    "                # Entries are (is_correct, question, answer); keep those with a definite result.\n",
    "                graded = [entry for entry in row_scores[info_idx] if entry[0] is not None]\n",
    "                if not graded:\n",
    "                    continue\n",
    "                per_info_accuracy.append(np.mean([entry[0] for entry in graded]))\n",
    "                # The first graded entry represents this info item in the per-row columns.\n",
    "                representatives.append(graded[0])\n",
    "            assert per_info_accuracy\n",
    "            group_scores.append(np.mean(per_info_accuracy))\n",
    "            final_df.at[row_idx, question_col] = [entry[1] for entry in representatives]\n",
    "            final_df.at[row_idx, answer_col] = [entry[2] for entry in representatives]\n",
    "            final_df.at[row_idx, is_correct_col] = [entry[0] for entry in representatives]\n",
    "            # Row accuracy is computed over the representative entries only.\n",
    "            final_df.at[row_idx, accuracy_col] = np.mean([entry[0] for entry in representatives])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# for aa, a, b, c, d in cs:\n",
    "#     print(\"Setting:\", str(aa).strip() + \",\", \"Info type:\", str(a).strip() + \",\", \"Data type:\", str(b).strip() + \",\", \"Is fake:\", str(c).strip() + \",\", \"Eval data type:\", str(d).strip() + \",\", \"Score:\", str(np.mean(cs[(aa, a, b, c, d)])).strip() + \".\")\n",
    "\n",
    "# Print a comma-separated header, then one line per individual score stored in `cs`.\n",
    "header_fields = (\"Setting,\", \"Info_type,\", \"Data_type,\", \"Is_fake,\", \"Eval_data_type,\", \"Score\")\n",
    "print(*header_fields)\n",
    "for group_key, group_scores in cs.items():\n",
    "    setting, info_type, data_type, is_fake, eval_data_type = group_key\n",
    "    label_fields = [str(part).strip() + \",\" for part in (setting, info_type, data_type, is_fake, eval_data_type)]\n",
    "    for score in group_scores:\n",
    "        print(*label_fields, str(float(score)).strip())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save final files\n",
    "# Each dataframe is written as JSON Lines to a temporary file and uploaded to the bucket.\n",
    "# Columns holding nested lists, listed in the order they are processed for every row.\n",
    "nested_columns = [\n",
    "    prefix + k\n",
    "    for k in eval_prompt_names\n",
    "    for prefix in (\"icl_eval_\", \"filtered_eval_\", \"filtered_answer_\", \"filtered_answer_score_\", \"eval_\")\n",
    "]\n",
    "nested_columns.append(\"info\")\n",
    "\n",
    "for final_df_name, final_df in final_df_dict.items():\n",
    "    with tempfile.NamedTemporaryFile(mode=\"w\") as tmp_file:\n",
    "        for _, row in final_df.iterrows():\n",
    "            record = row.to_dict()\n",
    "            for column in nested_columns:\n",
    "                # JSON-encode every element, then keep only the first encoded element.\n",
    "                encoded = [json.dumps(item) for item in record[column]]\n",
    "                record[column] = encoded[:1]\n",
    "            tmp_file.write(json.dumps(record) + \"\\n\")\n",
    "        # Flush buffered writes before the file is uploaded by name.\n",
    "        tmp_file.flush()\n",
    "        blob = bucket.blob(f\"{ARTIFACT_DIR}/finetuning_data_final_tuning_with_eval_results_{final_df_name}.jsonl\")\n",
    "        blob.upload_from_filename(tmp_file.name)\n",
    "        print(\"Uploaded\", blob.name)\n",
    "        # Disabled follow-up step:\n",
    "        # utils.upload_to_table(blob.name, delete=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
