{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "def construct_mcq(options, correct_option):\n",
    "    '''Format answer choices as a lettered multiple-choice block.\n",
    "\n",
    "    Args:\n",
    "        options: iterable of answer strings, labelled a, b, c, ... in order.\n",
    "        correct_option: the option that is the right answer (compared with ==).\n",
    "\n",
    "    Returns:\n",
    "        (mcq, correct_option_letter): mcq has one \"<letter>. <option>\" line per\n",
    "        option, joined by newlines with no trailing newline;\n",
    "        correct_option_letter is the label of the (last) option equal to\n",
    "        correct_option.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if correct_option is not among options, or there are more\n",
    "            options than the 26 lowercase letters can label.\n",
    "    '''\n",
    "    options = list(options)\n",
    "    if len(options) > 26:\n",
    "        # chr() past 'z' would produce non-letter labels such as '{'.\n",
    "        raise ValueError(f\"Too many options to label a-z: {len(options)}\")\n",
    "\n",
    "    lines = []\n",
    "    correct_option_letter = None\n",
    "    for idx, option in enumerate(options):\n",
    "        letter = chr(ord(\"a\") + idx)\n",
    "        if option == correct_option:\n",
    "            correct_option_letter = letter\n",
    "        lines.append(f\"{letter}. {option}\")\n",
    "\n",
    "    if correct_option_letter is None:\n",
    "        # Report the offending values in the exception rather than via a separate print.\n",
    "        raise ValueError(\n",
    "            f\"Correct option not found in the options: {correct_option!r} not in {options!r}\"\n",
    "        )\n",
    "\n",
    "    return \"\\n\".join(lines), correct_option_letter\n",
    "\n",
    "def resize_image(image_path, size):\n",
    "    '''Open an image and shrink it so its longest edge is at most `size` pixels.\n",
    "\n",
    "    The aspect ratio is preserved. An image already within the limit is returned\n",
    "    exactly as opened; a larger one is resampled with LANCZOS.\n",
    "\n",
    "    Args:\n",
    "        image_path: path of the image file to open.\n",
    "        size: maximum length, in pixels, of the longer edge.\n",
    "\n",
    "    Returns:\n",
    "        A PIL image.\n",
    "    '''\n",
    "    img = Image.open(image_path)\n",
    "    width, height = img.size\n",
    "\n",
    "    if width <= size and height <= size:\n",
    "        return img\n",
    "\n",
    "    scale = size / max(width, height)\n",
    "    # Round rather than truncate, and never go below 1 px: with an extreme aspect\n",
    "    # ratio int() could produce a 0-length edge, which Image.resize rejects.\n",
    "    new_width = max(1, round(width * scale))\n",
    "    new_height = max(1, round(height * scale))\n",
    "    return img.resize((new_width, new_height), Image.Resampling.LANCZOS)\n",
    "\n",
    "\n",
    "def add_row(content, data, i, with_answer=False):\n",
    "    '''Append one numbered example (label, image, question, choices) to `content`.\n",
    "\n",
    "    Args:\n",
    "        content: list of prompt parts; it is extended in place.\n",
    "        data: example dict with image_path, question and mcq keys, plus reasoning\n",
    "            and correct_option_letter when with_answer is true.\n",
    "        i: example number used in the \"Image i: \" label.\n",
    "        with_answer: append the reference reasoning and answer letter if true;\n",
    "            otherwise append an open \"Reasoning: \" cue for the model to continue.\n",
    "\n",
    "    Returns:\n",
    "        The same `content` list, for chaining.\n",
    "    '''\n",
    "    content.append(\"Image \" + str(i) + \": \")\n",
    "    content.append(resize_image(data[\"image_path\"], 512))  # longest edge capped at 512 px\n",
    "    for key in (\"question\", \"mcq\"):\n",
    "        content.append(data[key])\n",
    "\n",
    "    if not with_answer:\n",
    "        content.append(\"Reasoning: \")\n",
    "        return content\n",
    "\n",
    "    content.append(\"Reasoning: {}\".format(data[\"reasoning\"]))\n",
    "    content.append(\"Answer: {}\".format(data[\"correct_option_letter\"]))\n",
    "    return content\n",
    "   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "import google.generativeai as genai\n",
    "\n",
    "# Never hard-code the secret in the notebook: read it from the environment,\n",
    "# or prompt for it interactively when GOOGLE_API_KEY is not set.\n",
    "GOOGLE_API_KEY = os.environ.get(\"GOOGLE_API_KEY\") or getpass.getpass(\"Google API key: \")\n",
    "\n",
    "# Dataset locations, relative to the notebook's working directory.\n",
    "FEWSHOT_JSON = \"illusionVQA/comprehension/fewshot_labels.json\"\n",
    "FEWSHOT_IMAGE_DIR = \"illusionVQA/comprehension/FEW_SHOTS/\"\n",
    "EVAL_JSON = \"illusionVQA/comprehension/eval_labels.json\"\n",
    "EVAL_IMAGE_DIR = \"illusionVQA/comprehension/EVAL/\"\n",
    "\n",
    "model_name = 'gemini-pro-vision'\n",
    "genai.configure(api_key=GOOGLE_API_KEY)\n",
    "model = genai.GenerativeModel(model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the few-shot examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "# JSON is UTF-8 by spec; don't rely on the platform's default encoding.\n",
    "with open(FEWSHOT_JSON, encoding=\"utf-8\") as f:\n",
    "    fewshot_dataset = json.load(f)\n",
    "\n",
    "# Resolve each example's image path and pre-format its lettered choices.\n",
    "for data in fewshot_dataset:\n",
    "    data[\"image_path\"] = os.path.join(FEWSHOT_IMAGE_DIR, data[\"image\"])\n",
    "    data[\"mcq\"], data[\"correct_option_letter\"] = construct_mcq(data[\"options\"], data[\"answer\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from collections import defaultdict\n",
    "\n",
    "with open(EVAL_JSON, encoding=\"utf-8\") as f:\n",
    "    eval_dataset = json.load(f)\n",
    "\n",
    "# List the image directory once instead of once per example.\n",
    "available_images = set(os.listdir(EVAL_IMAGE_DIR))\n",
    "\n",
    "# Keep only examples whose image exists: a skipped example would otherwise stay in\n",
    "# eval_dataset without an \"image_path\" and crash the evaluation loop (add_row).\n",
    "category_count = defaultdict(int)\n",
    "present_examples = []\n",
    "for data in eval_dataset:\n",
    "    if data[\"image\"] not in available_images:\n",
    "        print(\"Missing image, skipping:\", data[\"image\"])\n",
    "        continue\n",
    "    data[\"image_path\"] = os.path.join(EVAL_IMAGE_DIR, data[\"image\"])\n",
    "    data[\"mcq\"], data[\"correct_option_letter\"] = construct_mcq(data[\"options\"], data[\"answer\"])\n",
    "    category_count[data[\"category\"]] += 1\n",
    "    present_examples.append(data)\n",
    "\n",
    "print(f\"Skipped {len(eval_dataset) - len(present_examples)} of {len(eval_dataset)} examples\")\n",
    "eval_dataset = present_examples\n",
    "\n",
    "print(category_count)\n",
    "print(len(eval_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instruction header followed by the worked few-shot examples (numbered from 1).\n",
    "content = [\"\"\"You'll be given an image, an instruction and some choices. You have to select the correct one. Reason about the choices in the context of the question and the image. End your answer with \"Answer\": {letter_of_correct_choice} without the curly brackets. Here are a few examples:\"\"\"\n",
    "]\n",
    "\n",
    "for i, data in enumerate(fewshot_dataset, start=1):\n",
    "    content = add_row(content, data, i, with_answer=True)\n",
    "content.append(\"Now you try it.\")\n",
    "\n",
    "# The query image is labelled with the number after the last few-shot example.\n",
    "next_data_idx = len(fewshot_dataset) + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Per-request safety settings (same categories and thresholds as before, hoisted out of the loop).\n",
    "SAFETY_SETTINGS = [\n",
    "    {\"category\": \"HARM_CATEGORY_HARASSMENT\", \"threshold\": \"HIGH\"},\n",
    "    {\"category\": \"HARM_CATEGORY_HATE_SPEECH\", \"threshold\": \"HIGH\"},\n",
    "    {\"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\", \"threshold\": \"HIGH\"},\n",
    "    {\"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\", \"threshold\": \"HIGH\"},\n",
    "]\n",
    "MAX_RETRIES = 8            # attempts per example before the error is re-raised\n",
    "MAX_BACKOFF_SECONDS = 60   # cap on the exponential delay between attempts\n",
    "BLOCKED_LABEL = \"BLOCK\"    # prediction recorded when the response carries no text\n",
    "\n",
    "\n",
    "def generate_with_retry(content_parts):\n",
    "    '''Call the model, retrying failures with exponential backoff.\n",
    "\n",
    "    Unlike a bare `while True` retry, this sleeps between attempts (so rate\n",
    "    limits can recover) and re-raises after MAX_RETRIES failures instead of\n",
    "    spinning forever on a persistent error.\n",
    "    '''\n",
    "    delay = 1\n",
    "    for attempt in range(1, MAX_RETRIES + 1):\n",
    "        try:\n",
    "            return model.generate_content(content_parts, safety_settings=SAFETY_SETTINGS)\n",
    "        except Exception as e:\n",
    "            if attempt == MAX_RETRIES:\n",
    "                raise\n",
    "            print(e)\n",
    "            print(\"Internal Error\")\n",
    "            time.sleep(delay)\n",
    "            delay = min(delay * 2, MAX_BACKOFF_SECONDS)\n",
    "\n",
    "\n",
    "def response_text(response):\n",
    "    '''Return the response text, or None when no text part is available.'''\n",
    "    try:\n",
    "        return response.text\n",
    "    except Exception:\n",
    "        try:\n",
    "            return response.parts[0].text\n",
    "        except Exception:\n",
    "            print(\"External Error:\", response.prompt_feedback)\n",
    "            return None\n",
    "\n",
    "\n",
    "def extract_answer_letter(text):\n",
    "    '''Return the lowercased last meaningful character of the model's answer.\n",
    "\n",
    "    The prompt asks the model to end with the option letter, so trailing\n",
    "    whitespace and periods are dropped first (a trailing newline would otherwise\n",
    "    become the prediction). Returns \"\" for blank text instead of raising IndexError.\n",
    "    '''\n",
    "    cleaned = text.rstrip().rstrip(\".\").rstrip()\n",
    "    return cleaned[-1].lower() if cleaned else \"\"\n",
    "\n",
    "\n",
    "ytrue = []\n",
    "ypred = []\n",
    "\n",
    "for data in tqdm(eval_dataset):\n",
    "    content_t = add_row(content.copy(), data, next_data_idx, with_answer=False)\n",
    "    response = generate_with_retry(content_t)\n",
    "    gemini_answer = response_text(response)\n",
    "\n",
    "    if gemini_answer is None:\n",
    "        # No text came back (presumably a safety block). Record a sentinel that the\n",
    "        # results-export cell checks for, rather than the last character of the feedback repr.\n",
    "        print(\"GEMINI: \", response.prompt_feedback)\n",
    "        prediction = BLOCKED_LABEL\n",
    "    else:\n",
    "        print(\"GEMINI: \", gemini_answer)\n",
    "        prediction = extract_answer_letter(gemini_answer)\n",
    "\n",
    "    ytrue.append(data[\"correct_option_letter\"])\n",
    "    ypred.append(prediction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Swap any newline inside a prediction for the placeholder \"x\" so it prints legibly below.\n",
    "ypred = list(map(lambda pred: pred.replace(\"\\n\", \"x\"), ypred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only a cell's last expression is rendered, so display the record explicitly.\n",
    "display(eval_dataset[0])\n",
    "Image.open(eval_dataset[0][\"image_path\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "from sklearn.metrics import accuracy_score, classification_report, confusion_matrix\n",
    "\n",
    "# Overall accuracy, confusion matrix and per-letter classification report.\n",
    "for metric in (accuracy_score, confusion_matrix, classification_report):\n",
    "    print(metric(ytrue, ypred))\n",
    "\n",
    "# Label distributions of the ground truth and of the predictions.\n",
    "for labels in (ytrue, ypred):\n",
    "    print(Counter(labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import prettytable\n",
    "\n",
    "table = prettytable.PrettyTable()\n",
    "table.field_names = [\"Category\", \"Total\", \"Wrong\", \"Accuracy\"]\n",
    "\n",
    "# Count mistakes per category; ypred/ytrue were appended in eval_dataset order.\n",
    "got_wrong_dict = defaultdict(int)\n",
    "for example, pred, truth in zip(eval_dataset, ypred, ytrue):\n",
    "    if pred != truth:\n",
    "        got_wrong_dict[example[\"category\"]] += 1\n",
    "\n",
    "# Iterate over every category, not only those with mistakes, so a category the\n",
    "# model answered perfectly still appears (0 wrong, accuracy 1.0).\n",
    "for category, total in category_count.items():\n",
    "    wrong = got_wrong_dict.get(category, 0)\n",
    "    table.add_row([category, total, wrong, 1 - wrong / total])\n",
    "\n",
    "# Largest categories first.\n",
    "table.sortby = \"Total\"\n",
    "table.reversesort = True\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "METRIC_SAVE_DIR = \"performance_metrics/\"\n",
    "\n",
    "\n",
    "def vlm_answer_text(pred, options):\n",
    "    '''Map a predicted answer letter back to the option text it selects.\n",
    "\n",
    "    Returns \"BLOCK\" when the prediction contains it, options[k] when pred is the\n",
    "    single letter at offset k from \"a\" and 0 <= k < len(options), and pred\n",
    "    itself (after printing it) otherwise. The explicit range check stops\n",
    "    characters just below \"a\" (e.g. \"`\") from wrapping around to options[-1].\n",
    "    '''\n",
    "    if \"BLOCK\" in pred:\n",
    "        return \"BLOCK\"\n",
    "    if len(pred) == 1:\n",
    "        idx = ord(pred) - ord(\"a\")\n",
    "        if 0 <= idx < len(options):\n",
    "            return options[idx]\n",
    "    print(pred)\n",
    "    return pred\n",
    "\n",
    "\n",
    "# list.copy() is shallow: writing \"vlm_answer\" into its dicts would also modify\n",
    "# eval_dataset, so copy each record.\n",
    "eval_dataset_copy = [dict(data) for data in eval_dataset]\n",
    "print(len(eval_dataset_copy))\n",
    "\n",
    "# Record the model's chosen option text for every wrong answer.\n",
    "for data, pred, truth in zip(eval_dataset_copy, ypred, ytrue):\n",
    "    if pred != truth:\n",
    "        data[\"vlm_answer\"] = vlm_answer_text(pred, data[\"options\"])\n",
    "\n",
    "os.makedirs(METRIC_SAVE_DIR, exist_ok=True)\n",
    "with open(os.path.join(METRIC_SAVE_DIR, \"gemini_reasoning_results.json\"), \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(eval_dataset_copy, f)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
