{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import prettytable\n",
    "import random\n",
    "from collections import defaultdict \n",
    "\n",
    "random.seed(53)\n",
    "\n",
    "# One per-category tally of correct answers for each prediction field,\n",
    "# plus the number of items seen in each category.\n",
    "category_correct_gem0 = defaultdict(int)\n",
    "category_correct_gem4 = defaultdict(int)\n",
    "category_correct_gem4cot = defaultdict(int)\n",
    "category_correct_gpt0 = defaultdict(int)\n",
    "category_correct_gpt4 = defaultdict(int)\n",
    "category_correct_gpt4cot = defaultdict(int)\n",
    "category_total = defaultdict(int)\n",
    "\n",
    "# Maps each prediction field of a result record to the tally it feeds.\n",
    "correct_by_field = {\n",
    "    'GEM_0_SHOT': category_correct_gem0,\n",
    "    'GEM_4_SHOT': category_correct_gem4,\n",
    "    'GEM_4_SHOT_COT': category_correct_gem4cot,\n",
    "    'GPT_0_SHOT': category_correct_gpt0,\n",
    "    'GPT_4_SHOT': category_correct_gpt4,\n",
    "    'GPT_4_SHOT_COT': category_correct_gpt4cot,\n",
    "}\n",
    "\n",
    "# Pass the encoding explicitly so reading the JSON file does not depend on\n",
    "# the platform's locale default.\n",
    "with open('results/illusionvqa_soft_loc_gpt_gemini_results.json', encoding='utf-8') as f:\n",
    "    dataset = json.load(f)\n",
    "\n",
    "for item in dataset:\n",
    "    answer = item['answer']\n",
    "    category = item['category']\n",
    "\n",
    "    # A prediction counts as correct only when it equals the reference answer exactly.\n",
    "    for field, correct_counter in correct_by_field.items():\n",
    "        if item[field] == answer:\n",
    "            correct_counter[category] += 1\n",
    "\n",
    "    category_total[category] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GEM_0_SHOT Accuracy:  0.435\n",
      "GEM_4_SHOT Accuracy:  0.418\n",
      "GEM_4_SHOT_COT Accuracy:  0.339\n",
      "GPT_0_SHOT Accuracy:  0.4\n",
      "GPT_4_SHOT Accuracy:  0.46\n",
      "GPT_4_SHOT_COT Accuracy:  0.497\n",
      "Total:  1000\n"
     ]
    }
   ],
   "source": [
    "def _accuracy(correct, total):\n",
    "    \"\"\"Return correct / total, or NaN when no items were scored.\"\"\"\n",
    "    return correct / total if total else float('nan')\n",
    "\n",
    "\n",
    "# Overall counts: sum each per-category tally over all of its categories.\n",
    "total_correct_gem0 = sum(category_correct_gem0.values())\n",
    "total_correct_gem4 = sum(category_correct_gem4.values())\n",
    "total_correct_gem4cot = sum(category_correct_gem4cot.values())\n",
    "total_correct_gpt0 = sum(category_correct_gpt0.values())\n",
    "total_correct_gpt4 = sum(category_correct_gpt4.values())\n",
    "total_correct_gpt4cot = sum(category_correct_gpt4cot.values())\n",
    "total = sum(category_total.values())\n",
    "\n",
    "# (label, correct count) pairs in the order they are reported.\n",
    "overall_correct = [\n",
    "    ('GEM_0_SHOT', total_correct_gem0),\n",
    "    ('GEM_4_SHOT', total_correct_gem4),\n",
    "    ('GEM_4_SHOT_COT', total_correct_gem4cot),\n",
    "    ('GPT_0_SHOT', total_correct_gpt0),\n",
    "    ('GPT_4_SHOT', total_correct_gpt4),\n",
    "    ('GPT_4_SHOT_COT', total_correct_gpt4cot),\n",
    "]\n",
    "\n",
    "# An empty dataset reports NaN instead of raising ZeroDivisionError.\n",
    "for label, correct in overall_correct:\n",
    "    print(label + ' Accuracy: ', _accuracy(correct, total))\n",
    "print('Total: ', total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------------+-------+------------+------------+------------+------------+\n",
      "|      Category     | Total | GEM_0_SHOT | GPT_0_SHOT | GEM_4_SHOT | GPT_4_SHOT |\n",
      "+-------------------+-------+------------+------------+------------+------------+\n",
      "| soft_localization |  1000 |   0.435    |    0.4     |   0.418    |    0.46    |\n",
      "+-------------------+-------+------------+------------+------------+------------+\n"
     ]
    }
   ],
   "source": [
    "# Per-category accuracy table for the 0-shot and 4-shot fields.\n",
    "# The 4-shot CoT fields are not included in this table.\n",
    "table = prettytable.PrettyTable()\n",
    "table.field_names = ['Category', 'Total', 'GEM_0_SHOT', 'GPT_0_SHOT', 'GEM_4_SHOT', 'GPT_4_SHOT']\n",
    "for category in category_total:\n",
    "    n_items = category_total[category]\n",
    "    table.add_row([\n",
    "        category,\n",
    "        n_items,\n",
    "        category_correct_gem0[category] / n_items,\n",
    "        category_correct_gpt0[category] / n_items,\n",
    "        category_correct_gem4[category] / n_items,\n",
    "        category_correct_gpt4[category] / n_items,\n",
    "    ])\n",
    "\n",
    "# Show the categories with the most items first.\n",
    "table.sortby = 'Total'\n",
    "table.reversesort = True\n",
    "print(table)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
