{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a8ef9ad9-c7e1-4fed-b0bd-581064558089",
   "metadata": {},
   "source": [
    "# GPT-3.5-Turbo Performance on MMLU - College Chemistry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4f1c09a4-4859-469b-a156-dbb037c83a65",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import openai\n",
    "import re\n",
    "import time\n",
    "import json\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "from tenacity import retry, stop_after_attempt, wait_chain, wait_fixed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "eaea2a2a-2515-4508-9cdb-084d10853170",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "openai.api_key = \"sk-\" "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "95ddd688-5bf5-40c5-a852-32e62f5a1bbb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "@retry(wait=wait_chain(*[wait_fixed(3) for i in range(3)] +\n",
    "                       [wait_fixed(5) for i in range(2)] +\n",
    "                       [wait_fixed(10)]))\n",
    "def completion_with_backoff(**kwargs):\n",
    "    return openai.ChatCompletion.create(**kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d4063503-c0e0-4df7-9866-0815f4c9fdf0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "mmlu_prompt = json.load(open('lib_prompt/mmlu-cot.json'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bbbbe828-fad8-48a9-9cab-4a7e675ecf9e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The following are multiple choice questions (with answers) about college chemistry.\n",
      "\n",
      "Q: 3 Cl−(aq) + 4 CrO_4^2−(aq) + 23 H+(aq) → 3 HClO2(aq) + 4 Cr3+(aq) + 10 H2O(l). In the reaction shown above, Cl−(aq) behaves as\n",
      "(A) an acid (B) a base (C) a catalyst (D) a reducing agent\n",
      "A: Let's think step by step. A molecule that behaves as a base accepts an H+ ion (or proton) from another molecule, whereas a molecule that behaves as an acid donates an H+ ion (or proton) to another molecule. Neither of these is the case for Cl in this reaction, which rules out (A) and (B). A catalyst is a substance that only accelerates a reaction without itself undergoing chemical change, which is not the case here. This rules out (C). Instead, the $Cl^{-} molecules carry a negative charge, which they donate in the reaction to form 3 HClO2. This is the behavior of a reducing agent, or (D). The answer is (D).\n",
      "\n",
      "Q: Which of the following statements about the lanthanide elements is NOT true?\n",
      "(A) The most common oxidation state for the lanthanide elements is +3. (B) Lanthanide complexes often have high coordination numbers (> 6). (C) All of the lanthanide elements react with aqueous acid to liberate hydrogen. (D) The atomic radii of the lanthanide elements increase across the period from La to Lu.\n",
      "A: Let's think step by step. The atomic radii of the lanthanide elements in fact decrease across the period from La to Lu. Options (A), (B), and (C) are all true. This means that only (D) is NOT true. The answer is (D).\n",
      "\n",
      "Q: Which of the following lists the hydrides of group-14 elements in order of thermal stability, from lowest to highest?\n",
      "(A) PbH4 < SnH4 < GeH4 < SiH4 < CH4 (B) PbH4 < SnH4 < CH4 < GeH4 < SiH4 (C) CH4 < SiH4 < GeH4 < SnH4 < PbH4 (D) CH4 < PbH4 < GeH4 < SnH4 < SiH4\n",
      "A: Let's think step by step. The thermal stability of group-14 hydrides decreases as we move from the top of group 14 to the bottom. The order of elements in the group from top to bottom is C, Si, Ge, Sn, Pb. Therefore in order of increasing thermal stability we have PbH4, SnH4, GeH4, SiH4, and CH4, or answer (A). The answer is (A).\n",
      "\n",
      "Q: Predict the number of lines in the EPR spectrum of a solution of 13C-labelled methyl radical (13CH3•), assuming the lines do not overlap.\n",
      "(A) 4 (B) 3 (C) 6 (D) 24 (E) 8\n",
      "A: Let's think step by step. The electron paramagnetic resonance spectrum will be split by two forms of interactions. The first is the hyperfine interaction with the 13C (nuclear spin $I = \n",
      "rac{1}{2}$) which will split the spectrum into 2 lines. This will be further split into 4 lines by the interaction with three equivalent 1H nuclei. The total number of lines is therefore $2 \\cdot 4 = 8$. The answer is (E).\n"
     ]
    }
   ],
   "source": [
    "task = 'college_chemistry'\n",
    "print(mmlu_prompt[task])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ce1c6b79-3530-4efe-bfd6-eead27d09ff3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and preparing dataset mmlu/college_chemistry to /Users/yaofu/.cache/huggingface/datasets/lukaemon___mmlu/college_chemistry/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dea5451020524ea1a7b9c91f737f5ebf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating test split:   0%|          | 0/99 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating validation split:   0%|          | 0/7 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a4633fe4d2ae440aafa6d22c99f7545c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/4 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset mmlu downloaded and prepared to /Users/yaofu/.cache/huggingface/datasets/lukaemon___mmlu/college_chemistry/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3. Subsequent calls will reuse this data.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "847899d154294510a2d17d0ef1559ac3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "task_data = load_dataset(\"lukaemon/mmlu\", task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0957327f-9454-4010-9501-b7086fbba125",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "99"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(task_data['test'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5b18d727-13a7-45cd-a842-3462636d9c25",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input': 'Infrared (IR) spectroscopy is useful for determining certain aspects of the structure of organic molecules because',\n",
       " 'A': 'all molecular bonds absorb IR radiation',\n",
       " 'B': 'IR peak intensities are related to molecular mass',\n",
       " 'C': 'most organic functional groups absorb in a characteristic region of the IR spectrum',\n",
       " 'D': 'each element absorbs at a characteristic wavelength',\n",
       " 'target': 'C'}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "task_data['test'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bec64ac3-96b4-45cd-b79e-6c8ad6fae234",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "prompt_q = mmlu_prompt[task] + \"\\n\\n\" + task_data['test'][0]['input'] + '\\n'\n",
    "for letter in ['A', 'B', 'C', 'D']:\n",
    "    prompt_q += '(' + letter + ') ' + task_data['test'][0][letter] + ' '\n",
    "prompt_q += \"\\nA: Let's think step by step.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8847b4b5-99fe-47ab-941b-5bd02c43755e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The following are multiple choice questions (with answers) about college chemistry.\n",
      "\n",
      "Q: 3 Cl−(aq) + 4 CrO_4^2−(aq) + 23 H+(aq) → 3 HClO2(aq) + 4 Cr3+(aq) + 10 H2O(l). In the reaction shown above, Cl−(aq) behaves as\n",
      "(A) an acid (B) a base (C) a catalyst (D) a reducing agent\n",
      "A: Let's think step by step. A molecule that behaves as a base accepts an H+ ion (or proton) from another molecule, whereas a molecule that behaves as an acid donates an H+ ion (or proton) to another molecule. Neither of these is the case for Cl in this reaction, which rules out (A) and (B). A catalyst is a substance that only accelerates a reaction without itself undergoing chemical change, which is not the case here. This rules out (C). Instead, the $Cl^{-} molecules carry a negative charge, which they donate in the reaction to form 3 HClO2. This is the behavior of a reducing agent, or (D). The answer is (D).\n",
      "\n",
      "Q: Which of the following statements about the lanthanide elements is NOT true?\n",
      "(A) The most common oxidation state for the lanthanide elements is +3. (B) Lanthanide complexes often have high coordination numbers (> 6). (C) All of the lanthanide elements react with aqueous acid to liberate hydrogen. (D) The atomic radii of the lanthanide elements increase across the period from La to Lu.\n",
      "A: Let's think step by step. The atomic radii of the lanthanide elements in fact decrease across the period from La to Lu. Options (A), (B), and (C) are all true. This means that only (D) is NOT true. The answer is (D).\n",
      "\n",
      "Q: Which of the following lists the hydrides of group-14 elements in order of thermal stability, from lowest to highest?\n",
      "(A) PbH4 < SnH4 < GeH4 < SiH4 < CH4 (B) PbH4 < SnH4 < CH4 < GeH4 < SiH4 (C) CH4 < SiH4 < GeH4 < SnH4 < PbH4 (D) CH4 < PbH4 < GeH4 < SnH4 < SiH4\n",
      "A: Let's think step by step. The thermal stability of group-14 hydrides decreases as we move from the top of group 14 to the bottom. The order of elements in the group from top to bottom is C, Si, Ge, Sn, Pb. Therefore in order of increasing thermal stability we have PbH4, SnH4, GeH4, SiH4, and CH4, or answer (A). The answer is (A).\n",
      "\n",
      "Q: Predict the number of lines in the EPR spectrum of a solution of 13C-labelled methyl radical (13CH3•), assuming the lines do not overlap.\n",
      "(A) 4 (B) 3 (C) 6 (D) 24 (E) 8\n",
      "A: Let's think step by step. The electron paramagnetic resonance spectrum will be split by two forms of interactions. The first is the hyperfine interaction with the 13C (nuclear spin $I = \n",
      "rac{1}{2}$) which will split the spectrum into 2 lines. This will be further split into 4 lines by the interaction with three equivalent 1H nuclei. The total number of lines is therefore $2 \\cdot 4 = 8$. The answer is (E).\n",
      "\n",
      "Infrared (IR) spectroscopy is useful for determining certain aspects of the structure of organic molecules because\n",
      "(A) all molecular bonds absorb IR radiation (B) IR peak intensities are related to molecular mass (C) most organic functional groups absorb in a characteristic region of the IR spectrum (D) each element absorbs at a characteristic wavelength \n",
      "A: Let's think step by step.\n"
     ]
    }
   ],
   "source": [
    "print(prompt_q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c91e1e78-0e6d-4d82-9600-b6c76637fa80",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "response = openai.ChatCompletion.create(\n",
    "    model=\"gpt-3.5-turbo\",\n",
    "    messages=[\n",
    "        {\"role\": \"system\", \"content\": \"Follow the given examples and answer the question.\"},\n",
    "        {\"role\": \"user\", \"content\": prompt_q},\n",
    "    ],\n",
    "    temperature=0, \n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b1e7b952-30b3-43b3-aa11-06595a7f592c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "IR spectroscopy is useful for determining certain aspects of the structure of organic molecules because most organic functional groups absorb in a characteristic region of the IR spectrum. This is option (C). Option (A) is not true because not all molecular bonds absorb IR radiation. Option (B) is not true because IR peak intensities are not related to molecular mass. Option (D) is not true because each element does not absorb at a characteristic wavelength in IR spectroscopy. The answer is (C).\n"
     ]
    }
   ],
   "source": [
    "print(response['choices'][0]['message']['content'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "867945a0-6527-409e-801e-4244eaf75e61",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def test_answer_mmlu(pred_str, ans_str):\n",
    "    pattern = 'the answer is ('\n",
    "    pred = pred_str.lower().split(pattern)\n",
    "    \n",
    "    if(len(pred) > 1):\n",
    "        # print(pred)\n",
    "        pred = pred[1][0]\n",
    "        gold = ans_str.split('A:\\n')[1][0].lower()\n",
    "        # print('debug 1, pred %s, gold %s' % (pred, gold))\n",
    "        return pred == gold\n",
    "    else: \n",
    "        pred = 'C'\n",
    "        gold = ans_str.split('A:\\n')[1][0].lower()\n",
    "        # print('debug 2, pred %s, gold %s' % (pred, gold))\n",
    "        return pred == gold\n",
    "\n",
    "def parse_pred_ans(filename):\n",
    "    with open(filename) as fd: lines = fd.readlines()\n",
    "    am, a = None, None\n",
    "    num_q, acc = 0, 0\n",
    "    current_mode = 'none'\n",
    "    questions = []\n",
    "    ans_pred = []\n",
    "    ans_gold = []\n",
    "    for l in lines:\n",
    "        if(l.startswith('Q: ')):\n",
    "            if(am is not None and a is not None):\n",
    "                questions.append(q)\n",
    "                ans_pred.append(am)\n",
    "                ans_gold.append(a)\n",
    "                # print(am)\n",
    "                # print(a)\n",
    "                if(test_answer_mmlu(am, a)):\n",
    "                    acc += 1\n",
    "            current_mode = 'q'\n",
    "            q = l\n",
    "            num_q += 1\n",
    "        elif(l.startswith('A_model:')):\n",
    "            current_mode = 'am'\n",
    "            am = l\n",
    "        elif(l.startswith('A:')):\n",
    "            current_mode = 'a'\n",
    "            a = l\n",
    "        else:\n",
    "            if(current_mode == 'q'): q += l\n",
    "            elif(current_mode == 'am'): am += l\n",
    "            elif(current_mode == 'a'): a += l\n",
    "            else:\n",
    "                raise ValueError(current_mode)\n",
    "                \n",
    "    questions.append(q)\n",
    "    ans_pred.append(am)\n",
    "    ans_gold.append(a)\n",
    "    # print(am)\n",
    "    # print(a)\n",
    "    if(test_answer_mmlu(am, a)):\n",
    "        acc += 1\n",
    "    print('num_q %d correct %d ratio %.4f' % (num_q, acc, float(acc / num_q)))\n",
    "    return questions, ans_pred, ans_gold\n",
    "\n",
    "def test_finished(ans_model):\n",
    "    if('answer is' in ans_model): return True\n",
    "    else: return False\n",
    "\n",
    "def extract_ans(ans_model):\n",
    "    ans_model = ans_model.split('\\n')\n",
    "    ans = []\n",
    "    residual = []\n",
    "    for li, al in enumerate(ans_model):\n",
    "        ans.append(al)\n",
    "        if('answer is' in al):\n",
    "            break\n",
    "    residual = list(ans_model[li + 1:])\n",
    "    ans = '\\n'.join(ans)\n",
    "    residual = '\\n'.join(residual)\n",
    "    return ans, residual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b398b88e-d8fd-47fc-be03-f4627f6719e1",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|                                                                                                                                                                                        | 0/99 [00:00<?, ?it/s]\u001b[A\n",
      "  1%|█▊                                                                                                                                                                              | 1/99 [00:02<03:35,  2.20s/it]\u001b[A\n",
      "  2%|███▌                                                                                                                                                                            | 2/99 [00:05<04:08,  2.56s/it]\u001b[A\n",
      "  3%|█████▎                                                                                                                                                                          | 3/99 [00:09<05:47,  3.62s/it]\u001b[A\n",
      "  4%|███████                                                                                                                                                                         | 4/99 [00:11<04:32,  2.87s/it]\u001b[A\n",
      "  5%|████████▉                                                                                                                                                                       | 5/99 [00:13<04:06,  2.63s/it]\u001b[A\n",
      "  6%|██████████▋                                                                                                                                                                     | 6/99 [00:15<03:27,  2.23s/it]\u001b[A\n",
      "  7%|████████████▍                                                                                                                                                                   | 7/99 [00:17<03:14,  2.11s/it]\u001b[A\n",
      "  8%|██████████████▏                                                                                                                                                                 | 8/99 [00:20<03:46,  2.49s/it]\u001b[A\n",
      "  9%|████████████████                                                                                                                                                                | 9/99 [00:23<03:53,  2.60s/it]\u001b[A\n",
      " 10%|█████████████████▋                                                                                                                                                             | 10/99 [00:28<04:50,  3.26s/it]\u001b[A\n",
      " 11%|███████████████████▍                                                                                                                                                           | 11/99 [00:34<06:22,  4.34s/it]\u001b[A\n",
      " 12%|█████████████████████▏                                                                                                                                                         | 12/99 [00:40<06:53,  4.75s/it]\u001b[A\n",
      " 13%|██████████████████████▉                                                                                                                                                        | 13/99 [00:42<05:42,  3.98s/it]\u001b[A\n",
      " 14%|████████████████████████▋                                                                                                                                                      | 14/99 [00:45<05:03,  3.57s/it]\u001b[A\n",
      " 15%|██████████████████████████▌                                                                                                                                                    | 15/99 [00:48<04:55,  3.52s/it]\u001b[A\n",
      " 16%|████████████████████████████▎                                                                                                                                                  | 16/99 [00:49<03:44,  2.70s/it]\u001b[A\n",
      " 17%|██████████████████████████████                                                                                                                                                 | 17/99 [00:53<04:02,  2.95s/it]\u001b[A\n",
      " 18%|███████████████████████████████▊                                                                                                                                               | 18/99 [00:54<03:29,  2.58s/it]\u001b[A\n",
      " 19%|█████████████████████████████████▌                                                                                                                                             | 19/99 [00:57<03:21,  2.52s/it]\u001b[A\n",
      " 20%|███████████████████████████████████▎                                                                                                                                           | 20/99 [01:00<03:32,  2.69s/it]\u001b[A\n",
      " 21%|█████████████████████████████████████                                                                                                                                          | 21/99 [01:01<03:07,  2.41s/it]\u001b[A\n",
      " 22%|██████████████████████████████████████▉                                                                                                                                        | 22/99 [01:05<03:34,  2.79s/it]\u001b[A\n",
      " 23%|████████████████████████████████████████▋                                                                                                                                      | 23/99 [01:08<03:22,  2.67s/it]\u001b[A\n",
      " 24%|██████████████████████████████████████████▍                                                                                                                                    | 24/99 [01:09<02:46,  2.22s/it]\u001b[A\n",
      " 25%|████████████████████████████████████████████▏                                                                                                                                  | 25/99 [01:14<03:42,  3.00s/it]\u001b[A\n",
      " 26%|█████████████████████████████████████████████▉                                                                                                                                 | 26/99 [01:16<03:25,  2.81s/it]\u001b[A\n",
      " 27%|███████████████████████████████████████████████▋                                                                                                                               | 27/99 [01:21<04:16,  3.57s/it]\u001b[A\n",
      " 28%|█████████████████████████████████████████████████▍                                                                                                                             | 28/99 [01:23<03:25,  2.90s/it]\u001b[A\n",
      " 29%|███████████████████████████████████████████████████▎                                                                                                                           | 29/99 [01:28<04:14,  3.64s/it]\u001b[A\n",
      " 30%|█████████████████████████████████████████████████████                                                                                                                          | 30/99 [01:30<03:31,  3.07s/it]\u001b[A\n",
      " 31%|██████████████████████████████████████████████████████▊                                                                                                                        | 31/99 [01:35<04:14,  3.75s/it]\u001b[A\n",
      " 32%|████████████████████████████████████████████████████████▌                                                                                                                      | 32/99 [01:39<04:22,  3.92s/it]\u001b[A\n",
      " 33%|██████████████████████████████████████████████████████████▎                                                                                                                    | 33/99 [01:42<03:54,  3.55s/it]\u001b[A\n",
      " 34%|████████████████████████████████████████████████████████████                                                                                                                   | 34/99 [01:50<05:07,  4.73s/it]\u001b[A\n",
      " 35%|█████████████████████████████████████████████████████████████▊                                                                                                                 | 35/99 [01:53<04:47,  4.49s/it]\u001b[A\n",
      " 36%|███████████████████████████████████████████████████████████████▋                                                                                                               | 36/99 [01:57<04:33,  4.34s/it]\u001b[A\n",
      " 37%|█████████████████████████████████████████████████████████████████▍                                                                                                             | 37/99 [02:00<03:51,  3.73s/it]\u001b[A\n",
      " 38%|███████████████████████████████████████████████████████████████████▏                                                                                                           | 38/99 [02:04<03:54,  3.84s/it]\u001b[A\n",
      " 39%|████████████████████████████████████████████████████████████████████▉                                                                                                          | 39/99 [02:06<03:25,  3.43s/it]\u001b[A\n",
      " 40%|██████████████████████████████████████████████████████████████████████▋                                                                                                        | 40/99 [02:08<02:59,  3.04s/it]\u001b[A\n",
      " 41%|████████████████████████████████████████████████████████████████████████▍                                                                                                      | 41/99 [02:16<04:08,  4.29s/it]\u001b[A\n",
      " 42%|██████████████████████████████████████████████████████████████████████████▏                                                                                                    | 42/99 [02:17<03:21,  3.53s/it]\u001b[A\n",
      " 43%|████████████████████████████████████████████████████████████████████████████                                                                                                   | 43/99 [02:23<03:59,  4.27s/it]\u001b[A\n",
      " 44%|█████████████████████████████████████████████████████████████████████████████▊                                                                                                 | 44/99 [02:26<03:19,  3.63s/it]\u001b[A\n",
      " 45%|███████████████████████████████████████████████████████████████████████████████▌                                                                                               | 45/99 [02:27<02:43,  3.03s/it]\u001b[A\n",
      " 46%|█████████████████████████████████████████████████████████████████████████████████▎                                                                                             | 46/99 [02:34<03:40,  4.17s/it]\u001b[A\n",
      " 47%|███████████████████████████████████████████████████████████████████████████████████                                                                                            | 47/99 [02:37<03:19,  3.83s/it]\u001b[A\n",
      " 48%|████████████████████████████████████████████████████████████████████████████████████▊                                                                                          | 48/99 [02:40<02:55,  3.44s/it]\u001b[A\n",
      " 49%|██████████████████████████████████████████████████████████████████████████████████████▌                                                                                        | 49/99 [02:42<02:39,  3.20s/it]\u001b[A\n",
      " 51%|████████████████████████████████████████████████████████████████████████████████████████▍                                                                                      | 50/99 [02:44<02:11,  2.69s/it]\u001b[A\n",
      " 52%|██████████████████████████████████████████████████████████████████████████████████████████▏                                                                                    | 51/99 [02:46<02:06,  2.64s/it]\u001b[A\n",
      " 53%|███████████████████████████████████████████████████████████████████████████████████████████▉                                                                                   | 52/99 [02:49<02:10,  2.77s/it]\u001b[A\n",
      " 54%|█████████████████████████████████████████████████████████████████████████████████████████████▋                                                                                 | 53/99 [02:50<01:43,  2.25s/it]\u001b[A\n",
      " 55%|███████████████████████████████████████████████████████████████████████████████████████████████▍                                                                               | 54/99 [02:54<01:55,  2.56s/it]\u001b[A\n",
      " 56%|█████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                             | 55/99 [02:56<01:50,  2.50s/it]\u001b[A\n",
      " 57%|██████████████████████████████████████████████████████████████████████████████████████████████████▉                                                                            | 56/99 [02:58<01:39,  2.31s/it]\u001b[A\n",
      " 58%|████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                          | 57/99 [03:03<02:11,  3.13s/it]\u001b[A\n",
      " 59%|██████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                        | 58/99 [03:10<02:54,  4.26s/it]\u001b[A\n",
      " 60%|████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                      | 59/99 [03:11<02:18,  3.45s/it]\u001b[A\n",
      " 61%|██████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                     | 60/99 [03:23<03:54,  6.01s/it]\u001b[A\n",
      " 62%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                   | 61/99 [03:30<03:56,  6.23s/it]\u001b[A\n",
      " 63%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                 | 62/99 [03:35<03:33,  5.76s/it]\u001b[A\n",
      " 64%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                               | 63/99 [03:36<02:43,  4.54s/it]\u001b[A\n",
      " 65%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                             | 64/99 [03:38<02:10,  3.73s/it]\u001b[A\n",
      " 66%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                            | 65/99 [03:40<01:45,  3.10s/it]\u001b[A\n",
      " 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                          | 66/99 [03:41<01:24,  2.56s/it]\u001b[A\n",
      " 68%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                        | 67/99 [03:44<01:23,  2.60s/it]\u001b[A\n",
      " 69%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                      | 68/99 [03:46<01:12,  2.34s/it]\u001b[A\n",
      " 70%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                     | 69/99 [03:48<01:13,  2.45s/it]\u001b[A\n",
      " 71%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                   | 70/99 [03:49<00:59,  2.06s/it]\u001b[A\n",
      " 72%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                 | 71/99 [03:52<00:59,  2.14s/it]\u001b[A\n",
      " 73%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                               | 72/99 [03:54<00:56,  2.10s/it]\u001b[A\n",
      " 74%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                              | 73/99 [03:55<00:50,  1.96s/it]\u001b[A\n",
      " 75%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                            | 74/99 [03:57<00:44,  1.78s/it]\u001b[A\n",
      " 76%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                          | 75/99 [04:02<01:06,  2.75s/it]\u001b[A\n",
      " 77%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                        | 76/99 [04:05<01:03,  2.77s/it]\u001b[A\n",
      " 78%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                       | 77/99 [04:06<00:53,  2.45s/it]\u001b[A\n",
      " 79%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                     | 78/99 [04:09<00:51,  2.43s/it]\u001b[A\n",
      " 80%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                   | 79/99 [04:12<00:53,  2.66s/it]\u001b[A\n",
      " 81%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                 | 80/99 [04:14<00:45,  2.41s/it]\u001b[A\n",
      " 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                               | 81/99 [04:18<00:54,  3.03s/it]\u001b[A\n",
      " 83%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                              | 82/99 [04:22<00:53,  3.17s/it]\u001b[A\n",
      " 84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                            | 83/99 [04:26<00:56,  3.56s/it]\u001b[A\n",
      " 85%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                          | 84/99 [04:28<00:45,  3.02s/it]\u001b[A\n",
      " 86%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                        | 85/99 [04:34<00:54,  3.86s/it]\u001b[A\n",
      " 87%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                       | 86/99 [04:38<00:52,  4.02s/it]\u001b[A\n",
      " 88%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                     | 87/99 [04:39<00:38,  3.17s/it]\u001b[A\n",
      " 89%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                   | 88/99 [04:42<00:33,  3.03s/it]\u001b[A\n",
      " 90%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                 | 89/99 [04:49<00:42,  4.30s/it]\u001b[A\n",
      " 91%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                | 90/99 [04:54<00:38,  4.30s/it]\u001b[A\n",
      " 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊              | 91/99 [04:57<00:32,  4.03s/it]\u001b[A\n",
      " 93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋            | 92/99 [05:01<00:27,  3.91s/it]\u001b[A\n",
      " 94%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍          | 93/99 [05:04<00:22,  3.80s/it]\u001b[A\n",
      " 95%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏        | 94/99 [05:06<00:15,  3.17s/it]\u001b[A\n",
      " 96%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉       | 95/99 [05:08<00:11,  2.88s/it]\u001b[A\n",
      " 97%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋     | 96/99 [05:11<00:08,  2.75s/it]\u001b[A\n",
      " 98%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍   | 97/99 [05:13<00:05,  2.68s/it]\u001b[A\n",
      " 99%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 98/99 [05:15<00:02,  2.35s/it]\u001b[A\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 99/99 [05:18<00:00,  3.21s/it]\u001b[A\n"
     ]
    }
   ],
   "source": [
    "i = 0\n",
    "with open('outputs/test_gpt_3.5_turbo_%s.txt' % task, 'w') as fd:\n",
    "    for q_ in tqdm(task_data['test'], total=len(task_data['test'])):\n",
    "        q = q_['input'] + '\\n'\n",
    "        for letter in ['A', 'B', 'C', 'D']:\n",
    "            q += '(' + letter + ') ' + q_[letter] + ' '\n",
    "        q += \"\\nA: Let's think step by step.\"  \n",
    "            \n",
    "        prompt_q = mmlu_prompt[task] + \"\\n\\n\" + q\n",
    "\n",
    "        response = completion_with_backoff(\n",
    "              model=\"gpt-3.5-turbo\",\n",
    "              messages=[\n",
    "                    {\"role\": \"system\", \"content\": \"Follow the given examples and answer the question.\"},\n",
    "                    {\"role\": \"user\", \"content\": prompt_q},\n",
    "                ],\n",
    "            temperature=0\n",
    "            )\n",
    "        ans_model = response['choices'][0]['message']['content']\n",
    "        ans_, residual = extract_ans(ans_model)\n",
    "            \n",
    "        a = q_['target']\n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))\n",
    "        i += 1\n",
    "        # if(i == 2): break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8ea5012b-444a-49d8-8752-51ea485d9beb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "college_chemistry\n",
      "num_q 99 correct 51 ratio 0.5152\n"
     ]
    }
   ],
   "source": [
    "print(task)\n",
    "_, _, _ = parse_pred_ans('outputs/test_gpt_3.5_turbo_%s.txt' % task)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
