{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Read Data",
   "id": "84ca79699badd69f"
  },
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-09-20T19:57:59.130952Z",
     "start_time": "2025-09-20T19:57:58.735532Z"
    }
   },
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "data = pd.read_json(\"train.jsonl\", lines = True) # read data\n",
    "n = data.shape[0]\n",
    "\n",
    "np.random.seed(42)\n",
    "indices = np.arange(n)\n",
    "data_calib = data.iloc[indices[:100], :]\n",
    "data_aug = data.iloc[indices[100:], :]"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-20T19:57:59.203717Z",
     "start_time": "2025-09-20T19:57:59.187338Z"
    }
   },
   "cell_type": "code",
   "source": [
    "data_train = pd.concat((data_calib, data_aug), axis = 0, ignore_index = True)\n",
    "data_train.to_csv(\"training_data_symptom.csv\", index = False)\n",
    "\n",
    "data_calib.to_csv(\"calib_original.csv\", index = False)\n",
    "\n",
    "data_aug.to_csv(\"aug_original.csv\", index = False)\n",
    "\n",
    "\n",
    "disease_ls = \", \".join(data[\"output_text\"].drop_duplicates())\n"
   ],
   "id": "72a3e464ceca04ce",
   "outputs": [],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-20T19:57:59.255569Z",
     "start_time": "2025-09-20T19:57:59.250139Z"
    }
   },
   "cell_type": "code",
   "source": "disease_ls",
   "id": "7975707c283ab845",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'cervical spondylosis, impetigo, urinary tract infection, arthritis, dengue, common cold, drug reaction, fungal infection, malaria, allergy, bronchial asthma, varicose veins, migraine, hypertension, gastroesophageal reflux disease, pneumonia, psoriasis, diabetes, jaundice, chicken pox, typhoid, peptic ulcer disease'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Generations",
   "id": "a7d69aae1ed207dd"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-13T19:45:55.304371Z",
     "start_time": "2025-09-13T19:45:55.252421Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import time\n",
    "import random\n",
    "import concurrent.futures\n",
    "import re\n",
    "from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type\n",
    "import openai\n",
    "from openai import OpenAI\n",
    "\n",
    "client = OpenAI(api_key=\"...\")\n",
    "\n",
    "class RateLimitException(Exception):\n",
    "    pass\n",
    "\n",
    "@retry(\n",
    "    retry=retry_if_exception_type(RateLimitException),\n",
    "    wait=wait_exponential(multiplier=1, min=1, max=60),\n",
    "    stop=stop_after_attempt(15)\n",
    ")\n",
    "def call_openai_api(client, prompt, n=20):\n",
    "    try:\n",
    "        response = client.chat.completions.create(\n",
    "            model=\"gpt-4.1-nano\",\n",
    "            messages=[{\"role\": \"user\", \"content\": prompt}],\n",
    "            n=n,\n",
    "            temperature=1.5\n",
    "        )\n",
    "        return response\n",
    "    except Exception as e:\n",
    "        error_message = str(e).lower()\n",
    "        if \"rate limit\" in error_message or \"429\" in error_message:\n",
    "            wait_time_match = re.search(r'try again in (\\d+)ms', error_message)\n",
    "            if wait_time_match:\n",
    "                wait_ms = int(wait_time_match.group(1))\n",
    "                wait_time = (wait_ms / 1000) + random.uniform(0.1, 0.5)\n",
    "            else:\n",
    "                wait_time = random.uniform(1, 3)\n",
    "\n",
    "            print(f\"Rate limit hit. Waiting for {wait_time:.2f} seconds before retry...\")\n",
    "            time.sleep(wait_time)\n",
    "            raise RateLimitException(\"Rate limit exceeded\")\n",
    "        else:\n",
    "            raise\n",
    "\n",
    "import re\n",
    "\n",
    "def count_sentences(text: str) -> list[str]:\n",
    "    \"\"\"\n",
    "    Split text into sentences conservatively.\n",
    "    Returns a list of non-empty sentences.\n",
    "    \"\"\"\n",
    "    # Split on ., ?, ! followed by space/newline or end of string\n",
    "    parts = re.split(r'(?<=[.!?])\\s+(?=[A-Z0-9(])|(?<=[.!?])$', text.strip())\n",
    "    # Clean up stray empties/whitespace\n",
    "    sents = [s.strip() for s in parts if s and s.strip()]\n",
    "    return sents\n",
    "\n",
    "def process_item(client, idx, sentences):\n",
    "    symptom = sentences[0]\n",
    "    base_prompt = (\n",
    "        \"You are given a description of a disease.\\n\\n\"\n",
    "        f\"Description: {symptom}\\n\\n\"\n",
    "        \"Task: Extend the symptom description with additional details that still plausibly describe the SAME disease.\\n\"\n",
    "        \"- Write EXACTLY 5 sentences.\\n\"\n",
    "        \"- Do not copy wording from the original; paraphrase and add plausible details consistent with the same condition.\\n\"\n",
    "        \"- Avoid lists, bullets, headings, or numbering; just 5 full sentences in a single paragraph.\\n\"\n",
    "        \"- No disclaimers, no citations, no markdown.\\n\"\n",
    "    )\n",
    "\n",
    "    max_attempts = 20\n",
    "    while max_attempts > 0:\n",
    "        try:\n",
    "            response = call_openai_api(client, base_prompt, n=1)\n",
    "            content = response.choices[0].message.content.strip()\n",
    "            sents = count_sentences(content)\n",
    "\n",
    "            if len(sents) == 5:\n",
    "                # Return clean paragraph\n",
    "                return sents\n",
    "            else:\n",
    "                # Not exactly 10 → retry\n",
    "                max_attempts -= 1\n",
    "        except Exception:\n",
    "            max_attempts -= 1\n",
    "\n",
    "    print(f\"Warning: Could not get exactly 5 sentences for item {idx}\")\n",
    "    return None\n",
    "\n",
    "def process_with_rate_limiting(client, input_texts, max_concurrent=5, batch_size=20):\n",
    "    all_responses = []\n",
    "\n",
    "    for batch_start in range(0, len(input_texts), batch_size):\n",
    "        batch_end = min(batch_start + batch_size, len(input_texts))\n",
    "        batch = input_texts[batch_start:batch_end]\n",
    "\n",
    "        print(f\"Processing batch {batch_start // batch_size + 1}, items {batch_start} to {batch_end - 1}\")\n",
    "\n",
    "        with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrent) as executor:\n",
    "            futures = [\n",
    "                executor.submit(process_item, client, idx + batch_start, sentences)\n",
    "                for idx, sentences in enumerate(batch)\n",
    "            ]\n",
    "\n",
    "            # Ensure order of responses matches order of input_texts\n",
    "            batch_results = [future.result() for future in futures]\n",
    "\n",
    "        all_responses.extend(batch_results)\n",
    "\n",
    "        if batch_end < len(input_texts):\n",
    "            wait_time = random.uniform(1, 3)\n",
    "            time.sleep(wait_time)\n",
    "\n",
    "    return all_responses"
   ],
   "id": "29494a44fbee03d",
   "outputs": [],
   "execution_count": 25
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-13T19:46:03.447337Z",
     "start_time": "2025-09-13T19:46:03.433571Z"
    }
   },
   "cell_type": "code",
   "source": [
    "pairs = []\n",
    "for i in range(data_train.shape[0]):\n",
    "    pairs.append((data_train[\"input_text\"].iloc[i], data_train[\"output_text\"].iloc[i]))"
   ],
   "id": "2834b4db49e026c4",
   "outputs": [],
   "execution_count": 26
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-13T19:54:40.070194Z",
     "start_time": "2025-09-13T19:47:00.377919Z"
    }
   },
   "cell_type": "code",
   "source": "generations = process_with_rate_limiting(client, pairs)",
   "id": "32bd50e5ac56d58b",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing batch 1, items 0 to 19\n",
      "Processing batch 2, items 20 to 39\n",
      "Processing batch 3, items 40 to 59\n",
      "Processing batch 4, items 60 to 79\n",
      "Processing batch 5, items 80 to 99\n",
      "Processing batch 6, items 100 to 119\n",
      "Processing batch 7, items 120 to 139\n",
      "Processing batch 8, items 140 to 159\n",
      "Processing batch 9, items 160 to 179\n",
      "Processing batch 10, items 180 to 199\n",
      "Processing batch 11, items 200 to 219\n",
      "Processing batch 12, items 220 to 239\n",
      "Processing batch 13, items 240 to 259\n",
      "Processing batch 14, items 260 to 279\n",
      "Processing batch 15, items 280 to 299\n",
      "Processing batch 16, items 300 to 319\n",
      "Processing batch 17, items 320 to 339\n",
      "Processing batch 18, items 340 to 359\n",
      "Processing batch 19, items 360 to 379\n",
      "Processing batch 20, items 380 to 399\n",
      "Processing batch 21, items 400 to 419\n",
      "Processing batch 22, items 420 to 439\n",
      "Processing batch 23, items 440 to 459\n",
      "Processing batch 24, items 460 to 479\n",
      "Processing batch 25, items 480 to 499\n",
      "Processing batch 26, items 500 to 519\n",
      "Processing batch 27, items 520 to 539\n",
      "Processing batch 28, items 540 to 559\n",
      "Processing batch 29, items 560 to 579\n",
      "Processing batch 30, items 580 to 599\n",
      "Processing batch 31, items 600 to 619\n",
      "Processing batch 32, items 620 to 639\n",
      "Processing batch 33, items 640 to 659\n",
      "Processing batch 34, items 660 to 679\n",
      "Processing batch 35, items 680 to 699\n",
      "Processing batch 36, items 700 to 719\n",
      "Processing batch 37, items 720 to 739\n",
      "Processing batch 38, items 740 to 759\n",
      "Processing batch 39, items 760 to 779\n",
      "Processing batch 40, items 780 to 799\n",
      "Processing batch 41, items 800 to 819\n",
      "Processing batch 42, items 820 to 839\n",
      "Processing batch 43, items 840 to 852\n"
     ]
    }
   ],
   "execution_count": 32
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-13T19:54:40.116018Z",
     "start_time": "2025-09-13T19:54:40.110232Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import pickle\n",
    "\n",
    "with open(\"generations_extend10.pickle\", \"wb\") as file:\n",
    "    pickle.dump(generations, file)"
   ],
   "id": "8370a548335fd2ff",
   "outputs": [],
   "execution_count": 33
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Gemini as a Judge",
   "id": "46c5ca6e97d85e20"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-13T21:21:12.858135Z",
     "start_time": "2025-09-13T21:21:12.421408Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import google.generativeai as genai\n",
    "import os\n",
    "import time\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "# Load environment variables\n",
    "load_dotenv()\n",
    "genai.configure(api_key=os.getenv(\"GOOGLE_API_KEY\"))\n",
    "\n",
    "model = genai.GenerativeModel(\"gemini-2.5-flash\")\n",
    "\n",
    "def batch_eval(all_pairs, disease_ls, batch_size=20, max_retries=5, sleep_time=2, save_every=25, save_path=\"scores_checkpoint.csv\"):\n",
    "    all_scores = []\n",
    "    batch_counter = 0  # track how many batches processed\n",
    "\n",
    "    for start in range(0, all_pairs.shape[0], batch_size):\n",
    "        end = min(start + batch_size, all_pairs.shape[0])\n",
    "        batch = all_pairs.iloc[start:end]\n",
    "\n",
    "        # Build base prompt\n",
    "        prompt = (\n",
    "            \"You are evaluating individual symptom descriptions for diseases.\\n\\n\"\n",
    "            \"Scoring instructions:\\n\"\n",
    "            \"- Assign each description a score between 0 and 1, rounded to two decimal places.\\n\"\n",
    "            \"- Criteria: The description should plausibly match the specified disease, avoid confusion with other diseases. \\n\"\n",
    "            \"- Use the full 0–1 range: 1 = perfectly clear, specific, and accurate; 0 = completely unusable.\\n\"\n",
    "            \"- 0.5 is the threshold: any description with a score ≤ 0.5 should be dropped to prevent misclassification.\\n\\n\"\n",
    "            f\"For reference, here is the complete list of possible diseases: {disease_ls}\\n\\n\"\n",
    "            \"Output requirements:\\n\"\n",
    "            \"- Output only the scores, one per line, in the same order as the input cases.\\n\"\n",
    "            \"- Do not include explanations, text, or formatting other than the numeric scores.\\n\\n\"\n",
    "        )\n",
    "\n",
    "        for i, row in batch.iterrows():\n",
    "            symp, diag = row[\"input_text\"], row[\"output_text\"]\n",
    "            prompt += f\"Case {i}:\\nDisease: {diag}\\nSymptom: {symp}\\n\"\n",
    "\n",
    "        # Retry loop\n",
    "        scores = []\n",
    "        for attempt in range(1, max_retries + 1):\n",
    "            try:\n",
    "                response = model.generate_content(prompt)\n",
    "                scores = response.text.strip().splitlines()\n",
    "\n",
    "                if len(scores) == len(batch):\n",
    "                    break  # ✅ got the right number of outputs\n",
    "                else:\n",
    "                    print(\n",
    "                        f\"⚠️ Attempt {attempt}: Expected {len(batch)} scores, got {len(scores)}. Retrying...\"\n",
    "                    )\n",
    "                    time.sleep(sleep_time)\n",
    "\n",
    "            except Exception as e:\n",
    "                print(f\"❌ Error on attempt {attempt}: {e}\")\n",
    "                time.sleep(sleep_time)\n",
    "\n",
    "        if len(scores) != len(batch):\n",
    "            raise ValueError(\n",
    "                f\"Failed after {max_retries} retries: Expected {len(batch)} scores, got {len(scores)}\"\n",
    "            )\n",
    "\n",
    "        all_scores.extend(scores)\n",
    "        batch_counter += 1\n",
    "        print(f\"✅ Processed {end}/{all_pairs.shape[0]}\")\n",
    "\n",
    "        if batch_counter % save_every == 0:\n",
    "            pd.DataFrame({\"score\": all_scores}).to_csv(save_path, index=False)\n",
    "            print(f\"💾 Saved checkpoint after {batch_counter} batches at {save_path}\")\n",
    "\n",
    "    pd.DataFrame({\"score\": all_scores}).to_csv(save_path, index=False)\n",
    "    print(f\"🎉 Finished. Final results saved at {save_path}\")\n",
    "\n",
    "    return all_scores"
   ],
   "id": "b7335aa73eafe8b9",
   "outputs": [],
   "execution_count": 3
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-13T21:21:20.994752Z",
     "start_time": "2025-09-13T21:21:20.980846Z"
    }
   },
   "cell_type": "code",
   "source": [
    "\n",
    "import pickle\n",
    "with open(\"generations_extend10.pickle\", \"rb\") as file:\n",
    "    generations = pickle.load(file)\n",
    "\n",
    "all_rows = []\n",
    "\n",
    "# 2. Loop through the augmented data.\n",
    "for i in range(data_train.shape[0]):\n",
    "    output_text = data_train[\"output_text\"].iloc[i]\n",
    "\n",
    "    # Add the original row\n",
    "    # Add the augmented rows\n",
    "    for j in generations[i]:\n",
    "        all_rows.append([j, output_text])\n",
    "\n",
    "\n",
    "# 4. Create the DataFrame from the list in one single, efficient operation.\n",
    "df = pd.DataFrame(all_rows, columns=[\"input_text\", \"output_text\"])\n"
   ],
   "id": "f1407fad7861bff6",
   "outputs": [],
   "execution_count": 4
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-13T22:14:45.620409Z",
     "start_time": "2025-09-13T21:22:14.881277Z"
    }
   },
   "cell_type": "code",
   "source": "scores = batch_eval(df, disease_ls, batch_size=10, save_path=\"scores_flash.csv\", save_every = 10)",
   "id": "bbec944173f547dd",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Processed 10/2065\n",
      "✅ Processed 20/2065\n",
      "✅ Processed 30/2065\n",
      "✅ Processed 40/2065\n",
      "✅ Processed 50/2065\n",
      "✅ Processed 60/2065\n",
      "✅ Processed 70/2065\n",
      "✅ Processed 80/2065\n",
      "✅ Processed 90/2065\n",
      "✅ Processed 100/2065\n",
      "💾 Saved checkpoint after 10 batches at scores_flash.csv\n",
      "✅ Processed 110/2065\n",
      "✅ Processed 120/2065\n",
      "✅ Processed 130/2065\n",
      "✅ Processed 140/2065\n",
      "✅ Processed 150/2065\n",
      "✅ Processed 160/2065\n",
      "✅ Processed 170/2065\n",
      "✅ Processed 180/2065\n",
      "✅ Processed 190/2065\n",
      "✅ Processed 200/2065\n",
      "💾 Saved checkpoint after 20 batches at scores_flash.csv\n",
      "✅ Processed 210/2065\n",
      "✅ Processed 220/2065\n",
      "✅ Processed 230/2065\n",
      "✅ Processed 240/2065\n",
      "✅ Processed 250/2065\n",
      "✅ Processed 260/2065\n",
      "✅ Processed 270/2065\n",
      "✅ Processed 280/2065\n",
      "✅ Processed 290/2065\n",
      "✅ Processed 300/2065\n",
      "💾 Saved checkpoint after 30 batches at scores_flash.csv\n",
      "✅ Processed 310/2065\n",
      "✅ Processed 320/2065\n",
      "✅ Processed 330/2065\n",
      "✅ Processed 340/2065\n",
      "✅ Processed 350/2065\n",
      "✅ Processed 360/2065\n",
      "✅ Processed 370/2065\n",
      "✅ Processed 380/2065\n",
      "✅ Processed 390/2065\n",
      "✅ Processed 400/2065\n",
      "💾 Saved checkpoint after 40 batches at scores_flash.csv\n",
      "✅ Processed 410/2065\n",
      "✅ Processed 420/2065\n",
      "✅ Processed 430/2065\n",
      "✅ Processed 440/2065\n",
      "✅ Processed 450/2065\n",
      "✅ Processed 460/2065\n",
      "✅ Processed 470/2065\n",
      "✅ Processed 480/2065\n",
      "✅ Processed 490/2065\n",
      "✅ Processed 500/2065\n",
      "💾 Saved checkpoint after 50 batches at scores_flash.csv\n",
      "✅ Processed 510/2065\n",
      "✅ Processed 520/2065\n",
      "✅ Processed 530/2065\n",
      "✅ Processed 540/2065\n",
      "✅ Processed 550/2065\n",
      "✅ Processed 560/2065\n",
      "✅ Processed 570/2065\n",
      "✅ Processed 580/2065\n",
      "✅ Processed 590/2065\n",
      "✅ Processed 600/2065\n",
      "💾 Saved checkpoint after 60 batches at scores_flash.csv\n",
      "✅ Processed 610/2065\n",
      "✅ Processed 620/2065\n",
      "✅ Processed 630/2065\n",
      "✅ Processed 640/2065\n",
      "✅ Processed 650/2065\n",
      "✅ Processed 660/2065\n",
      "✅ Processed 670/2065\n",
      "✅ Processed 680/2065\n",
      "✅ Processed 690/2065\n",
      "✅ Processed 700/2065\n",
      "💾 Saved checkpoint after 70 batches at scores_flash.csv\n",
      "✅ Processed 710/2065\n",
      "✅ Processed 720/2065\n",
      "✅ Processed 730/2065\n",
      "✅ Processed 740/2065\n",
      "✅ Processed 750/2065\n",
      "✅ Processed 760/2065\n",
      "✅ Processed 770/2065\n",
      "✅ Processed 780/2065\n",
      "✅ Processed 790/2065\n",
      "✅ Processed 800/2065\n",
      "💾 Saved checkpoint after 80 batches at scores_flash.csv\n",
      "✅ Processed 810/2065\n",
      "✅ Processed 820/2065\n",
      "✅ Processed 830/2065\n",
      "✅ Processed 840/2065\n",
      "✅ Processed 850/2065\n",
      "✅ Processed 860/2065\n",
      "✅ Processed 870/2065\n",
      "✅ Processed 880/2065\n",
      "✅ Processed 890/2065\n",
      "✅ Processed 900/2065\n",
      "💾 Saved checkpoint after 90 batches at scores_flash.csv\n",
      "✅ Processed 910/2065\n",
      "✅ Processed 920/2065\n",
      "✅ Processed 930/2065\n",
      "✅ Processed 940/2065\n",
      "✅ Processed 950/2065\n",
      "✅ Processed 960/2065\n",
      "✅ Processed 970/2065\n",
      "✅ Processed 980/2065\n",
      "✅ Processed 990/2065\n",
      "✅ Processed 1000/2065\n",
      "💾 Saved checkpoint after 100 batches at scores_flash.csv\n",
      "✅ Processed 1010/2065\n",
      "✅ Processed 1020/2065\n",
      "✅ Processed 1030/2065\n",
      "✅ Processed 1040/2065\n",
      "✅ Processed 1050/2065\n",
      "✅ Processed 1060/2065\n",
      "✅ Processed 1070/2065\n",
      "✅ Processed 1080/2065\n",
      "✅ Processed 1090/2065\n",
      "✅ Processed 1100/2065\n",
      "💾 Saved checkpoint after 110 batches at scores_flash.csv\n",
      "✅ Processed 1110/2065\n",
      "✅ Processed 1120/2065\n",
      "✅ Processed 1130/2065\n",
      "✅ Processed 1140/2065\n",
      "✅ Processed 1150/2065\n",
      "✅ Processed 1160/2065\n",
      "✅ Processed 1170/2065\n",
      "✅ Processed 1180/2065\n",
      "✅ Processed 1190/2065\n",
      "✅ Processed 1200/2065\n",
      "💾 Saved checkpoint after 120 batches at scores_flash.csv\n",
      "✅ Processed 1210/2065\n",
      "✅ Processed 1220/2065\n",
      "✅ Processed 1230/2065\n",
      "✅ Processed 1240/2065\n",
      "✅ Processed 1250/2065\n",
      "✅ Processed 1260/2065\n",
      "✅ Processed 1270/2065\n",
      "✅ Processed 1280/2065\n",
      "✅ Processed 1290/2065\n",
      "✅ Processed 1300/2065\n",
      "💾 Saved checkpoint after 130 batches at scores_flash.csv\n",
      "✅ Processed 1310/2065\n",
      "✅ Processed 1320/2065\n",
      "✅ Processed 1330/2065\n",
      "✅ Processed 1340/2065\n",
      "✅ Processed 1350/2065\n",
      "✅ Processed 1360/2065\n",
      "✅ Processed 1370/2065\n",
      "✅ Processed 1380/2065\n",
      "✅ Processed 1390/2065\n",
      "✅ Processed 1400/2065\n",
      "💾 Saved checkpoint after 140 batches at scores_flash.csv\n",
      "✅ Processed 1410/2065\n",
      "✅ Processed 1420/2065\n",
      "✅ Processed 1430/2065\n",
      "✅ Processed 1440/2065\n",
      "✅ Processed 1450/2065\n",
      "✅ Processed 1460/2065\n",
      "✅ Processed 1470/2065\n",
      "✅ Processed 1480/2065\n",
      "✅ Processed 1490/2065\n",
      "✅ Processed 1500/2065\n",
      "💾 Saved checkpoint after 150 batches at scores_flash.csv\n",
      "✅ Processed 1510/2065\n",
      "✅ Processed 1520/2065\n",
      "✅ Processed 1530/2065\n",
      "✅ Processed 1540/2065\n",
      "✅ Processed 1550/2065\n",
      "✅ Processed 1560/2065\n",
      "✅ Processed 1570/2065\n",
      "✅ Processed 1580/2065\n",
      "✅ Processed 1590/2065\n",
      "✅ Processed 1600/2065\n",
      "💾 Saved checkpoint after 160 batches at scores_flash.csv\n",
      "✅ Processed 1610/2065\n",
      "✅ Processed 1620/2065\n",
      "✅ Processed 1630/2065\n",
      "✅ Processed 1640/2065\n",
      "✅ Processed 1650/2065\n",
      "✅ Processed 1660/2065\n",
      "✅ Processed 1670/2065\n",
      "✅ Processed 1680/2065\n",
      "✅ Processed 1690/2065\n",
      "✅ Processed 1700/2065\n",
      "💾 Saved checkpoint after 170 batches at scores_flash.csv\n",
      "✅ Processed 1710/2065\n",
      "✅ Processed 1720/2065\n",
      "✅ Processed 1730/2065\n",
      "✅ Processed 1740/2065\n",
      "✅ Processed 1750/2065\n",
      "✅ Processed 1760/2065\n",
      "✅ Processed 1770/2065\n",
      "✅ Processed 1780/2065\n",
      "✅ Processed 1790/2065\n",
      "✅ Processed 1800/2065\n",
      "💾 Saved checkpoint after 180 batches at scores_flash.csv\n",
      "✅ Processed 1810/2065\n",
      "✅ Processed 1820/2065\n",
      "✅ Processed 1830/2065\n",
      "✅ Processed 1840/2065\n",
      "✅ Processed 1850/2065\n",
      "✅ Processed 1860/2065\n",
      "✅ Processed 1870/2065\n",
      "✅ Processed 1880/2065\n",
      "✅ Processed 1890/2065\n",
      "✅ Processed 1900/2065\n",
      "💾 Saved checkpoint after 190 batches at scores_flash.csv\n",
      "✅ Processed 1910/2065\n",
      "✅ Processed 1920/2065\n",
      "✅ Processed 1930/2065\n",
      "✅ Processed 1940/2065\n",
      "✅ Processed 1950/2065\n",
      "✅ Processed 1960/2065\n",
      "✅ Processed 1970/2065\n",
      "✅ Processed 1980/2065\n",
      "✅ Processed 1990/2065\n",
      "✅ Processed 2000/2065\n",
      "💾 Saved checkpoint after 200 batches at scores_flash.csv\n",
      "✅ Processed 2010/2065\n",
      "✅ Processed 2020/2065\n",
      "✅ Processed 2030/2065\n",
      "✅ Processed 2040/2065\n",
      "✅ Processed 2050/2065\n",
      "✅ Processed 2060/2065\n",
      "✅ Processed 2065/2065\n",
      "🎉 Finished. Final results saved at scores_flash.csv\n"
     ]
    }
   ],
   "execution_count": 8
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-13T22:43:12.262955Z",
     "start_time": "2025-09-13T22:43:12.228884Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import google.generativeai as genai\n",
    "import os\n",
    "import time\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "# Load environment variables\n",
    "load_dotenv()\n",
    "genai.configure(api_key=os.getenv(\"GOOGLE_API_KEY\"))\n",
    "\n",
    "model = genai.GenerativeModel(\"gemini-2.5-pro\")\n",
    "\n",
    "def batch_eval(all_pairs, disease_ls, batch_size=20, max_retries=5, sleep_time=2, save_every=50, save_path=\"scores_checkpoint.csv\"):\n",
    "    all_scores = []\n",
    "    batch_counter = 0  # track how many batches processed\n",
    "\n",
    "    for start in range(0, all_pairs.shape[0], batch_size):\n",
    "        end = min(start + batch_size, all_pairs.shape[0])\n",
    "        batch = all_pairs.iloc[start:end]\n",
    "\n",
    "        # Build base prompt\n",
    "        prompt = (\n",
    "            \"You are evaluating individual symptom descriptions for diseases.\\n\\n\"\n",
    "            \"Scoring instructions:\\n\"\n",
    "            \"- Assign each description a score between 0 and 1, rounded to two decimal places.\\n\"\n",
    "            \"- Criteria: The description should plausibly match the specified disease, avoid confusion with other diseases. \\n\"\n",
    "            \"- Use the full 0–1 range: 1 = perfectly clear, specific, and accurate; 0 = completely unusable.\\n\"\n",
    "            \"- 0.5 is the threshold: any description with a score ≤ 0.5 should be dropped to prevent misclassification.\\n\\n\"\n",
    "            f\"For reference, here is the complete list of possible diseases: {disease_ls}\\n\\n\"\n",
    "            \"Output requirements:\\n\"\n",
    "            \"- Output only the scores, one per line, in the same order as the input cases.\\n\"\n",
    "            \"- Do not include explanations, text, or formatting other than the numeric scores.\\n\\n\"\n",
    "        )\n",
    "\n",
    "        for i, row in batch.iterrows():\n",
    "            symp, diag = row[\"input_text\"], row[\"output_text\"]\n",
    "            prompt += f\"Case {i}:\\nDisease: {diag}\\nSymptom: {symp}\\n\"\n",
    "        # Retry loop\n",
    "        scores = []\n",
    "        for attempt in range(1, max_retries + 1):\n",
    "            try:\n",
    "                response = model.generate_content(prompt)\n",
    "                scores = response.text.strip().splitlines()\n",
    "\n",
    "                if len(scores) == len(batch):\n",
    "                    break  # ✅ got the right number of outputs\n",
    "                else:\n",
    "                    print(\n",
    "                        f\"⚠️ Attempt {attempt}: Expected {len(batch)} scores, got {len(scores)}. Retrying...\"\n",
    "                    )\n",
    "                    time.sleep(sleep_time)\n",
    "\n",
    "            except Exception as e:\n",
    "                print(f\"❌ Error on attempt {attempt}: {e}\")\n",
    "                time.sleep(sleep_time)\n",
    "\n",
    "        if len(scores) != len(batch):\n",
    "            raise ValueError(\n",
    "                f\"Failed after {max_retries} retries: Expected {len(batch)} scores, got {len(scores)}\"\n",
    "            )\n",
    "\n",
    "        all_scores.extend(scores)\n",
    "        batch_counter += 1\n",
    "        print(f\"✅ Processed {end}/{all_pairs.shape[0]}\")\n",
    "\n",
    "        if batch_counter % save_every == 0:\n",
    "            pd.DataFrame({\"score\": all_scores}).to_csv(save_path, index=False)\n",
    "            print(f\"💾 Saved checkpoint after {batch_counter} batches at {save_path}\")\n",
    "\n",
    "    pd.DataFrame({\"score\": all_scores}).to_csv(save_path, index=False)\n",
    "    print(f\"🎉 Finished. Final results saved at {save_path}\")\n",
    "\n",
    "    return all_scores"
   ],
   "id": "ad0b21a894109fa3",
   "outputs": [],
   "execution_count": 9
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-13T22:43:20.625639Z",
     "start_time": "2025-09-13T22:43:20.596005Z"
    }
   },
   "cell_type": "code",
   "source": [
    "\n",
    "import pickle\n",
    "with open(\"generations_extend10.pickle\", \"rb\") as file:\n",
    "    generations = pickle.load(file)\n",
    "\n",
    "all_rows = []\n",
    "\n",
    "# 2. Loop through the augmented data.\n",
    "for i in range(100):\n",
    "    output_text = data_train[\"output_text\"].iloc[i]\n",
    "\n",
    "    # Add the original row\n",
    "    # Add the augmented rows\n",
    "    for j in generations[i]:\n",
    "        all_rows.append([j, output_text])\n",
    "\n",
    "\n",
    "# 4. Create the DataFrame from the list in one single, efficient operation.\n",
    "df = pd.DataFrame(all_rows, columns=[\"input_text\", \"output_text\"])\n"
   ],
   "id": "586c0e607f32021b",
   "outputs": [],
   "execution_count": 10
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-13T22:43:24.887012Z",
     "start_time": "2025-09-13T22:43:24.874804Z"
    }
   },
   "cell_type": "code",
   "source": "df.shape",
   "id": "5a9fecab5d094387",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(500, 2)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 11
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-13T22:58:41.586525Z",
     "start_time": "2025-09-13T22:43:28.528226Z"
    }
   },
   "cell_type": "code",
   "source": "scores = batch_eval(df, disease_ls, batch_size=10, save_path=\"scores_pro.csv\", save_every = 10)\n",
   "id": "c90320553f546ab7",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Processed 10/500\n",
      "✅ Processed 20/500\n",
      "✅ Processed 30/500\n",
      "✅ Processed 40/500\n",
      "✅ Processed 50/500\n",
      "✅ Processed 60/500\n",
      "✅ Processed 70/500\n",
      "✅ Processed 80/500\n",
      "✅ Processed 90/500\n",
      "✅ Processed 100/500\n",
      "💾 Saved checkpoint after 10 batches at scores_pro.csv\n",
      "✅ Processed 110/500\n",
      "✅ Processed 120/500\n",
      "✅ Processed 130/500\n",
      "✅ Processed 140/500\n",
      "✅ Processed 150/500\n",
      "✅ Processed 160/500\n",
      "✅ Processed 170/500\n",
      "✅ Processed 180/500\n",
      "✅ Processed 190/500\n",
      "✅ Processed 200/500\n",
      "💾 Saved checkpoint after 20 batches at scores_pro.csv\n",
      "✅ Processed 210/500\n",
      "✅ Processed 220/500\n",
      "✅ Processed 230/500\n",
      "✅ Processed 240/500\n",
      "✅ Processed 250/500\n",
      "✅ Processed 260/500\n",
      "✅ Processed 270/500\n",
      "✅ Processed 280/500\n",
      "✅ Processed 290/500\n",
      "✅ Processed 300/500\n",
      "💾 Saved checkpoint after 30 batches at scores_pro.csv\n",
      "✅ Processed 310/500\n",
      "✅ Processed 320/500\n",
      "✅ Processed 330/500\n",
      "✅ Processed 340/500\n",
      "✅ Processed 350/500\n",
      "✅ Processed 360/500\n",
      "✅ Processed 370/500\n",
      "✅ Processed 380/500\n",
      "✅ Processed 390/500\n",
      "✅ Processed 400/500\n",
      "💾 Saved checkpoint after 40 batches at scores_pro.csv\n",
      "✅ Processed 410/500\n",
      "✅ Processed 420/500\n",
      "✅ Processed 430/500\n",
      "✅ Processed 440/500\n",
      "✅ Processed 450/500\n",
      "✅ Processed 460/500\n",
      "✅ Processed 470/500\n",
      "✅ Processed 480/500\n",
      "✅ Processed 490/500\n",
      "✅ Processed 500/500\n",
      "💾 Saved checkpoint after 50 batches at scores_pro.csv\n",
      "🎉 Finished. Final results saved at scores_pro.csv\n"
     ]
    }
   ],
   "execution_count": 12
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Read the data",
   "id": "4353a47aab37b8c2"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-13T23:00:06.104214Z",
     "start_time": "2025-09-13T23:00:06.070193Z"
    }
   },
   "cell_type": "code",
   "source": [
    "scores_flash =  pd.read_csv(\"scores_flash.csv\")\n",
    "scores_pro = pd.read_csv(\"scores_pro.csv\")"
   ],
   "id": "aa884aaa5f335760",
   "outputs": [],
   "execution_count": 13
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-14T03:48:24.715963Z",
     "start_time": "2025-09-14T03:48:24.703792Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import pickle\n",
    "with open(\"generations_extend10.pickle\", \"rb\") as file:\n",
    "    generations = pickle.load(file)\n",
    "\n",
    "all_rows = []\n",
    "\n",
    "for i in range(data_train.shape[0]):\n",
    "    output_text = data_train[\"output_text\"].iloc[i]\n",
    "    for j in generations[i]:\n",
    "        all_rows.append([j, output_text])\n",
    "\n",
    "df = pd.DataFrame(all_rows, columns=[\"input_text\", \"output_text\"])\n",
    "df[\"pro-score\"] = -1\n",
    "df[\"flash-score\"] = scores_flash.values\n",
    "df.iloc[:500, 2] = scores_pro.values.ravel()"
   ],
   "id": "cea9bbabb7b8df0",
   "outputs": [],
   "execution_count": 26
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
