{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Loading the larger models requires a large amount of RAM. Running on a GPU speeds up inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "import requests\n",
    "from lavis.models import load_model_and_preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "if device.type == \"cuda\":\n",
    "    print(\"CUDA device is available\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_scene_data(filepath):\n",
    "    \"\"\"Parse the JSON file at ``filepath`` and return the resulting Python object.\"\"\"\n",
    "    with open(filepath, 'r') as scene_file:\n",
    "        return json.load(scene_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the PTR validation scene-mapping JSON on the local machine.\n",
    "scene_mapping_path = '/home/user/Desktop/vqa_research/ptr_val/ptr_val_scene_mapping.json'\n",
    "scene_data = load_scene_data(scene_mapping_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard-coded example image path; the following cell reassigns img_dir from the scene metadata.\n",
    "img_dir = './val_images/PTR_val_007239.png'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolve the example scene's image path from its metadata entry.\n",
    "example_metadata = scene_data['PTR_val_007239']['metadata']\n",
    "img_filename = example_metadata['image_filename']\n",
    "img_dir = os.path.join('./val_images', img_filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Echo the resolved image path as the cell output.\n",
    "img_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Load image from directory**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open the example image, convert it to RGB, and show a resized preview.\n",
    "raw_image = Image.open(img_dir)\n",
    "raw_image = raw_image.convert('RGB')\n",
    "preview = raw_image.resize((596, 437))\n",
    "display(preview)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# scene_data['0']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# blip2-flant5 demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import sys\n",
    "# if 'google.colab' in sys.modules:\n",
    "#     print('Running in Colab.')\n",
    "#     !pip3 install salesforce-lavis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "import requests\n",
    "from lavis.models import load_model_and_preprocess"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load an example image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# img_url = 'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png' \n",
    "# raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')   \n",
    "# display(raw_image.resize((596, 437)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup device to use; both outcomes are torch.device objects so later code sees a single type\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load pretrained/finetuned BLIP2 captioning model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# A model is loaded together with its matching preprocessors to make inference easier.\n",
    "# The third value returned by load_model_and_preprocess is discarded here.\n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_t5\", model_type=\"pretrain_flant5xxl\", is_eval=True, device=device\n",
    "# )\n",
    "\n",
    "# Other available models (commented out; only one load call should be active):\n",
    "#\n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_opt\", model_type=\"pretrain_opt2.7b\", is_eval=True, device=device\n",
    "# )\n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_opt\", model_type=\"pretrain_opt6.7b\", is_eval=True, device=device\n",
    "# )\n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_opt\", model_type=\"caption_coco_opt2.7b\", is_eval=True, device=device\n",
    "# )\n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_opt\", model_type=\"caption_coco_opt6.7b\", is_eval=True, device=device\n",
    "# )\n",
    "#\n",
    "# Active choice: the blip2_t5 model with the pretrain_flant5xl checkpoint, in eval mode.\n",
    "model, vis_processors, _ = load_model_and_preprocess(\n",
    "    name=\"blip2_t5\", model_type=\"pretrain_flant5xl\", is_eval=True, device=device\n",
    ")\n",
    "#\n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_t5\", model_type=\"caption_coco_flant5xl\", is_eval=True, device=device\n",
    "# )\n",
    "\n",
    "# Show which preprocessor names are available (the inference loop below uses \"eval\").\n",
    "vis_processors.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Design Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_scene_metadata(scene_number, scene_data):\n",
    "    \"\"\"\n",
    "    Render one PTR scene's metadata as an indented, human-readable text block.\n",
    "\n",
    "    Args:\n",
    "        scene_number: Key of the scene inside ``scene_data``.\n",
    "        scene_data: Mapping from scene key to a scene dict with a ``'metadata'`` entry.\n",
    "\n",
    "    Returns:\n",
    "        The newline-joined description, or ``'Scene <scene_number> not found.'``\n",
    "        when ``scene_number`` is not a key of ``scene_data``.\n",
    "    \"\"\"\n",
    "    # Find the scene with the given number\n",
    "    scene = scene_data.get(scene_number)\n",
    "    if scene is None:\n",
    "        return f'Scene {scene_number} not found.'\n",
    "\n",
    "    # Add metadata to output\n",
    "    metadata = scene['metadata']\n",
    "    output = [f'Scene {scene_number}:\\n']\n",
    "    output.append(f'  Objects: {len(metadata[\"objects\"])}')\n",
    "    for obj in metadata[\"objects\"]:\n",
    "        output.append('    Object:')\n",
    "        output.append(f'      Category: {obj[\"category\"]}')\n",
    "        output.append(f'      Rotation: {obj[\"rotation\"]}')\n",
    "        output.append(f'      Scale: {obj[\"scale\"]}')\n",
    "        # \"stability\" is optional per object, so it is only printed when present.\n",
    "        if \"stability\" in obj:\n",
    "            output.append(f'      Stability: {obj[\"stability\"]}')\n",
    "        output.append(f'      3D Coords: {obj[\"3d_coords\"]}')\n",
    "        # The label names the field that is actually printed (pixel_coords).\n",
    "        output.append(f'      Pixel Coords: {obj[\"pixel_coords\"]}')\n",
    "        output.append(f'      Part Colors: {obj[\"part_color\"]}')\n",
    "        output.append(f'      Part Count: {obj[\"part_count\"]}')\n",
    "        output.append('\\n')\n",
    "    output.append(f'  Relationships: {metadata[\"relationships\"]}\\n')\n",
    "    output.append(f'  Directions: {metadata[\"directions\"]}\\n')\n",
    "    output.append(f'  Image Filename: {metadata[\"image_filename\"]}\\n')\n",
    "    # \"physics\" is optional at scene level.\n",
    "    if \"physics\" in metadata:\n",
    "        output.append(f'  Physics: {metadata[\"physics\"]}\\n')\n",
    "    output.append(f'  Cam location: {metadata[\"cam_location\"]}\\n')\n",
    "    output.append(f'  Cam Rotation: {metadata[\"cam_rotation\"]}\\n')\n",
    "\n",
    "    # Combine the output strings and return\n",
    "    return \"\\n\".join(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prompt(scene_number, scene_data,cot=False):\n",
    "    \"\"\"\n",
    "    Build the instruction text that precedes a PTR question.\n",
    "\n",
    "    Args:\n",
    "        scene_number: Not used by the prompt; kept so existing callers keep working.\n",
    "        scene_data: Not used by the prompt; kept so existing callers keep working.\n",
    "        cot: When True, ask for step-by-step reasoning followed by a one-word\n",
    "            final answer; otherwise ask for a one-word answer only.\n",
    "\n",
    "    Returns:\n",
    "        The answer-vocabulary description followed by the answering instruction(s),\n",
    "        separated by single spaces.\n",
    "    \"\"\"\n",
    "    scene_setup = \"The objects or things can have the following categories: 'Bed', 'Cart', 'Chair', 'Refrigerator', 'Table'. The different parts of the things can have the following categories: 'arm', 'arm horizontal bar', 'arm vertical bar', 'back', 'behind', 'body', 'central support', 'door', 'drawer', 'leg', 'leg bar', 'pedestal', 'seat', 'shelf', 'sleep area', 'top', 'wheel'. The things or objects can move in the following directions to make themselves stable: 'front', 'left', 'right'. The objects or their parts can have the following colors: 'blue', 'brown', 'cyan', 'gray', 'green', 'purple', 'red', 'yellow'. For numeric answers, give an answer in integers and not in words.\"\n",
    "    if cot:\n",
    "        parts = [\n",
    "            scene_setup,\n",
    "            'Now answer the following questions with step-by-step reasoning.',\n",
    "            'Give the final one word answer at the end of your reasoning. Thus, your response format should be:\\nReasoning for the answer\\nFinal answer:\\nFinal one word answer',\n",
    "        ]\n",
    "    else:\n",
    "        parts = [scene_setup, 'Now answer the following question in one word.']\n",
    "    # Join with a space so consecutive sentences do not run into each other.\n",
    "    return ' '.join(parts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Iterate over the PTR dataset and run the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of scenes in the loaded mapping.\n",
    "len(scene_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build and print one chain-of-thought prompt for a single question of a single scene.\n",
    "key = 'PTR_val_003021'\n",
    "i = 3\n",
    "# setup_prompt = \"The objects or things can have the following categories: 'Bed', 'Cart', 'Chair', 'Refrigerator', 'Table'. The different parts of the things can have the following categories: arm', 'arm horizontal bar', 'arm vertical bar', 'back', 'behind', 'body', 'central support', 'door', 'drawer', 'leg', 'leg bar', 'pedestal', 'seat', 'shelf', 'sleep area', 'top', 'wheel'. The things or objects can move in the following directions to make themselves stable: 'front', 'left', 'right'. The objects or their parts can have the following colors: 'blue', 'brown', 'cyan', 'gray', 'green', 'purple', 'red', 'yellow'. For numeric answers, give an answer in integers and not in words. Now answer the following question in one word:\\n\"        \n",
    "scene_setup = \"The objects or things can have the following categories: 'Bed', 'Cart', 'Chair', 'Refrigerator', 'Table'. The different parts of the things can have the following categories: 'arm', 'arm horizontal bar', 'arm vertical bar', 'back', 'behind', 'body', 'central support', 'door', 'drawer', 'leg', 'leg bar', 'pedestal', 'seat', 'shelf', 'sleep area', 'top', 'wheel'. The things or objects can move in the following directions to make themselves stable: 'front', 'left', 'right'. The objects or their parts can have the following colors: 'blue', 'brown', 'cyan', 'gray', 'green', 'purple', 'red', 'yellow'. For numeric answers, give an answer in integers and not in words.\\n\"\n",
    "instruction = 'Now answer the following questions with step-by-step reasoning.'\n",
    "format_instruction = 'Give the final one word answer at the end of your reasoning. Thus, your response format must be:\\nReasoning for the answer\\nFinal answer:\\nFinal one word answer\\n'\n",
    "question = scene_data[key]['questions'][i]['question']\n",
    "# A space separates the two instruction sentences so they do not run together.\n",
    "qn_to_ask = scene_setup + instruction + ' ' + format_instruction + 'Question: ' + question + \" Answer: Let's think step by step.\"\n",
    "print(qn_to_ask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scenes are keyed by their string id (`key`), not by the question index `i`.\n",
    "instructions = get_prompt(key, scene_data)\n",
    "metadata = get_scene_metadata(key, scene_data)\n",
    "setup = 'Given the following scene:\\n'\n",
    "scene = scene_data.get(key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# iterate through the dataset\n",
    "\n",
    "def run_on_model(model,start = 0, end = 4, output_file='response_data.json'):\n",
    "    \"\"\"\n",
    "    Walk the scenes of the global ``scene_data`` whose position lies in\n",
    "    ``start``..``end`` (inclusive), build the prompt for every question, and\n",
    "    write ``scene_data`` to ``output_file``.\n",
    "\n",
    "    The ``model.generate`` call is currently commented out, so ``model`` is not\n",
    "    used yet and no answers are stored in ``scene_data``.\n",
    "\n",
    "    Args:\n",
    "        model: Model meant to answer the questions (unused while generation is disabled).\n",
    "        start: Position of the first scene to process.\n",
    "        end: Position of the last scene to process (inclusive).\n",
    "        output_file: JSON file to write; a falsy value disables saving.\n",
    "\n",
    "    Returns:\n",
    "        The global ``scene_data`` dict.\n",
    "    \"\"\"\n",
    "    for i, key in tqdm.tqdm(enumerate(scene_data)):\n",
    "        if i < start:\n",
    "            continue\n",
    "        if i > end:\n",
    "            break\n",
    "\n",
    "        img_filename = scene_data[key]['metadata']['image_filename']\n",
    "        img_dir = os.path.join('./val_images', img_filename)\n",
    "        raw_image = Image.open(img_dir).convert('RGB')\n",
    "        # Prepare image\n",
    "        image = vis_processors[\"eval\"](raw_image).unsqueeze(0).to(device)\n",
    "\n",
    "        questions = scene_data[key]['questions']\n",
    "\n",
    "        # Prompt Setup\n",
    "        instructions = get_prompt(key, scene_data)\n",
    "        metadata = get_scene_metadata(key, scene_data)\n",
    "        setup = 'Given the following scene:\\n'\n",
    "\n",
    "        for j, q in enumerate(questions):\n",
    "            question = q['question']\n",
    "            qn_to_ask = setup + metadata + instructions + \"\\n\" + \"Question:\" + question\n",
    "#             print(qn_to_ask)\n",
    "#             qn_to_ask = setup_prompt +'Question: '+ question + ' Answer:'\n",
    "#             model_ans = model.generate({\n",
    "#                 'image': image,\n",
    "#                 'prompt': qn_to_ask\n",
    "#             })\n",
    "#             # print(model_ans)\n",
    "#             scene_data[key]['questions'][j]['blip2_t5'] = model_ans\n",
    "\n",
    "        # Periodic checkpoint: written whenever the scene position satisfies i % 1000 == 1.\n",
    "        if output_file and i % 1000 == 1:\n",
    "            with open(output_file, 'w') as f:\n",
    "                json.dump(scene_data, f)\n",
    "\n",
    "    # Save scene data to JSON file once the loop has finished\n",
    "    if output_file:\n",
    "        with open(output_file, 'w') as f:\n",
    "            json.dump(scene_data, f)\n",
    "\n",
    "    return scene_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass the loaded model object (not a placeholder string) and process scenes 0 and 1.\n",
    "response_map = run_on_model(model, 0, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scene keys are string ids such as 'PTR_val_003021', so look up the first key instead of '0'.\n",
    "first_key = next(iter(response_map))\n",
    "response_map[first_key]['questions'][7]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
