{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a8ef9ad9-c7e1-4fed-b0bd-581064558089",
   "metadata": {},
   "source": [
    "# GPT-3.5-Turbo Performance on MMLU - College Math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4f1c09a4-4859-469b-a156-dbb037c83a65",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import re\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import openai\n",
    "\n",
    "from datasets import load_dataset\n",
    "from tenacity import retry, stop_after_attempt, wait_chain, wait_fixed\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "eaea2a2a-2515-4508-9cdb-084d10853170",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Read the key from the environment so it is never saved inside the notebook;\n",
    "# the \"sk-\" placeholder remains the fallback when OPENAI_API_KEY is unset.\n",
    "openai.api_key = os.environ.get(\"OPENAI_API_KEY\", \"sk-\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "95ddd688-5bf5-40c5-a852-32e62f5a1bbb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Wait schedule between attempts: 3s three times, 5s twice, then 10s.\n",
    "_BACKOFF_WAITS = [wait_fixed(3)] * 3 + [wait_fixed(5)] * 2 + [wait_fixed(10)]\n",
    "\n",
    "\n",
    "@retry(wait=wait_chain(*_BACKOFF_WAITS))\n",
    "def completion_with_backoff(**kwargs):\n",
    "    \"\"\"Call openai.ChatCompletion.create(**kwargs), retrying when it raises.\n",
    "\n",
    "    NOTE(review): no stop condition is passed to retry, so failed calls are\n",
    "    presumably retried without limit, even for permanent errors such as an\n",
    "    invalid API key; confirm this is intended.\n",
    "    \"\"\"\n",
    "    return openai.ChatCompletion.create(**kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d4063503-c0e0-4df7-9866-0815f4c9fdf0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the chain-of-thought few-shot prompts, keyed by subject name.\n",
    "# The context manager closes the file handle once parsing is done, and the\n",
    "# explicit encoding keeps the read independent of the platform default.\n",
    "with open('lib_prompt/mmlu-cot.json', encoding='utf-8') as prompt_file:\n",
    "    mmlu_prompt = json.load(prompt_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "cb90f21b-6ea7-4bc7-bb8e-08e97ecbb818",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'abstract_algebra',\n",
      "'anatomy',\n",
      "'astronomy',\n",
      "'business_ethics',\n",
      "'clinical_knowledge',\n",
      "'college_biology',\n",
      "'college_chemistry',\n",
      "'college_computer_science',\n",
      "'college_mathematics',\n",
      "'college_medicine',\n",
      "'college_physics',\n",
      "'computer_security',\n",
      "'conceptual_physics',\n",
      "'econometrics',\n",
      "'electrical_engineering',\n",
      "'elementary_mathematics',\n",
      "'formal_logic',\n",
      "'global_facts',\n",
      "'high_school_biology',\n",
      "'high_school_chemistry',\n",
      "'high_school_computer_science',\n",
      "'high_school_european_history',\n",
      "'high_school_geography',\n",
      "'high_school_government_and_politics',\n",
      "'high_school_macroeconomics',\n",
      "'high_school_mathematics',\n",
      "'high_school_microeconomics',\n",
      "'high_school_physics',\n",
      "'high_school_psychology',\n",
      "'high_school_statistics',\n",
      "'high_school_us_history',\n",
      "'high_school_world_history',\n",
      "'human_aging',\n",
      "'human_sexuality',\n",
      "'international_law',\n",
      "'jurisprudence',\n",
      "'logical_fallacies',\n",
      "'machine_learning',\n",
      "'management',\n",
      "'marketing',\n",
      "'medical_genetics',\n",
      "'miscellaneous',\n",
      "'moral_disputes',\n",
      "'moral_scenarios',\n",
      "'nutrition',\n",
      "'philosophy',\n",
      "'prehistory',\n",
      "'professional_accounting',\n",
      "'professional_law',\n",
      "'professional_medicine',\n",
      "'professional_psychology',\n",
      "'public_relations',\n",
      "'security_studies',\n",
      "'sociology',\n",
      "'us_foreign_policy',\n",
      "'virology',\n",
      "'world_religions',\n"
     ]
    }
   ],
   "source": [
    "# List the available subject names, each single-quoted and followed by a comma.\n",
    "for subject in mmlu_prompt:\n",
    "    print(f\"'{subject}',\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bbbbe828-fad8-48a9-9cab-4a7e675ecf9e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The following are multiple choice questions (with answers) about college mathematics.\n",
      "\n",
      "Q: Let V be the set of all real polynomials p(x). Let transformations T, S be defined on V by T:p(x) -> xp(x) and S:p(x) -> p'(x) = d/dx p(x), and interpret (ST)(p(x)) as S(T(p(x))). Which of the following is true?\n",
      "(A) ST = 0 (B) ST = T (C) ST = TS (D) ST - TS is the identity map of V onto itself.\n",
      "A: Let's think step by step. For a given polynomial $p$ we have\n",
      "\\[ST(p) = (xp(x))’ = p(x) + xp’(x)\\]\n",
      "and\n",
      "\\[TS(p) = xp’(x).\\]\n",
      "Hence \\[ST(p) - TS(p) = p(x) + xp’(x) - xp’(x).\\] The answer is (D).\n",
      "\n",
      "Q: Suppose that f(1 + x) = f(x) for all real x. If f is a polynomial and f(5) = 11, then f(15/2)\n",
      "(A) -11 (B) 0 (C) 11 (D) 33/2\n",
      "A: Let's think step by step. The only polynomial so that $f(1 + x) = f(x)$ is a constant polynomial. Hence $f(5) = 11 = f(15/2)$. The answer is (C).\n",
      "\n",
      "Q: Let A be a real 2x2 matrix. Which of the following statements must be true?\n",
      "I. All of the entries of A^2 are nonnegative.\n",
      "II. The determinant of A^2 is nonnegative.\n",
      "III. If A has two distinct eigenvalues, then A^2 has two distinct eigenvalues.\n",
      "(A) I only (B) II only (C) III only (D) II and III only\n",
      "A: Let's think step by step. We have \\[ det(A^2) = (det(A))^2 \\geq 0,\\] hence II holds.\n",
      "III is false: as a counterexample take a diagonal matrix with -1 and 1 on the diagonal. Then $A^2$ is the identity matrix. The answer is (B).\n",
      "\n",
      "Q: Let A be the set of all ordered pairs of integers (m, n) such that 7m + 12n = 22. What is the greatest negative number in the set B = {m + n : (m, n) \\in A}?\n",
      "(A) -5 (B) -4 (C) -3 (D) -2\n",
      "A: Let's think step by step. We have 12n = 22 - 7m and one of the solutions is $m = -2$, $n = 3$. Then $m + n = 1$, hence we need to look for smaller $m$ in order to make $m + n$ negative. The next solution is $m = -14$ and $n = 10$. For smaller $m$ we have $m + n$ smaller than $-4$. The answer is (B).\n",
      "\n",
      "Q: A tank initially contains a salt solution of 3 grams of salt dissolved in 100 liters of water. A salt solution containing 0.02 grams of salt per liter of water is sprayed into the tank at a rate of 4 liters per minute. The sprayed solution is continually mixed with the salt solution in the tank, and the mixture flows out of the tank at a rate of 4 liters per minute. If the mixing is instantaneous, how many grams of salt are in the tank after 100 minutes have elapsed?\n",
      "(A) 2 (B) 2 - e^-2 (C) 2 + e^-2 (D) 2 + e^-4\n",
      "A: Let's think step by step. For all $t \\in \\mathbb{R}$, let $s(t)$ denote the number grams of salt in the tank at the $t$ minute mark. Then $s(0) = 3$.\n",
      "ight]$. For all $t \\in \\mathbb{R}$,ly. We also use $s^{\\prime}$ and $s^{\\prime}(t)$ interchangeably. The solution sprayed into the tank adds $(0.02) 4=2 / 25$ grams of salt per minute. There are always 100 liters of liquid in the tank, containing $s$ grams of salt. So the density of salt in the tank is $s / 100$ grams per liter. The flow of water out of the tank therefore subtracts $4(s / 100)=s / 25$ grams of salt per minute. Then, for all $t \\in \\mathbb{R}$, we have $s^{\\prime}(t)=(2 / 25)-(s / 25)=(2-s) / 25$, and so $[s(t)=2] \\Rightarrow\\left[s^{\\prime}(t)=0\n",
      "$$\n",
      "ight] .{d t}[\\ln (s-2)]=\f",
      "rac{s^{\\prime}}{s-2}=\f",
      "rac{-1}{25}=\f",
      "rac{d}{d t}\\left[-\f",
      "rac{t}{25}\n",
      "$$\n",
      "Choose $C \\in \\mathbb{R}$ such that, for all $t \\in \\mathbb{R}, \\ln ((s(t)-2))=-[t / 25]+C$. Let $K:=e^{C}$. Then, for all $t \\in \\mathbb{R}$, we have $(s(t))-2=K e^{-t / 25}$, and so $s(t)=2+K e^{-t / 25}$. Then $3=s(0)=2+K e^{0}=2+K$, so $K=1$. Then $s(100)=2+K e^{-100 / 25}=2+1 \\cdot e^{-4}=2+e^{-4}$. The answer is (D).\n"
     ]
    }
   ],
   "source": [
    "# Subject evaluated in the rest of the notebook; used as a key into mmlu_prompt.\n",
    "task = 'college_mathematics'\n",
    "# Show the few-shot prompt for this subject.\n",
    "print(mmlu_prompt[task])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ce1c6b79-3530-4efe-bfd6-eead27d09ff3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and preparing dataset mmlu/college_mathematics to /Users/yaofu/.cache/huggingface/datasets/lukaemon___mmlu/college_mathematics/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating test split:   0%|          | 0/99 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating validation split:   0%|          | 0/10 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/4 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset mmlu downloaded and prepared to /Users/yaofu/.cache/huggingface/datasets/lukaemon___mmlu/college_mathematics/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3. Subsequent calls will reuse this data.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9fc1ec336155417aa0a97ffd83a47945",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load the \"lukaemon/mmlu\" dataset configuration for the chosen subject.\n",
    "task_data = load_dataset(\"lukaemon/mmlu\", task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0957327f-9454-4010-9501-b7086fbba125",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "99"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of questions in the test split.\n",
    "len(task_data['test'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5b18d727-13a7-45cd-a842-3462636d9c25",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input': 'Up to isomorphism, how many additive abelian groups G of order 16 have the property that x + x + x + x = 0 for each x in G ?',\n",
       " 'A': '0',\n",
       " 'B': '1',\n",
       " 'C': '2',\n",
       " 'D': '3',\n",
       " 'target': 'D'}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the first test example.\n",
    "task_data['test'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "bec64ac3-96b4-45cd-b79e-6c8ad6fae234",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Build the prompt for the first test question: few-shot examples, a blank\n",
    "# line, the question, its four options on one line, then the chain-of-thought cue.\n",
    "# NOTE(review): the question is appended without the \"Q: \" prefix that the\n",
    "# few-shot examples use; confirm this is intended.\n",
    "first_q = task_data['test'][0]\n",
    "options_line = ''.join('(' + letter + ') ' + first_q[letter] + ' ' for letter in ['A', 'B', 'C', 'D'])\n",
    "prompt_q = mmlu_prompt[task] + \"\\n\\n\" + first_q['input'] + '\\n' + options_line + \"\\nA: Let's think step by step.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8847b4b5-99fe-47ab-941b-5bd02c43755e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The following are multiple choice questions (with answers) about college biology.\n",
      "\n",
      "Q: Which of the following represents an accurate statement concerning arthropods?\n",
      "(A) They possess an exoskeleton composed primarily of peptidoglycan. (B) They possess an open circulatory system with a dorsal heart. (C) They are members of a biologically unsuccessful phylum incapable of exploiting diverse habitats and nutrition sources. (D) They lack paired, jointed appendages.\n",
      "A: Let's think step by step. Peptidoglycan is known to comprise the plasma membrane of most bacteria, rather than the exoskeleton of arthropods, which is made of chitin, which rules out (A). The answer (C) is false because arthropods are a highly successful phylum. Likewise, arthropods have paired, jointed appendages, which rules out (D). The only remaining option is (B), as arthropods have an open circulatory system with a dorsal tubular heart. The answer is (B).\n",
      "\n",
      "Q: In a given population, 1 out of every 400 people has a cancer caused by a completely recessive allele, b. Assuming the population is in Hardy-Weinberg equilibrium, which of the following is the expected proportion of individuals who carry the b allele but are not expected to develop the cancer?\n",
      "(A) 1/400 (B) 19/400 (C) 20/400 (D) 38/400\n",
      "A: Let's think step by step. According to the Hardy Weinberg Law, $p^2 + 2 p q + q^2 = 1$, and $p + q = 1$ where $p$ is the frequency of the dominant allele, $q$ is the frequency of the recessive allele, and $p^2$, $q^2$, and $2pq$ are the frequencies of dominant homozygous, recessive homozygous, and heterozygous individuals, respectively. ​The frequency of the recessive allele (q) is $\\sqrt{\f",
      "rac{1}{400}} = 0.05$. We have $p = 1 - q = 0.95$. The frequency of heterozygous individuals is $2pq = 2 \\cdot 0.05 \\cdot 0.95 = 0.095$. The number of heterozygous individuals is equal to the frequency of heterozygous individuals times the size of the population, or $0.095 * 400 = 38$. So we end up with 38/400. The answer is (D).\n",
      "\n",
      "Q: According to the pressure-flow model of movement of phloem contents, photosynthate movement from source to sink is driven by\n",
      "(A) an ATP-dependent pressure-flow pump (B) a water-pressure potential gradient (C) transpiration (D) apoplastic diffusion\n",
      "A: Let's think step by step. It is a gradient in water pressure that induces the movement of phloem content, which refers to answer (B). The mechanism of movement does not rely on metabolism, which rules out (A). Transpiration refers to the exhalation of water vapor through plant stomata, and is also not related, which rules out (C). While the apoplastic pathway is one of two main pathways for water transport in plants, it is not central to the pressure flow model, which rules out (D). The answer is (B).\n",
      "\n",
      "Q: Which of the following contain DNA sequences required for the segregation of chromosomes in mitosis and meiosis?\n",
      "(A) Telomeres (B) Centromeres (C) Nucleosomes (D) Spliceosomes\n",
      "A: Let's think step by step. The genetic material in Telomeres is not used, which rules out (A). Nucleosomes are the repeating subunit that comprises chromatin packed in a cell nucleus, and do not specifically refer to DNA sequences necessary for segregating chromosomes in cell division, which rules out (C). A spliceosome is a large ribonucleoprotein that removes introns from transcribed pre-mRNA rather than governing chromosome segregation. Centromeres are directly responsible for segregating chromosomes in cell division. The answer is (B).\n",
      "\n",
      "Q: The presence of homologous structures in two different organisms, such as the humerus in the front limb of a human and a bird, indicates that\n",
      "(A) the human and bird are polyphyletic species (B) a human's and bird's evolution is convergent (C) the human and bird belong to a clade (D) the human and bird developed by analogy\n",
      "A: Let's think step by step. Polyphyletic species are organisms that are grouped due to having similar characteristics but which do not have a common ancestor. This is not the case for humans and birds, which rules out (A). Convergent evolution refers to the indepdendent development of similar features in different species at different periods, which is also not the case for humans and birds, which rules out (B). Analogy refers to the superficial resemblance of structures that have different origins, which is not the case for the human and bird forearms, which rules out (D). Humans and birds do belong to the same clade - a group of organisms composed of a common ancestor. The answer is (C).\n",
      "\n",
      "A frameshift mutation is created when\n",
      "(A) telomeric sequences are removed from DNA (B) a codon's nucleotide sequence changes so that it calls for production of a different amino acid than the original one (C) a base pair is either inserted or deleted in a gene (D) a codon's nucleotide sequence is changed so that instead of coding for a given amino acid it acts to terminate translation \n",
      "A: Let's think step by step.\n"
     ]
    }
   ],
   "source": [
    "# Show the full prompt built for the first test question.\n",
    "print(prompt_q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c91e1e78-0e6d-4d82-9600-b6c76637fa80",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# One-off request for the single prompt in prompt_q, sent directly (no retry wrapper).\n",
    "single_messages = [\n",
    "    {\"role\": \"system\", \"content\": \"Follow the given examples and answer the question.\"},\n",
    "    {\"role\": \"user\", \"content\": prompt_q},\n",
    "]\n",
    "response = openai.ChatCompletion.create(model=\"gpt-3.5-turbo\", messages=single_messages, temperature=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b1e7b952-30b3-43b3-aa11-06595a7f592c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "We know that $|G| = 16$, so by the fundamental theorem of finite abelian groups, we can write $G$ as a direct sum of cyclic groups of prime power order. Let $G \\cong \\mathbb{Z}_{p_1^{a_1}} \\oplus \\mathbb{Z}_{p_2^{a_2}} \\oplus \\cdots \\oplus \\mathbb{Z}_{p_k^{a_k}}$.\n",
      "\n",
      "Since $x+x+x+x=0$ for all $x \\in G$, we have $4x = 0$ for all $x \\in G$. This means that the order of each cyclic component of $G$ must divide 4.\n",
      "\n",
      "If $p_i \\neq 2$ for all $i$, then each cyclic component has order dividing 4, so it must be either 1 or 2. But the direct sum of cyclic groups of order 2 is not possible, since it would have order 2 raised to the power of the number of summands, which is greater than 16. Therefore, each cyclic component must have order 1, so $G$ is the trivial group.\n",
      "\n",
      "If $p_i = 2$ for some $i$, then each cyclic component has order dividing 4, so it must be either 1, 2, or 4. The direct sum of cyclic groups of order 4 is not possible, since it would have order 2 raised to the power of the number of summands, which is greater than 16. Therefore, each cyclic component must have order 1 or 2.\n",
      "\n",
      "If $a_i = 1$ for all $i$, then $G$ has two possible forms: $\\mathbb{Z}_2 \\oplus \\mathbb{Z}_2 \\oplus \\mathbb{Z}_2 \\oplus \\mathbb{Z}_2$ and $\\mathbb{Z}_2 \\oplus \\mathbb{Z}_4$. Both of these groups satisfy the condition $x+x+x+x=0$ for all $x \\in G$.\n",
      "\n",
      "If $a_i = 2$ for some $i$, then $G$ has two possible forms: $\\mathbb{Z}_2 \\oplus \\mathbb{Z}_2 \\oplus \\mathbb{Z}_8$ and $\\mathbb{Z}_2 \\oplus \\mathbb{Z}_4 \\oplus \\mathbb{Z}_4$. Both of these groups satisfy the condition $x+x+x+x=0$ for all $x \\in G$.\n",
      "\n",
      "Therefore, there are a total of $\\boxed{3}$ non-isomorphic additive abelian groups of order 16 that satisfy the condition $x+x+x+x=0$ for all $x \\in G$.\n"
     ]
    }
   ],
   "source": [
    "# Text of the first returned choice.\n",
    "print(response['choices'][0]['message']['content'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "867945a0-6527-409e-801e-4244eaf75e61",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def test_answer_mmlu_(pred_str, ans):\n",
    "    \"\"\"Check a model completion against a gold answer letter.\n",
    "\n",
    "    Args:\n",
    "        pred_str: Model completion. The prediction is the character right\n",
    "            after the first 'the answer is (' (matched case-insensitively).\n",
    "        ans: Gold answer letter in any case, e.g. 'D'.\n",
    "\n",
    "    Returns:\n",
    "        True if the predicted letter equals the gold letter. When the marker\n",
    "        is absent the prediction defaults to 'c'.\n",
    "    \"\"\"\n",
    "    pattern = 'the answer is ('\n",
    "    pieces = pred_str.lower().split(pattern)\n",
    "    gold = ans.lower()\n",
    "\n",
    "    if len(pieces) > 1:\n",
    "        # Slice instead of indexing so a completion that ends right after\n",
    "        # the '(' gives an empty prediction rather than an IndexError.\n",
    "        pred = pieces[1][:1]\n",
    "    else:\n",
    "        # Default guess; lowercase so it is comparable with the lowercased gold.\n",
    "        pred = 'c'\n",
    "    return pred == gold\n",
    "\n",
    "def test_answer_mmlu(pred_str, ans_str):\n",
    "    \"\"\"Check a saved model answer against a saved gold-answer section.\n",
    "\n",
    "    Args:\n",
    "        pred_str: Model answer text. The prediction is the character right\n",
    "            after the first 'the answer is (' (matched case-insensitively).\n",
    "        ans_str: Gold section text; the gold letter is the first character\n",
    "            after the first 'A:' line.\n",
    "\n",
    "    Returns:\n",
    "        True if the predicted letter equals the gold letter. When the marker\n",
    "        is absent the prediction defaults to 'c'.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: If ans_str has no 'A:' line followed by a character.\n",
    "    \"\"\"\n",
    "    pattern = 'the answer is ('\n",
    "    pieces = pred_str.lower().split(pattern)\n",
    "    gold = ans_str.split('A:\\n')[1][0].lower()\n",
    "\n",
    "    if len(pieces) > 1:\n",
    "        # Slice instead of indexing so text that ends right after the '('\n",
    "        # gives an empty prediction rather than an IndexError.\n",
    "        pred = pieces[1][:1]\n",
    "    else:\n",
    "        # Default guess; lowercase so it is comparable with the lowercased gold.\n",
    "        pred = 'c'\n",
    "    return pred == gold\n",
    "\n",
    "def parse_pred_ans(filename):\n",
    "    \"\"\"Parse a saved results file and print its accuracy.\n",
    "\n",
    "    The file is expected to contain records laid out as a 'Q: ' line with the\n",
    "    question, an 'A_model:' line followed by the model answer, and an 'A:'\n",
    "    line followed by the gold letter.\n",
    "\n",
    "    Args:\n",
    "        filename: Path of the results file to read.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (questions, ans_pred, ans_gold) of parallel lists holding the\n",
    "        raw text of every complete record. A summary line with the question\n",
    "        count, the number correct and the ratio is printed as a side effect.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ordinary text appears before any section has started.\n",
    "    \"\"\"\n",
    "    with open(filename) as fd:\n",
    "        lines = fd.readlines()\n",
    "\n",
    "    q, am, a = None, None, None\n",
    "    num_q, acc = 0, 0\n",
    "    current_mode = 'none'\n",
    "    questions = []\n",
    "    ans_pred = []\n",
    "    ans_gold = []\n",
    "\n",
    "    def flush_record():\n",
    "        # Store the record collected so far and return 1 if it is correct.\n",
    "        # Incomplete records (and the initial empty state) are skipped.\n",
    "        if q is None or am is None or a is None:\n",
    "            return 0\n",
    "        questions.append(q)\n",
    "        ans_pred.append(am)\n",
    "        ans_gold.append(a)\n",
    "        return 1 if test_answer_mmlu(am, a) else 0\n",
    "\n",
    "    for l in lines:\n",
    "        # A 'Q: ' line opens a new record only at the start of the file or\n",
    "        # after a gold answer. Inside a question or a model answer it is\n",
    "        # ordinary text; counting it would inflate num_q and score the\n",
    "        # previous answer a second time.\n",
    "        if l.startswith('Q: ') and current_mode in ('none', 'a'):\n",
    "            acc += flush_record()\n",
    "            q, am, a = l, None, None\n",
    "            current_mode = 'q'\n",
    "            num_q += 1\n",
    "        elif l.startswith('A_model:'):\n",
    "            current_mode = 'am'\n",
    "            am = l\n",
    "        elif l.startswith('A:') and not l.startswith(\"A: Let's think step by step\"):\n",
    "            # The chain-of-thought cue is part of the question, not a gold answer.\n",
    "            current_mode = 'a'\n",
    "            a = l\n",
    "        else:\n",
    "            # Continuation line: append it to whichever section is open.\n",
    "            if current_mode == 'q':\n",
    "                q += l\n",
    "            elif current_mode == 'am':\n",
    "                am += l\n",
    "            elif current_mode == 'a':\n",
    "                a += l\n",
    "            else:\n",
    "                raise ValueError(current_mode)\n",
    "\n",
    "    # The last record is not followed by another 'Q: ' line, so flush it here.\n",
    "    acc += flush_record()\n",
    "    # An empty file has no questions; report 0 instead of dividing by zero.\n",
    "    ratio = float(acc / num_q) if num_q else 0.0\n",
    "    print('num_q %d correct %d ratio %.4f' % (num_q, acc, ratio))\n",
    "    return questions, ans_pred, ans_gold\n",
    "\n",
    "def test_finished(ans_model):\n",
    "    \"\"\"Return True when the text contains the substring 'answer is'.\"\"\"\n",
    "    return 'answer is' in ans_model\n",
    "\n",
    "def extract_ans(ans_model):\n",
    "    \"\"\"Split a completion after its first line containing 'answer is'.\n",
    "\n",
    "    Args:\n",
    "        ans_model: Full completion text.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (ans, residual): ans is the text up to and including the first\n",
    "        line that contains 'answer is' (the whole text if there is none), and\n",
    "        residual is everything after that line.\n",
    "    \"\"\"\n",
    "    all_lines = ans_model.split('\\n')\n",
    "    # Index of the first line with the marker; default to the last line.\n",
    "    cut = next(\n",
    "        (idx for idx, text in enumerate(all_lines) if 'answer is' in text),\n",
    "        len(all_lines) - 1,\n",
    "    )\n",
    "    ans = '\\n'.join(all_lines[:cut + 1])\n",
    "    residual = '\\n'.join(all_lines[cut + 1:])\n",
    "    return ans, residual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "b398b88e-d8fd-47fc-be03-f4627f6719e1",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 99/99 [17:57<00:00, 10.88s/it]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every test question and log question, model answer and gold letter\n",
    "# in the 'Q: / A_model: / A:' layout that parse_pred_ans reads.\n",
    "i = 0\n",
    "acc = 0\n",
    "option_letters = ['A', 'B', 'C', 'D']\n",
    "system_message = {\"role\": \"system\", \"content\": \"Follow the given examples and answer the question.\"}\n",
    "with open('outputs/test_gpt_3.5_turbo_%s.txt' % task, 'w') as fd:\n",
    "    for q_ in tqdm(task_data['test'], total=len(task_data['test'])):\n",
    "        # Question, its four options on one line, then the chain-of-thought cue.\n",
    "        options_line = ''.join('(' + letter + ') ' + q_[letter] + ' ' for letter in option_letters)\n",
    "        q = q_['input'] + '\\n' + options_line + \"\\nA: Let's think step by step.\"\n",
    "        prompt_q = mmlu_prompt[task] + \"\\n\\n\" + q\n",
    "\n",
    "        response = completion_with_backoff(\n",
    "            model=\"gpt-3.5-turbo\",\n",
    "            messages=[system_message, {\"role\": \"user\", \"content\": prompt_q}],\n",
    "            temperature=0,\n",
    "        )\n",
    "        ans_model = response['choices'][0]['message']['content']\n",
    "        # Keep the completion up to the first line containing 'answer is'.\n",
    "        ans_, residual = extract_ans(ans_model)\n",
    "\n",
    "        a = q_['target']\n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))\n",
    "        i += 1\n",
    "\n",
    "        if test_answer_mmlu_(ans_, a):\n",
    "            acc += 1\n",
    "\n",
    "        # Debugging aid: uncomment to stop after the first 10 questions.\n",
    "        # if(i == 10): break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "d11a0b15-6fce-4c52-967c-160f67cb2536",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.29292929292929293\n"
     ]
    }
   ],
   "source": [
    "# Accuracy from the evaluation loop: correct count over the number of test questions.\n",
    "print(acc / len(task_data['test']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "8ea5012b-444a-49d8-8752-51ea485d9beb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "college_mathematics\n",
      "num_q 100 correct 29 ratio 0.2900\n"
     ]
    }
   ],
   "source": [
    "# Re-read the saved results file; parse_pred_ans prints its question count and accuracy.\n",
    "print(task)\n",
    "_, _, _ = parse_pred_ans('outputs/test_gpt_3.5_turbo_%s.txt' % task)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
