{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ecfa1755-3d2f-43b8-9883-920c28b116f0",
   "metadata": {},
   "source": [
    "# ICL Step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "896b6ab0-57db-48ef-af26-3656336064ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from reasoning_functions import *\n",
    "with open('train_gsm8k_llama2_7b_full_weak.json') as f:\n",
    "    llama_dat = json.load(f)\n",
    "with open('train_gsm8k_gemma_2b_full_weak.json') as f:\n",
    "    gemma_dat = json.load(f)\n",
    "    #print(d)\n",
    "with open('train_gsm8k_mistral_7b_full_weak.json') as f:\n",
    "    mistral_dat = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bbb9f0c5-b85c-4d07-ba9d-e2b4650dc9ac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'id': 2,\n",
       " 'content': 'Question: Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. There are only 25% as many green flowers as there are yellow and purple flowers. How many flowers does Mark have in his garden?\\nAnswer:',\n",
       " 'output': 'There are 10 * 0.8 = 8 more purple flowers than yellow flowers.\\nSo there are 10 + 8 = 18 purple flowers.\\nGreen flowers are 25/100 * (10 + 18) = 5.5 times less numerous than yellow and purple flowers.\\nSo there are 10 / 5.5 = 1.81818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818181818'}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mistral_dat[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2c838db7-fb05-46b9-bd07-02e7a2006dd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "llama_icl_examples = ['''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. \n",
    "How many clips did Natalia sell altogether in April and May?\\nAnswer: \n",
    "Natalia sold 48 x 1/2 = 24 clips in May.\\nSo, she sold 48 + 24 = 72 clips altogether in April and May.\\n#### 72\\n\\n''', \n",
    "''' Question: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?\\nAnswer:\n",
    "She earns $12 an hour and she just did 50 minutes of babysitting so she earned 12/60 = $0.2 an hour.\\nSo she earned $0.2 x 50 = $10.\\n#### 10\\n\\n''',\n",
    "'''Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, and there \n",
    "are 80% more of those in purple. There are only 25% as many green flowers as there are yellow and purple flowers. \n",
    "How many flowers does Mark have in his garden? \\nAnswer: There are 10 x 1.8 = 18 purple flowers.\\n\\nThere are 10 x 25% = \n",
    "2.5 green flowers.\\nThere are 10 + 18 + 2.5 = 30 flowers in total.\\n#### 30''']\n",
    "\n",
    "gemma_icl_examples = ['''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. \n",
    "How many clips did Natalia sell altogether in April and May?\\nAnswer: In April, Natalia sold 48 clips.\\nIn May, \n",
    "Natalia sold half as many clips as she did in April, so she sold 48/2=24 clips.\\nIn April and May, Natalia sold 48+24=72 clips.\\n#### 72''',\n",
    " '''Question: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?\\nAnswer:\n",
    " 12 * 50 / 60 = 10\\nShe earned 10 dollars.\\n#### 10 ''', '''Question: Mark has a garden with flowers. He planted plants \n",
    " of three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. There are only 25% as many green flowers as \n",
    " there are yellow and purple flowers. How many flowers does Mark have in his garden?\\nAnswer: There are 10 x 80/100 = 8 \n",
    " purple flowers.\\nThere are 25/100 x 10 = 2.5 green flowers.\\nSo, there are 10 + 8 + 2.5 = 12.5 green flowers.\\nTherefore, Mark has \n",
    " 12.5 x 100/25 = 50 flowers in his garden.\\n#### 50 ''' ]\n",
    "\n",
    "mistral_icl_examples = ['''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. \n",
    "How many clips did Natalia sell altogether in April and May?\\nAnswer: In May, Natalia sold 48/2 = 24 clips.\\nAltogether, \n",
    "Natalia sold 48 + 24 = 72 clips.\\n#### 72''', '''Question: Weng earns $12 an hour for babysitting. Yesterday, she just did \n",
    "50 minutes of babysitting. How much did she earn?\\nAnswer: She was paid for 50 minutes of babysitting.\\nSo she was paid for 50/60 = 0.833 \n",
    "of an hour.\\nSo she earned 12 * 0.833 = $9.99.\\n#### 9.99''', '''Question: Mark has a garden with flowers. He planted plants of \n",
    "three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. \n",
    "There are only 25% as many green flowers as there are yellow and purple flowers. How many flowers does Mark have in his garden?\\nAnswer:\n",
    "There are 10 * 0.8 = 8 more purple flowers than yellow flowers.\\nSo there are 10 + 8 = 18 purple flowers.\\nGreen flowers \n",
    "are 25/100 * (10 + 18) = 5.5 times less numerous than yellow and purple flowers.\\nSo there are 10 / 5.5 = 1.81818''']\n",
    "\n",
    "# getting the training examples, I use the first 200 samples exluding the icl examples \n",
    "training_questions = [llama_dat[i]['content'] for i in range(3, 203)]\n",
    "llama_ans = [llama_dat[i]['output'] for i in range(3, 203)]\n",
    "gemma_ans = [gemma_dat[i]['output'] for i in range(3, 203)]\n",
    "mistral_ans = [mistral_dat[i]['output'] for i in range(3, 203)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c4804145-fdb2-4d86-8c86-ca939df2a125",
   "metadata": {},
   "outputs": [],
   "source": [
    "##Construct ICL Data\n",
    "improvement_model = 'gpt-4o-mini-2024-07-18'\n",
    "ICL_llama_ans = []\n",
    "ICL_gemma_ans = []\n",
    "ICL_mistral_ans = []\n",
    "for question in training_questions:\n",
    "    system_prompt = ''''''\n",
    "    llama_ICL_prompt = ''.join(llama_icl_examples)\n",
    "    gemma_ICL_prompt = '\\n\\n'.join(gemma_icl_examples)\n",
    "    mistral_ICL_prompt = '\\n\\n'.join(mistral_icl_examples)\n",
    "    llama_prompt = llama_ICL_prompt+question\n",
    "    gemma_prompt = gemma_ICL_prompt+question\n",
    "    mistral_prompt = mistral_ICL_prompt+question\n",
    "    llama_prompt = FormatInput(system_prompt, gemma_prompt, improvement_model)\n",
    "    gemma_prompt = FormatInput(system_prompt, gemma_prompt, improvement_model)\n",
    "    mistral_prompt = FormatInput(system_prompt, mistral_prompt, improvement_model)\n",
    "    ICL_llama_ans.append(QueryModel(llama_prompt, improvement_model, api='OPENAI'))\n",
    "    ICL_gemma_ans.append(QueryModel(gemma_prompt, improvement_model, api='OPENAI'))\n",
    "    ICL_mistral_ans.append(QueryModel(mistral_prompt, improvement_model, api='OPENAI'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "415437e7-de27-4008-8a69-6b386692423e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# here we save the improved icl examples as a jsonl file\n",
    "system_prompt = ''''''\n",
    "SaveJSONL(system_prompt, training_questions, ICL_llama_ans, 'gsm8k_llama2_ICL_mini.jsonl')\n",
    "SaveJSONL(system_prompt, training_questions, llama_ans, 'gsm8k_llama2_weak.jsonl')\n",
    "SaveJSONL(system_prompt, training_questions, ICL_gemma_ans, 'gsm8k_gemma_ICL_mini.jsonl')\n",
    "SaveJSONL(system_prompt, training_questions, gemma_ans, 'gsm8k_gemma_weak.jsonl')\n",
    "SaveJSONL(system_prompt, training_questions, ICL_mistral_ans, 'gsm8k_mistral_ICL_mini.jsonl')\n",
    "SaveJSONL(system_prompt, training_questions, mistral_ans, 'gsm8k_mistral_weak.jsonl')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0859367-f459-4811-baf6-fde6746a6e4b",
   "metadata": {},
   "source": [
    "# Finetuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f63f43fb-a530-4d00-8863-3d96d5e2818e",
   "metadata": {},
   "outputs": [],
   "source": [
    "improvement_model = 'gpt-4o-mini-2024-07-18'\n",
    "frac_train = 1 #Do not modify this. Validation data will not be used anyways..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4aa08237-5c8d-4549-ae66-aecb1cf6ca95",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing file: gsm8k_llama2_weak.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 67, 562\n",
      "mean / median: 161.81, 148.5\n",
      "p5 / p95: 103.9, 222.0\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 17, 494\n",
      "mean / median: 83.35, 69.0\n",
      "p5 / p95: 43.0, 125.0\n",
      "**************************************************\n",
      "Processing file: gsm8k_llama2_ICL_mini.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 97, 595\n",
      "mean / median: 291.73, 275.5\n",
      "p5 / p95: 165.60000000000002, 446.5\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 50, 513\n",
      "mean / median: 213.27, 196.0\n",
      "p5 / p95: 105.9, 359.1\n",
      "**************************************************\n",
      "Processing file: gsm8k_gemma_weak.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 62, 628\n",
      "mean / median: 174.735, 150.5\n",
      "p5 / p95: 96.0, 241.29999999999998\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 10, 512\n",
      "mean / median: 96.275, 70.5\n",
      "p5 / p95: 34.900000000000006, 154.29999999999998\n",
      "**************************************************\n",
      "Processing file: gsm8k_gemma_ICL_mini.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 97, 2063\n",
      "mean / median: 297.895, 277.5\n",
      "p5 / p95: 167.9, 438.0\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 50, 1992\n",
      "mean / median: 219.435, 201.0\n",
      "p5 / p95: 110.9, 330.5999999999999\n",
      "**************************************************\n",
      "Processing file: gsm8k_mistral_weak.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 63, 316\n",
      "mean / median: 160.8, 158.0\n",
      "p5 / p95: 96.9, 234.1\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 18, 207\n",
      "mean / median: 82.34, 77.0\n",
      "p5 / p95: 40.0, 132.69999999999996\n",
      "**************************************************\n",
      "Processing file: gsm8k_mistral_ICL_mini.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 86, 2074\n",
      "mean / median: 298.47, 281.5\n",
      "p5 / p95: 166.0, 426.29999999999995\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 39, 2003\n",
      "mean / median: 220.01, 205.0\n",
      "p5 / p95: 108.9, 316.1999999999999\n",
      "**************************************************\n"
     ]
    }
   ],
   "source": [
    "CheckTokens(f'gsm8k_llama2_weak.jsonl')\n",
    "CheckTokens(f'gsm8k_llama2_ICL_mini.jsonl')\n",
    "CheckTokens(f'gsm8k_gemma_weak.jsonl')\n",
    "CheckTokens(f'gsm8k_gemma_ICL_mini.jsonl')\n",
    "CheckTokens(f'gsm8k_mistral_weak.jsonl')\n",
    "CheckTokens(f'gsm8k_mistral_ICL_mini.jsonl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "6868ed5f-830b-4378-a48d-2a38dc6b98fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "weak_fine_tune_job = FineTune('gsm8k_llama2_weak.jsonl')\n",
    "llama_weak_GPT_mini = GetFineTunedModelName(weak_fine_tune_job)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "42d44876-7b85-48ff-818d-15a82533bfb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "llama_ICL_fine_tune_job = FineTune('gsm8k_llama2_ICL_mini.jsonl')\n",
    "llama_ICL_gpt = GetFineTunedModelName(llama_ICL_fine_tune_job)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0faef555-e084-461a-9a8b-9ffaa6011a49",
   "metadata": {},
   "outputs": [],
   "source": [
    "gemma_weak_fine_tune_job = FineTune('gsm8k_gemma_weak.jsonl')\n",
    "gemma_weak_GPT_mini = GetFineTunedModelName(gemma_weak_fine_tune_job)\n",
    "print(gemma_weak_GPT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ef207e75-0978-4762-947b-658ec6ef23ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "gemma_ICL_fine_tune_job = FineTune('gsm8k_gemma_ICL_mini.jsonl')\n",
    "gemma_ICL_gpt_mini = GetFineTunedModelName(gemma_ICL_fine_tune_job)\n",
    "print(gemma_ICL_GPT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "824c06dc-af0e-46aa-a821-25e997140521",
   "metadata": {},
   "outputs": [],
   "source": [
    "mistral_weak_fine_tune_job = FineTune('gsm8k_mistral_weak.jsonl')\n",
    "mistral_weak_GPT = GetFineTunedModelName(mistral_weak_fine_tune_job)\n",
    "print(mistral_weak_GPT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "e96ebf77-6df6-48f5-894e-07abecc1160b",
   "metadata": {},
   "outputs": [],
   "source": [
    "mistral_ICL_fine_tune_job = FineTune('gsm8k_mistral_ICL_mini.jsonl')\n",
    "mistral_ICL_GPT_mini = GetFineTunedModelName(mistral_ICL_fine_tune_job)\n",
    "print(mistral_ICL_GPT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "238d35c9-551f-4ce3-a402-addb970c4223",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ft:gpt-3.5-turbo-0125:university-of-michigan::9u7BAH6O\n"
     ]
    }
   ],
   "source": [
    "print(llama_weak_GPT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "18e57c6b-65cd-4975-acf2-2c0b897be12b",
   "metadata": {},
   "outputs": [],
   "source": [
    "llama_ICL_id = \n",
    "llama_weak_id = \n",
    "gemma_weak_id  = \n",
    "gemma_ICL_id = \n",
    "mistral_weak_id = \n",
    "mistral_ICL_id = "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d676071-9aa8-41e3-998a-46b84d51ca3a",
   "metadata": {},
   "source": [
    "# Test Responses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "fba24154-66b2-4e8b-81fb-a1d0f302d4bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('gsm8k.jsonl') as f:\n",
    "    test_data = [json.loads(line) for line in f]\n",
    "small_test = [test_data[i]['question'] for i in range(100)]\n",
    "small_test_key = [test_data[i]['solution'] for i in range(100)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "bfa7600e-1442-4e17-8ce4-839d799756df",
   "metadata": {},
   "outputs": [],
   "source": [
    "##Construct test data\n",
    "llama_ICL_gpt_test_ans = []\n",
    "llama_weak_gpt_test_ans = []\n",
    "gemma_weak_gpt_test_ans = []\n",
    "gemma_ICL_gpt_test_ans = []\n",
    "mistral_weak_gpt_test_ans = []\n",
    "mistral_ICL_gpt_test_ans = []\n",
    "baseline_ans = []\n",
    "for i, question in enumerate(small_test):\n",
    "    system_prompt = ''''''\n",
    "    user_prompt = question\n",
    "    key = small_test_key[i]\n",
    "    prompt = FormatInput(system_prompt, user_prompt,'gpt-4o-mini-2024-07-18')\n",
    "    llama_ICL_gpt_test_ans.append(QueryModel(prompt, llama_ICL_id, api='OPENAI'))\n",
    "    llama_weak_gpt_test_ans.append(QueryModel(prompt, llama_weak_id, api='OPENAI'))\n",
    "    gemma_weak_gpt_test_ans.append(QueryModel(prompt,gemma_weak_id,api = 'OPENAI'))\n",
    "    gemma_ICL_gpt_test_ans.append(QueryModel(prompt,gemma_ICL_id,api = 'OPENAI'))\n",
    "    mistral_weak_gpt_test_ans.append(QueryModel(prompt,mistral_weak_id,api = 'OPENAI'))\n",
    "    mistral_ICL_gpt_test_ans.append(QueryModel(prompt,mistral_ICL_id,api = 'OPENAI'))\n",
    "    baseline_ans.append(QueryModel(prompt, 'gpt-4o-mini-2024-07-18',api='OPENAI'))\n",
    "SaveJSONL(system_prompt, small_test, llama_ICL_gpt_test_ans, 'gsm8k_llama2_mini_ICL_test.jsonl')\n",
    "SaveJSONL(system_prompt, small_test, llama_weak_gpt_test_ans, 'gsm8k_llama2_mini_weak_test.jsonl')\n",
    "SaveJSONL(system_prompt, small_test, baseline_ans, 'gsm8k_baseline_mini_test.jsonl')\n",
    "SaveJSONL(system_prompt, small_test, mistral_weak_gpt_test_ans, 'gsm8k_mistral_mini_weak_test.jsonl')\n",
    "SaveJSONL(system_prompt, small_test, mistral_ICL_gpt_test_ans, 'gsm8k_mistral_mini_ICL_test.jsonl')\n",
    "SaveJSONL(system_prompt, small_test, gemma_weak_gpt_test_ans, 'gsm8k_gemma_mini_weak_test.jsonl')\n",
    "SaveJSONL(system_prompt, small_test, gemma_ICL_gpt_test_ans, 'gsm8k_gemma_mini_ICL_test.jsonl')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae1b1d22-c6b0-473b-83ce-1933f6a3f660",
   "metadata": {},
   "source": [
    "# Eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "6cf82689-a704-4423-acbc-f4654e2d6e69",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('gsm8k_llama2_mini_ICL_test.jsonl') as f:\n",
    "    llama_ICL_test_resp = [json.loads(line) for line in f]\n",
    "with open('gsm8k_llama2_mini_weak_test.jsonl') as f:\n",
    "    llama_weak_test_resp = [json.loads(line) for line in f]\n",
    "with open('gsm8k_baseline_mini_test.jsonl') as f:\n",
    "    baseline_resp = [json.loads(line) for line in f]\n",
    "with open('gsm8k_mistral_mini_weak_test.jsonl') as f:\n",
    "    mistral_weak_test_resp = [json.loads(line) for line in f]\n",
    "with open('gsm8k_mistral_mini_ICL_test.jsonl') as f:\n",
    "    mistral_ICL_test_resp = [json.loads(line) for line in f]\n",
    "with open('gsm8k_gemma_mini_weak_test.jsonl') as f:\n",
    "    gemma_weak_test_resp = [json.loads(line) for line in f]\n",
    "with open('gsm8k_gemma_mini_ICL_test.jsonl') as f:\n",
    "    gemma_ICL_test_resp = [json.loads(line) for line in f]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "64006bf9-2dc8-4f92-82cd-11581ae037b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "llama_ICL_scores, llama_weak_scores, baseline = [], [], []\n",
    "mistral_ICL_scores, mistral_weak_scores, gemma_ICL_scores, gemma_weak_scores = [], [], [], []\n",
    "gemma_weak_scores = []\n",
    "llama_rep_scores = []\n",
    "eval_model = 'gpt-4o'\n",
    "for i, testq in enumerate(small_test):\n",
    "    eval_sys = GetEvalSystemPrompt()\n",
    "    key = small_test_key[i]\n",
    "    eval_prompt_baseline = FormatInput(eval_sys, GetEvalUserPrompt(testq, key, baseline_resp[i]['messages'][2]['content']), model = eval_model)\n",
    "    eval_prompt_llama_weak = FormatInput(eval_sys, GetEvalUserPrompt(testq, key, llama_weak_test_resp[i]['messages'][2]['content']), model = eval_model)\n",
    "    eval_prompt_llama_rep = FormatInput(eval_sys, GetEvalUserPrompt(testq, key, llama_weak_test_resp[i]['messages'][2]['content']), model = eval_model)\n",
    "    eval_prompt_llama_ICL = FormatInput(eval_sys, GetEvalUserPrompt(testq, key, llama_ICL_test_resp[i]['messages'][2]['content']), model=eval_model)\n",
    "    try: #sometimes the gemma trained model fails to provide an answer\n",
    "        eval_prompt_gemma_weak = FormatInput(eval_sys, GetEvalUserPrompt(testq, key, gemma_weak_test_resp[i]['messages'][2]['content']), model = eval_model)\n",
    "        gemma_weak_scores.append(GPTEval(eval_prompt_gemma_weak, model=eval_model))\n",
    "    except: # if answer is blank score 0\n",
    "        print(i)\n",
    "        gemma_weak_scores.append(0)\n",
    "    eval_prompt_gemma_ICL = FormatInput(eval_sys, GetEvalUserPrompt(testq, key, gemma_ICL_test_resp[i]['messages'][2]['content']), model = eval_model)\n",
    "    eval_prompt_mistral_weak = FormatInput(eval_sys, GetEvalUserPrompt(testq, key, mistral_weak_test_resp[i]['messages'][2]['content']), model = eval_model)\n",
    "    eval_prompt_mistral_ICL = FormatInput(eval_sys, GetEvalUserPrompt(testq, key, mistral_ICL_test_resp[i]['messages'][2]['content']), model = eval_model)\n",
    "    llama_ICL_scores.append(GPTEval(eval_prompt_llama_ICL, model=eval_model))\n",
    "    llama_weak_scores.append(GPTEval(eval_prompt_llama_weak, model=eval_model))\n",
    "    baseline.append(GPTEval(eval_prompt_baseline, model=eval_model))\n",
    "    mistral_ICL_scores.append(GPTEval(eval_prompt_mistral_ICL, model=eval_model))\n",
    "    mistral_weak_scores.append(GPTEval(eval_prompt_mistral_weak, model=eval_model))\n",
    "    gemma_ICL_scores.append(GPTEval(eval_prompt_gemma_ICL, model=eval_model))\n",
    "    gemma_weak_scores.append(GPTEval(eval_prompt_gemma_weak, model=eval_model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "8aa81454-f2ab-4606-8fed-4238fe744e99",
   "metadata": {},
   "outputs": [],
   "source": [
    "dictionary = {'baseline':baseline, 'lweak': llama_weak_scores, 'lICL': llama_ICL_scores, 'gweak': gemma_weak_scores, 'gICL:': gemma_ICL_scores, 'mweak': mistral_weak_scores, 'mICL': mistral_ICL_scores}\n",
    "np.save('gsm8k_scores_mini.npy', dictionary) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "edcdc912-183a-4519-9f42-3da5211d5474",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "baseline:[0.9135] +/- 0.05540496367655159\n",
      "lweak:[0.5195] +/- 0.09982880345872126\n",
      "lICL:[0.9065] +/- 0.05726525997496214\n",
      "gweak:[0.77460526] +/- 0.09660641280911364\n",
      "gICL:[0.92] +/- 0.054258639865002144\n",
      "mweak:[0.7265] +/- 0.08733218192625214\n",
      "mICL:[0.9295] +/- 0.05101166533254918\n"
     ]
    }
   ],
   "source": [
    "print('baseline:'+str(sum(baseline)/100)+' +/- '+str(2*np.std(baseline)/10))\n",
    "print('lweak:'+str(sum(llama_weak_scores)/100)+' +/- '+str(2*np.std(llama_weak_scores)/10))\n",
    "print('lICL:'+str(sum(llama_ICL_scores)/100)+' +/- '+str(2*np.std(llama_ICL_scores)/10))\n",
    "print('gweak:'+str(sum(gemma_weak_scores)/100)+' +/- '+str(2*np.std(gemma_weak_scores)/10))\n",
    "print('gICL:'+str(sum(gemma_ICL_scores)/100)+' +/- '+str(2*np.std(gemma_ICL_scores)/10))\n",
    "print('mweak:'+str(sum(mistral_weak_scores)/100)+' +/- '+str(2*np.std(mistral_weak_scores)/10))\n",
    "print('mICL:'+str(sum(mistral_ICL_scores)/100)+' +/- '+str(2*np.std(mistral_ICL_scores)/10))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
