{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import base64\n",
    "import copy\n",
    "import json\n",
    "import os\n",
    "import random\n",
    "import re\n",
    "from collections import defaultdict\n",
    "from io import BytesIO\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "from PIL import Image\n",
    "random.seed(42)\n",
    "load_dotenv()\n",
    "\n",
    "\n",
    "def resize_image(image_path, size):\n",
    "    \"\"\"Open an image and shrink it so its longest edge is at most `size`.\n",
    "\n",
    "    The aspect ratio is preserved and images that already fit are returned\n",
    "    as opened, without upscaling.\n",
    "\n",
    "    Args:\n",
    "        image_path: Path of the image file, opened with `Image.open`.\n",
    "        size: Maximum length in pixels allowed for either edge.\n",
    "\n",
    "    Returns:\n",
    "        The opened image, resized with LANCZOS resampling when needed.\n",
    "    \"\"\"\n",
    "    img = Image.open(image_path)\n",
    "    width, height = img.size\n",
    "\n",
    "    if width <= size and height <= size:\n",
    "        return img\n",
    "\n",
    "    # Scale by the longer edge; square images take the height branch.\n",
    "    if width > height:\n",
    "        new_width = size\n",
    "        new_height = int(height * (size / width))\n",
    "    else:\n",
    "        new_height = size\n",
    "        new_width = int(width * (size / height))\n",
    "\n",
    "    # A very elongated image can round its short edge down to 0, which\n",
    "    # `resize` rejects; keep every edge at least one pixel long.\n",
    "    new_size = (max(1, new_width), max(1, new_height))\n",
    "    return img.resize(new_size, Image.Resampling.LANCZOS)\n",
    "\n",
    "\n",
    "def encode_image(image_path):\n",
    "    \"\"\"Return the image at `image_path` as a base64-encoded JPEG string.\n",
    "\n",
    "    The image is first passed through `resize_image` with a 512 pixel limit\n",
    "    and then serialized as JPEG entirely in memory.\n",
    "\n",
    "    Args:\n",
    "        image_path: Path of the image file to encode.\n",
    "\n",
    "    Returns:\n",
    "        The base64 text of the JPEG bytes, decoded as UTF-8.\n",
    "    \"\"\"\n",
    "    img = resize_image(image_path, 512)\n",
    "\n",
    "    # JPEG cannot store alpha or palette data, so saving e.g. an RGBA PNG\n",
    "    # raises OSError; normalize such images to RGB first.\n",
    "    if img.mode not in (\"RGB\", \"L\"):\n",
    "        img = img.convert(\"RGB\")\n",
    "\n",
    "    # Encode into a buffer rather than a shared \"temp.jpg\" on disk, which\n",
    "    # left a stray file behind and could be overwritten by a parallel run.\n",
    "    buffer = BytesIO()\n",
    "    img.save(buffer, format=\"JPEG\")\n",
    "    return base64.b64encode(buffer.getvalue()).decode(\"utf-8\")\n",
    "\n",
    "\n",
    "def construct_mcq(options, correct_option):\n",
    "    \"\"\"Format answer options as a lettered multiple-choice list.\n",
    "\n",
    "    Options are labelled \"a\", \"b\", \"c\", ... in order, one per line.\n",
    "\n",
    "    Returns:\n",
    "        A `(mcq, letter)` tuple: the formatted text without a trailing\n",
    "        newline, and the letter of the option equal to `correct_option`\n",
    "        (the last one if several are equal, None if there is none).\n",
    "    \"\"\"\n",
    "    answer_letter = None\n",
    "    lines = []\n",
    "\n",
    "    for offset, choice in enumerate(options):\n",
    "        label = chr(ord(\"a\") + offset)\n",
    "        if choice == correct_option:\n",
    "            answer_letter = label\n",
    "        lines.append(f\"{label}. {choice}\")\n",
    "\n",
    "    return \"\\n\".join(lines), answer_letter\n",
    "\n",
    "def add_row(content, data, i, with_answer=False):\n",
    "    \"\"\"Append one question (text, image and answer scaffold) to a prompt.\n",
    "\n",
    "    Args:\n",
    "        content: List of chat message parts; it is extended in place.\n",
    "        data: Example dict providing \"question\", \"mcq\" and \"image_path\",\n",
    "            plus \"reasoning\" and \"correct_option_letter\" when\n",
    "            `with_answer` is True.\n",
    "        i: Number shown in the \"Image <i>:\" label.\n",
    "        with_answer: If True, append the example's reasoning and answer;\n",
    "            otherwise append an open \"Reasoning: \" cue.\n",
    "\n",
    "    Returns:\n",
    "        The same `content` list that was passed in.\n",
    "    \"\"\"\n",
    "    content.append({\n",
    "        \"type\": \"text\",\n",
    "        \"text\": \"Image \" + str(i) + \": \" + data[\"question\"] + \"\\n\" + data[\"mcq\"],\n",
    "    })\n",
    "\n",
    "    # Computed outside the f-string: reusing double quotes inside a\n",
    "    # replacement field is a SyntaxError before Python 3.12.\n",
    "    encoded_image = encode_image(data[\"image_path\"])\n",
    "    content.append({\n",
    "        \"type\": \"image_url\",\n",
    "        \"image_url\": {\n",
    "            \"url\": f\"data:image/jpeg;base64,{encoded_image}\",\n",
    "            \"detail\": \"low\",\n",
    "        },\n",
    "    })\n",
    "\n",
    "    if with_answer:\n",
    "        content.append({\n",
    "            \"type\": \"text\",\n",
    "            \"text\": \"Reasoning: {}\".format(data[\"reasoning\"]),\n",
    "        })\n",
    "        content.append({\n",
    "            \"type\": \"text\",\n",
    "            \"text\": \"Answer: {}\".format(data[\"correct_option_letter\"]),\n",
    "        })\n",
    "    else:\n",
    "        content.append({\n",
    "            \"type\": \"text\",\n",
    "            \"text\": \"Reasoning: \",\n",
    "        })\n",
    "\n",
    "    return content\n",
    "   \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Few-shot examples: label file and the directory holding their images.\n",
    "FEWSHOT_IMAGE_DIR = \"illusionVQA/comprehension/FEW_SHOTS/\"\n",
    "FEWSHOT_JSON = \"illusionVQA/comprehension/fewshot_labels.json\"\n",
    "\n",
    "# Evaluation split: label file and the directory holding its images.\n",
    "EVAL_IMAGE_DIR = \"illusionVQA/comprehension/EVAL/\"\n",
    "EVAL_JSON = \"illusionVQA/comprehension/eval_labels.json\"\n",
    "\n",
    "# Model identifier and the API client, keyed from the OPENAI_API_KEY variable.\n",
    "model_name = \"gpt-4-vision-preview\"\n",
    "client = OpenAI(api_key=os.getenv(\"OPENAI_API_KEY\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import json\n",
    "\n",
    "# Read the few-shot examples, then attach the fields used to build the prompt.\n",
    "with open(FEWSHOT_JSON) as fewshot_file:\n",
    "    fewshot_dataset = json.load(fewshot_file)\n",
    "\n",
    "for example in fewshot_dataset:\n",
    "    example[\"image_path\"] = FEWSHOT_IMAGE_DIR + example[\"image\"]\n",
    "    formatted_mcq, answer_letter = construct_mcq(example[\"options\"], example[\"answer\"])\n",
    "    example[\"mcq\"] = formatted_mcq\n",
    "    example[\"correct_option_letter\"] = answer_letter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(EVAL_JSON) as f:\n",
    "    eval_dataset = json.load(f)\n",
    "\n",
    "# Shuffle with a dedicated, freshly seeded generator so that re-running this\n",
    "# cell always gives the same order. The global `random` state only matches\n",
    "# its seed the first time it is used after `random.seed(42)`.\n",
    "random.Random(42).shuffle(eval_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "category_count = defaultdict(int)\n",
    "\n",
    "# List the image directory once instead of once per sample.\n",
    "available_images = set(os.listdir(EVAL_IMAGE_DIR))\n",
    "\n",
    "# Keep only samples whose image exists. A skipped sample never receives\n",
    "# \"image_path\", \"mcq\" or \"correct_option_letter\", so leaving it in\n",
    "# `eval_dataset` would raise a KeyError in the evaluation loop.\n",
    "usable_samples = []\n",
    "for data in eval_dataset:\n",
    "    if data[\"image\"] not in available_images:\n",
    "        print(data[\"image\"])\n",
    "        continue\n",
    "    data[\"image_path\"] = EVAL_IMAGE_DIR + data[\"image\"]\n",
    "    data[\"mcq\"], data[\"correct_option_letter\"] = construct_mcq(data[\"options\"], data[\"answer\"])\n",
    "    category_count[data[\"category\"]] += 1\n",
    "    usable_samples.append(data)\n",
    "eval_dataset = usable_samples\n",
    "\n",
    "print(category_count)\n",
    "print(len(eval_dataset))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instruction placed before the few-shot examples.\n",
    "content = [\n",
    "    {\n",
    "        \"type\": \"text\",\n",
    "        \"text\": \"\"\"You'll be given an image, an instruction and some choices. You have to select the correct one. Reason about the choices in the context of the question and the image. End your answer with \"Answer\": {letter_of_correct_choice} without the curly brackets. Here are a few examples:\"\"\",\n",
    "    }\n",
    "]\n",
    "\n",
    "# Few-shot examples are numbered from 1, each with its reasoning and answer.\n",
    "for example_number, example in enumerate(fewshot_dataset, start=1):\n",
    "    content = add_row(content, example, example_number, with_answer=True)\n",
    "\n",
    "content.append({\"type\": \"text\", \"text\": \"Now you try it!\"})\n",
    "\n",
    "# The question under evaluation is numbered right after the last example.\n",
    "next_idx = len(fewshot_dataset) + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "# Ground-truth letters and model responses, filled in by the evaluation loop.\n",
    "ytrue, ypred = [], []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "MAX_RETRIES = 2\n",
    "RETRY_DELAY_SECONDS = 30\n",
    "\n",
    "# Start from empty lists so that re-running this cell does not append a\n",
    "# second copy of every result to `ytrue`/`ypred`.\n",
    "ytrue, ypred = [], []\n",
    "\n",
    "for data in tqdm(eval_dataset):\n",
    "    # Shallow copy: the few-shot parts are reused, only this question is added.\n",
    "    content_t = add_row(content.copy(), data, next_idx, with_answer=False)\n",
    "\n",
    "    # Placeholder recorded when every attempt fails.\n",
    "    gpt4_answer = 'GPT4 could not answer this question.'\n",
    "    for attempt in range(1, MAX_RETRIES + 1):\n",
    "        try:\n",
    "            response = client.chat.completions.create(\n",
    "                model=model_name,\n",
    "                messages=[\n",
    "                    {\n",
    "                        \"role\": \"user\",\n",
    "                        \"content\": content_t,\n",
    "                    }\n",
    "                ],\n",
    "                max_tokens=512\n",
    "            )\n",
    "            gpt4_answer = response.choices[0].message.content.strip()\n",
    "            break\n",
    "        except Exception as e:\n",
    "            # Best effort: log the failure and retry rather than abort the run.\n",
    "            print(e)\n",
    "            if attempt == MAX_RETRIES:\n",
    "                # No attempt left, so there is nothing to wait for.\n",
    "                print(\"retries exhausted\")\n",
    "                break\n",
    "            time.sleep(RETRY_DELAY_SECONDS)\n",
    "\n",
    "    ytrue.append(data[\"correct_option_letter\"].strip())\n",
    "    ypred.append(gpt4_answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_answer_letter(response):\n",
    "    \"\"\"Return the option letter that a model response settles on.\n",
    "\n",
    "    Prefers the letter following the last \"Answer\" marker, lower-cased, so\n",
    "    that endings such as \"Answer: b.\" or \"Answer: (B)\" are still read as\n",
    "    \"b\". Without a marker it falls back to the final character, and an empty\n",
    "    response yields an empty string instead of raising IndexError.\n",
    "    \"\"\"\n",
    "    marked = re.findall(r\"\\banswer\\b\\W+([a-z])\\b\", response, flags=re.IGNORECASE)\n",
    "    if marked:\n",
    "        return marked[-1].lower()\n",
    "    return response[-1:]\n",
    "\n",
    "\n",
    "# Safe to re-run: a single letter has no marker and maps to itself.\n",
    "ypred = [extract_answer_letter(x) for x in ypred]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, classification_report, confusion_matrix\n",
    "from collections import Counter\n",
    "\n",
    "# Overall accuracy, confusion matrix, then the per-letter report.\n",
    "for metric in (accuracy_score, confusion_matrix, classification_report):\n",
    "    print(metric(ytrue, ypred))\n",
    "\n",
    "# Letter distribution of the ground truth, then of the predictions.\n",
    "for letters in (ytrue, ypred):\n",
    "    print(Counter(letters))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import prettytable\n",
    "from collections import defaultdict\n",
    "\n",
    "# Count wrong answers per category; a correct answer adds 0 so that every\n",
    "# evaluated category still ends up with an entry.\n",
    "got_wrong_dict = defaultdict(int)\n",
    "for idx, predicted in enumerate(ypred):\n",
    "    is_wrong = predicted != ytrue[idx]\n",
    "    got_wrong_dict[eval_dataset[idx][\"category\"]] += 1 if is_wrong else 0\n",
    "\n",
    "table = prettytable.PrettyTable()\n",
    "table.field_names = [\"Category\", \"Total\", \"Wrong\", \"Accuracy\"]\n",
    "for category, wrong in got_wrong_dict.items():\n",
    "    total = category_count[category]\n",
    "    table.add_row([category, total, wrong, 1 - (wrong / total)])\n",
    "\n",
    "# Show the categories with the most samples first.\n",
    "table.sortby = \"Total\"\n",
    "table.reversesort = True\n",
    "print(table)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the loop appends to both lists together, so these should match.\n",
    "len(ypred), len(ytrue)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "METRIC_SAVE_DIR = \"performance_metrics/\"\n",
    "\n",
    "# Deep copy so that annotating the results does not also write \"vlm_answer\"\n",
    "# into the dicts of `eval_dataset` (a shallow list copy shares them).\n",
    "eval_dataset_copy = copy.deepcopy(eval_dataset)\n",
    "\n",
    "print(len(eval_dataset_copy))\n",
    "# zip stops at the shortest input, so only samples with a prediction are visited.\n",
    "for data, predicted, expected in zip(eval_dataset_copy, ypred, ytrue):\n",
    "    if predicted == expected:\n",
    "        continue\n",
    "    if \"BLOCK\" in predicted:\n",
    "        data[\"vlm_answer\"] = \"BLOCK\"\n",
    "        continue\n",
    "    # Map the predicted letter back to its option text. The explicit range\n",
    "    # check matters: a negative index would silently pick an option from the\n",
    "    # end of the list instead of flagging an unparseable prediction.\n",
    "    option_idx = ord(predicted) - ord(\"a\") if len(predicted) == 1 else -1\n",
    "    if 0 <= option_idx < len(data[\"options\"]):\n",
    "        data[\"vlm_answer\"] = data[\"options\"][option_idx]\n",
    "    else:\n",
    "        data[\"vlm_answer\"] = predicted\n",
    "        print(predicted)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import json\n",
    "\n",
    "# Create the output directory on first use; open() does not create folders.\n",
    "os.makedirs(METRIC_SAVE_DIR, exist_ok=True)\n",
    "with open(os.path.join(METRIC_SAVE_DIR, \"gpt4_reasoning_results.json\"), \"w\") as f:\n",
    "    json.dump(eval_dataset_copy, f, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
