{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7cc7b9d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import json\n",
    "import sklearn as sk\n",
    "import matplotlib.pyplot as plt\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7d9b5316",
   "metadata": {},
   "outputs": [],
   "source": [
    "def confusion_matrix(data, prompt):\n",
    "    data = data[data[\"PromptType\"]==prompt]\n",
    "    return np.array([[len(data[data[\"model_bruteforce\"]==\"1\"][data[\"human_bruteforce\"]==\"1\"]), len(data[data[\"model_bruteforce\"]==\"1\"][data[\"human_bruteforce\"]==\"0\"])], [len(data[data[\"model_bruteforce\"]==\"0\"][data[\"human_bruteforce\"]==\"1\"]), len(data[data[\"model_bruteforce\"]==\"0\"][data[\"human_bruteforce\"]==\"0\"])]])\n",
    "\n",
    "def correctness(data, prompt):\n",
    "    data = data[data[\"PromptType\"]==prompt]\n",
    "    return len(data[data[\"correctness\"]==\"1\"]) / len(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1ac9523f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 1, 10: 0, 11: 0, 12: 1, 13: 0, 14: 0, 15: 0, 16: 0, 17: 0, 18: 0, 19: 0, 20: 0, 21: 0, 22: 0, 23: 0, 24: 0, 25: 0, 26: 0, 27: 1, 28: 0, 29: 0, 30: 0, 31: 0, 32: 1, 33: 0, 34: 0, 35: 0, 36: 0, 37: 0, 38: 1, 39: 0, 40: 0, 41: 0, 42: 0, 43: 1, 44: 1, 45: 1, 46: 1, 47: 0, 48: 0, 49: 0, 50: 0, 51: 0, 52: 0, 53: 0, 54: 1, 55: 1, 56: 0, 57: 0, 58: 0, 59: 0, 60: 0, 61: 0, 62: 0, 63: 1, 64: 0, 65: 0, 66: 0, 67: 0, 68: 0, 69: 0, 70: 0, 71: 0, 72: 1, 73: 0, 74: 0, 75: 0, 76: 0, 77: 0, 78: 0, 79: 0, 80: 0, 81: 1, 82: 0, 83: 0, 84: 0, 85: 0, 86: 0, 87: 0, 88: 0, 89: 0, 90: 0, 91: 0, 92: 1, 93: 0, 94: 0, 95: 0, 96: 0, 97: 1, 98: 0, 99: 0, 100: 0, 101: 0, 102: 0, 103: 1, 104: 1, 105: 0, 106: 0, 107: 1, 108: 0, 109: 0, 110: 0, 111: 0, 112: 0, 113: 0, 114: 0, 115: 0, 116: 0, 117: 0, 118: 0, 119: 0, 120: 0, 121: 0, 122: 0, 123: 0, 124: 0, 125: 0, 126: 0, 127: 0, 128: 0, 129: 1, 130: 0, 131: 0, 132: 1, 133: 1, 134: 0, 135: 1, 136: 0, 137: 0, 138: 0, 139: 0, 140: 0, 141: 0, 142: 0, 143: 0, 144: 0, 145: 0, 146: 0, 147: 0, 148: 0, 149: 0, 150: 1, 151: 0, 152: 0, 153: 0, 154: 1, 155: 0, 156: 0, 157: 0, 158: 0, 159: 0, 160: 0, 161: 0, 162: 0, 163: 0, 164: 0, 165: 0, 166: 0, 167: 0, 168: 0, 169: 0, 170: 1, 171: 0, 172: 0, 173: 1, 174: 0, 175: 0, 176: 0, 177: 0, 178: 0, 179: 0, 180: 0, 181: 0, 182: 0, 183: 0, 184: 1, 185: 0, 186: 1, 187: 0, 188: 0, 189: 0, 190: 0, 191: 0, 192: 0, 193: 0, 194: 0, 195: 0, 196: 1, 197: 0, 198: 0, 199: 0, 200: 0, 201: 0, 202: 0, 203: 0, 204: 0, 205: 0, 206: 0, 207: 1, 208: 0, 209: 0, 210: 0, 211: 0, 212: 0, 213: 0, 214: 0, 215: 0, 216: 0, 217: 0, 218: 0, 219: 0, 220: 0, 221: 0, 222: 0, 223: 0, 224: 0, 225: 1, 226: 0, 227: 0, 228: 0, 229: 0, 230: 0, 231: 0, 232: 0, 233: 0, 234: 0, 235: 0, 236: 0, 237: 0, 238: 0, 239: 0, 240: 0, 241: 0, 242: 0, 243: 0, 244: 0, 245: 0, 246: 1, 247: 0, 248: 0, 249: 0}\n",
      "{0: 0, 1: 0, 2: 0, 3: 1, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0, 11: 0, 12: 0, 13: 1, 14: 0, 15: 0, 16: 0, 17: 0, 18: 0, 19: 1, 20: 0, 21: 0, 22: 0, 23: 0, 24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 1, 34: 0, 35: 0, 36: 0, 37: 0, 38: 0, 39: 1, 40: 0, 41: 0, 42: 0, 43: 0, 44: 0, 45: 0, 46: 0, 47: 0, 48: 0, 49: 0, 50: 1, 51: 0, 52: 1, 53: 0, 54: 0, 55: 0, 56: 0, 57: 0, 58: 1, 59: 0, 60: 0, 61: 0, 62: 0, 63: 1, 64: 1, 65: 0, 66: 0, 67: 0, 68: 0, 69: 0, 70: 0, 71: 0, 72: 1, 73: 0, 74: 0, 75: 0, 76: 0, 77: 0, 78: 0, 79: 0, 80: 0, 81: 0, 82: 0, 83: 0, 84: 0, 85: 0, 86: 0, 87: 0, 88: 0, 89: 0, 90: 0, 91: 0, 92: 0, 93: 0, 94: 0, 95: 0, 96: 1, 97: 0, 98: 0, 99: 0, 100: 0, 101: 0, 102: 0, 103: 0, 104: 1, 105: 0, 106: 0, 107: 1, 108: 0, 109: 1, 110: 0, 111: 0, 112: 0, 113: 0, 114: 0, 115: 0, 116: 0, 117: 0, 118: 0, 119: 0, 120: 0, 121: 0, 122: 0, 123: 1, 124: 0, 125: 0, 126: 0, 127: 1, 128: 0, 129: 0, 130: 0, 131: 0, 132: 0, 133: 0, 134: 0, 135: 0, 136: 0, 137: 0, 138: 0, 139: 0, 140: 0, 141: 0, 142: 0, 143: 0, 144: 0, 145: 0, 146: 0, 147: 0, 148: 0, 149: 0, 150: 0, 151: 0, 152: 0, 153: 0, 154: 0, 155: 0, 156: 0, 157: 0, 158: 0, 159: 0, 160: 0, 161: 0, 162: 0, 163: 0, 164: 0, 165: 0, 166: 0, 167: 0, 168: 0, 169: 0, 170: 0, 171: 0, 172: 0, 173: 0, 174: 0, 175: 0, 176: 1, 177: 0, 178: 0, 179: 0, 180: 0, 181: 0, 182: 0, 183: 0, 184: 0, 185: 0, 186: 0, 187: 1, 188: 0, 189: 0, 190: 0, 191: 0, 192: 0, 193: 1, 194: 0, 195: 0, 196: 0, 197: 0, 198: 0, 199: 0, 200: 0, 201: 0, 202: 0, 203: 0, 204: 1, 205: 0, 206: 0, 207: 0, 208: 0, 209: 0, 210: 0, 211: 1, 212: 0, 213: 0, 214: 0, 215: 0, 216: 0, 217: 0, 218: 0, 219: 0, 220: 0, 221: 0, 222: 1, 223: 0, 224: 0, 225: 0, 226: 0, 227: 1, 228: 0, 229: 0, 230: 0, 231: 0, 232: 0, 233: 0, 234: 0, 235: 0, 236: 0, 237: 0, 238: 0, 239: 0, 240: 0, 241: 0, 242: 0, 243: 0, 244: 1, 245: 0, 246: 0, 247: 0, 248: 0, 249: 0}\n"
     ]
    }
   ],
   "source": [
    "alldata = {}\n",
    "models = ['DSChat', 'DSReason', 'GeminiFlash', 'o3', 'Qwen1', 'Qwen14', 'Qwen70']\n",
    "tests = ['MathMain', 'MathHint', 'MathHintCombined', 'LogicMain', 'LogicHint', 'LogicHintCombined']\n",
    "humanbruteforcemath = {}\n",
    "humanbruteforcetotalmath = {}\n",
    "humanbruteforcelogic = {}\n",
    "humanbruteforcetotallogic = {}\n",
    "mathdifficulty = {}\n",
    "logicdifficulty = {}\n",
    "mathpopularity = {}\n",
    "logicpopularity = {}\n",
    "mathcategory = {}\n",
    "logiccategory = {}\n",
    "\n",
    "logicdata = pd.read_csv(\"data/braingle/braingle_Logic_with_categories.csv\")\n",
    "mathdata = pd.read_csv(\"data/braingle/braingle_Math_with_categories.csv\")\n",
    "for index, row in logicdata.iterrows():\n",
    "    logicdifficulty[index] = row[\"Difficulty\"]\n",
    "    logicpopularity[index] = row[\"Popularity/Fun\"]\n",
    "    logiccategory[index] = row['categories']\n",
    "for index, row in mathdata.iterrows():\n",
    "    mathdifficulty[index] = row[\"Difficulty\"]\n",
    "    mathpopularity[index] = row[\"Popularity/Fun\"]\n",
    "    mathcategory[index] = row['categories']\n",
    "\n",
    "# models = ['FinalLogic-Qwen1']\n",
    "for model in models:\n",
    "    file_path = f\"response_evaluation/Logic/HintRedo-{model}/resultsEvaluations_evaluatedbyo3-2025-04-16.jsonl\"\n",
    "    with open(file_path, 'r') as file:\n",
    "        alldata[(model, \"LogicHint\")] = [json.loads(line) for line in file]\n",
    "\n",
    "    for index, row in pd.DataFrame(alldata[(model, \"LogicHint\")]).iterrows():\n",
    "        try:\n",
    "            if (row['human_bruteforce'] != '1' and row['human_bruteforce'] != '0'):\n",
    "                continue\n",
    "            if (row[\"ID\"] not in humanbruteforcetotallogic.keys()):\n",
    "                humanbruteforcetotallogic[row[\"ID\"]] = 0\n",
    "                humanbruteforcelogic[row[\"ID\"]] = 0\n",
    "            humanbruteforcetotallogic[row[\"ID\"]] += 1\n",
    "            humanbruteforcelogic[row[\"ID\"]] += int(row['human_bruteforce'])\n",
    "        except:\n",
    "            pass\n",
    "\n",
    "for model in models:\n",
    "    file_path = f\"response_evaluation/Logic/FinalLogic-{model}/resultsEvaluations_evaluatedbyo3-2025-04-16.jsonl\"\n",
    "    with open(file_path, 'r') as file:\n",
    "        alldata[(model, \"LogicMain\")] = [json.loads(line) for line in file]\n",
    "\n",
    "    for index, row in pd.DataFrame(alldata[(model, \"LogicMain\")]).iterrows():\n",
    "        try:\n",
    "            if (row['human_bruteforce'] != '1' and row['human_bruteforce'] != '0'):\n",
    "                continue\n",
    "            if (row[\"ID\"] not in humanbruteforcetotallogic.keys()):\n",
    "                humanbruteforcetotallogic[row[\"ID\"]] = 0\n",
    "                humanbruteforcelogic[row[\"ID\"]] = 0\n",
    "            humanbruteforcetotallogic[row[\"ID\"]] += 1\n",
    "            humanbruteforcelogic[row[\"ID\"]] += int(row['human_bruteforce'])\n",
    "        except:\n",
    "            pass\n",
    "\n",
    "for model in models:\n",
    "    file_path = f\"response_evaluation/Math/FinalMath-{model}/resultsEvaluations_evaluatedbyo3-2025-04-16.jsonl\"\n",
    "    with open(file_path, 'r') as file:\n",
    "        alldata[(model, \"MathMain\")] = [json.loads(line) for line in file]\n",
    "\n",
    "    for index, row in pd.DataFrame(alldata[(model, \"MathMain\")]).iterrows():\n",
    "        try:\n",
    "            if (row['human_bruteforce'] != '1' and row['human_bruteforce'] != '0'):\n",
    "                continue\n",
    "            if (row[\"ID\"] not in humanbruteforcetotalmath.keys()):\n",
    "                humanbruteforcetotalmath[row[\"ID\"]] = 0\n",
    "                humanbruteforcemath[row[\"ID\"]] = 0\n",
    "            humanbruteforcetotalmath[row[\"ID\"]] += 1\n",
    "            humanbruteforcemath[row[\"ID\"]] += int(row['human_bruteforce'])\n",
    "        except:\n",
    "            pass\n",
    "\n",
    "for model in models:\n",
    "    file_path = f\"response_evaluation/Math/HintRedo-{model}/resultsEvaluations_evaluatedbyo3-2025-04-16.jsonl\"\n",
    "    with open(file_path, 'r') as file:\n",
    "        alldata[(model, \"MathHint\")] = [json.loads(line) for line in file]\n",
    "\n",
    "    for index, row in pd.DataFrame(alldata[(model, \"MathHint\")]).iterrows():\n",
    "        try:\n",
    "            if (row['human_bruteforce'] != '1' and row['human_bruteforce'] != '0'):\n",
    "                continue\n",
    "            if (row[\"ID\"] not in humanbruteforcetotalmath.keys()):\n",
    "                humanbruteforcetotalmath[row[\"ID\"]] = 0\n",
    "                humanbruteforcemath[row[\"ID\"]] = 0\n",
    "            humanbruteforcetotalmath[row[\"ID\"]] += 1\n",
    "            humanbruteforcemath[row[\"ID\"]] += int(row['human_bruteforce'])\n",
    "        except:\n",
    "            pass\n",
    "\n",
    "for model in models:\n",
    "    file_path = f\"response_evaluation/Logic/CombinedHint-{model}/resultsEvaluations_evaluatedbyo3-2025-04-16.jsonl\"\n",
    "    try:\n",
    "        with open(file_path, 'r') as file:\n",
    "            alldata[(model, \"LogicHintCombined\")] = [json.loads(line) for line in file]\n",
    "\n",
    "        for index, row in pd.DataFrame(alldata[(model, \"LogicHintCombined\")]).iterrows():\n",
    "            try:\n",
    "                if (row['human_bruteforce'] != '1' and row['human_bruteforce'] != '0'):\n",
    "                    continue\n",
    "                if (row[\"ID\"] not in humanbruteforcetotallogic.keys()):\n",
    "                    humanbruteforcetotallogic[row[\"ID\"]] = 0\n",
    "                    humanbruteforcelogic[row[\"ID\"]] = 0\n",
    "                humanbruteforcetotallogic[row[\"ID\"]] += 1\n",
    "                humanbruteforcelogic[row[\"ID\"]] += int(row['human_bruteforce'])\n",
    "            except:\n",
    "                pass\n",
    "    except:\n",
    "        pass\n",
    "\n",
    "for model in models:\n",
    "    file_path = f\"response_evaluation/Math/CombinedHint-{model}/resultsEvaluations_evaluatedbyo3-2025-04-16.jsonl\"\n",
    "    try:\n",
    "        with open(file_path, 'r') as file:\n",
    "            alldata[(model, \"MathHintCombined\")] = [json.loads(line) for line in file]\n",
    "\n",
    "        for index, row in pd.DataFrame(alldata[(model, \"MathHintCombined\")]).iterrows():\n",
    "            try:\n",
    "                if (row['human_bruteforce'] != '1' and row['human_bruteforce'] != '0'):\n",
    "                    continue\n",
    "                if (row[\"ID\"] not in humanbruteforcetotalmath.keys()):\n",
    "                    humanbruteforcetotalmath[row[\"ID\"]] = 0\n",
    "                    humanbruteforcemath[row[\"ID\"]] = 0\n",
    "                humanbruteforcetotalmath[row[\"ID\"]] += 1\n",
    "                humanbruteforcemath[row[\"ID\"]] += int(row['human_bruteforce'])\n",
    "            except:\n",
    "                pass\n",
    "    except:\n",
    "        pass\n",
    "\n",
    "\n",
    "for i in range(250):\n",
    "    humanbruteforcemath[i] = round(humanbruteforcemath[i] / humanbruteforcetotalmath[i])\n",
    "    humanbruteforcelogic[i] = round(humanbruteforcelogic[i] / humanbruteforcetotallogic[i])\n",
    "\n",
    "print(humanbruteforcemath)\n",
    "print(humanbruteforcelogic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d03c56aa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: DSChat, Test: MathMain, Prompt: basicprompt [10.4 30.8  2.8 56. ] 250.0 Correctness: 0.58/0.4\n",
      "Model: DSChat, Test: MathMain, Prompt: mathPrompt [ 9.6 28.4  3.6 58.4] 250.0 Correctness: 0.536/0.32\n",
      "Model: DSChat, Test: MathHint, Prompt: hintPrompt [ 9.6 28.8  3.6 58. ] 250.0 Correctness: 0.568/0.38\n",
      "Model: DSChat, Test: MathHintCombined, Prompt: combinedhintPrompt [11.2 30.   2.  56.8] 250.0 Correctness: 0.6/0.38\n",
      "Model: DSChat, Test: LogicMain, Prompt: basicprompt [ 6. 38.  4. 52.] 250.0 Correctness: 0.392/0.32\n",
      "Model: DSChat, Test: LogicMain, Prompt: mathPrompt [ 7.2 34.   2.8 56. ] 250.0 Correctness: 0.412/0.34\n",
      "Model: DSChat, Test: LogicHint, Prompt: hintPrompt [ 7.6 30.   2.4 60. ] 250.0 Correctness: 0.414/0.26\n",
      "Model: DSChat, Test: LogicHintCombined, Prompt: combinedhintPrompt [ 6.4 28.8  3.6 61.2] 250.0 Correctness: 0.408/0.2\n",
      "Model: DSReason, Test: MathMain, Prompt: basicprompt [ 8.  13.2  5.2 73.6] 250.0 Correctness: 0.664/0.48\n",
      "Model: DSReason, Test: MathMain, Prompt: mathPrompt [ 8.8 10.8  4.4 75.9] 249.0 Correctness: 0.677/0.52\n",
      "Model: DSReason, Test: MathHint, Prompt: hintPrompt [ 8.8 15.6  4.4 71.2] 250.0 Correctness: 0.728/0.48\n",
      "Model: DSReason, Test: MathHintCombined, Prompt: combinedhintPrompt [ 8.4 10.8  4.8 76. ] 250.0 Correctness: 0.724/0.6\n",
      "Model: DSReason, Test: LogicMain, Prompt: basicprompt [ 5.2 13.3  4.4 77.1] 249.0 Correctness: 0.446/0.26\n",
      "Model: DSReason, Test: LogicMain, Prompt: mathPrompt [ 4. 14.  6. 76.] 250.0 Correctness: 0.44/0.28\n",
      "Model: DSReason, Test: LogicHint, Prompt: hintPrompt [ 5.2 10.4  4.4 79.9] 249.0 Correctness: 0.498/0.347\n",
      "Model: DSReason, Test: LogicHintCombined, Prompt: combinedhintPrompt [ 6.   6.   3.6 84.3] 249.0 Correctness: 0.526/0.44\n",
      "Model: GeminiFlash, Test: MathMain, Prompt: basicprompt [ 7.8 16.7  5.3 70.2] 245.0 Correctness: 0.678/0.5\n",
      "Model: GeminiFlash, Test: MathMain, Prompt: mathPrompt [ 7.  15.2  6.6 71.3] 244.0 Correctness: 0.687/0.479\n",
      "Model: GeminiFlash, Test: MathHint, Prompt: hintPrompt [ 1.2  7.4  2.5 88.9] 81.0 Correctness: 0.827/0.6\n",
      "Model: GeminiFlash, Test: MathHintCombined, Prompt: combinedhintPrompt [ 1.3  0.   2.7 96. ] 75.0 Correctness: 0.707/0.571\n",
      "Model: GeminiFlash, Test: LogicMain, Prompt: basicprompt [ 7.6 15.3  2.4 74.7] 249.0 Correctness: 0.141/0.02\n",
      "Model: GeminiFlash, Test: LogicMain, Prompt: mathPrompt [ 8.1 11.7  2.  78.1] 247.0 Correctness: 0.108/0.044\n",
      "Model: GeminiFlash, Test: LogicHint, Prompt: hintPrompt [ 4.9  8.2  3.3 83.6] 61.0 Correctness: 0.508/0.2\n",
      "Model: GeminiFlash, Test: LogicHintCombined, Prompt: combinedhintPrompt [ 1.8 12.5  1.8 83.9] 56.0 Correctness: 0.518/0.0\n",
      "Model: o3, Test: MathMain, Prompt: basicprompt [ 7.3  9.8  6.1 76.8] 246.0 Correctness: 0.809/0.667\n",
      "Model: o3, Test: MathMain, Prompt: mathPrompt [ 4.8  4.4  8.5 82.3] 248.0 Correctness: 0.798/0.66\n",
      "Model: o3, Test: MathHint, Prompt: hintPrompt [ 4.9  8.9  6.4 79.8] 203.0 Correctness: 0.901/0.806\n",
      "Model: o3, Test: MathHintCombined, Prompt: combinedhintPrompt [ 3.   4.6  7.1 85.3] 197.0 Correctness: 0.904/0.806\n",
      "Model: o3, Test: LogicMain, Prompt: basicprompt [ 2.3  8.9  7.5 81.3] 214.0 Correctness: 0.827/0.824\n",
      "Model: o3, Test: LogicMain, Prompt: mathPrompt [ 1.8  5.   6.8 86.4] 220.0 Correctness: 0.782/0.694\n",
      "Model: o3, Test: LogicHint, Prompt: hintPrompt [ 2.3  7.   6.4 84.3] 172.0 Correctness: 0.878/0.818\n",
      "Model: o3, Test: LogicHintCombined, Prompt: combinedhintPrompt [ 4.3  9.8  5.5 80.5] 164.0 Correctness: 0.915/0.952\n",
      "Model: Qwen1, Test: MathMain, Prompt: basicprompt [ 9.2 28.4  4.  58.4] 250.0 Correctness: 0.164/0.12\n",
      "Model: Qwen1, Test: MathMain, Prompt: mathPrompt [ 9.2 30.   4.  56.8] 250.0 Correctness: 0.16/0.08\n",
      "Model: Qwen1, Test: MathHint, Prompt: hintPrompt [ 8.8 31.2  4.4 55.6] 250.0 Correctness: 0.148/0.08\n",
      "Model: Qwen1, Test: MathHintCombined, Prompt: combinedhintPrompt [ 9.6 28.8  3.6 58. ] 250.0 Correctness: 0.192/0.1\n",
      "Model: Qwen1, Test: LogicMain, Prompt: basicprompt [ 6.4 22.4  3.6 67.6] 250.0 Correctness: 0.04/0.04\n",
      "Model: Qwen1, Test: LogicMain, Prompt: mathPrompt [ 6.4 21.6  3.6 68.4] 250.0 Correctness: 0.036/0.04\n",
      "Model: Qwen1, Test: LogicHint, Prompt: hintPrompt [ 6.4 16.8  3.6 73.2] 250.0 Correctness: 0.056/0.06\n",
      "Model: Qwen1, Test: LogicHintCombined, Prompt: combinedhintPrompt [ 8.4 22.8  1.6 67.2] 250.0 Correctness: 0.044/0.06\n",
      "Model: Qwen14, Test: MathMain, Prompt: basicprompt [10.  33.6  3.2 53.2] 250.0 Correctness: 0.424/0.24\n",
      "Model: Qwen14, Test: MathMain, Prompt: mathPrompt [10.8 27.2  2.4 59.6] 250.0 Correctness: 0.412/0.24\n",
      "Model: Qwen14, Test: MathHint, Prompt: hintPrompt [10.4 28.4  2.8 58.4] 250.0 Correctness: 0.436/0.2\n",
      "Model: Qwen14, Test: MathHintCombined, Prompt: combinedhintPrompt [ 9.6 28.4  3.6 58.4] 250.0 Correctness: 0.424/0.24\n",
      "Model: Qwen14, Test: LogicMain, Prompt: basicprompt [ 7.2 35.6  2.8 54.4] 250.0 Correctness: 0.22/0.16\n",
      "Model: Qwen14, Test: LogicMain, Prompt: mathPrompt [ 6.4 34.8  3.6 55.2] 250.0 Correctness: 0.228/0.16\n",
      "Model: Qwen14, Test: LogicHint, Prompt: hintPrompt [ 7.2 32.4  2.8 57.6] 250.0 Correctness: 0.284/0.22\n",
      "Model: Qwen14, Test: LogicHintCombined, Prompt: combinedhintPrompt [ 5.2 33.2  4.8 56.8] 250.0 Correctness: 0.26/0.28\n",
      "Model: Qwen70, Test: MathMain, Prompt: basicprompt [11.6 27.2  1.6 59.6] 250.0 Correctness: 0.42/0.18\n",
      "Model: Qwen70, Test: MathMain, Prompt: mathPrompt [ 9.6 27.2  3.6 59.6] 250.0 Correctness: 0.404/0.22\n",
      "Model: Qwen70, Test: MathHint, Prompt: hintPrompt [ 9.2 25.6  4.  61.2] 250.0 Correctness: 0.448/0.24\n",
      "Model: Qwen70, Test: MathHintCombined, Prompt: combinedhintPrompt [ 8.8 26.   4.4 60.8] 250.0 Correctness: 0.428/0.18\n",
      "Model: Qwen70, Test: LogicMain, Prompt: basicprompt [ 8.  26.8  2.  63.2] 250.0 Correctness: 0.248/0.18\n",
      "Model: Qwen70, Test: LogicMain, Prompt: mathPrompt [ 7.2 25.6  2.8 64.4] 250.0 Correctness: 0.252/0.16\n",
      "Model: Qwen70, Test: LogicHint, Prompt: hintPrompt [ 6.4 25.6  3.6 64.4] 250.0 Correctness: 0.272/0.2\n",
      "Model: Qwen70, Test: LogicHintCombined, Prompt: combinedhintPrompt [ 7.2 23.2  2.8 66.8] 250.0 Correctness: 0.301/0.265\n"
     ]
    }
   ],
   "source": [
    "for model in models:\n",
    "    for test in tests:\n",
    "        try:\n",
    "            data = pd.DataFrame(alldata[(model, test)])\n",
    "        except:\n",
    "            continue\n",
    "        prompts = data[\"PromptType\"].unique()\n",
    "\n",
    "        for prompt in prompts:\n",
    "            if \"hint\" in prompt and \"Main\" in test or \"symbol\" in prompt:\n",
    "                continue\n",
    "            tempdata = data[data[\"PromptType\"]==prompt]\n",
    "            bfarray = np.zeros((2, 2))\n",
    "            bfdiff = []\n",
    "            nbfdiff = []\n",
    "            bfpop = []\n",
    "            nbfpop = []\n",
    "            correctness = []\n",
    "            correctnessDiff = []\n",
    "            category = {}\n",
    "            count = 0\n",
    "\n",
    "            for index, row in tempdata.iterrows():\n",
    "                # if (type(row[\"Response\"]) != str):\n",
    "                #     print(row[\"Response\"])\n",
    "                if (type(row[\"Response\"]) != str or row[\"Response\"] == None or row[\"Response\"] == \"NaN\" or row[\"Response\"] == \"None\" or row[\"Response\"] == \"\" or row[\"model_bruteforce\"] == \"NULL\" or row[\"Response\"] is str and row[\"Response\"].isspace()):\n",
    "                    # print(\"Filtered!\")\n",
    "                    continue\n",
    "                try:\n",
    "                    if (row[\"model_bruteforce\"] == \"1\"):\n",
    "                        if \"Math\" in test:\n",
    "                            bfdiff.append(mathdifficulty[row[\"ID\"]])\n",
    "                            bfpop.append(mathpopularity[row[\"ID\"]])\n",
    "                        else:\n",
    "                            bfdiff.append(logicdifficulty[row[\"ID\"]])\n",
    "                            bfpop.append(logicpopularity[row[\"ID\"]])\n",
    "                    elif (row[\"model_bruteforce\"] == \"0\"):\n",
    "                        if \"Math\" in test:\n",
    "                            nbfdiff.append(mathdifficulty[row[\"ID\"]])\n",
    "                            nbfpop.append(mathpopularity[row[\"ID\"]])\n",
    "                        else:\n",
    "                            nbfdiff.append(logicdifficulty[row[\"ID\"]])\n",
    "                            nbfpop.append(logicpopularity[row[\"ID\"]])\n",
    "                    \n",
    "                    if \"Math\" in test:\n",
    "                        bfarray[1-int(row[\"model_bruteforce\"])][1-humanbruteforcemath[row['ID']]] += 1\n",
    "                    if \"Logic\" in test:\n",
    "                        bfarray[1-int(row[\"model_bruteforce\"])][1-humanbruteforcelogic[row['ID']]] += 1\n",
    "                    correctness.append(int(row[\"correctness\"]))\n",
    "\n",
    "                    if \"Logic\" in test:\n",
    "                        # print(logiccategory[row[\"ID\"]])\n",
    "                        if logiccategory[row[\"ID\"]] not in category.keys():\n",
    "                            category[logiccategory[row[\"ID\"]]] = [0, 0]\n",
    "                        \n",
    "                        # print(row[\"model_bruteforce\"])\n",
    "\n",
    "                        category[logiccategory[row[\"ID\"]]][1-int(row[\"model_bruteforce\"])] += 1\n",
    "\n",
    "                    if (row['ID'] < 50):\n",
    "                        correctnessDiff.append(int(row[\"correctness\"]))\n",
    "                    # count += 1\n",
    "\n",
    "                    \n",
    "                except Exception as e:\n",
    "                    # print(\"Error:\", e)\n",
    "                    # print(row[\"model_bruteforce\"])\n",
    "                    pass\n",
    "            # print(category)\n",
    "            \n",
    "            print(f\"Model: {model}, Test: {test}, Prompt: {prompt}\", \n",
    "                #   print(category),\n",
    "                  np.round(100*bfarray.flatten()/np.sum(bfarray), 1), np.sum(bfarray), \n",
    "                  \"Correctness:\", str(np.round(np.mean(correctness), 3)) + \"/\" + str(np.round(np.mean(correctnessDiff), 3)), \n",
    "                #   \"Difficulty (BF/NBF):\", np.round(np.mean(bfdiff), 2), np.round(np.mean(nbfdiff), 2), \n",
    "                #   \"Popularity (BF/NBF):\", np.round(np.mean(bfpop), 2), np.round(np.mean(nbfpop), 2),\n",
    "                  \n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5ed10e30",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Human Solution Math:\n",
      "Difficulty (BF/NBF): 2.83 2.8\n",
      "Popularity (BF/NBF): 2.32 2.33\n",
      "Human Solution Logic:\n",
      "Difficulty (BF/NBF): 2.7 2.65\n",
      "Popularity (BF/NBF): 2.37 2.51\n"
     ]
    }
   ],
   "source": [
    "print(\"Human Solution Math:\")\n",
    "totaldiffbf = []\n",
    "totaldiffnbf = []\n",
    "totalpopbf = []\n",
    "totalpopnbf = []\n",
    "for i in range(250):\n",
    "    if (humanbruteforcemath[i] == 1):\n",
    "        totaldiffbf.append(mathdifficulty[i])\n",
    "        totalpopbf.append(mathpopularity[i])\n",
    "    else:\n",
    "        totaldiffnbf.append(mathdifficulty[i])\n",
    "        totalpopnbf.append(mathpopularity[i])\n",
    "print(\"Difficulty (BF/NBF):\", np.round(np.mean(totaldiffbf), 2), np.round(np.mean(totaldiffnbf), 2))\n",
    "print(\"Popularity (BF/NBF):\", np.round(np.mean(totalpopbf), 2), np.round(np.mean(totalpopnbf), 2))\n",
    "print(\"Human Solution Logic:\")\n",
    "totaldiffbf = []\n",
    "totaldiffnbf = []\n",
    "totalpopbf = []\n",
    "totalpopnbf = []\n",
    "for i in range(250):\n",
    "    if (humanbruteforcelogic[i] == 1):\n",
    "        totaldiffbf.append(logicdifficulty[i])\n",
    "        totalpopbf.append(logicpopularity[i])\n",
    "    else:\n",
    "        totaldiffnbf.append(logicdifficulty[i])\n",
    "        totalpopnbf.append(logicpopularity[i])\n",
    "print(\"Difficulty (BF/NBF):\", np.round(np.mean(totaldiffbf), 2), np.round(np.mean(totaldiffnbf), 2))\n",
    "print(\"Popularity (BF/NBF):\", np.round(np.mean(totalpopbf), 2), np.round(np.mean(totalpopnbf), 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c912abea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: '4,7',\n",
       " 1: '2',\n",
       " 2: '5',\n",
       " 3: '1',\n",
       " 4: '2',\n",
       " 5: '2,8',\n",
       " 6: '4',\n",
       " 7: '4',\n",
       " 8: '2,8',\n",
       " 9: '8',\n",
       " 10: '2',\n",
       " 11: '2',\n",
       " 12: '1',\n",
       " 13: '9',\n",
       " 14: '9',\n",
       " 15: '2,8',\n",
       " 16: '2,8',\n",
       " 17: '7',\n",
       " 18: '2',\n",
       " 19: '2',\n",
       " 20: '3',\n",
       " 21: '2',\n",
       " 22: '1',\n",
       " 23: '2',\n",
       " 24: '5',\n",
       " 25: '1',\n",
       " 26: '1,2',\n",
       " 27: '2',\n",
       " 28: '5',\n",
       " 29: '1',\n",
       " 30: '9',\n",
       " 31: '8',\n",
       " 32: '2',\n",
       " 33: '2',\n",
       " 34: '4',\n",
       " 35: '4',\n",
       " 36: '2',\n",
       " 37: '2',\n",
       " 38: '2',\n",
       " 39: '3',\n",
       " 40: '5',\n",
       " 41: '2',\n",
       " 42: '2',\n",
       " 43: '2,7',\n",
       " 44: '8',\n",
       " 45: '7',\n",
       " 46: '3',\n",
       " 47: '7',\n",
       " 48: '4',\n",
       " 49: '2',\n",
       " 50: '4',\n",
       " 51: '4,2',\n",
       " 52: '2',\n",
       " 53: '2,7',\n",
       " 54: '2',\n",
       " 55: '2',\n",
       " 56: '1',\n",
       " 57: '2,5',\n",
       " 58: '4,9',\n",
       " 59: '5',\n",
       " 60: '8',\n",
       " 61: '9',\n",
       " 62: '2,5',\n",
       " 63: '2,5',\n",
       " 64: '6',\n",
       " 65: '7',\n",
       " 66: '1,8',\n",
       " 67: '9,4',\n",
       " 68: '2',\n",
       " 69: '2',\n",
       " 70: '2',\n",
       " 71: '8',\n",
       " 72: '8',\n",
       " 73: '2',\n",
       " 74: '7',\n",
       " 75: '2',\n",
       " 76: '2',\n",
       " 77: '2',\n",
       " 78: '6',\n",
       " 79: '5',\n",
       " 80: '1',\n",
       " 81: '2',\n",
       " 82: '4,7',\n",
       " 83: '4',\n",
       " 84: '3',\n",
       " 85: '6',\n",
       " 86: '8',\n",
       " 87: '2',\n",
       " 88: '8',\n",
       " 89: '2',\n",
       " 90: '2',\n",
       " 91: '2',\n",
       " 92: '3',\n",
       " 93: '5',\n",
       " 94: '2,8',\n",
       " 95: '2',\n",
       " 96: '5',\n",
       " 97: '2',\n",
       " 98: '2',\n",
       " 99: '7',\n",
       " 100: '2',\n",
       " 101: '5',\n",
       " 102: '5',\n",
       " 103: '2,8',\n",
       " 104: '3',\n",
       " 105: '4,9',\n",
       " 106: '1,2',\n",
       " 107: '1,7',\n",
       " 108: '8',\n",
       " 109: '9',\n",
       " 110: '2',\n",
       " 111: '9',\n",
       " 112: '9',\n",
       " 113: '2',\n",
       " 114: '2',\n",
       " 115: '3,7',\n",
       " 116: '4',\n",
       " 117: '4,7',\n",
       " 118: '2',\n",
       " 119: '4',\n",
       " 120: '2',\n",
       " 121: '2',\n",
       " 122: '8',\n",
       " 123: '2',\n",
       " 124: '2',\n",
       " 125: '7',\n",
       " 126: '7',\n",
       " 127: '1',\n",
       " 128: '1',\n",
       " 129: '8',\n",
       " 130: '4',\n",
       " 131: '9',\n",
       " 132: '9',\n",
       " 133: '2',\n",
       " 134: '9',\n",
       " 135: '2,8',\n",
       " 136: '7',\n",
       " 137: '4',\n",
       " 138: '4',\n",
       " 139: '3',\n",
       " 140: '2,7',\n",
       " 141: '9',\n",
       " 142: '2',\n",
       " 143: '1',\n",
       " 144: '2',\n",
       " 145: '2',\n",
       " 146: '8',\n",
       " 147: '3',\n",
       " 148: '2',\n",
       " 149: '3',\n",
       " 150: '7',\n",
       " 151: '4',\n",
       " 152: '7',\n",
       " 153: '2',\n",
       " 154: '2',\n",
       " 155: '9',\n",
       " 156: '7',\n",
       " 157: '1',\n",
       " 158: '4',\n",
       " 159: '9',\n",
       " 160: '7',\n",
       " 161: '9',\n",
       " 162: '7',\n",
       " 163: '2',\n",
       " 164: '9',\n",
       " 165: '3',\n",
       " 166: '7,9',\n",
       " 167: '2',\n",
       " 168: '1,5',\n",
       " 169: '7',\n",
       " 170: '2',\n",
       " 171: '1,3',\n",
       " 172: '7',\n",
       " 173: '7',\n",
       " 174: '8',\n",
       " 175: '8',\n",
       " 176: '2,8',\n",
       " 177: '3',\n",
       " 178: '3',\n",
       " 179: '9',\n",
       " 180: '2',\n",
       " 181: '4',\n",
       " 182: '3',\n",
       " 183: '3',\n",
       " 184: '8',\n",
       " 185: '7',\n",
       " 186: '2',\n",
       " 187: '7',\n",
       " 188: '2',\n",
       " 189: '9',\n",
       " 190: '7',\n",
       " 191: '2',\n",
       " 192: '8',\n",
       " 193: '5',\n",
       " 194: '7',\n",
       " 195: '5',\n",
       " 196: '4,2',\n",
       " 197: '8',\n",
       " 198: '5',\n",
       " 199: '7',\n",
       " 200: '3',\n",
       " 201: '5,7',\n",
       " 202: '2',\n",
       " 203: '2',\n",
       " 204: '1',\n",
       " 205: '2',\n",
       " 206: '7',\n",
       " 207: '8',\n",
       " 208: '1',\n",
       " 209: '5',\n",
       " 210: '7',\n",
       " 211: '2,7',\n",
       " 212: '1,4',\n",
       " 213: '5',\n",
       " 214: '8',\n",
       " 215: '1',\n",
       " 216: '1',\n",
       " 217: '7',\n",
       " 218: '9',\n",
       " 219: '9',\n",
       " 220: '1',\n",
       " 221: '7',\n",
       " 222: '4',\n",
       " 223: '2',\n",
       " 224: '8',\n",
       " 225: '2,7',\n",
       " 226: '7,9',\n",
       " 227: '2',\n",
       " 228: '1',\n",
       " 229: '8',\n",
       " 230: '8',\n",
       " 231: '2',\n",
       " 232: '2',\n",
       " 233: '2',\n",
       " 234: '2',\n",
       " 235: '2',\n",
       " 236: '8,2',\n",
       " 237: '2',\n",
       " 238: '9',\n",
       " 239: '8',\n",
       " 240: '5',\n",
       " 241: '5',\n",
       " 242: '2',\n",
       " 243: '3',\n",
       " 244: '7',\n",
       " 245: '1',\n",
       " 246: '8',\n",
       " 247: '2,8',\n",
       " 248: '9',\n",
       " 249: '2'}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mathcategory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "b09aa979",
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "'dict' object is not callable",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mTypeError\u001b[39m                                 Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[17]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m      1\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[32m10\u001b[39m):\n\u001b[32m      2\u001b[39m     \u001b[38;5;28;01mfor\u001b[39;00m j \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[32m250\u001b[39m):\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m         \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28mstr\u001b[39m(i) \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[43mmathcategory\u001b[49m\u001b[43m(\u001b[49m\u001b[43mj\u001b[49m\u001b[43m)\u001b[49m):\n\u001b[32m      4\u001b[39m             \u001b[38;5;28;01mcontinue\u001b[39;00m\n",
      "\u001b[31mTypeError\u001b[39m: 'dict' object is not callable"
     ]
    }
   ],
   "source": [
    "for i in range(10):\n",
    "    for j in range(250):\n",
    "        if (str(i) not in mathcategory(j)):\n",
    "            continue"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Brainteasers",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
