{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Large RAM is required to load the larger models. Running on GPU can optimize inference speed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "import requests\n",
    "from lavis.models import load_model_and_preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "    print(\"CUDA device is available\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Replace with the path to the directory containing the CLEVR dataset\n",
    "CLEVR_DIR = './CLEVR_v1.0'\n",
    "SCENE_MAPPING_DIR = './'\n",
    "\n",
    "def load_data():\n",
    "    \"\"\"\n",
    "    Load metadata and questions for all scenes in the CLEVR dataset.\n",
    "    \"\"\"\n",
    "    print('opening scenes...')\n",
    "    # Load metadata for all scenes\n",
    "    with open(os.path.join(CLEVR_DIR, 'scenes', 'CLEVR_train_scenes.json')) as f:\n",
    "        scenes = json.load(f)['scenes']\n",
    "    \n",
    "    print('Opening questions...')\n",
    "    # Load questions for all scenes\n",
    "    with open(os.path.join(CLEVR_DIR, 'questions', 'CLEVR_train_questions.json')) as f:\n",
    "        questions = json.load(f)['questions']\n",
    "    \n",
    "    print('Creating map...')\n",
    "    # Create a dictionary mapping scene numbers to scene data and questions\n",
    "    scene_data = {}\n",
    "    for scene in tqdm.tqdm(scenes):\n",
    "        scene_number = scene['image_index']\n",
    "        scene_questions = [q for q in questions if q['image_index'] == scene_number]\n",
    "        scene_data[scene_number] = {'metadata': scene, 'questions': scene_questions}\n",
    "\n",
    "    return scene_data\n",
    "\n",
    "def load_scene_data():\n",
    "    \"\"\"\n",
    "    Loads scene_data.json \n",
    "    \"\"\"\n",
    "    with open(os.path.join(SCENE_MAPPING_DIR, 'clevr_val_scene_mapping.json')) as f:\n",
    "        scene_data = json.load(f)\n",
    "    \n",
    "    return scene_data\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_scene_data(scene_number, scene_data):\n",
    "    \"\"\"\n",
    "    Get metadata and questions for a given scene number in the CLEVR dataset.\n",
    "    \"\"\"\n",
    "    # Find the scene with the given number\n",
    "    scene = scene_data.get(scene_number)\n",
    "    if scene is None:\n",
    "        print(f'Scene {scene_number} not found.')\n",
    "        return\n",
    "    \n",
    "    # Print metadata\n",
    "    metadata = scene['metadata']\n",
    "    print()\n",
    "    print(f'Scene {scene_number}:')\n",
    "    print()\n",
    "    print(f'  Objects: {len(metadata[\"objects\"])}')\n",
    "    for obj in metadata[\"objects\"]:\n",
    "        print(f'    Object:')\n",
    "        print(f'      Color: {obj[\"color\"]}')\n",
    "        print(f'      Size: {obj[\"size\"]}')\n",
    "        print(f'      Rotation: {obj[\"rotation\"]}')\n",
    "        print(f'      Shape: {obj[\"shape\"]}')\n",
    "        print(f'      Material: {obj[\"material\"]}')\n",
    "        print(f'      3D Coords: {obj[\"3d_coords\"]}')\n",
    "        print(f'      Pixel Coords: {obj[\"pixel_coords\"]}')\n",
    "        print()\n",
    "    print(f'  Relationships: {metadata[\"relationships\"]}')\n",
    "    print()\n",
    "    print(f'  Directions: {metadata[\"directions\"]}')\n",
    "    print()\n",
    "    print(f'  Image Filename: {metadata[\"image_filename\"]}')\n",
    "    print()\n",
    "    \n",
    "    # Print questions and answers\n",
    "    questions = scene['questions']\n",
    "    print(f'  Questions: {len(questions)}')\n",
    "    for q in questions:\n",
    "        print(f'    Question: {q[\"question\"]}')\n",
    "        print()\n",
    "        print(f'      Answer: {q[\"answer\"]}')\n",
    "        print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "scene_data = load_scene_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_filename = scene_data['0']['metadata']['image_filename']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_dir = os.path.join('/home/user/Desktop/vqa_research/CLEVR_v1.0/images/val',img_filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Load image from directory**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_image = Image.open(img_dir).convert('RGB')   \n",
    "display(raw_image.resize((596, 437)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scene_data['0']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# blip2-flant5 demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import sys\n",
    "# if 'google.colab' in sys.modules:\n",
    "#     print('Running in Colab.')\n",
    "#     !pip3 install salesforce-lavis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "import requests\n",
    "from lavis.models import load_model_and_preprocess"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load an example image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# img_url = 'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png' \n",
    "# raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')   \n",
    "# display(raw_image.resize((596, 437)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# setup device to use\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load pretrained/finetuned BLIP2 captioning model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# we associate a model with its preprocessors to make it easier for inference.\n",
    "model, vis_processors, _ = load_model_and_preprocess(\n",
    "    name=\"blip2_t5\", model_type=\"pretrain_flant5xxl\", is_eval=True, device=device\n",
    ")\n",
    "\n",
    "# Other available models:\n",
    "# \n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_opt\", model_type=\"pretrain_opt2.7b\", is_eval=True, device=device\n",
    "# )\n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_opt\", model_type=\"pretrain_opt6.7b\", is_eval=True, device=device\n",
    "# )\n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_opt\", model_type=\"caption_coco_opt2.7b\", is_eval=True, device=device\n",
    "# )\n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_opt\", model_type=\"caption_coco_opt6.7b\", is_eval=True, device=device\n",
    "# )\n",
    "#\n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_t5\", model_type=\"pretrain_flant5xl\", is_eval=True, device=device\n",
    "# )\n",
    "#\n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_t5\", model_type=\"caption_coco_flant5xl\", is_eval=True, device=device\n",
    "# )\n",
    "\n",
    "vis_processors.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Iterate dataset and run on clevr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Extract val files if needed\n",
    "# import zipfile\n",
    "\n",
    "# # Open the zip file in read mode.\n",
    "# with zipfile.ZipFile(\"val.zip\", \"r\") as zip_file:\n",
    "\n",
    "#     # Extract all the files to the current working directory.\n",
    "#     zip_file.extractall()\n",
    "\n",
    "# # The zip file has now been unzipped to the current working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "scene_data['14999']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# iterate through the dataset\n",
    "\n",
    "def run_on_model(model,start = 0, end = 14999, output_file='response_data_blip2_t5_xxl_COT.json'):\n",
    "    for i, key in tqdm.tqdm(enumerate(scene_data)):\n",
    "        if i< start:\n",
    "            continue\n",
    "        if i> end:\n",
    "            break\n",
    "            \n",
    "        if i%1000 == 1:\n",
    "            # Save scene data to JSON file after each iteration\n",
    "            if output_file:\n",
    "                with open(output_file, 'w') as f:\n",
    "                    json.dump(scene_data, f)\n",
    "\n",
    "        img_filename = scene_data[key]['metadata']['image_filename']\n",
    "        img_dir = os.path.join('./val',img_filename)\n",
    "        raw_image = Image.open(img_dir).convert('RGB') \n",
    "        # Prepare image\n",
    "        image = vis_processors[\"eval\"](raw_image).unsqueeze(0).to(device)\n",
    "\n",
    "        scene_setup = 'You may assume that any metal object is shiny, and any rubber object is not shiny (\"matte\"). All objects are either \"metal\" or \"rubber\", and in 2 sizes: \"large\" or \"small\". All objects are one of the following colours: \"blue\", \"brown\", \"cyan\", \"gray\", \"green\", \"purple\", \"red\", \"yellow\". All objects are one of the following shapes: \"cube\", \"cylinder\", \"sphere\". For numeric answers, give an integer and not in words.\\n'\n",
    "        instruction = 'Now answer the following questions with step-by-step reasoning.'\n",
    "        format_instruction = 'Give the final one word answer at the end of your reasoning. Thus, your response format must be:\\nReasoning for the answer\\nFinal answer:\\nFinal one word answer'\n",
    "        \n",
    "        questions = scene_data[key]['questions']\n",
    "        \n",
    "        for i,q in enumerate(questions):\n",
    "            ground_truth = q['answer']\n",
    "            question = q['question']\n",
    "            qn_to_ask = scene_setup+ instruction+ format_instruction +'Question: '+ question + \"Answer: Let's think step by step.\"\n",
    "            model_ans = model.generate({\n",
    "                'image': image,\n",
    "                'prompt': qn_to_ask\n",
    "            }, max_length = 200)\n",
    "            scene_data[key]['questions'][i]['blip2_t5_xxl_cot'] = model_ans\n",
    "        \n",
    "    # Save scene data to JSON file after each iteration\n",
    "    if output_file:\n",
    "        with open(output_file, 'w') as f:\n",
    "            json.dump(scene_data, f)\n",
    "        \n",
    "    return scene_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "response_map = run_on_model(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response_map['0']['questions'][7]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
