{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a8ef9ad9-c7e1-4fed-b0bd-581064558089",
   "metadata": {},
   "source": [
    "# GPT-3.5-Turbo Performance on MMLU - College Computer Science"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4f1c09a4-4859-469b-a156-dbb037c83a65",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import openai\n",
    "import re\n",
    "import time\n",
    "import json\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "from tenacity import retry, stop_after_attempt, wait_chain, wait_fixed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "eaea2a2a-2515-4508-9cdb-084d10853170",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "openai.api_key = \"sk-\" "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "95ddd688-5bf5-40c5-a852-32e62f5a1bbb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "@retry(wait=wait_chain(*[wait_fixed(3) for i in range(3)] +\n",
    "                       [wait_fixed(5) for i in range(2)] +\n",
    "                       [wait_fixed(10)]))\n",
    "def completion_with_backoff(**kwargs):\n",
    "    return openai.ChatCompletion.create(**kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d4063503-c0e0-4df7-9866-0815f4c9fdf0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "mmlu_prompt = json.load(open('lib_prompt/mmlu-cot.json'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bbbbe828-fad8-48a9-9cab-4a7e675ecf9e",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The following are multiple choice questions (with answers) about college computer science.\n",
      "\n",
      "Q: Which of the following regular expressions is equivalent to (describes the same set of strings as) (a* + b)*(c + d)?\n",
      "(A) a*(c + d)+ b(c + d)\n",
      "(B) a*(c + d)* + b(c + d)*\n",
      "(C) a*(c + d)+ b*(c + d)\n",
      "(D) (a + b)*c +(a + b)*d\n",
      "A: Let's think step by step. We know that:\n",
      "1. (X* + Y)* = (X + Y)*\n",
      "2. X(Y + Z)? = XY + XZ\n",
      "Using equation 1 we can rewrite (a* + b)*(c + d)? as:\n",
      "3. (a + b)*(c + d)?\n",
      "Using equation 2 we can rewrite equation 3 as:\n",
      "(a + b)*c + (a + b)*d The answer is (D).\n",
      "\n",
      "Q: The Singleton design pattern is used to guarantee that only a single instance of a class may be instantiated. Which of the following is (are) true of this design pattern?\n",
      "I. The Singleton class has a static factory method to provide its instance.\n",
      "II. The Singleton class can be a subclass of another class.\n",
      "III. The Singleton class has a private constructor.\n",
      "(A) I only\n",
      "(B) II only\n",
      "(C) III only\n",
      "(D) I, II, and III\n",
      "A: Let's think step by step. Statement I is a correct statement about a Singleton, because a Singleton restricts instantiation to a single, static method. Statement II is also correct, because there is no inherent restriction regarding the inheritance of a Singleton. Statement III is also correct, because a Singletons must be instantiated only once, so its constructor is made private to prevent any construction except via its static factory method.\n",
      "Given these facts, statements I, II, and III are all correct. The answer is (D).\n",
      "\n",
      "Q: A certain pipelined RISC machine has 8 general-purpose registers R0, R1, . . . , R7 and supports the following operations:\n",
      "ADD Rs1, Rs2, Rd (Add Rs1 to Rs2 and put the sum in Rd)\n",
      "MUL Rs1, Rs2, Rd (Multiply Rs1 by Rs2 and put the product in Rd)\n",
      "An operation normally takes one cycle; however, an operation takes two cycles if it produces a result required by the immediately following operation in an operation sequence.\n",
      "Consider the expression AB + ABC + BC, where variables A, B, C are located in registers R0, R1, R2. If the contents of these three registers must not be modified, what is the minimum number of clock cycles required for an operation sequence that computes the value of AB + ABC + BC?\n",
      "(A) 5 (B) 6 (C) 7 (D) 8\n",
      "A: Let's think step by step. First, we are given that A is in R0, B is in R1, and C is in R2.\n",
      "Next, we can see that we must compute three multiplies (AB, BC, and ABC) and two adds (AB + ABC, (AB + ABC) + BC) to compute our final answer, resulting in a minimum of five clock cycles.\n",
      "Next, we can see that there is no way to avoid at least one pipeline stall when computing our final answer, because to compute our final sum we must wait at least one cycle for the results from the previous stage to be ready. Thus, our minimum number of cycles must be 6.\n",
      "We can verify that we can create a solution that requires only six cycles as follows:\n",
      "compute AB: MUL R0, R1, R3\n",
      "compute BC: MUL R1, R2, R4\n",
      "compute ABC: MUL R3, R4, R5\n",
      "compute AB + BC: ADD R3, R4, R6\n",
      "STALL\n",
      "compute AB + ABC + BC: ADD R5, R6, R7\n",
      "So there are 6 cycles. The answer is (B).\n",
      "\n",
      "Q: A compiler generates code for the following assignment statement.\n",
      "G := (A + B) * C - (D + E) * F\n",
      "The target machine has a single accumulator and a single-address instruction set consisting of instructions load, store, add, subtract, and multiply. For the arithmetic operations, the left operand is taken from the accumulator and the result appears in the accumulator. The smallest possible number of instructions in the resulting code is\n",
      "(A) 5 (B) 6 (C) 7 (D) 9\n",
      "A: Let's think step by step. We can compute the final answer with the following sequence of operations:\n",
      "1. LOAD D  (accumulator = D)\n",
      "2. ADD E  (accumulator = D+E)\n",
      "3. MUL F  (accumulator = (D+E)*F)\n",
      "4. STORE X (X = (D+E)*F)\n",
      "5. LOAD A  (accumulator = A)\n",
      "6. ADD B  (accumulator = A+B)\n",
      "7. MUL C  (accumulator = (A+B)*C)\n",
      "8. SUB X  (accumulator = (A+B)*C - (D+E)*F)\n",
      "9. STORE G (G = (A+B)*C - (D+E)*F)\n",
      "This sequence takes 9 instructions. The answer is (D).\n",
      "\n",
      "Q: Consider a computer design in which multiple processors, each with a private cache memory, share global memory using a single bus. This bus is the critical system resource. Each processor can execute one instruction every 500 nanoseconds as long as memory references are satisfied by its local cache. When a cache miss occurs, the processor is delayed for an additional 2,000 nanoseconds. During half of this additional delay, the bus is dedicated to serving the cache miss. During the other half, the processor cannot continue, but the bus is free to service requests from other processors. On average, each instruction requires 2 memory references. On average, cache misses occur on 1 percent of references. What proportion of the capacity of the bus would a single processor consume, ignoring delays due to competition from other processors?\n",
      "(A) 1/50 (B) 1/27 (C) 1/25 (D) 2/27\n",
      "A: Let's think step by step. We know that each instruction requires two memory references per instruction, and that there is an average cache miss rate of one percent.\n",
      "Thus a given processor has:\n",
      "(1 cache miss / 100 references) * (2 references / instruction) =\n",
      "(2 cache misses / 100 instructions), so:\n",
      "misses_per_instruction = 1 cache miss / 50 instructions.\n",
      "Next, we know that each instruction requires 500 nanoseconds when there is no cache miss, and 500 + 2000 = 2500 nanoseconds when there is a cache miss. Thus:\n",
      "50 instructions / (49 * 500) + (1 * 2500) nanoseconds, so:\n",
      "instructions_per_ns = 50 instructions / 27000 nanoseconds.\n",
      "Now, we know that each cache miss locks the bus for half of the 2000 nanosecond cache miss delay, or 1000 nanoseconds, so:\n",
      "lock_ns_per_miss = 1000 nanoseconds / cache miss.\n",
      "Thus we can see that on average a single processor will lock the bus for:\n",
      "lock_ns_per_miss * misses_per_instruction * instructions_per_ns =\n",
      "(1000 nanoseconds / cache miss) * (1 cache miss / 50 instructions) * (50 instructions / 27000 nanoseconds) = 1000 * (1/50) * (50/27000) = 1000/27000 = 1/27. The answer is (B).\n"
     ]
    }
   ],
   "source": [
    "task = 'college_computer_science'\n",
    "print(mmlu_prompt[task])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ce1c6b79-3530-4efe-bfd6-eead27d09ff3",
   "metadata": {
    "tags": []
   },
   "outputs": [
   ],
   "source": [
    "task_data = load_dataset(\"lukaemon/mmlu\", task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0957327f-9454-4010-9501-b7086fbba125",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "99"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(task_data['test'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5b18d727-13a7-45cd-a842-3462636d9c25",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input': 'An integer c is a common divisor of two integers x and y if and only if c is a divisor of x and c is a divisor of y. Which of the following sets of integers could possibly be the set of all common divisors of two integers?',\n",
       " 'A': '{-6,-2, -1, 1, 2, 6}',\n",
       " 'B': '{-6, -2, -1, 0, 1, 2, 6}',\n",
       " 'C': '{-6, -3, -2, -1, 1, 2, 3, 6}',\n",
       " 'D': '{-6, -3, -2, -1, 0, 1, 2, 3, 6}',\n",
       " 'target': 'C'}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "task_data['test'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bec64ac3-96b4-45cd-b79e-6c8ad6fae234",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "prompt_q = mmlu_prompt[task] + \"\\n\\n\" + task_data['test'][0]['input'] + '\\n'\n",
    "for letter in ['A', 'B', 'C', 'D']:\n",
    "    prompt_q += '(' + letter + ') ' + task_data['test'][0][letter] + ' '\n",
    "prompt_q += \"\\nA: Let's think step by step.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8847b4b5-99fe-47ab-941b-5bd02c43755e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The following are multiple choice questions (with answers) about college computer science.\n",
      "\n",
      "Q: Which of the following regular expressions is equivalent to (describes the same set of strings as) (a* + b)*(c + d)?\n",
      "(A) a*(c + d)+ b(c + d)\n",
      "(B) a*(c + d)* + b(c + d)*\n",
      "(C) a*(c + d)+ b*(c + d)\n",
      "(D) (a + b)*c +(a + b)*d\n",
      "A: Let's think step by step. We know that:\n",
      "1. (X* + Y)* = (X + Y)*\n",
      "2. X(Y + Z)? = XY + XZ\n",
      "Using equation 1 we can rewrite (a* + b)*(c + d)? as:\n",
      "3. (a + b)*(c + d)?\n",
      "Using equation 2 we can rewrite equation 3 as:\n",
      "(a + b)*c + (a + b)*d The answer is (D).\n",
      "\n",
      "Q: The Singleton design pattern is used to guarantee that only a single instance of a class may be instantiated. Which of the following is (are) true of this design pattern?\n",
      "I. The Singleton class has a static factory method to provide its instance.\n",
      "II. The Singleton class can be a subclass of another class.\n",
      "III. The Singleton class has a private constructor.\n",
      "(A) I only\n",
      "(B) II only\n",
      "(C) III only\n",
      "(D) I, II, and III\n",
      "A: Let's think step by step. Statement I is a correct statement about a Singleton, because a Singleton restricts instantiation to a single, static method. Statement II is also correct, because there is no inherent restriction regarding the inheritance of a Singleton. Statement III is also correct, because a Singletons must be instantiated only once, so its constructor is made private to prevent any construction except via its static factory method.\n",
      "Given these facts, statements I, II, and III are all correct. The answer is (D).\n",
      "\n",
      "Q: A certain pipelined RISC machine has 8 general-purpose registers R0, R1, . . . , R7 and supports the following operations:\n",
      "ADD Rs1, Rs2, Rd (Add Rs1 to Rs2 and put the sum in Rd)\n",
      "MUL Rs1, Rs2, Rd (Multiply Rs1 by Rs2 and put the product in Rd)\n",
      "An operation normally takes one cycle; however, an operation takes two cycles if it produces a result required by the immediately following operation in an operation sequence.\n",
      "Consider the expression AB + ABC + BC, where variables A, B, C are located in registers R0, R1, R2. If the contents of these three registers must not be modified, what is the minimum number of clock cycles required for an operation sequence that computes the value of AB + ABC + BC?\n",
      "(A) 5 (B) 6 (C) 7 (D) 8\n",
      "A: Let's think step by step. First, we are given that A is in R0, B is in R1, and C is in R2.\n",
      "Next, we can see that we must compute three multiplies (AB, BC, and ABC) and two adds (AB + ABC, (AB + ABC) + BC) to compute our final answer, resulting in a minimum of five clock cycles.\n",
      "Next, we can see that there is no way to avoid at least one pipeline stall when computing our final answer, because to compute our final sum we must wait at least one cycle for the results from the previous stage to be ready. Thus, our minimum number of cycles must be 6.\n",
      "We can verify that we can create a solution that requires only six cycles as follows:\n",
      "compute AB: MUL R0, R1, R3\n",
      "compute BC: MUL R1, R2, R4\n",
      "compute ABC: MUL R3, R4, R5\n",
      "compute AB + BC: ADD R3, R4, R6\n",
      "STALL\n",
      "compute AB + ABC + BC: ADD R5, R6, R7\n",
      "So there are 6 cycles. The answer is (B).\n",
      "\n",
      "Q: A compiler generates code for the following assignment statement.\n",
      "G := (A + B) * C - (D + E) * F\n",
      "The target machine has a single accumulator and a single-address instruction set consisting of instructions load, store, add, subtract, and multiply. For the arithmetic operations, the left operand is taken from the accumulator and the result appears in the accumulator. The smallest possible number of instructions in the resulting code is\n",
      "(A) 5 (B) 6 (C) 7 (D) 9\n",
      "A: Let's think step by step. We can compute the final answer with the following sequence of operations:\n",
      "1. LOAD D  (accumulator = D)\n",
      "2. ADD E  (accumulator = D+E)\n",
      "3. MUL F  (accumulator = (D+E)*F)\n",
      "4. STORE X (X = (D+E)*F)\n",
      "5. LOAD A  (accumulator = A)\n",
      "6. ADD B  (accumulator = A+B)\n",
      "7. MUL C  (accumulator = (A+B)*C)\n",
      "8. SUB X  (accumulator = (A+B)*C - (D+E)*F)\n",
      "9. STORE G (G = (A+B)*C - (D+E)*F)\n",
      "This sequence takes 9 instructions. The answer is (D).\n",
      "\n",
      "Q: Consider a computer design in which multiple processors, each with a private cache memory, share global memory using a single bus. This bus is the critical system resource. Each processor can execute one instruction every 500 nanoseconds as long as memory references are satisfied by its local cache. When a cache miss occurs, the processor is delayed for an additional 2,000 nanoseconds. During half of this additional delay, the bus is dedicated to serving the cache miss. During the other half, the processor cannot continue, but the bus is free to service requests from other processors. On average, each instruction requires 2 memory references. On average, cache misses occur on 1 percent of references. What proportion of the capacity of the bus would a single processor consume, ignoring delays due to competition from other processors?\n",
      "(A) 1/50 (B) 1/27 (C) 1/25 (D) 2/27\n",
      "A: Let's think step by step. We know that each instruction requires two memory references per instruction, and that there is an average cache miss rate of one percent.\n",
      "Thus a given processor has:\n",
      "(1 cache miss / 100 references) * (2 references / instruction) =\n",
      "(2 cache misses / 100 instructions), so:\n",
      "misses_per_instruction = 1 cache miss / 50 instructions.\n",
      "Next, we know that each instruction requires 500 nanoseconds when there is no cache miss, and 500 + 2000 = 2500 nanoseconds when there is a cache miss. Thus:\n",
      "50 instructions / (49 * 500) + (1 * 2500) nanoseconds, so:\n",
      "instructions_per_ns = 50 instructions / 27000 nanoseconds.\n",
      "Now, we know that each cache miss locks the bus for half of the 2000 nanosecond cache miss delay, or 1000 nanoseconds, so:\n",
      "lock_ns_per_miss = 1000 nanoseconds / cache miss.\n",
      "Thus we can see that on average a single processor will lock the bus for:\n",
      "lock_ns_per_miss * misses_per_instruction * instructions_per_ns =\n",
      "(1000 nanoseconds / cache miss) * (1 cache miss / 50 instructions) * (50 instructions / 27000 nanoseconds) = 1000 * (1/50) * (50/27000) = 1000/27000 = 1/27. The answer is (B).\n",
      "\n",
      "An integer c is a common divisor of two integers x and y if and only if c is a divisor of x and c is a divisor of y. Which of the following sets of integers could possibly be the set of all common divisors of two integers?\n",
      "(A) {-6,-2, -1, 1, 2, 6} (B) {-6, -2, -1, 0, 1, 2, 6} (C) {-6, -3, -2, -1, 1, 2, 3, 6} (D) {-6, -3, -2, -1, 0, 1, 2, 3, 6} \n",
      "A: Let's think step by step.\n"
     ]
    }
   ],
   "source": [
    "print(prompt_q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c91e1e78-0e6d-4d82-9600-b6c76637fa80",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "response = openai.ChatCompletion.create(\n",
    "    model=\"gpt-3.5-turbo\",\n",
    "    messages=[\n",
    "        {\"role\": \"system\", \"content\": \"Follow the given examples and answer the question.\"},\n",
    "        {\"role\": \"user\", \"content\": prompt_q},\n",
    "    ],\n",
    "    temperature=0, \n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b1e7b952-30b3-43b3-aa11-06595a7f592c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "We can start by looking at the divisors of 6, which are 1, 2, 3, and 6. Any common divisor of two integers must be a divisor of both integers, so any set of common divisors must contain these four numbers.\n",
      "\n",
      "Option (A) contains only these four numbers, so it is a possible set of common divisors.\n",
      "\n",
      "Option (B) adds 0 to the set, but 0 is not a divisor of any non-zero integer, so this set is not possible.\n",
      "\n",
      "Option (C) adds -3 to the set, but -3 is not a divisor of 6, so this set is not possible.\n",
      "\n",
      "Option (D) adds 0 and -3 to the set, but as mentioned before, 0 is not a divisor of any non-zero integer, and -3 is not a divisor of 6, so this set is not possible.\n",
      "\n",
      "Therefore, the only possible set of all common divisors of two integers is {-6, -2, -1, 1, 2, 6}, which is option (A).\n"
     ]
    }
   ],
   "source": [
    "print(response['choices'][0]['message']['content'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "867945a0-6527-409e-801e-4244eaf75e61",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def test_answer_mmlu(pred_str, ans_str):\n",
    "    pattern = 'the answer is ('\n",
    "    pred = pred_str.lower().split(pattern)\n",
    "    \n",
    "    if(len(pred) > 1):\n",
    "        # print(pred)\n",
    "        pred = pred[1][0]\n",
    "        gold = ans_str.split('A:\\n')[1][0].lower()\n",
    "        # print('debug 1, pred %s, gold %s' % (pred, gold))\n",
    "        return pred == gold\n",
    "    else: \n",
    "        pred = 'C'\n",
    "        gold = ans_str.split('A:\\n')[1][0].lower()\n",
    "        # print('debug 2, pred %s, gold %s' % (pred, gold))\n",
    "        return pred == gold\n",
    "\n",
    "def parse_pred_ans(filename):\n",
    "    with open(filename) as fd: lines = fd.readlines()\n",
    "    am, a = None, None\n",
    "    num_q, acc = 0, 0\n",
    "    current_mode = 'none'\n",
    "    questions = []\n",
    "    ans_pred = []\n",
    "    ans_gold = []\n",
    "    for l in lines:\n",
    "        if(l.startswith('Q: ')):\n",
    "            if(am is not None and a is not None):\n",
    "                questions.append(q)\n",
    "                ans_pred.append(am)\n",
    "                ans_gold.append(a)\n",
    "                # print(am)\n",
    "                # print(a)\n",
    "                if(test_answer_mmlu(am, a)):\n",
    "                    acc += 1\n",
    "            current_mode = 'q'\n",
    "            q = l\n",
    "            num_q += 1\n",
    "        elif(l.startswith('A_model:')):\n",
    "            current_mode = 'am'\n",
    "            am = l\n",
    "        elif(l.startswith('A:')):\n",
    "            current_mode = 'a'\n",
    "            a = l\n",
    "        else:\n",
    "            if(current_mode == 'q'): q += l\n",
    "            elif(current_mode == 'am'): am += l\n",
    "            elif(current_mode == 'a'): a += l\n",
    "            else:\n",
    "                raise ValueError(current_mode)\n",
    "                \n",
    "    questions.append(q)\n",
    "    ans_pred.append(am)\n",
    "    ans_gold.append(a)\n",
    "    # print(am)\n",
    "    # print(a)\n",
    "    if(test_answer_mmlu(am, a)):\n",
    "        acc += 1\n",
    "    print('num_q %d correct %d ratio %.4f' % (num_q, acc, float(acc / num_q)))\n",
    "    return questions, ans_pred, ans_gold\n",
    "\n",
    "def test_finished(ans_model):\n",
    "    if('answer is' in ans_model): return True\n",
    "    else: return False\n",
    "\n",
    "def extract_ans(ans_model):\n",
    "    ans_model = ans_model.split('\\n')\n",
    "    ans = []\n",
    "    residual = []\n",
    "    for li, al in enumerate(ans_model):\n",
    "        ans.append(al)\n",
    "        if('answer is' in al):\n",
    "            break\n",
    "    residual = list(ans_model[li + 1:])\n",
    "    ans = '\\n'.join(ans)\n",
    "    residual = '\\n'.join(residual)\n",
    "    return ans, residual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b398b88e-d8fd-47fc-be03-f4627f6719e1",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|                                                                                                                                                                                        | 0/99 [00:00<?, ?it/s]\u001b[A\n",
      "  1%|█▊                                                                                                                                                                              | 1/99 [00:09<15:02,  9.21s/it]\u001b[A\n",
      "  2%|███▌                                                                                                                                                                            | 2/99 [00:14<10:43,  6.64s/it]\u001b[A\n",
      "  3%|█████▎                                                                                                                                                                          | 3/99 [00:17<08:13,  5.14s/it]\u001b[A\n",
      "  4%|███████                                                                                                                                                                         | 4/99 [00:20<06:34,  4.15s/it]\u001b[A\n",
      "  5%|████████▉                                                                                                                                                                       | 5/99 [00:23<06:06,  3.90s/it]\u001b[A\n",
      "  6%|██████████▋                                                                                                                                                                     | 6/99 [00:26<05:25,  3.50s/it]\u001b[A\n",
      "  7%|████████████▍                                                                                                                                                                   | 7/99 [00:28<04:42,  3.07s/it]\u001b[A\n",
      "  8%|██████████████▏                                                                                                                                                                 | 8/99 [00:30<04:06,  2.71s/it]\u001b[A\n",
      "  9%|████████████████                                                                                                                                                                | 9/99 [00:34<04:57,  3.31s/it]\u001b[A\n",
      " 10%|█████████████████▋                                                                                                                                                             | 10/99 [00:37<04:44,  3.19s/it]\u001b[A\n",
      " 11%|███████████████████▍                                                                                                                                                           | 11/99 [00:39<03:57,  2.70s/it]\u001b[A\n",
      " 12%|█████████████████████▏                                                                                                                                                         | 12/99 [00:46<05:45,  3.97s/it]\u001b[A\n",
      " 13%|██████████████████████▉                                                                                                                                                        | 13/99 [00:47<04:39,  3.24s/it]\u001b[A\n",
      " 14%|████████████████████████▋                                                                                                                                                      | 14/99 [00:50<04:24,  3.12s/it]\u001b[A\n",
      " 15%|██████████████████████████▌                                                                                                                                                    | 15/99 [00:55<05:00,  3.57s/it]\u001b[A\n",
      " 16%|████████████████████████████▎                                                                                                                                                  | 16/99 [00:57<04:15,  3.08s/it]\u001b[A\n",
      " 17%|██████████████████████████████                                                                                                                                                 | 17/99 [01:06<06:50,  5.00s/it]\u001b[A\n",
      " 18%|███████████████████████████████▊                                                                                                                                               | 18/99 [01:08<05:27,  4.04s/it]\u001b[A\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 99/99 [05:50<00:00,  3.54s/it]\u001b[A\n"
     ]
    }
   ],
   "source": [
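    "# Query the model on every test question and log the question, the extracted answer, and the gold answer.\n",
    "i = 0  # number of questions processed so far (used only by the debug break below)\n",
    "with open('outputs/test_gpt_3.5_turbo_%s.txt' % task, 'w') as fd:\n",
    "    for q_ in tqdm(task_data['test'], total=len(task_data['test'])):\n",
    "        # Build the question text followed by the four lettered options.\n",
    "        q = q_['input'] + '\\n'\n",
    "        for letter in ['A', 'B', 'C', 'D']:\n",
    "            q += '(' + letter + ') ' + q_[letter] + ' '\n",
    "        q += \"\\nA: Let's think step by step.\"\n",
    "\n",
    "        # Prepend the few-shot chain-of-thought examples for this task.\n",
    "        prompt_q = mmlu_prompt[task] + \"\\n\\n\" + q\n",
    "\n",
    "        response = completion_with_backoff(\n",
    "            model=\"gpt-3.5-turbo\",\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": \"Follow the given examples and answer the question.\"},\n",
    "                {\"role\": \"user\", \"content\": prompt_q},\n",
    "            ],\n",
    "            temperature=0,\n",
    "        )\n",
    "        ans_model = response['choices'][0]['message']['content']\n",
    "        # extract_ans returns the parsed answer and the remaining text; only the answer is logged.\n",
    "        ans_, residual = extract_ans(ans_model)\n",
    "\n",
    "        a = q_['target']\n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))\n",
    "        i += 1\n",
    "        # Uncomment to stop after two questions when debugging:\n",
    "        # if i == 2: break"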
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8ea5012b-444a-49d8-8752-51ea485d9beb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "college_computer_science\n",
      "num_q 99 correct 38 ratio 0.3838\n"
     ]
    }
   ],
   "source": [
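    "# Score the logged predictions against the gold answers; the summary is printed by parse_pred_ans.\n",
    "print(task)\n",
    "_, _, _ = parse_pred_ans('outputs/test_gpt_3.5_turbo_%s.txt' % task)"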
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
