{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import base64\n",
    "from openai import OpenAI\n",
    "from dotenv import load_dotenv\n",
    "import os\n",
    "import json\n",
    "from collections import defaultdict\n",
    "import os\n",
    "import random\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Pull environment variables (e.g. OPENAI_API_KEY, read further below) from a local .env file.\n",
    "load_dotenv()\n",
    "# Fixed seed for Python's global `random` module so any later use of it is reproducible.\n",
    "random.seed(42)\n",
    "\n",
    "\n",
    "def resize_image(image_path, size):\n",
    "    '''Load an image and downscale it so that its largest edge is at most `size` pixels.\n",
    "\n",
    "    The aspect ratio is preserved and images already within the limit are returned\n",
    "    unscaled. The image file is closed before this function returns.\n",
    "\n",
    "    Args:\n",
    "        image_path: path of the image file to open.\n",
    "        size: maximum allowed width/height in pixels.\n",
    "\n",
    "    Returns:\n",
    "        A PIL Image whose pixel data is fully loaded into memory.\n",
    "    '''\n",
    "    # Image.open is lazy and keeps the file handle open until the pixels are read;\n",
    "    # copying inside a `with` block loads the data and releases the handle.\n",
    "    with Image.open(image_path) as src:\n",
    "        img = src.copy()\n",
    "    width, height = img.size\n",
    "\n",
    "    if width <= size and height <= size:\n",
    "        return img\n",
    "\n",
    "    if width > height:\n",
    "        new_width = size\n",
    "        # max(1, ...) avoids a 0-pixel edge for extremely elongated images,\n",
    "        # which Image.resize rejects.\n",
    "        new_height = max(1, int(height * (size / width)))\n",
    "    else:\n",
    "        new_height = size\n",
    "        new_width = max(1, int(width * (size / height)))\n",
    "    return img.resize((new_width, new_height), Image.Resampling.LANCZOS)\n",
    "\n",
    "\n",
    "def encode_image(image_path):\n",
    "    '''Return the image at `image_path` as a base64-encoded JPEG string.\n",
    "\n",
    "    The image is first downscaled so its largest edge is at most 512 px; the result\n",
    "    is meant for a `data:image/jpeg;base64,...` URL.\n",
    "    '''\n",
    "    import io  # local import keeps this function self-contained\n",
    "\n",
    "    img = resize_image(image_path, 512)\n",
    "    # JPEG has no alpha channel or palette mode: Pillow raises OSError when saving\n",
    "    # e.g. RGBA/P/LA images as JPEG, so convert anything but RGB/grayscale first.\n",
    "    if img.mode not in (\"RGB\", \"L\"):\n",
    "        img = img.convert(\"RGB\")\n",
    "    # Encode in memory instead of writing temp.jpg into the working directory\n",
    "    # (a leftover file that concurrent calls could clobber).\n",
    "    buffer = io.BytesIO()\n",
    "    img.save(buffer, format=\"JPEG\")\n",
    "    return base64.b64encode(buffer.getvalue()).decode(\"utf-8\")\n",
    "\n",
    "\n",
    "def construct_mcq(options, correct_option):\n",
    "    '''Format `options` as a lettered multiple-choice list.\n",
    "\n",
    "    Returns a pair (mcq, letter): `mcq` has one '<letter>. <option>' line per option\n",
    "    (a, b, c, ...) joined by newlines without a trailing newline, and `letter` is the\n",
    "    letter of the last option equal to `correct_option`, or None if none matches.\n",
    "    '''\n",
    "    lines = []\n",
    "    answer_letter = None\n",
    "    for offset, option in enumerate(options):\n",
    "        letter = chr(ord(\"a\") + offset)\n",
    "        lines.append(f\"{letter}. {option}\")\n",
    "        if option == correct_option:\n",
    "            answer_letter = letter\n",
    "    return \"\\n\".join(lines), answer_letter\n",
    "\n",
    "def add_row(content, data, i, with_answer=False):\n",
    "    '''Append question number `i` (text, image, answer prompt) to `content` and return it.\n",
    "\n",
    "    Appends, in order: a text part with the question and its lettered options, the\n",
    "    image as a low-detail base64 JPEG data URL, and an 'Answer i: ' text part that also\n",
    "    contains the correct option letter when `with_answer` is True (few-shot examples).\n",
    "\n",
    "    `content` is modified in place; pass a copy to keep the caller's list unchanged.\n",
    "    '''\n",
    "    content.append({\n",
    "        \"type\": \"text\",\n",
    "        \"text\": \"Image \" + str(i) + \": \" + data[\"question\"] + \"\\n\" + data[\"mcq\"],\n",
    "    })\n",
    "\n",
    "    # Encode outside the f-string: reusing double quotes inside a double-quoted\n",
    "    # f-string (as in a data[...] lookup) is a SyntaxError before Python 3.12 (PEP 701).\n",
    "    image_b64 = encode_image(data[\"image_path\"])\n",
    "    content.append({\n",
    "        \"type\": \"image_url\",\n",
    "        \"image_url\": {\n",
    "            \"url\": f\"data:image/jpeg;base64,{image_b64}\",\n",
    "            \"detail\": \"low\",\n",
    "        },\n",
    "    })\n",
    "\n",
    "    answer_text = \"Answer {}: \".format(i)\n",
    "    if with_answer:\n",
    "        answer_text += data[\"correct_option_letter\"]\n",
    "    content.append({\n",
    "        \"type\": \"text\",\n",
    "        \"text\": answer_text,\n",
    "    })\n",
    "\n",
    "    return content\n",
    "   \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Dataset locations: question/answer JSON files and the directories holding their images.\n",
    "FEWSHOT_JSON = \"illusionVQA/comprehension/fewshot_labels.json\"\n",
    "EVAL_JSON = \"illusionVQA/comprehension/eval_labels.json\"\n",
    "FEWSHOT_IMAGE_DIR = \"illusionVQA/comprehension/FEW_SHOTS/\"\n",
    "EVAL_IMAGE_DIR = \"illusionVQA/comprehension/EVAL/\"\n",
    "\n",
    "# Vision model under evaluation, and an API client keyed from the environment.\n",
    "model_name = \"gpt-4-vision-preview\"\n",
    "client = OpenAI(api_key=os.getenv(\"OPENAI_API_KEY\"))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import json\n",
    "\n",
    "with open(FEWSHOT_JSON) as f:\n",
    "    fewshot_dataset = json.load(f)\n",
    "\n",
    "# Give each few-shot example its image path plus the lettered MCQ text and answer letter.\n",
    "for data in fewshot_dataset:\n",
    "    data[\"image_path\"] = FEWSHOT_IMAGE_DIR + data[\"image\"]\n",
    "    mcq_text, answer_letter = construct_mcq(data[\"options\"], data[\"answer\"])\n",
    "    data[\"mcq\"] = mcq_text\n",
    "    data[\"correct_option_letter\"] = answer_letter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(EVAL_JSON) as f:\n",
    "    eval_dataset = json.load(f)\n",
    "\n",
    "# Shuffle with a dedicated, freshly seeded RNG instead of the global `random` state so\n",
    "# re-running this cell reproduces the same order rather than drawing a new permutation.\n",
    "# random.Random(42) yields the same permutation as random.seed(42) + random.shuffle.\n",
    "SHUFFLE_SEED = 42\n",
    "random.Random(SHUFFLE_SEED).shuffle(eval_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare eval items whose image exists on disk and drop the rest: skipped items never\n",
    "# get \"mcq\"/\"correct_option_letter\", so add_row would raise KeyError on them later.\n",
    "available_images = set(os.listdir(EVAL_IMAGE_DIR))  # list the directory once, not per item\n",
    "category_count = defaultdict(int)\n",
    "usable_items = []\n",
    "\n",
    "for data in eval_dataset:\n",
    "    if data[\"image\"] not in available_images:\n",
    "        print(data[\"image\"])\n",
    "        continue\n",
    "    data[\"image_path\"] = EVAL_IMAGE_DIR + data[\"image\"]\n",
    "    data[\"mcq\"], data[\"correct_option_letter\"] = construct_mcq(data[\"options\"], data[\"answer\"])\n",
    "    category_count[data[\"category\"]] += 1\n",
    "    usable_items.append(data)\n",
    "\n",
    "eval_dataset = usable_items\n",
    "print(category_count)\n",
    "print(len(eval_dataset))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 0-shot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ZERO_SHOT_PROMPT = \"You'll be given an image, an instruction and some options. You have to select the correct one. Do not explain your reasoning. Answer with only the letter which corresponds to the correct option. Do not repeat the entire answer.\"\n",
    "\n",
    "# 0-shot prefix shared by every evaluation request: the instruction only, no examples.\n",
    "content = [{\"type\": \"text\", \"text\": ZERO_SHOT_PROMPT}]\n",
    "# With no in-context examples, the evaluated question is numbered 1.\n",
    "next_idx = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4-shot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch to True to replace the 0-shot prompt above with a few-shot prompt that\n",
    "# shows every example in fewshot_dataset together with its answer.\n",
    "USE_FEWSHOT = False\n",
    "\n",
    "if USE_FEWSHOT:\n",
    "    content = [\n",
    "        {\n",
    "            \"type\": \"text\",\n",
    "            \"text\": \"You'll be given an image, an instruction and some choices. You have to select the correct one. Do not explain your reasoning. Answer with the option's letter from the given choices directly. Here are a few examples:\",\n",
    "        }\n",
    "    ]\n",
    "    for i, data in enumerate(fewshot_dataset, start=1):\n",
    "        content = add_row(content, data, i, with_answer=True)\n",
    "    content.append({\n",
    "        \"type\": \"text\",\n",
    "        \"text\": \"Now you try it!\",\n",
    "    })\n",
    "    # The evaluated question is numbered right after the examples.\n",
    "    next_idx = len(fewshot_dataset) + 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Gold and predicted option letters, appended in eval_dataset order by the loop below.\n",
    "ytrue, ypred = [], []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "MAX_RETRIES = 2  # total API attempts per question\n",
    "RETRY_SLEEP_SECONDS = 30\n",
    "FAILED_ANSWER = 'GPT4 could not answer this question.'\n",
    "\n",
    "# Start after the items already answered: re-running this cell after an interruption\n",
    "# resumes instead of appending duplicates that would misalign ypred/ytrue with\n",
    "# eval_dataset. Re-run the cell that resets ytrue/ypred to start over.\n",
    "for data in tqdm(eval_dataset[len(ypred):]):\n",
    "    content_t = add_row(content.copy(), data, next_idx, with_answer=False)\n",
    "    gpt4_answer = FAILED_ANSWER\n",
    "    for attempt in range(1, MAX_RETRIES + 1):\n",
    "        try:\n",
    "            response = client.chat.completions.create(\n",
    "                model=model_name,\n",
    "                messages=[\n",
    "                    {\n",
    "                        \"role\": \"user\",\n",
    "                        \"content\": content_t,\n",
    "                    }\n",
    "                ],\n",
    "                max_tokens=5,\n",
    "            )\n",
    "            # Keep the first character of the reply, lower-cased (expected to be the option letter).\n",
    "            gpt4_answer = response.choices[0].message.content.strip()[0].lower()\n",
    "            break\n",
    "        except Exception as e:\n",
    "            print(e)\n",
    "            if attempt == MAX_RETRIES:\n",
    "                print(\"retries exhausted\")\n",
    "            else:\n",
    "                # Back off only when another attempt will follow.\n",
    "                time.sleep(RETRY_SLEEP_SECONDS)\n",
    "\n",
    "    answer = data[\"correct_option_letter\"].strip()\n",
    "    ytrue.append(answer)\n",
    "    ypred.append(gpt4_answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positions (in eval_dataset order) of questions where every API attempt failed.\n",
    "for i, pred in enumerate(ypred):\n",
    "    if pred == 'GPT4 could not answer this question.':\n",
    "        print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both lists should hold one entry per evaluated question.\n",
    "(len(ytrue), len(ypred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Result Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, classification_report, confusion_matrix\n",
    "from collections import Counter\n",
    "\n",
    "# Overall accuracy, the confusion matrix, then the per-letter classification report.\n",
    "for metric_fn in (accuracy_score, confusion_matrix, classification_report):\n",
    "    print(metric_fn(ytrue, ypred))\n",
    "\n",
    "# How often each letter is the gold answer vs. predicted (useful to spot a bias toward one letter).\n",
    "for letters in (ytrue, ypred):\n",
    "    print(Counter(letters))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import prettytable\n",
    "from collections import defaultdict\n",
    "\n",
    "table = prettytable.PrettyTable()\n",
    "table.field_names = [\"Category\", \"Total\", \"Wrong\", \"Accuracy\"]\n",
    "\n",
    "# Tally over the questions that were actually evaluated. Dividing by category_count\n",
    "# (all prepared items) would overstate accuracy whenever ypred covers only part of\n",
    "# eval_dataset, e.g. after an interrupted evaluation loop.\n",
    "evaluated_count = defaultdict(int)\n",
    "got_wrong_dict = defaultdict(int)\n",
    "for data, truth, pred in zip(eval_dataset, ytrue, ypred):\n",
    "    category = data[\"category\"]\n",
    "    evaluated_count[category] += 1\n",
    "    if pred != truth:\n",
    "        got_wrong_dict[category] += 1\n",
    "\n",
    "for k, total in evaluated_count.items():\n",
    "    wrong = got_wrong_dict[k]\n",
    "    table.add_row([k, total, wrong, 1 - (wrong / total)])\n",
    "\n",
    "# sort by total, largest first\n",
    "table.sortby = \"Total\"\n",
    "table.reversesort = True\n",
    "print(table)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "METRIC_SAVE_DIR = \"performance_metrics/\"\n",
    "\n",
    "# Copy each item dict too: list.copy() is shallow, so adding \"vlm_answer\" to its\n",
    "# items would also modify the dicts inside eval_dataset.\n",
    "eval_dataset_copy = [dict(data) for data in eval_dataset]\n",
    "\n",
    "print(len(eval_dataset_copy))\n",
    "for data, truth, pred in zip(eval_dataset_copy, ytrue, ypred):\n",
    "    if pred == truth:\n",
    "        continue\n",
    "    # For wrong answers, record the option text the model chose (or its raw output).\n",
    "    if \"BLOCK\" in pred:\n",
    "        data[\"vlm_answer\"] = \"BLOCK\"\n",
    "        continue\n",
    "    option_idx = ord(pred) - ord(\"a\") if len(pred) == 1 else -1\n",
    "    if 0 <= option_idx < len(data[\"options\"]):\n",
    "        data[\"vlm_answer\"] = data[\"options\"][option_idx]\n",
    "    else:\n",
    "        # Not a valid option letter (e.g. the retry-failure message). Checking the range\n",
    "        # explicitly also stops characters before 'a' from wrapping to a negative index.\n",
    "        data[\"vlm_answer\"] = pred\n",
    "        print(pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import json\n",
    "\n",
    "# Create the output directory first; open() raises FileNotFoundError if it is missing.\n",
    "os.makedirs(METRIC_SAVE_DIR, exist_ok=True)\n",
    "with open(METRIC_SAVE_DIR + \"gpt4_results.json\", \"w\") as f:\n",
    "    json.dump(eval_dataset_copy, f, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
