{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import json\n",
    "\n",
    "def construct_mcq(options, correct_option):\n",
    "    correct_option_letter = None\n",
    "    i = \"a\"\n",
    "    mcq = \"\"\n",
    "\n",
    "    for option in options:\n",
    "        if option == correct_option:\n",
    "            correct_option_letter = i\n",
    "        mcq += f\"{i}. {option}\\n\"\n",
    "        i = chr(ord(i) + 1)\n",
    "\n",
    "    if correct_option_letter is None:\n",
    "        print(options, correct_option)\n",
    "        raise ValueError(\"Correct option not found in the options\")\n",
    "    \n",
    "    mcq = mcq[:-1]\n",
    "    return mcq, correct_option_letter\n",
    "\n",
    "def resize_image(image_path, size):\n",
    "    '''resize image so that the largest edge is atmost size'''\n",
    "    img = Image.open(image_path)\n",
    "    width, height = img.size\n",
    "\n",
    "    if width <= size and height <= size:\n",
    "        return img\n",
    "    \n",
    "    if width > height:\n",
    "        new_width = size\n",
    "        new_height = int(height * (size / width))\n",
    "    else:\n",
    "        new_height = size\n",
    "        new_width = int(width * (size / height))\n",
    "    img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)\n",
    "    return img\n",
    "\n",
    "\n",
    "def add_row(content, data, i, with_answer=False):  \n",
    "\n",
    "    content.append(\"Image \"+str(i)+\": \")\n",
    "    content.append(resize_image(data[\"image_path\"], 512))\n",
    "    content.append(data[\"question\"])\n",
    "    content.append(data[\"mcq\"])\n",
    "\n",
    "    if with_answer:\n",
    "        content.append(\"Answer {}: {}\".format(i, data[\"correct_option_letter\"]))\n",
    "    else:\n",
    "        content.append(\"Answer {}: \".format(i))\n",
    "    \n",
    "    return content\n",
    "   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import google.generativeai as genai\n",
    "\n",
    "GOOGLE_API_KEY='YOUR_API_KEY_HERE'\n",
    "\n",
    "FEWSHOT_JSON = \"illusionVQA/comprehension/fewshot_labels.json\"\n",
    "FEWSHOT_IMAGE_DIR = \"illusionVQA/comprehension/FEW_SHOTS/\"\n",
    "EVAL_JSON = \"illusionVQA/comprehension/eval_labels.json\"\n",
    "EVAL_IMAGE_DIR = \"illusionVQA/comprehension/EVAL/\"\n",
    "\n",
    "genai.configure(api_key=GOOGLE_API_KEY)\n",
    "model = genai.GenerativeModel('gemini-pro-vision')\n",
    "model_name = 'gemini-pro-vision'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(FEWSHOT_JSON) as f:\n",
    "    fewshot_dataset = json.load(f)\n",
    "\n",
    "for data in fewshot_dataset:\n",
    "    data[\"image_path\"] = FEWSHOT_IMAGE_DIR + data[\"image\"]\n",
    "    data[\"mcq\"], data[\"correct_option_letter\"] = construct_mcq(data[\"options\"], data[\"answer\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import json\n",
    "with open(EVAL_JSON) as f:\n",
    "    eval_dataset = json.load(f)\n",
    "\n",
    "from collections import defaultdict\n",
    "\n",
    "category_count = defaultdict(int)\n",
    "import os\n",
    "for data in eval_dataset:\n",
    "    if data[\"image\"] not in os.listdir(EVAL_IMAGE_DIR):\n",
    "        print(data[\"image\"])\n",
    "        continue\n",
    "    data[\"image_path\"] = EVAL_IMAGE_DIR + data[\"image\"]\n",
    "    data[\"mcq\"], data[\"correct_option_letter\"] = construct_mcq(data[\"options\"], data[\"answer\"])\n",
    "    category_count[data[\"category\"]] += 1\n",
    "\n",
    "print(category_count)\n",
    "print(len(eval_dataset))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 0-shot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "content = [\"You'll be given an image, an instruction and some options. You have to select the correct one. Do not explain your reasoning. Answer with the option's letter from the given choices directly.\"\n",
    "]\n",
    "\n",
    "next_data_idx = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4-shot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# content = [\"You'll be given an image, an instruction and some choices. You have to select the correct one. Do not explain your reasoning. Answer with the option's letter from the given choices directly. Here are a few examples:\"\n",
    "# ]\n",
    "\n",
    "# i = 1\n",
    "# for data in fewshot_dataset:\n",
    "#     content = add_row(content, data, i, with_answer=True)\n",
    "#     i += 1\n",
    "# content.append(\"Now you try it.\")\n",
    "\n",
    "# next_data_idx = i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "ytrue = []\n",
    "ypred = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for  data in tqdm(eval_dataset):\n",
    "    content_t = add_row(content.copy(), data, next_data_idx, with_answer=False)\n",
    "    # print(content_t)\n",
    "    while True:\n",
    "        try:\n",
    "            response = model.generate_content(content_t,\n",
    "                                              safety_settings=[\n",
    "                {\n",
    "                    \"category\": \"HARM_CATEGORY_HARASSMENT\",\n",
    "                    \"threshold\": \"HIGH\",\n",
    "                },\n",
    "                {\n",
    "                    \"category\": \"HARM_CATEGORY_HATE_SPEECH\",\n",
    "                    \"threshold\": \"HIGH\",\n",
    "                },\n",
    "                {\n",
    "                    \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n",
    "                    \"threshold\": \"HIGH\",\n",
    "                },\n",
    "                {\n",
    "                    \"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n",
    "                    \"threshold\": \"HIGH\",\n",
    "                },\n",
    "                ]\n",
    "            )\n",
    "            break\n",
    "        except Exception as e:\n",
    "            print(e)\n",
    "            print(\"Internal Error\")\n",
    "            continue\n",
    "    \n",
    "    try:\n",
    "        gemini_answer = response.text\n",
    "    except Exception as e:\n",
    "        try:\n",
    "            gemini_answer = response.parts[0].text\n",
    "        except Exception as e:\n",
    "            print(\"External Error:\", response.prompt_feedback)\n",
    "\n",
    "            gemini_answer = str(response.prompt_feedback)\n",
    "        \n",
    "    if gemini_answer[-1] == \".\":\n",
    "        gemini_answer = gemini_answer[:-1]\n",
    "    gemini_answer = gemini_answer[-1].lower()\n",
    "\n",
    "    answer = data[\"correct_option_letter\"]\n",
    "\n",
    "    ytrue.append(answer)\n",
    "    ypred.append(gemini_answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#replace \\n with x\n",
    "ypred = [x.replace(\"\\n\", \"x\") for x in ypred]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(ytrue), len(ypred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set(ytrue), set(ypred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_dataset[0]\n",
    "Image.open(eval_dataset[0][\"image_path\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, classification_report, confusion_matrix\n",
    "from collections import Counter\n",
    "\n",
    "print(accuracy_score(ytrue, ypred))\n",
    "print(confusion_matrix(ytrue, ypred))\n",
    "print(classification_report(ytrue, ypred))\n",
    "\n",
    "print(Counter(ytrue))\n",
    "print(Counter(ypred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import prettytable\n",
    "\n",
    "table = prettytable.PrettyTable()\n",
    "table.field_names = [\"Category\", \"Total\", \"Wrong\", \"Accuracy\"]\n",
    "\n",
    "got_wrong_dict = defaultdict(int)\n",
    "\n",
    "for i in range(len(ytrue)):\n",
    "    if ypred[i] != ytrue[i]:\n",
    "        got_wrong_dict[eval_dataset[i][\"category\"]] += 1\n",
    "    else:\n",
    "        got_wrong_dict[eval_dataset[i][\"category\"]] += 0\n",
    "\n",
    "for k, v in got_wrong_dict.items():\n",
    "    table.add_row([k, category_count[k], v, 1 - (v/category_count[k])])\n",
    "\n",
    "\n",
    "#sort by total\n",
    "table.sortby = \"Total\"\n",
    "table.reversesort = True\n",
    "print(table)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "METRIC_SAVE_DIR = \"performance_metrics/\"\n",
    "\n",
    "eval_dataset_copy = eval_dataset.copy()\n",
    "\n",
    "print(len(eval_dataset_copy))\n",
    "for i,data in enumerate(eval_dataset_copy):\n",
    "    if ypred[i] != ytrue[i]:\n",
    "        # map letter to option f\n",
    "        if \"BLOCK\" in ypred[i]:\n",
    "            data[\"vlm_answer\"] = \"BLOCK\"\n",
    "        else:\n",
    "            # print(ypred[i])\n",
    "            # print(i)\n",
    "            try:\n",
    "                data[\"vlm_answer\"] = data[\"options\"][ord(ypred[i]) - ord(\"a\")]\n",
    "            except Exception as e:\n",
    "                data[\"vlm_answer\"] = ypred[i]\n",
    "                print(ypred[i])\n",
    "\n",
    "\n",
    "import json\n",
    "with open(METRIC_SAVE_DIR + \"gemini_results.json\", \"w\") as f:\n",
    "    json.dump(eval_dataset_copy, f)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
