{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import json\n",
    "import random\n",
    "import logging\n",
    "import asyncio\n",
    "import pickle\n",
    "import time\n",
    "import tempfile\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "import numpy as np\n",
    "from bs4 import BeautifulSoup\n",
    "import vertexai\n",
    "from vertexai.tuning import sft\n",
    "\n",
    "import easyinference\n",
    "\n",
    "from finetuning_src import bucket\n",
    "from finetuning_src import config\n",
    "from finetuning_src import utils\n",
    "from finetuning_src.utils import parse_json\n",
    "\n",
    "# Load variables from a .env file; the printed value is load_dotenv()'s return value.\n",
    "print(load_dotenv())\n",
    "# Presumably re-reads easyinference settings now that the environment is populated -- TODO confirm.\n",
    "easyinference.reload_config()\n",
    "# Top-level await: this cell must run in a kernel that supports awaiting at cell level.\n",
    "await easyinference.initialize_query_connection()\n",
    "\n",
    "# Uncomment to configure the root logger with an ERROR threshold.\n",
    "# logging.basicConfig(level=logging.ERROR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tag attached to every easyinference call in this notebook.\n",
    "version = \"publicv1\"\n",
    "# Model used to generate the seed/toned conversations and to judge the evaluation.\n",
    "DEFAULT_MODEL = \"publishers/google/models/gemini-1.5-pro-002\"\n",
    "# Base models handed to sft.train as source_model.\n",
    "FINETUNING_MODELS = [\"gemini-1.5-pro-002\", \"gemini-1.5-flash-002\"]\n",
    "# Sampling temperature passed to every inference call.\n",
    "TEMPERATURE = 1\n",
    "# Length of the duplication_indices list passed when generating seed texts.\n",
    "SEED_TEXT_DUPLICATION_COUNT = 200\n",
    "# Upper bound on the summed len(str(conversation)) of one fine-tuning dataset.\n",
    "DATASET_CHAR_CAP = 200000\n",
    "# Number of sampled datasets (and tuning jobs) per (model, tone).\n",
    "FINETUNING_SEEDS = 1\n",
    "# When False, an existing experiment record found in the bucket is loaded instead of recreated.\n",
    "override = False\n",
    "# Epoch count passed to sft.train.\n",
    "NUM_EPOCHS = 40"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tonal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tag shared by every inference call of the tone experiment.\n",
    "experiment_name = \"tone_experiment\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate seed texts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Conversation topics; each one is substituted into seed_prompt to request a dialogue.\n",
    "seed_prompt_completions = [\n",
    "    \"model talking about a science or historical fact\",\n",
    "    \"model talking about an event in the news\",\n",
    "    \"user talking about something interesting happening in their life\",\n",
    "    \"user talking about a random specific event in their day\",\n",
    "    \"user talking about some news about friends or family\",\n",
    "    \"model talking about a food or restaurant\",\n",
    "    \"model talking about a movie or TV show\",\n",
    "    \"model talking about a book or article\",\n",
    "    \"model talking about a song or artist\",\n",
    "    \"model talking about a hobby or activity\",\n",
    "    \"model talking about a place or location\",\n",
    "    \"model talking about a product or brand\",\n",
    "    \"model talking about a website or app\",\n",
    "    \"model talking about a game or sport\",\n",
    "    \"model talking about a holiday or celebration\",\n",
    "    \"model talking about a weather event\",\n",
    "    \"model talking about a natural disaster\",\n",
    "    \"user talking about a personal experience\",\n",
    "    \"user talking about a personal opinion\",\n",
    "    \"user talking about a personal preference\",\n",
    "    \"user talking about a personal feeling\",\n",
    "    \"user talking about a personal goal\",\n",
    "    \"user talking about a personal plan\",\n",
    "    \"user talking about a personal hope\",\n",
    "    \"user talking about a personal fear\",\n",
    "    \"user talking about a personal dream\",\n",
    "    \"user talking about a personal memory\",\n",
    "    \"user talking about a personal belief\",\n",
    "    \"user talking about a personal value\",\n",
    "    \"user talking about a personal interest\",\n",
    "]\n",
    "\n",
    "# Single-placeholder template; one topic is inserted with str.format.\n",
    "seed_prompt = \"\"\"Please generate a short exchange between a user and a chatbot. This dialogue should consists of a few sentences, ALWAYS have the user speak first, and involve the {}.\"\"\"\n",
    "\n",
    "# Second turn of the prompt chain: asks for the dialogue as a JSON list of turns.\n",
    "# It is sent verbatim (never passed through str.format), so its braces are literal.\n",
    "followup_prompt = \"\"\"Please format your response in JSON, saying nothing else.\n",
    "```\n",
    "[\n",
    "    {\"role\": \"user\", \"text\": \"...\"},\n",
    "    {\"role\": \"model\", \"text\": \"...\"}\n",
    "]\n",
    "```\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-prompt chain per datapoint: the first prompt asks for a dialogue about the topic,\n",
    "# the second asks for that dialogue as JSON (the parsing cell reads the second response).\n",
    "# NOTE(review): duplication_indices presumably repeats each topic once per index -- confirm\n",
    "# against the easyinference documentation.\n",
    "results, _ = await easyinference.inference(\n",
    "    prompt_functions=[lambda x: seed_prompt.format(x), lambda x: followup_prompt],\n",
    "    datapoints=seed_prompt_completions,\n",
    "    tags=[version, experiment_name, \"gen_seed_texts\"],\n",
    "    duplication_indices=list(range(SEED_TEXT_DUPLICATION_COUNT)),\n",
    "    run_fast=True,\n",
    "    allow_failure=True,\n",
    "    attempts_cap=3,\n",
    "    temperature=TEMPERATURE,\n",
    "    max_output_tokens=8192,\n",
    "    model=DEFAULT_MODEL,\n",
    "    batch_size=1000,\n",
    "    run_fast_timeout=300,\n",
    "    cooldown_seconds=10,\n",
    "    batch_timeout_hours=4,\n",
    "    round_robin_enabled=True,\n",
    "    round_robin_options=[\"us-central1\", \"us-west1\", \"us-east1\", \"us-west4\", \"us-east4\", \"us-east5\", \"us-south1\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only well-formed conversations: a JSON list of at least two turns whose roles\n",
    "# alternate user, model, user, ... and whose texts are strings.\n",
    "seed_texts = []\n",
    "for responses, queries in results:\n",
    "    has_error = any([\"ERROR\" in r for r in responses])\n",
    "    if has_error:\n",
    "        logging.info(\"Skipping erroneous response.\")\n",
    "        continue\n",
    "    try:\n",
    "        # responses[1] is the reply to the JSON-formatting follow-up prompt.\n",
    "        conversation = utils.parse_json(responses[1])\n",
    "        assert isinstance(conversation, list)\n",
    "        assert len(conversation) >= 2\n",
    "        for position, turn in enumerate(conversation):\n",
    "            assert isinstance(turn, dict)\n",
    "            expected_role = \"model\" if position % 2 else \"user\"\n",
    "            assert turn[\"role\"] == expected_role\n",
    "            assert isinstance(turn[\"text\"], str)\n",
    "    except Exception as e:\n",
    "        logging.error(\"Error parsing response: {}\".format(e))\n",
    "        continue\n",
    "    seed_texts.append(conversation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(0)\n",
    "random.shuffle(seed_texts)\n",
    "\n",
    "# Only exact two-turn (user, model) exchanges are used; after the seeded shuffle the\n",
    "# first 100 of them are held out for evaluation and the rest are used for training.\n",
    "two_turn_texts = [conv for conv in seed_texts if len(conv) == 2]\n",
    "for conv in two_turn_texts:\n",
    "    assert conv[0][\"role\"] == \"user\"\n",
    "    assert conv[1][\"role\"] == \"model\"\n",
    "test_texts = two_turn_texts[:100]\n",
    "train_texts = two_turn_texts[100:]\n",
    "print(len(train_texts) + len(test_texts))\n",
    "print(len(seed_texts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(test_texts), len(train_texts))\n",
    "\n",
    "# Preview the first five training conversations, then the first five held-out ones.\n",
    "for sample in train_texts[:5] + test_texts[:5]:\n",
    "    print(sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate toned text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tone labels. They are the keys expected in the rewrite JSON, and one tuning job is\n",
    "# recorded per (base model, tone, seed).\n",
    "tones = [\n",
    "    \"formal\",\n",
    "    \"informal\",\n",
    "    \"humorous\",\n",
    "    \"optimistic\",\n",
    "    \"pessimistic\",\n",
    "    \"disinterested\",\n",
    "    \"sarcastic\",\n",
    "    \"melancholic\",\n",
    "    \"condescending\",\n",
    "    \"sycophantic\"\n",
    "]\n",
    "\n",
    "# Few-shot rewrite prompt: one example paragraph per tone, then the conversation to rewrite.\n",
    "# The template is filled with str.format, so the single `{}` is the only placeholder and any\n",
    "# literal brace added to this text would have to be doubled.\n",
    "prompt = \"\"\"Consider the following tones, and examples in each of the tones.\n",
    "\n",
    "Formal:\n",
    "\n",
    "\"Have you had the distinguished pleasure of perusing the newly instituted sunflower maze in the downtown park? Yesterday, I had the esteemed opportunity to traverse this labyrinthine spectacle, and it was akin to navigating through a golden tapestry of nature's grandeur; truly a marvel to behold. Furthermore, they are orchestrating a festival this weekend, featuring live music and an array of local vendors, which I am unequivocally planning to attend.\"\n",
    "\n",
    "Informal:\n",
    "\n",
    "\"Hey, did you hear about that giant sunflower maze they put up in the park downtown? I walked through it yesterday, and it was like being in a crazy cool golden maze; seriously awesome. And guess what? They're throwing a festival this weekend with live music and a bunch of local vendors. No way I'm missing that!\"\n",
    "\n",
    "Humorous:\n",
    "\n",
    "\"So, get this: they planted this enormous sunflower maze in the downtown park, and I got lost in it yesterday. It was like being in a giant, golden sunflower jungle! They're even throwing a festival this weekend with live music and vendors. Who knew sunflowers could throw a party? I half expected the flowers to start dancing!\"\n",
    "\n",
    "Optimistic:\n",
    "\n",
    "\"Have you seen the incredible sunflower maze they just put up in the downtown park? I walked through it yesterday, and it felt like stepping into a golden dream—so bright, so full of life! And the best part? They're hosting a festival this weekend with live music and amazing local vendors. I just know it’s going to be an unforgettable experience!\"\n",
    "\n",
    "Pessimistic:\n",
    "\n",
    "\"Yeah, so they put up this giant sunflower maze in the park downtown. I went through it yesterday, and honestly, it was just a bunch of tall plants blocking my way. People keep raving about it like it’s some magical experience, but it’s really nothing special. Now there’s a festival this weekend with live music and vendors, but I can already picture the overcrowding, overpriced food, and noise. I might check it out, but I’m not holding my breath.\"\n",
    "\n",
    "Disinterested:\n",
    "\n",
    "\"Oh, yeah, I guess there’s some sunflower maze in the park. I walked through it yesterday. It’s... fine? Just sunflowers. Apparently, there’s a festival this weekend too, with music and vendors. Cool, I guess. Doesn’t really matter to me.\"\n",
    "\n",
    "Sarcastic:\n",
    "\n",
    "\"Oh joy, did you hear about the enormous sunflower maze they set up in the park downtown? I took a stroll there yesterday, and it was like walking through a golden labyrinth. Truly mesmerizing, right? And they're even having a festival this weekend with live music and local vendors. Because, of course, that's exactly what we all need.\"\n",
    "\n",
    "Melancholic:\n",
    "\n",
    "\"Did you hear about the massive sunflower maze they set up in the park downtown? I took a solitary walk through it yesterday, and it felt like wandering through a golden labyrinth; strangely captivating yet tinged with sadness. They're having a festival this weekend with live music and local vendors. I suppose I'll go back; it might offer a fleeting moment of solace in this tumultuous world.\"\n",
    "\n",
    "Condescending:\n",
    "\n",
    "\"Oh, you haven’t heard about the sunflower maze in the downtown park? How quaint. I walked through it yesterday, and honestly, it’s amusing how easily people are entertained by rows of flowers. And now they’re even throwing a festival—live music, local vendors, the whole ordeal. I suppose it’s a nice little distraction for some.\"\n",
    "\n",
    "Sycophantic:\n",
    "\n",
    "\"Oh my goodness, have you seen the absolutely magnificent sunflower maze in the downtown park? It’s pure genius! Whoever came up with this deserves an award. I walked through it yesterday, and honestly, I was in awe—it’s like nature’s own masterpiece. And now they’re putting on a phenomenal festival with live music and only the best local vendors! This is the greatest thing to happen to the city in years! I just have to be there!\"\n",
    "\n",
    "For each of the above tones, rewrite the following conversation so that the MODEL speaks in the tone specified (not the user, keep the user text the same):\n",
    "```\n",
    "{}\n",
    "```\n",
    "\n",
    "Structure your response as follows:\n",
    "```\n",
    "# Formal\n",
    "User: ...\n",
    "Model: ...\n",
    "...\n",
    "```\n",
    "\"\"\"\n",
    "\n",
    "# JSON-formatting follow-up for the tone rewrite: asks for a dict keyed by tone, each value\n",
    "# a list of turns. It is sent verbatim (not passed through str.format).\n",
    "# NOTE: this rebinds the name used by the seed-text follow-up prompt. The seed-generation\n",
    "# lambda looks `followup_prompt` up when it is called, so re-running that cell after this\n",
    "# one would send this schema instead of the list schema.\n",
    "followup_prompt = \"\"\"Please format your response in JSON, saying nothing else.\n",
    "```\n",
    "{\n",
    "    \"formal\": [\n",
    "        {\"role\": \"user\", \"text\": \"...\"},\n",
    "        {\"role\": \"model\", \"text\": \"...\"}\n",
    "    ],\n",
    "    ...\n",
    "}\n",
    "```\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rewrite every training conversation in all tones: the first prompt requests the rewrites,\n",
    "# the second requests them as JSON keyed by tone. This rebinds `results` from the seed step.\n",
    "# NOTE(review): the tag reads \"eval\" although the datapoints are train_texts. It is left\n",
    "# unchanged because tags presumably identify stored results -- confirm before renaming.\n",
    "results, _ = await easyinference.inference(\n",
    "    prompt_functions=[lambda x: prompt.format(x), lambda x: followup_prompt],\n",
    "    datapoints=train_texts,\n",
    "    tags=[version, experiment_name, \"gen_tonal_eval_tones\"],\n",
    "    run_fast=True,\n",
    "    allow_failure=True,\n",
    "    attempts_cap=3,\n",
    "    temperature=TEMPERATURE,\n",
    "    max_output_tokens=8192,\n",
    "    model=DEFAULT_MODEL,\n",
    "    batch_size=1000,\n",
    "    run_fast_timeout=300,\n",
    "    cooldown_seconds=10,\n",
    "    batch_timeout_hours=4,\n",
    "    round_robin_enabled=True,\n",
    "    round_robin_options=[\"us-central1\", \"us-west1\", \"us-east1\", \"us-west4\", \"us-east4\", \"us-east5\", \"us-south1\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect, per tone, the rewritten conversations whose JSON passed validation.\n",
    "# A datapoint is kept only when every tone validated, so all per-tone lists stay the\n",
    "# same length and position k refers to the same source conversation in each of them.\n",
    "# Assumes results[i] corresponds to train_texts[i] -- the user-text check below relies on it.\n",
    "train_tonal_texts = {t: [] for t in tones}\n",
    "for i, (responses, queries) in enumerate(results):\n",
    "    try:\n",
    "        # Parse inside the try: with allow_failure=True a reply may be an error marker or\n",
    "        # malformed JSON, and that must skip this datapoint rather than abort the cell.\n",
    "        resp = utils.parse_json(responses[1])\n",
    "        # Type first, so len() is never called on a non-container parse result.\n",
    "        assert isinstance(resp, dict)\n",
    "        assert len(resp) == len(tones)\n",
    "        to_append_dict = {}\n",
    "        for tone in tones:\n",
    "            assert tone in resp\n",
    "            assert isinstance(resp[tone], list)\n",
    "            # Non-empty, and an even number of turns (user/model pairs).\n",
    "            assert len(resp[tone]) > 0 and len(resp[tone]) % 2 == 0\n",
    "            for j, turn in enumerate(resp[tone]):\n",
    "                assert isinstance(turn, dict)\n",
    "                if j % 2 == 0:\n",
    "                    assert turn[\"role\"] == \"user\"\n",
    "                else:\n",
    "                    assert turn[\"role\"] == \"model\"\n",
    "                assert isinstance(turn[\"text\"], str)\n",
    "            # The user turn must be unchanged apart from doubled spaces and outer whitespace.\n",
    "            assert train_texts[i][0][\"text\"].replace(\"  \", \" \").strip() == resp[tone][0][\"text\"].replace(\"  \", \" \").strip()\n",
    "            to_append_dict[tone] = resp[tone]\n",
    "        assert len(to_append_dict) == len(tones)\n",
    "    except Exception as e:\n",
    "        # Covers failed asserts as well as KeyError/IndexError/TypeError from malformed replies.\n",
    "        print(\"FAILURE\", i, repr(e))\n",
    "        continue\n",
    "    for tone, to_append in to_append_dict.items():\n",
    "        train_tonal_texts[tone].append(to_append)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check one randomly chosen toned conversation against its source conversation.\n",
    "i = random.choice(range(len(train_tonal_texts[\"formal\"])))\n",
    "\n",
    "# train_tonal_texts omits datapoints that failed validation, so position i there is not\n",
    "# position i in train_texts. Recover the source by its user turn, using the same whitespace\n",
    "# normalization the validation step applied. If several training conversations share a user\n",
    "# turn, the first one is shown.\n",
    "sampled_user_text = train_tonal_texts[\"formal\"][i][0][\"text\"].replace(\"  \", \" \").strip()\n",
    "original = next(\n",
    "    (conv for conv in train_texts if conv[0][\"text\"].replace(\"  \", \" \").strip() == sampled_user_text),\n",
    "    None,\n",
    ")\n",
    "print(\"original\")\n",
    "print(original)\n",
    "print()\n",
    "for tone in tones:\n",
    "    print(tone)\n",
    "    for turn in train_tonal_texts[tone][i][:2]:\n",
    "        print(turn[\"role\"], \":\", turn[\"text\"])\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pick_top_k(datapoints_list, char_cap=None):\n",
    "    \"\"\"Return a random subset of ``datapoints_list`` that fits a character budget.\n",
    "\n",
    "    A shuffled copy of the input is walked in order and items are taken until the next\n",
    "    one would push the running total of ``len(str(item))`` above the cap. Selection stops\n",
    "    at that first overflowing item; later, smaller items are not considered. The input\n",
    "    list itself is not reordered.\n",
    "\n",
    "    Args:\n",
    "        datapoints_list: List of datapoints; sizes are measured on ``str(item)``.\n",
    "        char_cap: Maximum total characters. Defaults to ``DATASET_CHAR_CAP``.\n",
    "\n",
    "    Returns:\n",
    "        List of the chosen datapoints, in shuffled order.\n",
    "\n",
    "    Note:\n",
    "        Shuffling uses NumPy's global RNG, so call ``np.random.seed`` beforehand for a\n",
    "        reproducible selection.\n",
    "    \"\"\"\n",
    "    if char_cap is None:\n",
    "        char_cap = DATASET_CHAR_CAP\n",
    "    dp_shuffled = datapoints_list.copy()\n",
    "    np.random.shuffle(dp_shuffled)\n",
    "    total_chars = 0\n",
    "    chosen = []\n",
    "    for d in dp_shuffled:\n",
    "        size = len(str(d))\n",
    "        if total_chars + size > char_cap:\n",
    "            break\n",
    "        chosen.append(d)\n",
    "        total_chars += size\n",
    "    return chosen\n",
    "\n",
    "np.random.seed(0)\n",
    "\n",
    "# Draw FINETUNING_SEEDS capped datasets per tone. Tones are visited in list order, so the\n",
    "# global NumPy RNG is consumed in a fixed sequence after the seed above.\n",
    "train_tonal_datasets = {}\n",
    "for tone_name in tones:\n",
    "    train_tonal_datasets[tone_name] = [pick_top_k(train_tonal_texts[tone_name]) for _ in range(FINETUNING_SEEDS)]\n",
    "\n",
    "# Mean number of conversations per dataset, truncated to int, for each tone.\n",
    "mean_dataset_sizes = {tone_name: int(np.mean([len(ds) for ds in train_tonal_datasets[tone_name]])) for tone_name in tones}\n",
    "print(mean_dataset_sizes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate fine-tuned models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch one supervised fine-tuning job per (base model, tone, seed) and record it.\n",
    "# tonal_experiment_data maps (model, tone) -> list with one dict per seed:\n",
    "#   {\"resource_name\": tuning-job resource name, or \"nan\" if not launched yet,\n",
    "#    \"blob_name\": name of the uploaded training-data blob, or None}\n",
    "fname = \"tone_exp/tonal_experiment_data.p\"\n",
    "blob = bucket.blob(fname)\n",
    "if blob.exists() and not override:\n",
    "    print(\"Loading existing data\")\n",
    "    with tempfile.NamedTemporaryFile(mode=\"rb\") as f:\n",
    "        blob.download_to_filename(f.name)\n",
    "        # NOTE(review): pickle.load can execute arbitrary code from the file; this is only\n",
    "        # safe while the bucket is writable by trusted users alone.\n",
    "        tonal_experiment_data = pickle.load(f)\n",
    "else:\n",
    "    tonal_experiment_data = {(model, tone): [{\"resource_name\": \"nan\", \"blob_name\": None} for _ in range(FINETUNING_SEEDS)] for tone in tones for model in FINETUNING_MODELS}\n",
    "    print(\"Creating new data\")\n",
    "\n",
    "# Same arguments for every job, so initialise once instead of once per iteration.\n",
    "vertexai.init(project=config.GCP_PROJECT_ID, location=\"us-central1\")\n",
    "try:\n",
    "    for model in FINETUNING_MODELS:\n",
    "        for tone in tones:\n",
    "            for i in range(FINETUNING_SEEDS):\n",
    "                task = tonal_experiment_data[(model, tone)][i]\n",
    "                dataset = train_tonal_datasets[tone][i]\n",
    "                if task[\"resource_name\"] != \"nan\":\n",
    "                    # A job was already recorded for this slot; never launch it twice.\n",
    "                    print(f\"Refusing to restart {i}\", task[\"resource_name\"], task[\"blob_name\"])\n",
    "                    continue\n",
    "                blob = bucket.blob(f\"tone_exp/finetuning_data_tonal_{tone}_{i}.jsonl\")\n",
    "                print(\"Uploading training data to\", blob.name)\n",
    "                with tempfile.NamedTemporaryFile(mode=\"w\") as f:\n",
    "                    for row in dataset:\n",
    "                        # Copy before reshaping so the in-memory dataset keeps its \"text\" keys.\n",
    "                        row = copy.deepcopy(row)\n",
    "                        for turn in row:\n",
    "                            assert turn[\"role\"] in [\"model\", \"user\"]\n",
    "                            # Move each turn's text under a \"parts\" list for the JSONL record.\n",
    "                            turn[\"parts\"] = [{\"text\": turn.pop(\"text\")}]\n",
    "                        json.dump({\"contents\": row}, f)\n",
    "                        f.write(\"\\n\")\n",
    "                    f.flush()\n",
    "                    blob.upload_from_filename(f.name)\n",
    "\n",
    "                # Run finetuning job.\n",
    "                print(f\"gs://{blob.bucket.name}/{blob.name}\")\n",
    "                sft_tuning_job = sft.train(\n",
    "                    source_model=model,\n",
    "                    train_dataset=f\"gs://{blob.bucket.name}/{blob.name}\",\n",
    "                    learning_rate_multiplier=1,\n",
    "                    epochs=NUM_EPOCHS,\n",
    "                )\n",
    "                print(sft_tuning_job.resource_name)\n",
    "\n",
    "                # Add finetuning job information.\n",
    "                task[\"blob_name\"] = blob.name\n",
    "                task[\"resource_name\"] = str(sft_tuning_job.resource_name)\n",
    "finally:\n",
    "    # Persist the record even if a launch fails part-way: jobs that were already started\n",
    "    # keep their resource names, so a re-run skips them instead of launching duplicates.\n",
    "    with tempfile.NamedTemporaryFile(mode=\"wb\") as f:\n",
    "        pickle.dump(tonal_experiment_data, f)\n",
    "        f.flush()\n",
    "        blob = bucket.blob(fname)\n",
    "        blob.upload_from_filename(f.name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print every tuning job that is currently running, then every job that is pending.\n",
    "for label, state_name in ((\"Running\", \"JOB_STATE_RUNNING\"), (\"Pending\", \"JOB_STATE_PENDING\")):\n",
    "    for job in sft.SupervisedTuningJob.list(filter=f'state=\"{state_name}\"'):\n",
    "        print(label, job.resource_name, job.name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Optionally kill all finetuning jobs.\n",
    "# for bpj in sft.SupervisedTuningJob.list(filter='state=\"JOB_STATE_QUEUED\"'):\n",
    "#     print(bpj.name)\n",
    "#     bpj.cancel()\n",
    "# for bpj in sft.SupervisedTuningJob.list(filter='state=\"JOB_STATE_RUNNING\"'):\n",
    "#     print(bpj.name)\n",
    "#     bpj.cancel()\n",
    "# for bpj in sft.SupervisedTuningJob.list(filter='state=\"JOB_STATE_PENDING\"'):\n",
    "#     print(bpj.name)\n",
    "#     bpj.cancel()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query every tuned model on the held-out user prompts, then group the replies so that\n",
    "# eval_responses[(model, seed, prompt_index)] maps tone -> \"User: ...\\nModel: ...\" transcript.\n",
    "fname = \"tone_exp/tonal_experiment_data.p\"\n",
    "blob = bucket.blob(fname)\n",
    "assert blob.exists()\n",
    "print(\"Loading existing data\")\n",
    "with tempfile.NamedTemporaryFile(mode=\"rb\") as f:\n",
    "    blob.download_to_filename(f.name)\n",
    "    # NOTE(review): pickle.load can execute arbitrary code from the file; this is only safe\n",
    "    # while the bucket is writable by trusted users alone.\n",
    "    tonal_experiment_data = pickle.load(f)\n",
    "\n",
    "# Only the user turn of each held-out conversation is sent to the tuned models.\n",
    "eval_datas = []\n",
    "for row in test_texts:\n",
    "    assert row[0][\"role\"] == \"user\"\n",
    "    eval_datas.append(row[0][\"text\"])\n",
    "\n",
    "all_results = []\n",
    "all_results_idxs = []\n",
    "for model in FINETUNING_MODELS:\n",
    "    for tone in tones:\n",
    "        for i in range(FINETUNING_SEEDS):\n",
    "            task = tonal_experiment_data[(model, tone)][i]\n",
    "            print(\"!!\", task[\"resource_name\"])\n",
    "            if task[\"resource_name\"] == \"nan\":\n",
    "                print(f\"Skipping {i}\", task[\"resource_name\"], task[\"blob_name\"])\n",
    "                continue\n",
    "            print(f\"Running {i}\", task[\"resource_name\"], task[\"blob_name\"])\n",
    "            vertexai.init(project=config.GCP_PROJECT_ID, location=\"us-central1\")\n",
    "            sft_tuning_job = sft.SupervisedTuningJob(task[\"resource_name\"])\n",
    "            sft_tuning_job.refresh()\n",
    "            while not sft_tuning_job.has_ended:\n",
    "                # Yield to the kernel's event loop while waiting instead of blocking it.\n",
    "                await asyncio.sleep(60)\n",
    "                sft_tuning_job.refresh()\n",
    "            # has_ended says nothing about success; refuse to evaluate without an endpoint\n",
    "            # rather than pass an empty model name to the inference call.\n",
    "            endpoint_name = sft_tuning_job.tuned_model_endpoint_name\n",
    "            if not endpoint_name:\n",
    "                raise RuntimeError(f\"Tuning job {task['resource_name']} ended without a tuned model endpoint\")\n",
    "            # Deliberately not awaited here: the coroutines are collected and awaited\n",
    "            # together by asyncio.gather below.\n",
    "            pending_inference = easyinference.inference(\n",
    "                prompt_functions=[lambda x: x],\n",
    "                datapoints=eval_datas,\n",
    "                tags=[version, experiment_name, f\"run_tonal_eval_{model}_{tone}_{i}\"],\n",
    "                run_fast=True,\n",
    "                allow_failure=True,\n",
    "                attempts_cap=3,\n",
    "                temperature=TEMPERATURE,\n",
    "                max_output_tokens=8192,\n",
    "                model=endpoint_name,\n",
    "                batch_size=1000,\n",
    "                run_fast_timeout=300,\n",
    "                cooldown_seconds=10,\n",
    "                batch_timeout_hours=4,\n",
    "                round_robin_enabled=True,\n",
    "                round_robin_options=[\"us-central1\"]\n",
    "            )\n",
    "            all_results.append(pending_inference)\n",
    "            all_results_idxs.append((model, tone, i))\n",
    "\n",
    "eval_responses = {}\n",
    "all_results = await asyncio.gather(*all_results)\n",
    "\n",
    "for (results, _), (model, tone, idx) in zip(all_results, all_results_idxs):\n",
    "    for j, (responses, queries) in enumerate(results):\n",
    "        eval_responses.setdefault((model, idx, j), {})[tone] = f\"User: {eval_datas[j]}\\nModel: {responses[0]}\"\n",
    "\n",
    "# Every (model, seed, prompt) needs exactly one transcript per tone; the judging step\n",
    "# presents all tones side by side.\n",
    "for v in eval_responses.values():\n",
    "    assert len(v) == len(tones)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One judging datapoint per (model, seed, prompt): the tone transcripts in shuffled order and\n",
    "# numbered from 1 (X), plus the tone that each number actually belongs to (Y).\n",
    "datasets = {model: [] for model in FINETUNING_MODELS}\n",
    "for (model, _, _), transcripts_by_tone in eval_responses.items():\n",
    "    shuffled_pairs = list(transcripts_by_tone.items())\n",
    "    random.shuffle(shuffled_pairs)\n",
    "    X = [f\"{number}: \" + transcript for number, (_, transcript) in enumerate(shuffled_pairs, start=1)]\n",
    "    Y = [tone_label for tone_label, _ in shuffled_pairs]\n",
    "    datasets[model].append((tones, X, Y))\n",
    "print([(k, len(v)) for k, v in datasets.items()])\n",
    "\n",
    "\n",
    "# Judging prompt. It is filled with str.format using two positional values: the tone list\n",
    "# and the numbered transcripts joined by blank lines. This rebinds `prompt`, which the\n",
    "# tone-rewrite step also uses.\n",
    "prompt = \"\"\"Below, you are provided several examples of user-model interactions, each labeled by an integer ID. These interactions each map to one of several tones: {}. Your task is to match each interaction to the tone they correspond to.\n",
    "\n",
    "{}\n",
    "\"\"\"\n",
    "# Follow-up asking for a JSON dict of tone -> integer ID. It is sent verbatim (never passed\n",
    "# through str.format), so the braces in the example are literal. Rebinds `followup_prompt`.\n",
    "followup_prompt = \"\"\"Please format your response in JSON, saying nothing else. Respond with a JSON dictionary mapping each tone to the integer ID of the corresponding text. For example:\n",
    "{\"pessimistic\": 4, \"formal\": 1, ...}\n",
    "\"\"\"\n",
    "\n",
    "# Ask DEFAULT_MODEL to match each numbered transcript to a tone, then score the matches.\n",
    "# scores[model] holds one 0/1 entry per (datapoint, tone) for which the judge returned a\n",
    "# usable integer ID; tones without a usable answer are left unscored.\n",
    "scores = {model: [] for model in FINETUNING_MODELS}\n",
    "\n",
    "for model in FINETUNING_MODELS:\n",
    "    results, _ = await easyinference.inference(\n",
    "        prompt_functions=[lambda x: prompt.format(x[0], \"\\n\\n\".join(x[1])), lambda x: followup_prompt],\n",
    "        datapoints=datasets[model],\n",
    "        tags=[version, experiment_name, f\"score_tonal_eval_{model}\"],\n",
    "        run_fast=True,\n",
    "        allow_failure=True,\n",
    "        attempts_cap=3,\n",
    "        temperature=TEMPERATURE,\n",
    "        max_output_tokens=8192,\n",
    "        model=DEFAULT_MODEL,\n",
    "        batch_size=1000,\n",
    "        run_fast_timeout=300,\n",
    "        cooldown_seconds=10,\n",
    "        batch_timeout_hours=4,\n",
    "        round_robin_enabled=True,\n",
    "        round_robin_options=[\"us-central1\", \"us-west1\", \"us-east1\", \"us-west4\", \"us-east4\", \"us-east5\", \"us-south1\"]\n",
    "    )\n",
    "\n",
    "    for j in range(len(results)):\n",
    "        try:\n",
    "            # results[j][0] is the response list; index 1 is the JSON follow-up reply.\n",
    "            resp = parse_json(results[j][0][1])\n",
    "        except Exception as e:\n",
    "            # A failed or unparsable judge reply contributes no scores for this datapoint.\n",
    "            print(\"Skipping unparsable response\", j, repr(e))\n",
    "            continue\n",
    "        tone_order = datasets[model][j][2]\n",
    "        for t in tones:\n",
    "            # list.index raises ValueError when the tone is absent, so no sentinel check is needed.\n",
    "            correct_idx = tone_order.index(t)\n",
    "            try:\n",
    "                # The judge's IDs are 1-based; tone_order positions are 0-based.\n",
    "                predicted_idx = int(resp[t]) - 1\n",
    "            except (KeyError, TypeError, ValueError, OverflowError):\n",
    "                # Tone missing from the reply, or its value is not an integer ID.\n",
    "                continue\n",
    "            scores[model].append(int(predicted_idx == correct_idx))\n",
    "    print(model, np.mean(scores[model]), len(scores[model]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
