{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import prettytable\n",
    "import random\n",
    "from collections import defaultdict \n",
    "\n",
    "random.seed(53)\n",
    "\n",
    "RESULTS_PATH = 'results/illusionvqa_comprehension_gpt_gemini_results.json'\n",
    "\n",
    "# Per-category number of correct answers, one counter per prediction field.\n",
    "category_correct_gem0 = defaultdict(int)\n",
    "category_correct_gem4 = defaultdict(int)\n",
    "# category_correct_gem4cot = defaultdict(int)\n",
    "category_correct_gpt0 = defaultdict(int)\n",
    "category_correct_gpt4 = defaultdict(int)\n",
    "# category_correct_gpt4cot = defaultdict(int)\n",
    "# Per-category number of questions (independent of the prediction field).\n",
    "category_total = defaultdict(int)\n",
    "\n",
    "# Prediction field of a result record -> the counter that field feeds.\n",
    "correct_by_field = {\n",
    "    'GEM_0_SHOT': category_correct_gem0,\n",
    "    'GEM_4_SHOT': category_correct_gem4,\n",
    "    # 'gem4cot': category_correct_gem4cot,\n",
    "    'GPT_0_SHOT': category_correct_gpt0,\n",
    "    'GPT_4_SHOT': category_correct_gpt4,\n",
    "    # 'gpt4cot': category_correct_gpt4cot,\n",
    "}\n",
    "\n",
    "# open() falls back to a locale-dependent encoding when none is given;\n",
    "# pin UTF-8 so the results file decodes the same way on every platform.\n",
    "with open(RESULTS_PATH, encoding='utf-8') as f:\n",
    "    dataset = json.load(f)\n",
    "\n",
    "for item in dataset:\n",
    "    answer = item['answer']\n",
    "    category = item['category']\n",
    "\n",
    "    # A prediction is counted only when it equals the reference answer exactly.\n",
    "    for field, correct_counts in correct_by_field.items():\n",
    "        if item[field] == answer:\n",
    "            correct_counts[category] += 1\n",
    "\n",
    "    category_total[category] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GEM_0_SHOT Accuracy:  0.5126436781609195\n",
      "GEM_4_SHOT Accuracy:  0.5287356321839081\n",
      "GPT_0_SHOT Accuracy:  0.5885057471264368\n",
      "GPT_4_SHOT Accuracy:  0.6298850574712643\n",
      "Total:  435\n"
     ]
    }
   ],
   "source": [
    "# Fold the per-category tallies into one overall count per prediction field.\n",
    "total = sum(category_total[name] for name in category_total)\n",
    "total_correct_gem0 = sum(category_correct_gem0[name] for name in category_total)\n",
    "total_correct_gem4 = sum(category_correct_gem4[name] for name in category_total)\n",
    "# total_correct_gem4cot = sum(category_correct_gem4cot[name] for name in category_total)\n",
    "total_correct_gpt0 = sum(category_correct_gpt0[name] for name in category_total)\n",
    "total_correct_gpt4 = sum(category_correct_gpt4[name] for name in category_total)\n",
    "# total_correct_gpt4cot = sum(category_correct_gpt4cot[name] for name in category_total)\n",
    "\n",
    "# Overall accuracy = correct answers / all questions, printed in a fixed order.\n",
    "overall_rows = [\n",
    "    (\"GEM_0_SHOT Accuracy: \", total_correct_gem0),\n",
    "    (\"GEM_4_SHOT Accuracy: \", total_correct_gem4),\n",
    "    # (\"GEM_4_SHOT_COT Accuracy: \", total_correct_gem4cot),\n",
    "    (\"GPT_0_SHOT Accuracy: \", total_correct_gpt0),\n",
    "    (\"GPT_4_SHOT Accuracy: \", total_correct_gpt4),\n",
    "    # (\"GPT_4_SHOT_COT Accuracy: \", total_correct_gpt4cot),\n",
    "]\n",
    "for label, n_correct in overall_rows:\n",
    "    print(label, n_correct / total)\n",
    "print(\"Total: \", total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------------------+-------+---------------------+---------------------+---------------------+---------------------+\n",
      "|         Category        | Total |      GEM_0_SHOT     |      GPT_0_SHOT     |      GEM_4_SHOT     |      GPT_4_SHOT     |\n",
      "+-------------------------+-------+---------------------+---------------------+---------------------+---------------------+\n",
      "|    impossible object    |  134  |  0.5671641791044776 |  0.5522388059701493 |  0.5671641791044776 |  0.5970149253731343 |\n",
      "|        real-scene       |   64  |       0.46875       |       0.578125      |       0.46875       |       0.546875      |\n",
      "|           size          |   46  | 0.45652173913043476 |  0.5869565217391305 |  0.5217391304347826 |  0.6956521739130435 |\n",
      "|          hidden         |   45  |  0.4222222222222222 |  0.5111111111111111 |  0.4888888888888889 |  0.4666666666666667 |\n",
      "|     deceptive design    |   37  |  0.6486486486486487 |  0.7027027027027027 |  0.6756756756756757 |  0.7297297297297297 |\n",
      "|      angle illusion     |   26  |  0.5384615384615384 |  0.6923076923076923 |         0.5         |  0.8461538461538461 |\n",
      "|          color          |   23  | 0.17391304347826086 |  0.6956521739130435 | 0.17391304347826086 |  0.8260869565217391 |\n",
      "|       edited-scene      |   21  |  0.6666666666666666 |  0.7142857142857143 |  0.6666666666666666 |  0.8095238095238095 |\n",
      "|         counting        |   11  |  0.5454545454545454 | 0.36363636363636365 |  0.5454545454545454 | 0.45454545454545453 |\n",
      "|       upside-down       |   7   | 0.42857142857142855 |  0.7142857142857143 |  0.5714285714285714 |  0.7142857142857143 |\n",
      "| positive-negative space |   7   |  0.8571428571428571 |  0.5714285714285714 |  0.7142857142857143 |  0.8571428571428571 |\n",
      "|      circle-spiral      |   6   |  0.3333333333333333 |         0.5         |  0.3333333333333333 |  0.3333333333333333 |\n",
      "|    repeating pattern    |   2   |         0.5         |         0.5         |         0.5         |         0.0         |\n",
      "|       perspective       |   2   |         0.5         |         1.0         |         1.0         |         1.0         |\n",
      "|        occlusion        |   2   |         0.5         |         0.0         |         0.5         |         0.5         |\n",
      "|     angle constancy     |   2   |         0.5         |         0.5         |         0.5         |         0.0         |\n",
      "+-------------------------+-------+---------------------+---------------------+---------------------+---------------------+\n"
     ]
    }
   ],
   "source": [
    "# Column header -> per-category correct counter, in display order.\n",
    "accuracy_columns = {\n",
    "    \"GEM_0_SHOT\": category_correct_gem0,\n",
    "    \"GPT_0_SHOT\": category_correct_gpt0,\n",
    "    \"GEM_4_SHOT\": category_correct_gem4,\n",
    "    \"GPT_4_SHOT\": category_correct_gpt4,\n",
    "}\n",
    "\n",
    "table = prettytable.PrettyTable()\n",
    "table.field_names = [\"Category\", \"Total\", *accuracy_columns]\n",
    "for name, n_questions in category_total.items():\n",
    "    # One accuracy (correct / questions in this category) per column.\n",
    "    accuracies = [counts[name] / n_questions for counts in accuracy_columns.values()]\n",
    "    table.add_row([name, n_questions, *accuracies])\n",
    "\n",
    "# Show the categories with the most questions first.\n",
    "table.sortby = \"Total\"\n",
    "table.reversesort = True\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5263157894736842\n",
      "0.42105263157894735\n",
      "0.5789473684210527\n",
      "0.42105263157894735\n"
     ]
    }
   ],
   "source": [
    "# Categories that are pooled into a single \"misc\" bucket.\n",
    "MISC_CATEGORIES = [\n",
    "    \"counting\",\n",
    "    \"repeating pattern\",\n",
    "    \"perspective\",\n",
    "    \"occlusion\",\n",
    "    \"angle constancy\",\n",
    "]\n",
    "\n",
    "# Output order of the printed accuracies: GEM 0-shot, GPT 0-shot, GEM 4-shot, GPT 4-shot.\n",
    "cats = [\n",
    "    category_correct_gem0,\n",
    "    category_correct_gpt0,\n",
    "    category_correct_gem4,\n",
    "    category_correct_gpt4\n",
    "]\n",
    "\n",
    "# Use .get() rather than indexing: indexing a defaultdict inserts any missing key,\n",
    "# and a zero-count entry in category_total would make the per-category table\n",
    "# divide by zero the next time that cell is run.\n",
    "# The pooled question count is the same for every counter, so compute it once.\n",
    "misc_total = sum(category_total.get(name, 0) for name in MISC_CATEGORIES)\n",
    "\n",
    "for cat in cats:\n",
    "    #misc\n",
    "    misc_correct = sum(cat.get(name, 0) for name in MISC_CATEGORIES)\n",
    "    print(misc_correct / misc_total)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
