{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "7cc7b9d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import json\n",
    "import sklearn as sk\n",
    "import matplotlib as plt\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "7d9b5316",
   "metadata": {},
   "outputs": [],
   "source": [
    "def confusion_matrix(data, prompt):\n",
    "    data = data[data[\"PromptType\"]==prompt]\n",
    "    return np.array([[len(data[data[\"model_bruteforce\"]==\"1\"][data[\"human_bruteforce\"]==\"1\"]), len(data[data[\"model_bruteforce\"]==\"1\"][data[\"human_bruteforce\"]==\"0\"])], [len(data[data[\"model_bruteforce\"]==\"0\"][data[\"human_bruteforce\"]==\"1\"]), len(data[data[\"model_bruteforce\"]==\"0\"][data[\"human_bruteforce\"]==\"0\"])]])\n",
    "\n",
    "def correctness(data, prompt):\n",
    "    data = data[data[\"PromptType\"]==prompt]\n",
    "    return len(data[data[\"correctness\"]==\"1\"]) / len(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "1ac9523f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Adequate: 190\n",
      "Inadequate: 55\n",
      "Model o3 Prompt: basicprompt 83.5/90.2\n",
      "Model o3 Prompt: mathPrompt 84.8/85.4\n",
      "Model o3 Prompt: hint_prompt 88.1/77.5\n",
      "Model o3 Prompt: combinedhintPrompt 88.2/80.0\n",
      "Adequate: 158\n",
      "Inadequate: 65\n",
      "Model GeminiFlash Prompt: basicprompt 59.6/48.0\n",
      "Model GeminiFlash Prompt: mathPrompt 59.1/59.2\n",
      "Model GeminiFlash Prompt: hint_prompt 60.5/60.8\n",
      "Model GeminiFlash Prompt: combinedhintPrompt 64.3/70.5\n",
      "Adequate: 11\n",
      "Inadequate: 239\n",
      "Model Qwen1 Prompt: basicprompt 36.4/2.5\n",
      "Model Qwen1 Prompt: mathPrompt 9.1/3.8\n",
      "Model Qwen1 Prompt: hintPrompt 36.4/5.4\n",
      "Model Qwen1 Prompt: combinedhintPrompt 9.1/3.4\n",
      "Adequate: 100\n",
      "Inadequate: 150\n",
      "Model Qwen14 Prompt: basicprompt 38.0/11.3\n",
      "Model Qwen14 Prompt: mathPrompt 41.0/12.0\n",
      "Model Qwen14 Prompt: hintPrompt 45.0/15.3\n",
      "Model Qwen14 Prompt: combinedhintPrompt 38.0/18.0\n",
      "Adequate: 111\n",
      "Inadequate: 139\n",
      "Model Qwen70 Prompt: basicprompt 33.3/17.3\n",
      "Model Qwen70 Prompt: mathPrompt 36.0/15.1\n",
      "Model Qwen70 Prompt: hintPrompt 39.6/15.1\n",
      "Model Qwen70 Prompt: combinedhintPrompt 42.3/18.7\n",
      "Adequate: 183\n",
      "Inadequate: 67\n",
      "Model DSChat Prompt: basicprompt 41.0/28.8\n",
      "Model DSChat Prompt: mathPrompt 45.4/28.4\n",
      "Model DSChat Prompt: hintPrompt 47.0/26.9\n",
      "Model DSChat Prompt: combinedhintPrompt 47.0/25.8\n",
      "Adequate: 192\n",
      "Inadequate: 58\n",
      "Model DSReason Prompt: basicprompt 47.6/34.5\n",
      "Model DSReason Prompt: mathPrompt 46.1/43.1\n",
      "Model DSReason Prompt: hintPrompt 53.6/35.1\n",
      "Model DSReason Prompt: combinedhintPrompt 53.1/42.1\n"
     ]
    }
   ],
   "source": [
    "data = {}\n",
    "summary = {}\n",
    "models = ['o3', 'GeminiFlash', 'Qwen1', 'Qwen14', 'Qwen70', 'DSChat', 'DSReason']\n",
    "for model in models:\n",
    "    file_path = f\"response_evaluation/Logic/SolutionSummary-{model}/resultsEvaluations_evaluatedbyo3-2025-04-16.jsonl\"\n",
    "    with open(file_path, 'r') as file:\n",
    "        data[model] = [json.loads(line) for line in file]\n",
    "    \n",
    "    with open(f\"response_evaluation/Logic/LogicAll-{model}{'Batch' if model == 'o3' else ''}/resultsEvaluations_evaluatedbyo3-2025-04-16.jsonl\", 'r') as file:\n",
    "        correctnessData = [json.loads(line) for line in file]\n",
    "    \n",
    "    prompts = [\"basicprompt\", \"mathPrompt\", \"hintPrompt\", \"hint_prompt\", \"combinedhintPrompt\"]\n",
    "    summaries = pd.DataFrame(data[model])\n",
    "    correctness = pd.DataFrame(correctnessData)\n",
    "    # print(correctness.head())\n",
    "    data[model] = correctness\n",
    "    summary = {}\n",
    "    # print(model, prompts)\n",
    "    \n",
    "    print(\"Adequate:\", len(summaries[summaries[\"Summary\"]==\"1\"]))\n",
    "    print(\"Inadequate:\", len(summaries[summaries[\"Summary\"]==\"0\"]))\n",
    "    \n",
    "    for index, row in summaries.iterrows():\n",
    "        summary[row[\"ID\"]] = int(row[\"Summary\"])\n",
    "\n",
    "    for prompt in prompts:\n",
    "        bscores = [[0, 0], [0, 0]]\n",
    "        cscores = [[0, 0], [0, 0]]\n",
    "        for index, row in correctness.iterrows():\n",
    "            if row[\"PromptType\"] == prompt:\n",
    "                if (type(row[\"Response\"]) != str or row[\"Response\"] == None or row[\"Response\"] == \"NaN\" or row[\"Response\"] == \"None\" or row[\"Response\"] == \"\" or row[\"model_bruteforce\"] == \"NULL\" or row[\"Response\"] is str and row[\"Response\"].isspace()):\n",
    "                        continue\n",
    "                try:\n",
    "                    cscores[summary[int(row[\"ID\"])]][int(row[\"correctness\"])] += 1\n",
    "                    bscores[summary[int(row[\"ID\"])]][int(row[\"model_bruteforce\"])] += 1\n",
    "                except:\n",
    "                    # print(\"failed\", row[\"ID\"], row[\"correctness\"], row[\"model_bruteforce\"])\n",
    "                    pass\n",
    "        # print(cscores, bscores)\n",
    "        try:\n",
    "            print(\"Model\", model, \"Prompt:\", prompt, \n",
    "            f'{round(100*cscores[1][1]/(cscores[1][0]+cscores[1][1]), 1)}/{round(100*cscores[0][1]/(cscores[0][0]+cscores[0][1]), 1)}',\n",
    "            # f'{round(100*bscores[1][1]/(bscores[1][0]+bscores[1][1]), 1)}/{round(100*bscores[0][1]/(bscores[0][0]+bscores[0][1]), 1)}',\n",
    "              )\n",
    "        except:\n",
    "            pass\n",
    "    # for j in range(4)\n",
    "    #     print(round(100*scores[0][j][1]/(scores[0][j][0]), 1), round(100*scores[1][j][1]/(scores[1][j][0]), 1)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "8a722452",
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyError",
     "evalue": "'human_bruteforce'",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mKeyError\u001b[39m                                  Traceback (most recent call last)",
      "\u001b[36mFile \u001b[39m\u001b[32m~/anaconda3/envs/Brainteasers/lib/python3.13/site-packages/pandas/core/indexes/base.py:3805\u001b[39m, in \u001b[36mIndex.get_loc\u001b[39m\u001b[34m(self, key)\u001b[39m\n\u001b[32m   3804\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m3805\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_engine\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcasted_key\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   3806\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
      "\u001b[36mFile \u001b[39m\u001b[32mindex.pyx:167\u001b[39m, in \u001b[36mpandas._libs.index.IndexEngine.get_loc\u001b[39m\u001b[34m()\u001b[39m\n",
      "\u001b[36mFile \u001b[39m\u001b[32mindex.pyx:196\u001b[39m, in \u001b[36mpandas._libs.index.IndexEngine.get_loc\u001b[39m\u001b[34m()\u001b[39m\n",
      "\u001b[36mFile \u001b[39m\u001b[32mpandas/_libs/hashtable_class_helper.pxi:7081\u001b[39m, in \u001b[36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[39m\u001b[34m()\u001b[39m\n",
      "\u001b[36mFile \u001b[39m\u001b[32mpandas/_libs/hashtable_class_helper.pxi:7089\u001b[39m, in \u001b[36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[39m\u001b[34m()\u001b[39m\n",
      "\u001b[31mKeyError\u001b[39m: 'human_bruteforce'",
      "\nThe above exception was the direct cause of the following exception:\n",
      "\u001b[31mKeyError\u001b[39m                                  Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[74]\u001b[39m\u001b[32m, line 9\u001b[39m\n\u001b[32m      7\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(prompts)):\n\u001b[32m      8\u001b[39m     prompt = prompts[i]\n\u001b[32m----> \u001b[39m\u001b[32m9\u001b[39m     cmd = sk.metrics.ConfusionMatrixDisplay(\u001b[43mconfusion_matrix\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfulldata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprompt\u001b[49m\u001b[43m)\u001b[49m, display_labels=[\u001b[33m\"\u001b[39m\u001b[33mUsed Brute Force\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mDid Not Use Brute Force\u001b[39m\u001b[33m\"\u001b[39m])\n\u001b[32m     10\u001b[39m     cmd.plot()\n\u001b[32m     11\u001b[39m     cmd.ax_.set(xlabel=\u001b[33m'\u001b[39m\u001b[33mHuman Solution\u001b[39m\u001b[33m'\u001b[39m, ylabel=\u001b[33m'\u001b[39m\u001b[33mModel Solution\u001b[39m\u001b[33m'\u001b[39m)\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[72]\u001b[39m\u001b[32m, line 3\u001b[39m, in \u001b[36mconfusion_matrix\u001b[39m\u001b[34m(data, prompt)\u001b[39m\n\u001b[32m      1\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mconfusion_matrix\u001b[39m(data, prompt):\n\u001b[32m      2\u001b[39m     data = data[data[\u001b[33m\"\u001b[39m\u001b[33mPromptType\u001b[39m\u001b[33m\"\u001b[39m]==prompt]\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m np.array([[\u001b[38;5;28mlen\u001b[39m(data[data[\u001b[33m\"\u001b[39m\u001b[33mmodel_bruteforce\u001b[39m\u001b[33m\"\u001b[39m]==\u001b[33m\"\u001b[39m\u001b[33m1\u001b[39m\u001b[33m\"\u001b[39m][\u001b[43mdata\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mhuman_bruteforce\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m==\u001b[33m\"\u001b[39m\u001b[33m1\u001b[39m\u001b[33m\"\u001b[39m]), \u001b[38;5;28mlen\u001b[39m(data[data[\u001b[33m\"\u001b[39m\u001b[33mmodel_bruteforce\u001b[39m\u001b[33m\"\u001b[39m]==\u001b[33m\"\u001b[39m\u001b[33m1\u001b[39m\u001b[33m\"\u001b[39m][data[\u001b[33m\"\u001b[39m\u001b[33mhuman_bruteforce\u001b[39m\u001b[33m\"\u001b[39m]==\u001b[33m\"\u001b[39m\u001b[33m0\u001b[39m\u001b[33m\"\u001b[39m])], [\u001b[38;5;28mlen\u001b[39m(data[data[\u001b[33m\"\u001b[39m\u001b[33mmodel_bruteforce\u001b[39m\u001b[33m\"\u001b[39m]==\u001b[33m\"\u001b[39m\u001b[33m0\u001b[39m\u001b[33m\"\u001b[39m][data[\u001b[33m\"\u001b[39m\u001b[33mhuman_bruteforce\u001b[39m\u001b[33m\"\u001b[39m]==\u001b[33m\"\u001b[39m\u001b[33m1\u001b[39m\u001b[33m\"\u001b[39m]), \u001b[38;5;28mlen\u001b[39m(data[data[\u001b[33m\"\u001b[39m\u001b[33mmodel_bruteforce\u001b[39m\u001b[33m\"\u001b[39m]==\u001b[33m\"\u001b[39m\u001b[33m0\u001b[39m\u001b[33m\"\u001b[39m][data[\u001b[33m\"\u001b[39m\u001b[33mhuman_bruteforce\u001b[39m\u001b[33m\"\u001b[39m]==\u001b[33m\"\u001b[39m\u001b[33m0\u001b[39m\u001b[33m\"\u001b[39m])]])\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/anaconda3/envs/Brainteasers/lib/python3.13/site-packages/pandas/core/frame.py:4102\u001b[39m, in \u001b[36mDataFrame.__getitem__\u001b[39m\u001b[34m(self, key)\u001b[39m\n\u001b[32m   4100\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.columns.nlevels > \u001b[32m1\u001b[39m:\n\u001b[32m   4101\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._getitem_multilevel(key)\n\u001b[32m-> \u001b[39m\u001b[32m4102\u001b[39m indexer = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mcolumns\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   4103\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m is_integer(indexer):\n\u001b[32m   4104\u001b[39m     indexer = [indexer]\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/anaconda3/envs/Brainteasers/lib/python3.13/site-packages/pandas/core/indexes/base.py:3812\u001b[39m, in \u001b[36mIndex.get_loc\u001b[39m\u001b[34m(self, key)\u001b[39m\n\u001b[32m   3807\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(casted_key, \u001b[38;5;28mslice\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m (\n\u001b[32m   3808\u001b[39m         \u001b[38;5;28misinstance\u001b[39m(casted_key, abc.Iterable)\n\u001b[32m   3809\u001b[39m         \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28many\u001b[39m(\u001b[38;5;28misinstance\u001b[39m(x, \u001b[38;5;28mslice\u001b[39m) \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m casted_key)\n\u001b[32m   3810\u001b[39m     ):\n\u001b[32m   3811\u001b[39m         \u001b[38;5;28;01mraise\u001b[39;00m InvalidIndexError(key)\n\u001b[32m-> \u001b[39m\u001b[32m3812\u001b[39m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(key) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01merr\u001b[39;00m\n\u001b[32m   3813\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[32m   3814\u001b[39m     \u001b[38;5;66;03m# If we have a listlike key, _check_indexing_error will raise\u001b[39;00m\n\u001b[32m   3815\u001b[39m     \u001b[38;5;66;03m#  InvalidIndexError. Otherwise we fall through and re-raise\u001b[39;00m\n\u001b[32m   3816\u001b[39m     \u001b[38;5;66;03m#  the TypeError.\u001b[39;00m\n\u001b[32m   3817\u001b[39m     \u001b[38;5;28mself\u001b[39m._check_indexing_error(key)\n",
      "\u001b[31mKeyError\u001b[39m: 'human_bruteforce'"
     ]
    }
   ],
   "source": [
    "fulldata = pd.DataFrame()\n",
    "for dataset in data.keys():\n",
    "    fulldata = pd.DataFrame(data[dataset])\n",
    "    prompts = fulldata[\"PromptType\"].unique()\n",
    "    \n",
    "    # fig, axes = plt.pyplot.subplots(1, len(prompts), figsize=(30, 6))\n",
    "    for i in range(len(prompts)):\n",
    "        prompt = prompts[i]\n",
    "        cmd = sk.metrics.ConfusionMatrixDisplay(confusion_matrix(fulldata, prompt), display_labels=[\"Used Brute Force\", \"Did Not Use Brute Force\"])\n",
    "        cmd.plot()\n",
    "        cmd.ax_.set(xlabel='Human Solution', ylabel='Model Solution')\n",
    "        cmd.ax_.set_title(f'{dataset} Brute Force Matrix for {prompt}')\n",
    "    \n",
    "        print(f\"Correctness for {dataset} on {prompt}:\", correctness(fulldata, prompt))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d03c56aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.pyplot.subplots(1, 5, figsize=(30, 6))\n",
    "for i in range(len(prompts)):\n",
    "    prompt = prompts[i]\n",
    "    cmd = sk.metrics.ConfusionMatrixDisplay(confusion_matrix(fulldata, prompt), display_labels=[\"Used Brute Force\", \"Did Not Use Brute Force\"])\n",
    "    cmd.plot()\n",
    "    cmd.ax_.set(xlabel='Human Solution', ylabel='Model Solution')\n",
    "    cmd.ax_.set_title(f'Brute Force Matrix for {prompt}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bed10ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "tempdata = fulldata[fulldata[\"PromptType\"] == \"basicprompt\"]\n",
    "# print(tempdata.head())\n",
    "len(tempdata[tempdata[\"model_bruteforce\"]==\"1\"][tempdata[\"human_bruteforce\"]==\"0\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24d56f2b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Brainteasers",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
