{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1ee8b54",
   "metadata": {},
   "source": [
    "# This is where I test different llama2 models, prompts and inputs to assess computation time/accuracy before building the full pipeline for the project.\n",
    "Fine tuning: https://medium.com/@ogbanugot/notes-on-fine-tuning-llama-2-using-qlora-a-detailed-breakdown-370be42ccca1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d1a5c7b",
   "metadata": {},
   "source": [
    "## Questions/Thoughts\n",
    "1. Combine children and neonatal or analyze separately?\n",
    "2. Start with small model then use big?\n",
    "3. When/if to move to Azure?\n",
    "4. Prompting - how many classes to allow as potential outputs\n",
    "5. Validation - InSilicoVA, openVA, etc.\n",
    "6. PPI correction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "305aa57d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-11 11:48:47.476834: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'NoneType' object has no attribute 'cadam32bit_grad_fp32'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/adam/anaconda3/lib/python3.11/site-packages/bitsandbytes/cextension.py:34: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.\n",
      "  warn(\"The installed version of bitsandbytes was compiled without GPU support. \"\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import time\n",
    "import pathlib\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "from datasets import load_dataset\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoTokenizer,\n",
    "    BitsAndBytesConfig,\n",
    "    TrainingArguments,\n",
    "    pipeline,\n",
    "    logging,\n",
    ")\n",
    "from peft import LoraConfig, PeftModel\n",
    "from trl import SFTTrainer\n",
    "from IPython.display import display, HTML\n",
    "from llama_cpp import Llama\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "warnings.simplefilter('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2594089",
   "metadata": {},
   "source": [
    "## Read in data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "df225c18",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv('../../data/phmrc/phmrc_adult_tokenized.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a21aa802",
   "metadata": {},
   "outputs": [],
   "source": [
    "regions = list(df['site'].unique())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "af6b2291",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train\n",
    "# Read in CSV files and store in dictionary\n",
    "train_excluded_dict = {}\n",
    "for region in regions:\n",
    "    file_path = f'../../data/train_test_val/train_ex_{region.lower()}.csv'\n",
    "    train_excluded_dict[region] = pd.read_csv(file_path)\n",
    "    \n",
    "# assign training data df names\n",
    "train_ex_ap = train_excluded_dict['ap']\n",
    "train_ex_dar = train_excluded_dict['dar']\n",
    "train_ex_pemba = train_excluded_dict['pemba']\n",
    "train_ex_mexico = train_excluded_dict['mexico']\n",
    "train_ex_bohol = train_excluded_dict['bohol']\n",
    "train_ex_up = train_excluded_dict['up']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ea99a283",
   "metadata": {},
   "outputs": [],
   "source": [
    "# test / val\n",
    "\n",
    "# Dictionary to store DataFrames\n",
    "test_dict = {}\n",
    "val_dict = {}\n",
    "\n",
    "# Read in test and validation CSV files and store in dictionaries\n",
    "for region in regions:\n",
    "    test_file_path = f'../../data/train_test_val/test_{region}.csv'\n",
    "    val_file_path = f'../../data/train_test_val/val_{region}.csv'\n",
    "    \n",
    "    test_dict[region] = pd.read_csv(test_file_path)\n",
    "    val_dict[region] = pd.read_csv(val_file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7fa6f98a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# assign test and val data df names\n",
    "test_ap = test_dict['ap']\n",
    "test_dar = test_dict['dar']\n",
    "test_pemba = test_dict['pemba']\n",
    "test_mexico = test_dict['mexico']\n",
    "test_bohol = test_dict['bohol']\n",
    "test_up = test_dict['up']\n",
    "\n",
    "val_ap = val_dict['ap']\n",
    "val_dar = val_dict['dar']\n",
    "val_pemba = val_dict['pemba']\n",
    "val_mexico = val_dict['mexico']\n",
    "val_bohol = val_dict['bohol']\n",
    "val_up = val_dict['up']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c97e5d04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# list of dfs\n",
    "training_dfs = [\n",
    "    train_ex_ap,\n",
    "    train_ex_dar, \n",
    "    train_ex_pemba, \n",
    "    train_ex_mexico, \n",
    "    train_ex_bohol, \n",
    "    train_ex_up]\n",
    "\n",
    "# combine labeled and unlabeled testing data\n",
    "test_ap = pd.concat([test_ap, val_ap])\n",
    "test_dar = pd.concat([test_dar, val_dar])\n",
    "test_pemba = pd.concat([test_pemba, val_pemba])\n",
    "test_mexico = pd.concat([test_mexico, val_mexico])\n",
    "test_bohol = pd.concat([test_bohol, val_bohol])\n",
    "test_up = pd.concat([test_up, val_up])\n",
    "    \n",
    "testing_dfs = [\n",
    "    test_ap,\n",
    "    test_dar,\n",
    "    test_pemba,\n",
    "    test_mexico,\n",
    "    test_bohol,\n",
    "    test_up\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4311be59",
   "metadata": {},
   "source": [
    "## Load different models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7a2cf3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "small = Llama(\n",
    "    model_path=\"../../models/llama-2-7b-chat.Q2_K.gguf\",\n",
    "    n_ctx=2048)\n",
    "# medium = Llama(\n",
    "#     model_path=\"../models/llama-2-7b-chat.Q4_K_M.gguf\",\n",
    "#     n_ctx=2048)\n",
    "big = Llama(\n",
    "    model_path=\"../../models/llama-2-7b-chat.Q8_0.gguf\",\n",
    "    n_ctx=2048)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d74ca118",
   "metadata": {},
   "source": [
    "## Query function for making prompt calls to model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b7bad25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def query(model, question):\n",
    "#     model_name = pathlib.Path(model.model_path).name\n",
    "#     time_start = time.time()\n",
    "#     prompt = f\"Q: {question} A:\"\n",
    "#     output = model(prompt=prompt, max_tokens=0) # if max tokens is zero, depends on n_ctx\n",
    "#     response = output[\"choices\"][0][\"text\"]\n",
    "#     time_elapsed = time.time() - time_start\n",
    "#     display(HTML(f'<code>{model_name} response time: {time_elapsed:.02f} sec</code>'))\n",
    "#     display(HTML(f'<strong>Question:</strong> {question}'))\n",
    "#     display(HTML(f'<strong>Answer:</strong> {response}'))\n",
    "#     print(json.dumps(output, indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8c114d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def query_tostring(model, question):\n",
    "    model_name = pathlib.Path(model.model_path).name\n",
    "    time_start = time.time()\n",
    "    prompt = question\n",
    "    output = model(prompt=prompt, max_tokens=0) # if max tokens is zero, depends on n_ctx\n",
    "    response = output[\"choices\"][0][\"text\"]\n",
    "    time_elapsed = time.time() - time_start\n",
    "    print(time_elapsed)\n",
    "    return response"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a68c15bd",
   "metadata": {},
   "source": [
    "## Create Prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52749236",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_to_score = {\n",
    "    'aids-tb': 0,\n",
    "    'communicable': 1,\n",
    "    'external': 2,\n",
    "    'maternal': 3, \n",
    "    'non-communicable': 4\n",
    "}\n",
    "\n",
    "score_to_label = {\n",
    "    0: 'aids-tb',\n",
    "    1: 'communicable',\n",
    "    2: 'external',\n",
    "    3: 'maternal',\n",
    "    4: 'non-communicable' \n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ea9e719",
   "metadata": {},
   "outputs": [],
   "source": [
    "def inspect_narrative(row):\n",
    "    print('Narrative: ' + df['narrative'][row])\n",
    "    print('True Label: ' + df['gs_text34'][row])\n",
    "    print('Broad Category: ' + df['gs_cod'][row])\n",
    "    print('Embedding Representation: ' + str(label_to_score[df['gs_cod'][row]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "046e78c2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "inspect_narrative(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "981c4cfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_prompt(narrative):\n",
    "    '''\n",
    "    takes in narrative string and returns full prompt as string\n",
    "    '''\n",
    "    \n",
    "    result = f\"\"\"\n",
    "    <narrative>\n",
    "    {narrative}\n",
    "    </narrative>\n",
    "\n",
    "    <labels>\n",
    "    aids-tb: Patient died resulting from HIV-AIDs or Tuberculosis.\n",
    "    communicable: Patient died from a communicable disease which is defined as \n",
    "    illnesses that spread from one human to another such as pneumonia, diarrhea \n",
    "    or dysentery.\n",
    "    external: Patient died from external causes including as accidents like fires,\n",
    "    drowning, road traffic, falls, poisonous animals and violence like suicide, \n",
    "    homicide, or other injuries.\n",
    "    maternal: Patient died from complications related to pregnancy or childbirth \n",
    "    including from severe bleeding, sepsis, pre-eclampsia and eclampsia.\n",
    "    non-communicable: Patient died from a non-communicable disease which is defined\n",
    "    as illnesses that cannot be transmitted from one human to another such as cirrhosis,\n",
    "    epilepsy, acute myocardial infarction, copd, renal failure, cancer, diabetes,\n",
    "    stroke, malaria, asthma, or other non-communicable diseases.\n",
    "    </labels>\n",
    "\n",
    "    <options>\n",
    "    aids-tb, communicable, external, maternal, non-communicable\n",
    "    </options>\n",
    "\n",
    "\n",
    "    Which label best applies applies to the narrative (aids-tb, communicable, external, maternal, non-communicable)?\n",
    "    Limit your response to one of the options exactly as it appears in the list.\n",
    "    \"\"\"\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eafa2a90",
   "metadata": {},
   "outputs": [],
   "source": [
    "inspect_narrative(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31b12de4",
   "metadata": {},
   "outputs": [],
   "source": [
    "create_prompt(df['narrative'][4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "954e12fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_tostring(small, create_prompt(df['narrative'][4]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bb2132c",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_tostring(big, create_prompt(df['narrative'][4]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5771326",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions_llama = []\n",
    "for text in tqdm(df['narrative'][:5]):\n",
    "    predictions_llama.append(query_tostring(big, create_prompt(df['narrative'][4])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d53b0091",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions_llama"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "321c96e6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "42f6484d",
   "metadata": {},
   "source": [
    "## fuzzy match to extract labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10def61a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_exact_match(dictionary, long_strings):\n",
    "    result_list = []\n",
    "\n",
    "    for long_string in long_strings:\n",
    "        # Extract the first 30 characters from the string\n",
    "        short_string = long_string[:30]\n",
    "\n",
    "        # Check if any dictionary string exists in the input string\n",
    "        matching_keys = [key for key, value in dictionary.items() if value in short_string]\n",
    "\n",
    "        # Check if any matches were found\n",
    "        if matching_keys:\n",
    "            result_list.extend(matching_keys)\n",
    "        else:\n",
    "            # No match found, return 3\n",
    "            result_list.append(3)\n",
    "\n",
    "    return result_list\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2a32546",
   "metadata": {},
   "outputs": [],
   "source": [
    "find_exact_match(cod_dict, 'aids')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77bdcc06",
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.Series(predictions_llama_up).to_csv('text_predictions_llama2_up.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "510c2ed6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fuzzymatch(text):\n",
    "    '''\n",
    "    takes in text that needs to be matched\n",
    "    returns constrained label from dict\n",
    "    '''\n",
    "    \n",
    "    def get_first_30_characters(input_string):\n",
    "        return input_string[:30]\n",
    "    \n",
    "    first = get_first_30_characters(text)\n",
    "    \n",
    "    def fuzzy_match_and_get_value(input_string, dictionary):\n",
    "        # Get the best match and its score\n",
    "        match, score = process.extractOne(input_string, dictionary.keys())\n",
    "\n",
    "        # You can adjust the threshold for the fuzzy matching score\n",
    "        # For example, consider matches with a score of at least 80\n",
    "        if score >= 80:\n",
    "            return dictionary[match]\n",
    "        else:\n",
    "            return None  # No satisfactory match found\n",
    "    \n",
    "    \n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "641d8402",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af97096b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd3600ed",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ff76718",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66c0b81e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7dd09f7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edbe85bd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9d41c011",
   "metadata": {},
   "outputs": [],
   "source": [
    "labeled_data = {'Sepsis': 'According to respondent child had severe pain in back from last 15 days which was unbearable. Doctor told that may be child got tumor…child received treatment for few days and also got relief for few days but again child had the same condition.Then child was taken to Lucknow where after all treatment nothing could be diagnosed but the pain was increasing day by day. Child received treatment and got some relief but suddenly died.',\n",
    "                'Fires': 'When my son was playing with a kite, its thread was caught up on an electric pole. He climbed the electric pole to take the kite but he got the electric shock. Then immediately we took him to the Siddipet hospital. They told us to take him to the Gandhi hospital. Then we went to the Gandhi hospital. As he was under the treatment there, he died.',\n",
    "                'Road Traffic': 'My nice was studying in a hostel at Mulugu. One day she was met an accident with a car when she was crossing the road at the school. We have admitted in the Gandhi hospital, she died while the treatment was going on.'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4f6f1d47",
   "metadata": {},
   "outputs": [],
   "source": [
    "common_causes = ['pneumonia',\n",
    " 'diarrhea',\n",
    " 'malaria',\n",
    " 'road traffic',\n",
    " 'drowning',\n",
    " 'cardiovascular disease',\n",
    " 'fires',\n",
    " 'meningitis',\n",
    " 'venomous animal',\n",
    " 'falls',\n",
    " 'encephalitis',\n",
    " 'sepsis',\n",
    " 'measles',\n",
    " 'aids',\n",
    " 'tuberculosis']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fe59d79d",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt1 = 'Given a text narrative about a death, attribute the most likely cause of death. Respond only with the cause of death, or \"other\" if you are not sure. Narrative: '"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "44580283",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Given a text narrative about a death, attribute the most likely cause of death. Respond only with the cause of death, or \"other\" if you are not sure. Narrative: '"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9c2aa1ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# include explicit list of output classes\n",
    "prompt2 = 'Given a text narrative about a death, attribute the most likely cause of death from this list: ' + ', '.join(common_causes) + '. Respond only with the cause of death from the list, or \"other\" if you are not sure. Narrative: '"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "8305cc33",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Given a text narrative about a death, attribute the most likely cause of death from this list: pneumonia, diarrhea, malaria, road traffic, drowning, cardiovascular disease, fires, meningitis, venomous animal, falls, encephalitis, sepsis, measles, aids, tuberculosis. Respond only with the cause of death from the list, or \"other\" if you are not sure. Narrative: '"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "329d62da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# include explicit list of output classes and do not offer 'other' as option\n",
    "prompt3 = 'Given a text narrative about a death, attribute the most likely cause of death from this list: ' + ', '.join(common_causes) + '. Respond only with the one word cause of death from the list. Narrative: '"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "189901ee",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Given a text narrative about a death, attribute the most likely cause of death from this list: pneumonia, diarrhea, malaria, road traffic, drowning, cardiovascular disease, fires, meningitis, venomous animal, falls, encephalitis, sepsis, measles, aids, tuberculosis. Respond only with the one word cause of death from the list. Narrative: '"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "47a4f93c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# make it explicit that the response must come from the given list. \n",
    "prompt4 = 'Given a text narrative about a death, attribute the most likely cause of death from this list: ' + ', '.join(common_causes) + '. Your response must match exactly to one of the options from this list. Narrative: '"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "27de71cb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Given a text narrative about a death, attribute the most likely cause of death from this list: pneumonia, diarrhea, malaria, road traffic, drowning, cardiovascular disease, fires, meningitis, venomous animal, falls, encephalitis, sepsis, measles, aids, tuberculosis. Your response must match exactly to one of the options from this list. Narrative: '"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt4"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93d5cb48",
   "metadata": {},
   "source": [
    "## Few shot fine tuning"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6e8149c",
   "metadata": {},
   "source": [
    "## Run models and prompt1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fe674e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(small, prompt1 + labeled_data['Road Traffic'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7114b9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(medium, prompt1 + labeled_data['Road Traffic'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bed46005",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(big, prompt1 + labeled_data['Road Traffic'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1fdcac3",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(small, prompt1 + labeled_data['Sepsis'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "818068a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(small, prompt1 + labeled_data['Fires'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac25e5dc",
   "metadata": {},
   "source": [
    "## Run models and prompt 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95532a38",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(small, prompt2 + labeled_data['Road Traffic'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67bffc14",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(medium, prompt2 + labeled_data['Road Traffic'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "376cc8eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(big, prompt2 + labeled_data['Road Traffic'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edbba6da",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(small, prompt2 + labeled_data['Sepsis'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1d351fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(small, prompt2 + labeled_data['Fires'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1df8a6a7",
   "metadata": {},
   "source": [
    "## Run models and prompt 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9c989e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(small, prompt3 + labeled_data['Road Traffic'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e59831f2",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "query(medium, prompt3 + labeled_data['Road Traffic'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5bd5b25",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(big, prompt3 + labeled_data['Road Traffic'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0233fd37",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(small, prompt3 + labeled_data['Sepsis'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93dc850d",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(small, prompt3 + labeled_data['Fires'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "748e2588",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "query(big, prompt3 + labeled_data['Sepsis'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee70b93c",
   "metadata": {},
   "source": [
    "## Run models and prompt 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88ef4428",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(small, prompt4 + labeled_data['Road Traffic'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9094ec8",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "query(medium, prompt4 + labeled_data['Road Traffic'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0beaea0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(big, prompt4 + labeled_data['Road Traffic'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c905fd50",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(small, prompt4 + labeled_data['Sepsis'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f4d3536",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(small, prompt4 + labeled_data['Fires'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae2f1d3e",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "query(big, prompt4 + labeled_data['Sepsis'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f9a3dcc",
   "metadata": {},
   "source": [
    "### Small model prompt - need to return only one or two words."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5da8eb1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(small, prompt4 + labeled_data['Sepsis'] + ' LIMIT YOUR RESPONSE TO ONE OR TWO WORDS ONLY.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dff8173",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(small, prompt4 + labeled_data['Fires'] + ' LIMIT YOUR RESPONSE TO ONE OR TWO WORDS ONLY.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff2b9406",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(small, prompt4 + labeled_data['Road Traffic'] + ' LIMIT YOUR RESPONSE TO ONE OR TWO WORDS ONLY.')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81400bed",
   "metadata": {},
   "source": [
    "### Big model prompting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baecedb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(big, prompt4 + labeled_data['Road Traffic'] + ' LIMIT YOUR RESPONSE TO ONE OR TWO WORDS ONLY, EXACTLY AS THEY APPEAR IN THE LIST.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dde65b23",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(big, prompt4 + labeled_data['Sepsis'] + ' LIMIT YOUR RESPONSE TO ONE OR TWO WORDS ONLY, EXACTLY AS THEY APPEAR IN THE LIST.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a5c9769",
   "metadata": {},
   "outputs": [],
   "source": [
    "query(big, prompt4 + labeled_data['Fires'] + ' LIMIT YOUR RESPONSE TO ONE OR TWO WORDS ONLY, EXACTLY AS THEY APPEAR IN THE LIST.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a0d797a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "39c33754",
   "metadata": {},
   "source": [
    "### Update query function to instead return only the text response instead of printint Q, A, and token info."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f8f4a30",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6743524b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "442b6393",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
