{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "06651303-e344-4422-9fb7-3eb543e3f1ba",
   "metadata": {},
   "source": [
    "# ICL Step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d29e719a-2b65-496a-9e1f-4ecdadac4228",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from functions import *\n",
    "with open('math_train_1.json') as f:\n",
    "    gold_dat = json.load(f)\n",
    "with open('train_math_llama2_7b_full_weak.json') as f:\n",
    "    llama_dat = json.load(f)\n",
    "with open('train_math_gemma_2b_full_weak.json') as f:\n",
    "    gemma_dat = json.load(f)\n",
    "    #print(d)\n",
    "with open('train_math_mistral_7b_full_weak.json') as f:\n",
    "    mistral_dat = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97d6d381-cbc5-4bb3-8668-674bd98d76e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "gold_quest = []\n",
    "gold_answers = []\n",
    "for i in range(500):\n",
    "    dct = gold_dat[i]\n",
    "    gold_quest.append(dct['content'])\n",
    "    gold_answers.append(dct['output'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e31ecd58-d5a7-4e3e-99b1-ac89a45dac0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_prompt = ''''''\n",
    "SaveJSONL(system_prompt, gold_quest, gold_answers, 'MATH_gold.jsonl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39f90be5-aba2-4921-abc0-7f1926e755ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "#getting gpt-3.5-turbo data\n",
    "training_questions = [llama_dat[i]['content'] for i in range(0, 203)]\n",
    "gpt_ans = []\n",
    "for question in training_questions:\n",
    "    system_prompt = ''''''\n",
    "    gpt_prompt = FormatInput(system_prompt, question, gpt_35_teacher_id)\n",
    "    gpt_ans.append(QueryModel(gpt_prompt, gpt_35_teacher_id, api='OPENAI'))\n",
    "SaveJSONL(system_prompt, training_questions, gpt_ans, 'gpt_3.5_weak.jsonl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "79ed52eb-2fcf-419c-a64c-bdfa76a9a52a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "with open('gpt_3.5_weak.jsonl') as f:\n",
    "    gpt_data = [json.loads(line) for line in f]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "40673744-f1b5-41f9-a1b0-79f233f97ee1",
   "metadata": {},
   "outputs": [],
   "source": [
    "llama_icl_examples, gemma_icl_examples, mistral_icl_examples, gpt_ICL_examples = [], [], [], []\n",
    "for i in [0,2]:\n",
    "    llama_icl_examples.append(llama_dat[i]['content']+llama_dat[i]['output'])\n",
    "    gemma_icl_examples.append(gemma_dat[i]['content']+gemma_dat[i]['output'])\n",
    "    mistral_icl_examples.append(mistral_dat[i]['content']+mistral_dat[i]['output'])\n",
    "    gpt_ICL_examples.append(gpt_data[0]['messages'][1]['content']+gpt_data[0]['messages'][2]['content'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "2ebb30e6-a220-43fa-955d-736f858419fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# getting the training examples, I use the first 200 samples exluding the icl examples \n",
    "training_questions = [llama_dat[i]['content'] for i in range(3, 203)]\n",
    "llama_ans = [llama_dat[i]['output'] for i in range(3, 203)]\n",
    "gemma_ans = [gemma_dat[i]['output'] for i in range(3, 203)]\n",
    "mistral_ans = [mistral_dat[i]['output'] for i in range(3, 203)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a403b476-133e-4d06-ba82-2d5bf20024e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#getting ICL data\n",
    "improvement_model = 'gpt-4o-mini-2024-07-18'\n",
    "ICL_llama_ans = []\n",
    "ICL_gemma_ans = []\n",
    "ICL_mistral_ans = []\n",
    "for question in training_questions:\n",
    "    system_prompt = ''''''\n",
    "    llama_ICL_prompt = '\\n\\n'.join(llama_icl_examples)\n",
    "    gemma_ICL_prompt = '\\n\\n'.join(gemma_icl_examples)\n",
    "    mistral_ICL_prompt = '\\n\\n'.join(mistral_icl_examples)\n",
    "    llama_prompt = llama_ICL_prompt+question\n",
    "    gemma_prompt = gemma_ICL_prompt+question\n",
    "    mistral_prompt = mistral_ICL_prompt+question\n",
    "    llama_prompt = FormatInput(system_prompt, llama_prompt, improvement_model)\n",
    "    gemma_prompt = FormatInput(system_prompt, gemma_prompt, improvement_model)\n",
    "    mistral_prompt = FormatInput(system_prompt, mistral_prompt, improvement_model)\n",
    "    ICL_llama_ans.append(QueryModel(llama_prompt, improvement_model, api='OPENAI'))\n",
    "    ICL_gemma_ans.append(QueryModel(gemma_prompt, improvement_model, api='OPENAI'))\n",
    "    ICL_mistral_ans.append(QueryModel(mistral_prompt, improvement_model, api='OPENAI'))\n",
    "SaveJSONL(system_prompt, training_questions, ICL_llama_ans, 'math_llama2_mini_ICL.jsonl')\n",
    "SaveJSONL(system_prompt, training_questions, llama_ans, 'math_llama2_mini_weak.jsonl')\n",
    "SaveJSONL(system_prompt, training_questions, ICL_gemma_ans, 'math_gemma_mini_ICL.jsonl')\n",
    "SaveJSONL(system_prompt, training_questions, gemma_ans, 'math_gemma_mini_weak.jsonl')\n",
    "SaveJSONL(system_prompt, training_questions, ICL_mistral_ans, 'math_mistral_mini_ICL.jsonl')\n",
    "SaveJSONL(system_prompt, training_questions, mistral_ans, 'math_mistral_mini_weak.jsonl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "874f2144-2146-4a40-9ac0-e45845404032",
   "metadata": {},
   "outputs": [],
   "source": [
    "# getting gpt ICL data\n",
    "improvement_model = 'gpt-4o-mini-2024-07-18'\n",
    "ICL_gpt_ans = []\n",
    "for question in training_questions:\n",
    "    system_prompt = ''''''\n",
    "    gpt_ICL_prompt = '\\n\\n'.join(gpt_ICL_examples)\n",
    "    gpt_prompt = gpt_ICL_prompt+question\n",
    "    gpt_prompt = FormatInput(system_prompt, gpt_prompt, improvement_model)\n",
    "    ICL_gpt_ans.append(QueryModel(gpt_prompt, improvement_model, api='OPENAI'))\n",
    "SaveJSONL(system_prompt, training_questions, ICL_gpt_ans, 'math_gpt_mini_ICL.jsonl')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ade7f18d-84ec-49ad-9796-014c8fed6410",
   "metadata": {},
   "source": [
    "# FineTuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "351a7518-db67-4060-a017-72b3e6b9a7ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "improvement_model = 'gpt-4o-mini-2024-07-18'\n",
    "frac_train = 1 #Do not modify this. Validation data will not be used anyways..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "bf335e0a-9b24-4826-97ad-672380e71dc3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing file: math_llama2_mini_weak.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 70, 2379\n",
      "mean / median: 601.35, 238.0\n",
      "p5 / p95: 112.80000000000001, 2050.6\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 32, 2282\n",
      "mean / median: 512.3, 132.0\n",
      "p5 / p95: 55.0, 1963.0\n",
      "**************************************************\n",
      "Processing file: math_llama2_mini_ICL.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 133, 2117\n",
      "mean / median: 655.78, 565.5\n",
      "p5 / p95: 267.0, 1148.2\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 100, 1999\n",
      "mean / median: 566.73, 471.5\n",
      "p5 / p95: 201.9, 1014.8\n",
      "**************************************************\n",
      "Processing file: math_gemma_mini_weak.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 59, 2553\n",
      "mean / median: 512.315, 207.0\n",
      "p5 / p95: 103.0, 1953.4\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 15, 2434\n",
      "mean / median: 423.265, 125.0\n",
      "p5 / p95: 55.900000000000006, 1847.6\n",
      "**************************************************\n",
      "Processing file: math_gemma_mini_ICL.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 92, 2069\n",
      "mean / median: 626.195, 559.0\n",
      "p5 / p95: 248.20000000000002, 1089.1999999999998\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 58, 2000\n",
      "mean / median: 537.145, 461.0\n",
      "p5 / p95: 195.3, 961.8\n",
      "**************************************************\n",
      "Processing file: math_mistral_mini_weak.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 69, 2097\n",
      "mean / median: 292.01, 213.5\n",
      "p5 / p95: 102.0, 512.1\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 35, 1956\n",
      "mean / median: 202.96, 133.5\n",
      "p5 / p95: 54.900000000000006, 405.0\n",
      "**************************************************\n",
      "Processing file: math_mistral_mini_ICL.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 144, 2070\n",
      "mean / median: 657.825, 587.0\n",
      "p5 / p95: 280.7, 1157.7\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 111, 2001\n",
      "mean / median: 568.775, 483.0\n",
      "p5 / p95: 212.8, 1043.2\n",
      "**************************************************\n",
      "Processing file: math_gpt_mini_ICL.jsonl\n",
      "\n",
      "#### Distribution of total tokens:\n",
      "min / max: 134, 2160\n",
      "mean / median: 688.525, 597.5\n",
      "p5 / p95: 271.8, 1203.8999999999999\n",
      "\n",
      "#### Distribution of assistant tokens:\n",
      "min / max: 100, 2001\n",
      "mean / median: 599.475, 506.0\n",
      "p5 / p95: 220.20000000000002, 1095.3\n",
      "**************************************************\n"
     ]
    }
   ],
   "source": [
    "CheckTokens(f'math_llama2_mini_weak.jsonl')\n",
    "CheckTokens(f'math_llama2_mini_ICL.jsonl')\n",
    "CheckTokens(f'math_gemma_mini_weak.jsonl')\n",
    "CheckTokens(f'math_gemma_mini_ICL.jsonl')\n",
    "CheckTokens(f'math_mistral_mini_weak.jsonl')\n",
    "CheckTokens(f'math_mistral_mini_ICL.jsonl')\n",
    "CheckTokens(f'math_gpt_mini_ICL.jsonl')\n",
    "#CheckTokens(f'MATH_gold.jsonl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4c9ed305-42b1-4344-9245-1d826088d2ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "#math_llama_weak_fine_tune_job = FineTune('math_llama2_mini_weak.jsonl')\n",
    "math_llama_mini_weak_id = "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9ca15103-62dc-40c0-a579-97a17ac2f4a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#math_llama_ICL_fine_tune_job = FineTune('math_llama2_mini_ICL.jsonl')\n",
    "math_llama_mini_ICL_id = "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "16802e5b-19e9-4ae5-a1ce-71220926c173",
   "metadata": {},
   "outputs": [],
   "source": [
    "#math_gemma_weak_fine_tune_job = FineTune('math_gemma_mini_weak.jsonl')\n",
    "math_gemma_mini_weak_id = "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c371eebc-41cd-454a-ae22-44adbd597da3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#math_gemma_ICL_fine_tune_job = FineTune('math_gemma_mini_ICL.jsonl')\n",
    "math_gemma_mini_ICL_id = "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "32640c49-31c1-4609-86d3-18e3cf665d35",
   "metadata": {},
   "outputs": [],
   "source": [
    "#math_mistral_weak_fine_tune_job = FineTune('math_mistral_mini_weak.jsonl')\n",
    "math_mistral_mini_weak_id = "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "edb0a0f1-8f5d-444d-bcec-801ca1c705ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "#math_mistral_ICL_fine_tune_job = FineTune('math_mistral_mini_ICL.jsonl')\n",
    "math_mistral_mini_ICL_id = "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "7da7f64b-9c75-4610-865e-dd5a626cd576",
   "metadata": {},
   "outputs": [],
   "source": [
    "#math_gpt_ICL_fine_tune_job = FineTune('math_gpt_mini_ICL.jsonl')\n",
    "math_gpt_mini_ICL_id = "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbc000d9-c064-41f8-b1a8-e0f59ceb7906",
   "metadata": {},
   "source": [
    "# Test Responses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "df780819-6798-4af6-9874-bcbb6c2eda53",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('test.jsonl') as f:\n",
    "    test_data = [json.loads(line) for line in f]\n",
    "#test_data[0]\n",
    "small_test = [test_data[i]['problem'] for i in range(100)]\n",
    "small_test_key = [test_data[i]['solution'] for i in range(100)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "89e4bb8b-6bc1-4b7d-aeac-b4aace6fac92",
   "metadata": {},
   "outputs": [],
   "source": [
    "##Construct test data\n",
    "llama_ICL_gpt_test_ans = []\n",
    "llama_weak_gpt_test_ans = []\n",
    "gemma_weak_gpt_test_ans = []\n",
    "gemma_ICL_gpt_test_ans = []\n",
    "mistral_weak_gpt_test_ans = []\n",
    "mistral_ICL_gpt_test_ans = []\n",
    "baseline_ans = []\n",
    "for i, question in enumerate(small_test):\n",
    "    system_prompt = ''''''\n",
    "    #print(question)\n",
    "    user_prompt = question\n",
    "    key = small_test_key[i]\n",
    "    prompt = FormatInput(system_prompt, user_prompt,'gpt-4o-mini')\n",
    "    llama_ICL_gpt_test_ans.append(QueryModel(prompt, math_llama_mini_ICL_id, api='OPENAI'))\n",
    "    llama_weak_gpt_test_ans.append(QueryModel(prompt, math_llama_mini_weak_id, api='OPENAI'))\n",
    "    gemma_ICL_gpt_test_ans.append(QueryModel(prompt, math_gemma_mini_ICL_id,api = 'OPENAI'))\n",
    "    gemma_weak_gpt_test_ans.append(QueryModel(prompt, math_gemma_mini_weak_id,api = 'OPENAI'))\n",
    "    mistral_weak_gpt_test_ans.append(QueryModel(prompt, math_mistral_mini_weak_id,api = 'OPENAI'))\n",
    "    mistral_ICL_gpt_test_ans.append(QueryModel(prompt, math_mistral_mini_ICL_id,api = 'OPENAI'))\n",
    "    baseline_ans.append(QueryModel(prompt, 'gpt-4o-mini-2024-07-18',api='OPENAI'))\n",
    "SaveJSONL(system_prompt, small_test, llama_ICL_gpt_test_ans, 'math_llama2_mini_ICL_test.jsonl')\n",
    "SaveJSONL(system_prompt, small_test, llama_weak_gpt_test_ans, 'math_llama2_mini_weak_test.jsonl')\n",
    "SaveJSONL(system_prompt, small_test, baseline_ans, 'math_mini_baseline_test.jsonl')\n",
    "SaveJSONL(system_prompt, small_test, mistral_weak_gpt_test_ans, 'math_mistral_mini_weak_test.jsonl')\n",
    "SaveJSONL(system_prompt, small_test, mistral_ICL_gpt_test_ans, 'math_mistral_mini_ICL_test.jsonl')\n",
    "SaveJSONL(system_prompt, small_test, gemma_weak_gpt_test_ans, 'math_gemma_mini_weak_test.jsonl')\n",
    "SaveJSONL(system_prompt, small_test, gemma_ICL_gpt_test_ans, 'math_gemma_mini_ICL_test.jsonl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "c8887dfb-46ae-414a-8473-bd1380461216",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'gemma_ICL_gpt_test_ans' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[34], line 9\u001b[0m\n\u001b[1;32m      7\u001b[0m     prompt \u001b[38;5;241m=\u001b[39m FormatInput(system_prompt, user_prompt,\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mgpt-4o-mini\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m      8\u001b[0m     gpt_ICL_mini_test_ans\u001b[38;5;241m.\u001b[39mappend(QueryModel(prompt, math_gpt_mini_ICL_id, api\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mOPENAI\u001b[39m\u001b[38;5;124m'\u001b[39m))\n\u001b[0;32m----> 9\u001b[0m SaveJSONL(system_prompt, small_test, \u001b[43mgemma_ICL_gpt_test_ans\u001b[49m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmath_gpt_mini_ICL_test.jsonl\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'gemma_ICL_gpt_test_ans' is not defined"
     ]
    }
   ],
   "source": [
    "gpt_ICL_mini_test_ans = []\n",
    "for i, question in enumerate(small_test):\n",
    "    system_prompt = ''''''\n",
    "    #print(question)\n",
    "    user_prompt = question\n",
    "    key = small_test_key[i]\n",
    "    prompt = FormatInput(system_prompt, user_prompt,'gpt-4o-mini')\n",
    "    gpt_ICL_mini_test_ans.append(QueryModel(prompt, math_gpt_mini_ICL_id, api='OPENAI'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "e8bda435-a279-4bbd-9bd1-5082056f0d2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "SaveJSONL(system_prompt, small_test, gpt_ICL_mini_test_ans, 'math_gpt_mini_ICL_test.jsonl')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7eebf1bd-4a32-4240-b19d-4be471d21d51",
   "metadata": {},
   "source": [
    "# Eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "53036d87-ec50-4a79-9b90-abb0ae9150f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('math_llama2_mini_ICL_test.jsonl') as f:\n",
    "    llama_ICL_test_resp = [json.loads(line) for line in f]\n",
    "with open('math_llama2_mini_weak_test.jsonl') as f:\n",
    "    llama_weak_test_resp = [json.loads(line) for line in f]\n",
    "with open('math_mini_baseline_test.jsonl') as f:\n",
    "    baseline_resp = [json.loads(line) for line in f]\n",
    "with open('math_mistral_mini_weak_test.jsonl') as f:\n",
    "    mistral_weak_test_resp = [json.loads(line) for line in f]\n",
    "with open('math_mistral_mini_ICL_test.jsonl') as f:\n",
    "    mistral_ICL_test_resp = [json.loads(line) for line in f]\n",
    "with open('math_gemma_mini_weak_test.jsonl') as f:\n",
    "    gemma_weak_test_resp = [json.loads(line) for line in f]\n",
    "with open('math_gemma_mini_ICL_test.jsonl') as f:\n",
    "    gemma_ICL_test_resp = [json.loads(line) for line in f]\n",
    "with open('math_gpt_mini_ICL_test.jsonl') as f:\n",
    "    gpt_ICL_test_resp = [json.loads(line) for line in f]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "b52b4263-175d-4e86-ac43-b071aafc5868",
   "metadata": {},
   "outputs": [],
   "source": [
    "llama_ICL_scores, llama_weak_scores, baseline = [], [], []\n",
    "mistral_ICL_scores, mistral_weak_scores, gemma_ICL_scores, gemma_weak_scores = [], [], [], []\n",
    "eval_model = 'gpt-4o'\n",
    "for i, testq in enumerate(small_test):\n",
    "    eval_sys = GetEvalSystemPrompt()\n",
    "    key = small_test_key[i]\n",
    "    eval_prompt_baseline = FormatInput(eval_sys, GetEvalUserPrompt(testq, key, baseline_resp[i]['messages'][2]['content']), model = eval_model)\n",
    "    eval_prompt_llama_weak = FormatInput(eval_sys, GetEvalUserPrompt(testq, key, llama_weak_test_resp[i]['messages'][2]['content']), model = eval_model)\n",
    "    eval_prompt_llama_ICL = FormatInput(eval_sys, GetEvalUserPrompt(testq, key, llama_ICL_test_resp[i]['messages'][2]['content']), model=eval_model)\n",
    "    try: #sometimes the gemma trained model fails to provide an answer\n",
    "        eval_prompt_gemma_weak = FormatInput(eval_sys, GetEvalUserPrompt(testq, key, gemma_weak_test_resp[i]['messages'][2]['content']), model = eval_model)\n",
    "        gemma_weak_scores.append(GPTEval(eval_prompt_gemma_weak, model=eval_model))\n",
    "    except: # if answer is blank score 0\n",
    "        print(i)\n",
    "        gemma_weak_scores.append(0)\n",
    "    eval_prompt_gemma_ICL = FormatInput(eval_sys, GetEvalUserPrompt(testq, key, gemma_ICL_test_resp[i]['messages'][2]['content']), model = eval_model)\n",
    "    eval_prompt_mistral_weak = FormatInput(eval_sys, GetEvalUserPrompt(testq, key, mistral_weak_test_resp[i]['messages'][2]['content']), model = eval_model)\n",
    "    eval_prompt_mistral_ICL = FormatInput(eval_sys, GetEvalUserPrompt(testq, key, mistral_ICL_test_resp[i]['messages'][2]['content']), model = eval_model)\n",
    "    llama_ICL_scores.append(GPTEval(eval_prompt_llama_ICL, model=eval_model))\n",
    "    llama_weak_scores.append(GPTEval(eval_prompt_llama_weak, model=eval_model))\n",
    "    baseline.append(GPTEval(eval_prompt_baseline, model=eval_model))\n",
    "    mistral_ICL_scores.append(GPTEval(eval_prompt_mistral_ICL, model=eval_model))\n",
    "    mistral_weak_scores.append(GPTEval(eval_prompt_mistral_weak, model=eval_model))\n",
    "    gemma_ICL_scores.append(GPTEval(eval_prompt_gemma_ICL, model=eval_model))\n",
    "    #gemma_weak_scores.append(GPTEval(eval_prompt_gemma_weak, model=eval_model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "456543be-1f00-416f-b224-11d2fb695f5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_model = 'gpt-4o'\n",
    "scores = []\n",
    "for i, testq in enumerate(small_test):\n",
    "    eval_sys = GetEvalSystemPrompt()\n",
    "    key = small_test_key[i]\n",
    "    eval_prompt_gpt = FormatInput(eval_sys, GetEvalUserPrompt(testq, key, gpt_ICL_test_resp[i]['messages'][2]['content']), model = eval_model)\n",
    "    scores.append(GPTEval(eval_prompt_gpt, model=eval_model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "c65911a9-af4a-4fad-9d46-642c7b586cfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "arrays = [llama_ICL_scores, llama_weak_scores, baseline, mistral_ICL_scores, mistral_weak_scores, gemma_ICL_scores, gemma_weak_scores]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "4ce367d7-dc53-40e6-9445-62009447407c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "baseline:[0.78344444] +/- 0.07967046016072847\n",
      "lweak:[0.06789474] +/- 0.04964795733552417\n",
      "lICL:[0.7415] +/- 0.08618648385912955\n",
      "gweak:[0.30013158] +/- 0.0893730509928318\n",
      "gICL:[0.7395] +/- 0.08575890624302526\n",
      "mweak:[0.35194737] +/- 0.09159445180308703\n",
      "mICL:[0.7295] +/- 0.08616373947316817\n"
     ]
    }
   ],
   "source": [
    "print('baseline:'+str(sum(baseline)/100)+' +/- '+str(2*np.std(baseline)/10))\n",
    "print('lweak:'+str(sum(llama_weak_scores)/100)+' +/- '+str(2*np.std(llama_weak_scores)/10))\n",
    "print('lICL:'+str(sum(llama_ICL_scores)/100)+' +/- '+str(2*np.std(llama_ICL_scores)/10))\n",
    "print('gweak:'+str(sum(gemma_weak_scores)/100)+' +/- '+str(2*np.std(gemma_weak_scores)/10))\n",
    "print('gICL:'+str(sum(gemma_ICL_scores)/100)+' +/- '+str(2*np.std(gemma_ICL_scores)/10))\n",
    "print('mweak:'+str(sum(mistral_weak_scores)/100)+' +/- '+str(2*np.std(mistral_weak_scores)/10))\n",
    "print('mICL:'+str(sum(mistral_ICL_scores)/100)+' +/- '+str(2*np.std(mistral_ICL_scores)/10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "5cc8e7fe-8fb2-4456-9b39-fd5cfe49d73b",
   "metadata": {},
   "outputs": [],
   "source": [
    "dictionary = {'baseline':baseline, 'lweak': llama_weak_scores, 'lICL': llama_ICL_scores, 'gweak': gemma_weak_scores, 'gICL:': gemma_ICL_scores, 'mweak': mistral_weak_scores, 'mICL': mistral_ICL_scores}\n",
    "np.save('math_scores_mini.npy', dictionary) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "5c11fcbd-3a8c-4f43-8264-40aa656438c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.74002632]\n"
     ]
    }
   ],
   "source": [
    "print(sum(scores)/100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a58960d4-e4e0-428f-bfb6-20dc171b1b6a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
