{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ecfa1755-3d2f-43b8-9883-920c28b116f0",
   "metadata": {},
   "source": [
    "# ICL Step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "896b6ab0-57db-48ef-af26-3656336064ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from reasoning_functions import *\n",
    "\n",
    "# Weak-model training data on GSM8K, one file per weak model\n",
    "# (Llama-2 7B, Gemma 2B, Mistral 7B).\n",
    "with open('data/train_gsm8k_llama2_7b_full_weak.json') as f:\n",
    "    llama_dat = json.loads(f.read())\n",
    "with open('data/train_gsm8k_gemma_2b_full_weak.json') as f:\n",
    "    gemma_dat = json.loads(f.read())\n",
    "with open('data/train_gsm8k_mistral_7b_full_weak.json') as f:\n",
    "    mistral_dat = json.loads(f.read())\n",
    "\n",
    "# GSM8K test set, stored as one JSON object per line.\n",
    "with open('data/gsm8k.jsonl') as f:\n",
    "    test_data = list(map(json.loads, f))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2c838db7-fb05-46b9-bd07-02e7a2006dd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Three few-shot examples per weak model, written in that model's answer style\n",
    "# (presumably its own outputs on the first three training items, wrong answers included - TODO confirm).\n",
    "llama_icl_examples = ['''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. \n",
    "How many clips did Natalia sell altogether in April and May?\\nAnswer: \n",
    "Natalia sold 48 x 1/2 = 24 clips in May.\\nSo, she sold 48 + 24 = 72 clips altogether in April and May.\\n#### 72\\n\\n''', \n",
    "'''Question: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?\\nAnswer:\n",
    "She earns $12 an hour and she just did 50 minutes of babysitting so she earned 12/60 = $0.2 an hour.\\nSo she earned $0.2 x 50 = $10.\\n#### 10\\n\\n''',\n",
    "'''Question: Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, and there \n",
    "are 80% more of those in purple. There are only 25% as many green flowers as there are yellow and purple flowers. \n",
    "How many flowers does Mark have in his garden? \\nAnswer: There are 10 x 1.8 = 18 purple flowers.\\n\\nThere are 10 x 25% = \n",
    "2.5 green flowers.\\nThere are 10 + 18 + 2.5 = 30 flowers in total.\\n#### 30''']\n",
    "\n",
    "gemma_icl_examples = ['''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. \n",
    "How many clips did Natalia sell altogether in April and May?\\nAnswer: In April, Natalia sold 48 clips.\\nIn May, \n",
    "Natalia sold half as many clips as she did in April, so she sold 48/2=24 clips.\\nIn April and May, Natalia sold 48+24=72 clips.\\n#### 72''',\n",
    " '''Question: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?\\nAnswer:\n",
    " 12 * 50 / 60 = 10\\nShe earned 10 dollars.\\n#### 10 ''', '''Question: Mark has a garden with flowers. He planted plants \n",
    " of three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. There are only 25% as many green flowers as \n",
    " there are yellow and purple flowers. How many flowers does Mark have in his garden?\\nAnswer: There are 10 x 80/100 = 8 \n",
    " purple flowers.\\nThere are 25/100 x 10 = 2.5 green flowers.\\nSo, there are 10 + 8 + 2.5 = 12.5 green flowers.\\nTherefore, Mark has \n",
    " 12.5 x 100/25 = 50 flowers in his garden.\\n#### 50 ''' ]\n",
    "\n",
    "mistral_icl_examples = ['''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. \n",
    "How many clips did Natalia sell altogether in April and May?\\nAnswer: In May, Natalia sold 48/2 = 24 clips.\\nAltogether, \n",
    "Natalia sold 48 + 24 = 72 clips.\\n#### 72''', '''Question: Weng earns $12 an hour for babysitting. Yesterday, she just did \n",
    "50 minutes of babysitting. How much did she earn?\\nAnswer: She was paid for 50 minutes of babysitting.\\nSo she was paid for 50/60 = 0.833 \n",
    "of an hour.\\nSo she earned 12 * 0.833 = $9.99.\\n#### 9.99''', '''Question: Mark has a garden with flowers. He planted plants of \n",
    "three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. \n",
    "There are only 25% as many green flowers as there are yellow and purple flowers. How many flowers does Mark have in his garden?\\nAnswer:\n",
    "There are 10 * 0.8 = 8 more purple flowers than yellow flowers.\\nSo there are 10 + 8 = 18 purple flowers.\\nGreen flowers \n",
    "are 25/100 * (10 + 18) = 5.5 times less numerous than yellow and purple flowers.\\nSo there are 10 / 5.5 = 1.81818''']\n",
    "\n",
    "# Training set: the N_TRAIN samples that follow the N_ICL items used as ICL examples (indices 3-202).\n",
    "N_ICL = 3\n",
    "N_TRAIN = 200\n",
    "train_idx = range(N_ICL, N_ICL + N_TRAIN)\n",
    "training_questions = [llama_dat[i]['content'] for i in train_idx]\n",
    "llama_ans = [llama_dat[i]['output'] for i in train_idx]\n",
    "gemma_ans = [gemma_dat[i]['output'] for i in train_idx]\n",
    "mistral_ans = [mistral_dat[i]['output'] for i in train_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5161bd47-56d9-4c35-a238-12e26c0ebf5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "##Construct Oracle Data\n",
    "# GPT-4o answers the 200 training questions (indices 3-202) and 100 held-out validation\n",
    "# questions (indices 203-302); both sets are written to data/ as JSONL files.\n",
    "oracle = 'gpt-4o'\n",
    "system_prompt = ''''''\n",
    "oracle_training_questions = [llama_dat[i]['content'] for i in range(3, 203)]\n",
    "oracle_validation_questions = [llama_dat[i]['content'] for i in range(203, 303)]\n",
    "# Append-in-a-loop keeps the responses gathered so far if an API call fails midway.\n",
    "oracle_train_resp = []\n",
    "for question in oracle_training_questions:\n",
    "    oprompt = FormatInput(system_prompt, question, oracle)\n",
    "    oracle_train_resp.append(QueryModel(oprompt, oracle, api='OPENAI'))\n",
    "oracle_val_resp = []\n",
    "for question in oracle_validation_questions:\n",
    "    oprompt = FormatInput(system_prompt, question, oracle)\n",
    "    oracle_val_resp.append(QueryModel(oprompt, oracle, api='OPENAI'))\n",
    "SaveJSONL(system_prompt, oracle_training_questions, oracle_train_resp, 'data/gsm8k_oracle_train.jsonl')\n",
    "SaveJSONL(system_prompt, oracle_validation_questions, oracle_val_resp, 'data/gsm8k_oracle_validation.jsonl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a80f3ace-5b7c-4588-8b33-309b615307f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the GPT-4o validation responses collected above to their JSONL file.\n",
    "SaveJSONL(\n",
    "    system_prompt,\n",
    "    oracle_validation_questions,\n",
    "    oracle_val_resp,\n",
    "    'data/gsm8k_oracle_validation.jsonl',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c4804145-fdb2-4d86-8c86-ca939df2a125",
   "metadata": {},
   "outputs": [],
   "source": [
    "##Construct ICL Data\n",
    "# For every training question, gpt-3.5-turbo is prompted with one weak model's few-shot\n",
    "# examples followed by the question; its answer becomes that model's ICL-improved label.\n",
    "improvement_model = 'gpt-3.5-turbo'\n",
    "system_prompt = ''''''\n",
    "# The few-shot prefixes do not depend on the question, so build them once. The first two\n",
    "# llama examples already end in a blank line, hence the empty join separator there.\n",
    "llama_ICL_prompt = ''.join(llama_icl_examples)\n",
    "gemma_ICL_prompt = '\\n\\n'.join(gemma_icl_examples)\n",
    "mistral_ICL_prompt = '\\n\\n'.join(mistral_icl_examples)\n",
    "# Blank line between the last example and the new question, so the question is not glued\n",
    "# onto the final example's answer line (e.g. '#### 30' immediately followed by the question).\n",
    "QUESTION_SEP = '\\n\\n'\n",
    "ICL_llama_ans = []\n",
    "ICL_gemma_ans = []\n",
    "ICL_mistral_ans = []\n",
    "for question in training_questions:\n",
    "    # Each weak model gets its own few-shot prefix.\n",
    "    llama_prompt = FormatInput(system_prompt, llama_ICL_prompt + QUESTION_SEP + question, improvement_model)\n",
    "    gemma_prompt = FormatInput(system_prompt, gemma_ICL_prompt + QUESTION_SEP + question, improvement_model)\n",
    "    mistral_prompt = FormatInput(system_prompt, mistral_ICL_prompt + QUESTION_SEP + question, improvement_model)\n",
    "    ICL_llama_ans.append(QueryModel(llama_prompt, improvement_model, api='OPENAI'))\n",
    "    ICL_gemma_ans.append(QueryModel(gemma_prompt, improvement_model, api='OPENAI'))\n",
    "    ICL_mistral_ans.append(QueryModel(mistral_prompt, improvement_model, api='OPENAI'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a4818088-7742-44b5-83b2-86421cfccf49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional export of gold_train_quest / gold_train_resp (presumably reference GSM8K answers).\n",
    "# No earlier cell defines them, so the export only runs when they already exist in the session.\n",
    "# It writes its own file so it cannot overwrite the GPT-4o oracle file data/gsm8k_oracle_train.jsonl.\n",
    "system_prompt = ''''''\n",
    "if 'gold_train_quest' in globals() and 'gold_train_resp' in globals():\n",
    "    SaveJSONL(system_prompt, gold_train_quest, gold_train_resp, 'data/gsm8k_gold_train.jsonl')\n",
    "else:\n",
    "    print('Skipping gold export: gold_train_quest / gold_train_resp are not defined.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "415437e7-de27-4008-8a69-6b386692423e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the ICL-improved labels and the raw weak-model labels as JSONL fine-tuning files.\n",
    "# NOTE(review): the token-check and fine-tuning cells below read gsm8k_*_ICL_mini.jsonl, not the\n",
    "# gsm8k_*_ICL.jsonl files written here - confirm which ICL files are intended.\n",
    "system_prompt = ''''''\n",
    "SaveJSONL(system_prompt, training_questions, ICL_llama_ans, 'gsm8k_llama2_ICL.jsonl')\n",
    "SaveJSONL(system_prompt, training_questions, llama_ans, 'gsm8k_llama2_weak.jsonl')\n",
    "SaveJSONL(system_prompt, training_questions, ICL_gemma_ans, 'gsm8k_gemma_ICL.jsonl')\n",
    "SaveJSONL(system_prompt, training_questions, gemma_ans, 'gsm8k_gemma_weak.jsonl')\n",
    "SaveJSONL(system_prompt, training_questions, ICL_mistral_ans, 'gsm8k_mistral_ICL.jsonl')\n",
    "SaveJSONL(system_prompt, training_questions, mistral_ans, 'gsm8k_mistral_weak.jsonl')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0859367-f459-4811-baf6-fde6746a6e4b",
   "metadata": {},
   "source": [
    "# Finetuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f63f43fb-a530-4d00-8863-3d96d5e2818e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing file: data/gsm8k_oracle_train.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 121, 610\n",
      "mean / median: 321.005, 320.5\n",
      "p5 / p95: 208.9, 435.2\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 76, 513\n",
      "mean / median: 242.545, 242.0\n",
      "p5 / p95: 152.70000000000002, 339.29999999999995\n",
      "**************************************************\n"
     ]
    }
   ],
   "source": [
    "# Base model that the oracle fine-tuning job is started from.\n",
    "improvement_model = 'gpt-3.5-turbo'\n",
    "# Do not modify this. Validation data will not be used anyway.\n",
    "# NOTE(review): frac_train does not appear to be read by any cell - confirm before removing.\n",
    "frac_train = 1\n",
    "# Token-length statistics for the GPT-4o oracle training file.\n",
    "CheckTokens('data/gsm8k_oracle_train.jsonl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31f5e5ab-fab0-4ec1-b05a-8675e66e865f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start fine-tuning improvement_model on the GPT-4o oracle data (no validation file);\n",
    "# FineTune returns the job id.\n",
    "oracle_turbo = FineTune(\n",
    "    training_file_name='data/gsm8k_oracle_train.jsonl',\n",
    "    validation_file_name=None,\n",
    "    model=improvement_model,\n",
    ")\n",
    "# oracle_turbo_name = GetFineTunedModelName(oracle_turbo)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4aa08237-5c8d-4549-ae66-aecb1cf6ca95",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing file: gsm8k_llama2_weak.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 67, 562\n",
      "mean / median: 161.81, 148.5\n",
      "p5 / p95: 103.9, 222.0\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 17, 494\n",
      "mean / median: 83.35, 69.0\n",
      "p5 / p95: 43.0, 125.0\n",
      "**************************************************\n",
      "Processing file: gsm8k_llama2_ICL_mini.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 97, 595\n",
      "mean / median: 291.73, 275.5\n",
      "p5 / p95: 165.60000000000002, 446.5\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 50, 513\n",
      "mean / median: 213.27, 196.0\n",
      "p5 / p95: 105.9, 359.1\n",
      "**************************************************\n",
      "Processing file: gsm8k_gemma_weak.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 62, 628\n",
      "mean / median: 174.735, 150.5\n",
      "p5 / p95: 96.0, 241.29999999999998\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 10, 512\n",
      "mean / median: 96.275, 70.5\n",
      "p5 / p95: 34.900000000000006, 154.29999999999998\n",
      "**************************************************\n",
      "Processing file: gsm8k_gemma_ICL_mini.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 97, 2063\n",
      "mean / median: 297.895, 277.5\n",
      "p5 / p95: 167.9, 438.0\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 50, 1992\n",
      "mean / median: 219.435, 201.0\n",
      "p5 / p95: 110.9, 330.5999999999999\n",
      "**************************************************\n",
      "Processing file: gsm8k_mistral_weak.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 63, 316\n",
      "mean / median: 160.8, 158.0\n",
      "p5 / p95: 96.9, 234.1\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 18, 207\n",
      "mean / median: 82.34, 77.0\n",
      "p5 / p95: 40.0, 132.69999999999996\n",
      "**************************************************\n",
      "Processing file: gsm8k_mistral_ICL_mini.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 86, 2074\n",
      "mean / median: 298.47, 281.5\n",
      "p5 / p95: 166.0, 426.29999999999995\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 39, 2003\n",
      "mean / median: 220.01, 205.0\n",
      "p5 / p95: 108.9, 316.1999999999999\n",
      "**************************************************\n"
     ]
    }
   ],
   "source": [
    "# Token-length statistics for every weak-label and ICL training file, in this order.\n",
    "for jsonl_path in ['gsm8k_llama2_weak.jsonl', 'gsm8k_llama2_ICL_mini.jsonl',\n",
    "                   'gsm8k_gemma_weak.jsonl', 'gsm8k_gemma_ICL_mini.jsonl',\n",
    "                   'gsm8k_mistral_weak.jsonl', 'gsm8k_mistral_ICL_mini.jsonl']:\n",
    "    CheckTokens(jsonl_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "6868ed5f-830b-4378-a48d-2a38dc6b98fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tune on the Llama-2 weak labels, then print the fine-tuned model name.\n",
    "weak_fine_tune_job = FineTune('gsm8k_llama2_weak.jsonl')\n",
    "llama_weak_GPT_mini = GetFineTunedModelName(weak_fine_tune_job)\n",
    "print(llama_weak_GPT_mini)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "42d44876-7b85-48ff-818d-15a82533bfb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tune on the ICL-improved Llama-2 labels, then print the fine-tuned model name.\n",
    "llama_ICL_fine_tune_job = FineTune('gsm8k_llama2_ICL_mini.jsonl')\n",
    "llama_ICL_gpt = GetFineTunedModelName(llama_ICL_fine_tune_job)\n",
    "print(llama_ICL_gpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0faef555-e084-461a-9a8b-9ffaa6011a49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tune on the Gemma weak labels, then print the fine-tuned model name.\n",
    "gemma_weak_fine_tune_job = FineTune('gsm8k_gemma_weak.jsonl')\n",
    "gemma_weak_GPT_mini = GetFineTunedModelName(gemma_weak_fine_tune_job)\n",
    "print(gemma_weak_GPT_mini)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ef207e75-0978-4762-947b-658ec6ef23ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tune on the ICL-improved Gemma labels, then print the fine-tuned model name.\n",
    "gemma_ICL_fine_tune_job = FineTune('gsm8k_gemma_ICL_mini.jsonl')\n",
    "gemma_ICL_gpt_mini = GetFineTunedModelName(gemma_ICL_fine_tune_job)\n",
    "print(gemma_ICL_gpt_mini)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "824c06dc-af0e-46aa-a821-25e997140521",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tune on the Mistral weak labels, then print the fine-tuned model name.\n",
    "mistral_weak_fine_tune_job = FineTune(training_file_name='gsm8k_mistral_weak.jsonl')\n",
    "mistral_weak_GPT = GetFineTunedModelName(mistral_weak_fine_tune_job)\n",
    "print(mistral_weak_GPT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "e96ebf77-6df6-48f5-894e-07abecc1160b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tune on the ICL-improved Mistral labels, then print the fine-tuned model name.\n",
    "mistral_ICL_fine_tune_job = FineTune('gsm8k_mistral_ICL_mini.jsonl')\n",
    "mistral_ICL_GPT_mini = GetFineTunedModelName(mistral_ICL_fine_tune_job)\n",
    "print(mistral_ICL_GPT_mini)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "238d35c9-551f-4ce3-a402-addb970c4223",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ft:gpt-3.5-turbo-0125:university-of-michigan::9u7BAH6O\n"
     ]
    }
   ],
   "source": [
    "print(llama_weak_GPT_mini)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "18e57c6b-65cd-4975-acf2-2c0b897be12b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tuned model ids queried by the test cells below. By default they come from the\n",
    "# fine-tuning cells above; to reuse earlier jobs instead, replace the right-hand sides with\n",
    "# the id strings printed there (or copied from the OpenAI website).\n",
    "llama_ICL_id = llama_ICL_gpt\n",
    "llama_weak_id = llama_weak_GPT_mini\n",
    "gemma_weak_id = gemma_weak_GPT_mini\n",
    "gemma_ICL_id = gemma_ICL_gpt_mini\n",
    "mistral_weak_id = mistral_weak_GPT\n",
    "mistral_ICL_id = mistral_ICL_GPT_mini"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "91a12536-3a61-4e85-a7a8-c344e451c38f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tuned gpt-3.5-turbo-0125 oracle model (presumably the job started from\n",
    "# data/gsm8k_oracle_train.jsonl above - confirm).\n",
    "turbo_oracle_id = 'ft:gpt-3.5-turbo-0125:university-of-michigan::B2m3Oc0g'\n",
    "# The test-response cell queries the oracle through `mini_oracle_id`, which no earlier cell\n",
    "# defines; alias it here so that cell does not raise NameError.\n",
    "mini_oracle_id = turbo_oracle_id"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d676071-9aa8-41e3-998a-46b84d51ca3a",
   "metadata": {},
   "source": [
    "# Test Responses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fba24154-66b2-4e8b-81fb-a1d0f302d4bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation subset: the first N_TEST problems of the GSM8K test file and their solutions.\n",
    "N_TEST = 100\n",
    "with open('data/gsm8k.jsonl') as f:\n",
    "    test_data = [json.loads(line) for line in f]\n",
    "assert len(test_data) >= N_TEST, f'expected at least {N_TEST} test problems, got {len(test_data)}'\n",
    "small_test = [row['question'] for row in test_data[:N_TEST]]\n",
    "small_test_key = [row['solution'] for row in test_data[:N_TEST]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bfa7600e-1442-4e17-8ce4-839d799756df",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Construct test data: query every fine-tuned model on the test subset and save the answers\n",
    "# Empty system prompt, defined before the loop so SaveJSONL below has it even if small_test is empty.\n",
    "system_prompt = ''''''\n",
    "turbo_oracle_test_ans = []\n",
    "llama_ICL_gpt_test_ans = []\n",
    "llama_weak_gpt_test_ans = []\n",
    "gemma_weak_gpt_test_ans = []\n",
    "gemma_ICL_gpt_test_ans = []\n",
    "mistral_weak_gpt_test_ans = []\n",
    "mistral_ICL_gpt_test_ans = []\n",
    "# Each run: (model id, list collecting its answers, output JSONL path). All results go under data/.\n",
    "query_runs = [\n",
    "    (llama_ICL_id, llama_ICL_gpt_test_ans, 'data/gsm8k_llama2_ICL_test.jsonl'),\n",
    "    (llama_weak_id, llama_weak_gpt_test_ans, 'data/gsm8k_llama2_weak_test.jsonl'),\n",
    "    (gemma_weak_id, gemma_weak_gpt_test_ans, 'data/gsm8k_gemma_weak_test.jsonl'),\n",
    "    (gemma_ICL_id, gemma_ICL_gpt_test_ans, 'data/gsm8k_gemma_ICL_test.jsonl'),\n",
    "    (mistral_weak_id, mistral_weak_gpt_test_ans, 'data/gsm8k_mistral_weak_test.jsonl'),\n",
    "    (mistral_ICL_id, mistral_ICL_gpt_test_ans, 'data/gsm8k_mistral_ICL_test.jsonl'),\n",
    "    (turbo_oracle_id, turbo_oracle_test_ans, 'data/gsm8k_oracle_test.jsonl'),\n",
    "]\n",
    "for question in small_test:\n",
    "    prompt = FormatInput(system_prompt, question, 'gpt-4o-mini')\n",
    "    for model_id, answers, _ in query_runs:\n",
    "        answers.append(QueryModel(prompt, model_id, api='OPENAI'))\n",
    "for _, answers, out_path in query_runs:\n",
    "    SaveJSONL(system_prompt, small_test, answers, out_path)"
   ]
  },
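  {
   "cell_type": "markdown",
   "id": "ae1b1d22-c6b0-473b-83ce-1933f6a3f660",
   "metadata": {},
   "source": [
    "# Eval\n",
    "\n",
    "Grade every saved response against the reference solution with `gpt-4o` and report each model's mean score."
   ]
  },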
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6cf82689-a704-4423-acbc-f4654e2d6e69",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "\n",
    "def load_jsonl(path):\n",
    "    '''Return the list of JSON objects stored one per line in the file at path.'''\n",
    "    with open(path) as f:\n",
    "        return [json.loads(line) for line in f]\n",
    "\n",
    "\n",
    "llama_ICL_test_resp = load_jsonl('data/gsm8k_llama2_ICL_test.jsonl')\n",
    "llama_weak_test_resp = load_jsonl('data/gsm8k_llama2_weak_test.jsonl')\n",
    "mistral_weak_test_resp = load_jsonl('data/gsm8k_mistral_weak_test.jsonl')\n",
    "mistral_ICL_test_resp = load_jsonl('data/gsm8k_mistral_ICL_test.jsonl')\n",
    "gemma_weak_test_resp = load_jsonl('data/gsm8k_gemma_weak_test.jsonl')\n",
    "gemma_ICL_test_resp = load_jsonl('data/gsm8k_gemma_ICL_test.jsonl')\n",
    "# Same path the Test Responses section passes to SaveJSONL for the turbo oracle answers.\n",
    "oracle_test_resp = load_jsonl('data/gsm8k_oracle_test.jsonl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "64006bf9-2dc8-4f92-82cd-11581ae037b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "oracle_scores = []\n",
    "llama_ICL_scores, llama_weak_scores, baseline = [], [], []\n",
    "mistral_ICL_scores, mistral_weak_scores, gemma_ICL_scores, gemma_weak_scores = [], [], [], []\n",
    "llama_rep_scores = []  # never filled below; kept so existing references still resolve\n",
    "eval_model = 'gpt-4o'\n",
    "\n",
    "\n",
    "def grade_response(eval_sys, question, key, resp):\n",
    "    '''Grade one saved response with the eval_model grader and return the GPTEval score.\n",
    "\n",
    "    The answer is read from resp['messages'][2]['content'] (presumably the assistant\n",
    "    turn written by SaveJSONL -- TODO confirm).\n",
    "    '''\n",
    "    answer = resp['messages'][2]['content']\n",
    "    eval_prompt = FormatInput(eval_sys, GetEvalUserPrompt(question, key, answer), model=eval_model)\n",
    "    return GPTEval(eval_prompt, model=eval_model)\n",
    "\n",
    "\n",
    "for i, testq in enumerate(small_test):\n",
    "    eval_sys = GetEvalSystemPrompt()\n",
    "    key = small_test_key[i]\n",
    "    # The gemma weak-trained model sometimes returns no usable answer; record 0 for it.\n",
    "    # Exactly one gemma_weak score is appended per question, and KeyboardInterrupt is not swallowed.\n",
    "    try:\n",
    "        gemma_weak_scores.append(grade_response(eval_sys, testq, key, gemma_weak_test_resp[i]))\n",
    "    except Exception as err:\n",
    "        print(f'question {i}: gemma_weak response not graded ({err!r}); scoring 0')\n",
    "        gemma_weak_scores.append(0)\n",
    "    llama_ICL_scores.append(grade_response(eval_sys, testq, key, llama_ICL_test_resp[i]))\n",
    "    llama_weak_scores.append(grade_response(eval_sys, testq, key, llama_weak_test_resp[i]))\n",
    "    mistral_ICL_scores.append(grade_response(eval_sys, testq, key, mistral_ICL_test_resp[i]))\n",
    "    mistral_weak_scores.append(grade_response(eval_sys, testq, key, mistral_weak_test_resp[i]))\n",
    "    gemma_ICL_scores.append(grade_response(eval_sys, testq, key, gemma_ICL_test_resp[i]))\n",
    "    oracle_scores.append(grade_response(eval_sys, testq, key, oracle_test_resp[i]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a33603c7-5c6f-454f-98ab-52d179a0745f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "oracle:[0.9145]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Mean oracle grade with an approximate 95% interval (2 standard errors);\n",
    "# the denominator is the number of scores collected, not a hard-coded 100.\n",
    "if oracle_scores:\n",
    "    n_oracle = len(oracle_scores)\n",
    "    print('oracle:' + str(sum(oracle_scores) / n_oracle) + ' +/- ' + str(2 * np.std(oracle_scores) / np.sqrt(n_oracle)))\n",
    "else:\n",
    "    print('oracle: no scores')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "8aa81454-f2ab-4606-8fed-4238fe744e99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: persist the per-model score lists (uncomment both lines to save).\n",
    "# np.save stores the dict as a 0-d object array, so reload it with\n",
    "# np.load('gsm8k_scores_mini.npy', allow_pickle=True).item()\n",
    "# NOTE(review): baseline is never filled in this notebook -- drop it or populate it before saving.\n",
    "#dictionary = {'baseline':baseline, 'lweak': llama_weak_scores, 'lICL': llama_ICL_scores, 'gweak': gemma_weak_scores, 'gICL': gemma_ICL_scores, 'mweak': mistral_weak_scores, 'mICL': mistral_ICL_scores}\n",
    "#np.save('gsm8k_scores_mini.npy', dictionary) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "edcdc912-183a-4519-9f42-3da5211d5474",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "baseline:[0.9135] +/- 0.05540496367655159\n",
      "lweak:[0.5195] +/- 0.09982880345872126\n",
      "lICL:[0.9065] +/- 0.05726525997496214\n",
      "gweak:[0.77460526] +/- 0.09660641280911364\n",
      "gICL:[0.92] +/- 0.054258639865002144\n",
      "mweak:[0.7265] +/- 0.08733218192625214\n",
      "mICL:[0.9295] +/- 0.05101166533254918\n"
     ]
    }
   ],
   "source": [
    "# Mean grader score per model with an approximate 95% interval (2 standard errors);\n",
    "# the denominator is the number of scores actually collected, not a hard-coded 100.\n",
    "score_table = {\n",
    "    'lweak': llama_weak_scores,\n",
    "    'lICL': llama_ICL_scores,\n",
    "    'gweak': gemma_weak_scores,\n",
    "    'gICL': gemma_ICL_scores,\n",
    "    'mweak': mistral_weak_scores,\n",
    "    'mICL': mistral_ICL_scores,\n",
    "}\n",
    "for label, scores in score_table.items():\n",
    "    if not scores:\n",
    "        print(label + ': no scores')\n",
    "        continue\n",
    "    n_scores = len(scores)\n",
    "    mean_score = sum(scores) / n_scores\n",
    "    half_width = 2 * np.std(scores) / np.sqrt(n_scores)\n",
    "    print(label + ':' + str(mean_score) + ' +/- ' + str(half_width))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
