{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31194f6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from PIL import Image\n",
    "import os\n",
    "\n",
    "with open(\"./val_sceneGraphs.json\",'rb') as f:\n",
    "    sgs =json.load(f)\n",
    "\n",
    "with open(\"./val_balanced_questions.json\",'rb') as f:\n",
    "    qset =json.load(f)\n",
    "\n",
    "img_dir = \"./images/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bf8373a",
   "metadata": {},
   "outputs": [],
   "source": [
    "list(qset.keys())[0:10], len(qset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0912f769",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_struct = qset['17197213']\n",
    "q_struct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fee3d12",
   "metadata": {},
   "outputs": [],
   "source": [
    "img = Image.open(os.path.join(img_dir, q_struct['imageId']+'.jpg'))\n",
    "img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5875173a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the image to the current working directory\n",
    "filename = q_struct['imageId']+'.jpg'\n",
    "img.save(filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e1060be",
   "metadata": {},
   "outputs": [],
   "source": [
    "filename"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19c6c396",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_sg = sgs[q_struct['imageId']]\n",
    "q_sg.keys() #sometimes other attrbs such as 'weather', 'location' might be present"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "491c148f",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_sg['objects'] #list of objects, attributes and relations, and bounding boxes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce9fbfe4",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_struct['semantic'] #semantic program"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f54d8591",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_struct['types'] #question type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af324947",
   "metadata": {},
   "outputs": [],
   "source": [
    "program_len_distrb = {}\n",
    "program_len_distrb_cases = {}\n",
    "program_type_distrb = {}\n",
    "program_type_distrb_cases = {}\n",
    "for key, item in qset.items():\n",
    "    sem_len = len(item['semantic'])\n",
    "    if(sem_len not in program_len_distrb):\n",
    "        program_len_distrb[sem_len]= 0\n",
    "        program_len_distrb_cases[sem_len] = []\n",
    "    program_len_distrb_cases[sem_len].append(key)\n",
    "    program_len_distrb[sem_len]+=1 \n",
    "    \n",
    "    q_type = item['types']['structural']\n",
    "    if(q_type not in program_type_distrb):\n",
    "        program_type_distrb[q_type] = 0\n",
    "        program_type_distrb_cases[q_type] = []\n",
    "    program_type_distrb[q_type] +=1\n",
    "    program_type_distrb_cases[q_type].append(key)\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76b51cc9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f4fa23c",
   "metadata": {},
   "outputs": [],
   "source": [
    "program_len_distrb, program_type_distrb #gqa val stats"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b926dd9",
   "metadata": {},
   "source": [
    "# Experiment Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0d62b52",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_gqa_scene_description(imageId, sgs):\n",
    "    \"\"\"\n",
    "    Get description for a given imageId in the GQA dataset scene graphs.\n",
    "    \"\"\"\n",
    "    scene = sgs.get(imageId)\n",
    "    if scene is None:\n",
    "        return f'Scene with imageId {imageId} not found.'\n",
    "\n",
    "    output = []\n",
    "    \n",
    "    # Optional scene information\n",
    "    if 'location' in scene:\n",
    "        output.append(f'Location: {scene[\"location\"]}')\n",
    "    if 'weather' in scene:\n",
    "        output.append(f'Weather: {scene[\"weather\"]}')\n",
    "    \n",
    "    # Add metadata to output\n",
    "    output.append(f'Image Dimensions: {scene[\"width\"]}x{scene[\"height\"]}')\n",
    "    \n",
    "    obj_count = len(scene['objects'])\n",
    "    output.append(f'Objects: {obj_count}')\n",
    "    for obj_id, obj in scene[\"objects\"].items():\n",
    "        output.append(f'  Object ID {obj_id}:')\n",
    "        output.append(f'    Name: {obj[\"name\"]}')\n",
    "        output.append(f'    Coordinates: x={obj[\"x\"]}, y={obj[\"y\"]}')\n",
    "        output.append(f'    Dimensions: w={obj[\"w\"]}, h={obj[\"h\"]}')\n",
    "        output.append(f'    Attributes: {\", \".join(obj[\"attributes\"])}')\n",
    "        for relation in obj.get(\"relations\", []):  # Adjusted this line\n",
    "            output.append(f'      Relation:')\n",
    "            output.append(f'        Name: {relation[\"name\"]}')\n",
    "            output.append(f'        Object: {relation[\"object\"]}')\n",
    "        output.append('\\n')\n",
    "\n",
    "    # Combine the output strings and return\n",
    "    return (\"\\n\".join(output))\n",
    "\n",
    "\n",
    "def iterate_through_questions_and_generate_description(qset, sgs):\n",
    "    \"\"\"\n",
    "    Iterate through each question, generate scene description and ask the question.\n",
    "    \"\"\"\n",
    "    for q_id, q_struct in qset.items():\n",
    "        description = get_gqa_scene_description(q_struct['imageId'], sgs)\n",
    "        print(description)\n",
    "        print(f'Question: {q_struct[\"question\"]}\\n')\n",
    "        break\n",
    "\n",
    "iterate_through_questions_and_generate_description(qset, sgs)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f481f706",
   "metadata": {},
   "source": [
    "## Subset Data Selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5005b12e",
   "metadata": {},
   "outputs": [],
   "source": [
    "answer_distribution = {}\n",
    "for question_id, q_struct in qset.items():\n",
    "    answer = q_struct['answer']\n",
    "    \n",
    "    # Check if the answer is already in the answer distribution dictionary\n",
    "    if answer in answer_distribution:\n",
    "        answer_distribution[answer]['count'] += 1\n",
    "        answer_distribution[answer]['question_ids'].append(question_id)\n",
    "    else:\n",
    "        answer_distribution[answer] = {\n",
    "            'count': 1,\n",
    "            'question_ids': [question_id]\n",
    "        }\n",
    "\n",
    "# Sort the answer distribution by count in descending order\n",
    "sorted_answers = sorted(answer_distribution.items(), key=lambda x: x[1]['count'], reverse=True)\n",
    "\n",
    "# Select the top 25 answers and calculate their total occurrences\n",
    "top_answers = sorted_answers[:25]\n",
    "total_occurrences = 0\n",
    "\n",
    "for answer, info in top_answers:\n",
    "    count = info['count']\n",
    "    total_occurrences += count\n",
    "    print(f\"Answer: {answer}, Count: {count}\")\n",
    "\n",
    "print(\"=\" * 30)\n",
    "print(f\"Total Occurrences of Top 25 Answers: {total_occurrences}\")\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e4a94274",
   "metadata": {},
   "source": [
    "### Modified subset from filtered dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00b5cae7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "def calculate_program_distributions(qset):\n",
    "    program_len_distrb = {}\n",
    "    program_len_distrb_cases = {}\n",
    "\n",
    "    for key, item in qset.items():\n",
    "        sem_len = len(item['semantic'])\n",
    "        if sem_len not in program_len_distrb:\n",
    "            program_len_distrb[sem_len] = 0\n",
    "            program_len_distrb_cases[sem_len] = []\n",
    "        program_len_distrb_cases[sem_len].append(key)\n",
    "        program_len_distrb[sem_len] += 1\n",
    "    \n",
    "    return program_len_distrb, program_len_distrb_cases\n",
    "\n",
    "def select_subset_by_n(n, program_len_distrb_cases):\n",
    "    subset = []\n",
    "    \n",
    "    for prog_len, cases in program_len_distrb_cases.items():\n",
    "        if prog_len < 6:\n",
    "            subset.extend(random.sample(cases, min(n, len(cases))))\n",
    "        else:  \n",
    "            subset.extend(cases)\n",
    "    \n",
    "    return subset\n",
    "\n",
    "def select_subset_by_k(k, program_len_distrb, program_len_distrb_cases):\n",
    "    subset = []\n",
    "    \n",
    "    for prog_len, count in program_len_distrb.items():\n",
    "        num_to_select = int(count * k)\n",
    "        subset.extend(random.sample(program_len_distrb_cases[prog_len], num_to_select))\n",
    "    \n",
    "    return subset\n",
    "\n",
    "def select_subset(n=None, k=None, qset=None):\n",
    "    program_len_distrb, program_len_distrb_cases = calculate_program_distributions(qset)\n",
    "    if n:\n",
    "        return select_subset_by_n(n, program_len_distrb_cases)\n",
    "    elif k:\n",
    "        return select_subset_by_k(k, program_len_distrb, program_len_distrb_cases)\n",
    "    else:\n",
    "        raise ValueError(\"Either n or k should be provided.\")\n",
    "\n",
    "# Top n most occurring answers\n",
    "top_n = 25\n",
    "\n",
    "# Get all the question_ids for the top n answers\n",
    "top_n_question_ids = []\n",
    "\n",
    "for _, info in sorted_answers[:top_n]:\n",
    "    top_n_question_ids.extend(info['question_ids'])\n",
    "\n",
    "# Filter the qset by the top n answers\n",
    "filtered_qset = {qid: qset[qid] for qid in top_n_question_ids}\n",
    "\n",
    "# Now use the filtered_qset when calling the select functions:\n",
    "\n",
    "n = 5000\n",
    "subset_by_n = select_subset(n=n, qset=filtered_qset)\n",
    "print(f\"Selected {len(subset_by_n)} questions by n from filtered dataset.\")\n",
    "\n",
    "k = 0.3\n",
    "subset_by_k = select_subset(k=k, qset=filtered_qset)\n",
    "print(f\"Selected {len(subset_by_k)} questions by k from filtered dataset.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b271fb42",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_final_question_set(subset_keys, qset):\n",
    "    \"\"\"\n",
    "    Return the subset of qset based on the provided subset of keys.\n",
    "    \"\"\"\n",
    "    return {key: qset[key] for key in subset_keys if key in qset}\n",
    "\n",
    "# Original Usage\n",
    "# n = 5000\n",
    "# subset_by_n = select_subset(n=n)\n",
    "# final_qset_by_n = get_final_question_set(subset_by_n, qset)\n",
    "# print(f\"Final question set by n contains {len(final_qset_by_n)} questions.\")\n",
    "\n",
    "# k = 0.5\n",
    "# subset_by_k = select_subset(k=k)\n",
    "# final_qset_by_k = get_final_question_set(subset_by_k, qset)\n",
    "# print(f\"Final question set by k contains {len(final_qset_by_k)} questions.\")\n",
    "\n",
    "n = 10000\n",
    "subset_by_n = select_subset(n=n, qset=filtered_qset)\n",
    "final_qset_by_n = get_final_question_set(subset_by_n, filtered_qset)\n",
    "print(f\"Selected {len(subset_by_n)} questions by n from filtered dataset.\")\n",
    "\n",
    "k = 0.5\n",
    "subset_by_k = select_subset(k=k, qset=filtered_qset)\n",
    "final_qset_by_k = get_final_question_set(subset_by_k, filtered_qset)\n",
    "print(f\"Selected {len(subset_by_k)} questions by k from filtered dataset.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "591332ec",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "d5bdb768",
   "metadata": {},
   "source": [
    "# Model Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16f79479",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this to install dependencies:\n",
    "# !pip3 install salesforce-lavis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4762e685",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "import requests\n",
    "import json\n",
    "import os\n",
    "import tqdm\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6d8bc3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check usage here: https://huggingface.co/docs/transformers/model_doc/flan-t5\n",
    "# We will have to run this experiment for flan-t5-xl and flan-t5-xxl\n",
    "tokenizer = T5Tokenizer.from_pretrained(\"google/flan-t5-xl\")\n",
    "model = T5ForConditionalGeneration.from_pretrained(\"google/flan-t5-xl\", device_map=\"auto\", torch_dtype=torch.float16)\n",
    "\n",
    "input_text = \"translate English to German: How old are you?\"\n",
    "input_ids = tokenizer(input_text, return_tensors=\"pt\").input_ids.to(\"cuda\")\n",
    "\n",
    "outputs = model.generate(input_ids)\n",
    "print(tokenizer.decode(outputs[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4107fad",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_params = model.num_parameters()\n",
    "print(f\"Number of parameters in {num_params}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "315105a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get set of correct answers from gqa\n",
    "def get_correct_answers(qset):\n",
    "    \"\"\"\n",
    "    Extract the set of correct answers from the question set.\n",
    "    \"\"\"\n",
    "    answers = {item['answer'] for _, item in qset.items()}\n",
    "    return answers\n",
    "\n",
    "correct_answers = get_correct_answers(qset)\n",
    "print(len(correct_answers))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e3df42a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prompt(q_struct,cot=False, top_answers=True):\n",
    "    \"\"\"\n",
    "    Returns the prompt except the question\n",
    "    \"\"\"\n",
    "    global top_n\n",
    "    # Extract top 25 answers\n",
    "    top_n_answers = [answer for answer, _ in sorted_answers[:top_n]]\n",
    "\n",
    "    # Format the top answers into a string\n",
    "    formatted_top_n = ', '.join(top_n_answers)\n",
    "\n",
    "    description = get_gqa_scene_description(q_struct['imageId'], sgs)\n",
    "    # Add the top answers to the instruction\n",
    "    if top_answers:\n",
    "        description += f\" The possible answers could be: {formatted_top_n}.\"\n",
    "        \n",
    "    setup = 'Given the following scene:\\n'\n",
    "    if not cot:\n",
    "        instruction = 'Now answer the following question in one word.'\n",
    "    else:\n",
    "        instruction = 'Now answer the following questions with step-by-step reasoning.'\n",
    "    \n",
    "    prompt = setup + description + instruction\n",
    "    if cot:\n",
    "        prompt += 'Give the final one word answer at the end of your reasoning. Thus, your response format should be:\\nReasoning for the answer\\nFinal answer:\\nFinal one word answer'\n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5c20361",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_text = get_prompt(q_struct)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c10f22fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(prompt_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ab3e027",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_input_text = f\"context: {prompt_text}\\nQuestion: {q_struct['question']}\"\n",
    "print(test_input_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71cb48a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "len_threshold = 6000\n",
    "def count_suitable_questions(qset):\n",
    "    count = 0\n",
    "    for q_id, q_struct in qset.items():\n",
    "        prompt = get_prompt(q_struct)\n",
    "        input_text = f\"context: {prompt}\\nQuestion: {q_struct['question']}\"\n",
    "        if len(input_text) < len_threshold:\n",
    "            count += 1\n",
    "    return count\n",
    "\n",
    "# Use the function to get the count:\n",
    "# n_suitable_questions = count_suitable_questions(final_qset_by_n)\n",
    "n_suitable_questions = count_suitable_questions(qset)\n",
    "print(f\"There are {n_suitable_questions} questions with input text length less than {len_threshold} characters.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c96cdf10",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(final_qset_by_n.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b1534de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# final_qset_by_n.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "222a09ed",
   "metadata": {},
   "source": [
    "# Run on T5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "271556ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def run_on_gqa(qset_in, response_mapping={}):\n",
    "    error_q_ids = []\n",
    "    print(response_mapping.keys())\n",
    "\n",
    "    q_ids = list(qset_in.keys())  # Get all the question IDs\n",
    "    num_done = 0\n",
    "    with torch.no_grad():\n",
    "        for q_id in tqdm(q_ids):\n",
    "            #if(num_done%3==0):\n",
    "             #   model = T5ForConditionalGeneration.from_pretrained(\"google/flan-t5-xl\", device_map=\"auto\", torch_dtype=torch.float16)\n",
    "            q_struct = qset_in[q_id]\n",
    "            response_mapping[q_id] = {}\n",
    "\n",
    "            prompt = get_prompt(q_struct)\n",
    "            # cot_prompt = get_prompt(q_struct, True)\n",
    "\n",
    "            if num_done % 100 == 0:  # Adjust this condition as needed\n",
    "                print(num_done)\n",
    "                print('Saving Checkpoint...')\n",
    "                with open('response_mapping_gqa_t5.json', 'w') as f:\n",
    "                    json.dump(response_mapping, f)\n",
    "                print('---------------------------------')\n",
    "\n",
    "            # Non COT response\n",
    "            input_text = f\"context: {prompt}\\nQuestion: {q_struct['question']}\"\n",
    "\n",
    "            if len(input_text) >= 5000:\n",
    "                continue  # Skip the iteration and move to the next question\n",
    "\n",
    "            input_ids = tokenizer(input_text, return_tensors=\"pt\").input_ids.to(\"cuda\")\n",
    "            print(input_ids.shape)\n",
    "            outputs = model.generate(input_ids)\n",
    "            prompt_response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "            response_mapping[q_id]['non_cot'] = prompt_response\n",
    "            #print(prompt_response, t)\n",
    "            #input()\n",
    "            num_done+=1\n",
    "            # COT response\n",
    "            # input_text = f\"context: {cot_prompt}\\nQuestion: {q_struct['question']}\"\n",
    "            # input_ids = tokenizer(input_text, return_tensors=\"pt\").input_ids.to(\"cuda\")\n",
    "            # outputs = model.generate(input_ids)\n",
    "            # cot_response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "            # response_mapping[q_id]['cot'] = cot_response\n",
    "\n",
    "    print('Final Save...')\n",
    "    with open('response_mapping_gqa_t5_final.json', 'w') as f:\n",
    "                json.dump(response_mapping, f)\n",
    "    if error_q_ids:\n",
    "        with open('error_q_ids.txt', 'w') as f:\n",
    "            for error_q_id in error_q_ids:\n",
    "                f.write(f\"{error_q_id}\\n\")\n",
    "\n",
    "    return response_mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f9a0920",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Usage:\n",
    "response_mapping = run_on_gqa(final_qset_by_n)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f9396f6",
   "metadata": {},
   "source": [
    "# Run on Blip2-Flant T5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30131a5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "import requests\n",
    "from lavis.models import load_model_and_preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d225c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup device to use\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14b80e8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load pretrained/finetuned blip2 captioning model\n",
    "# we associate a model with its preprocessors to make it easier for inference.\n",
    "model, vis_processors, _ = load_model_and_preprocess(\n",
    "    name=\"blip2_t5\", model_type=\"pretrain_flant5xl\", is_eval=True, device=device\n",
    ")\n",
    "\n",
    "# Other available models:\n",
    "# \n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_opt\", model_type=\"pretrain_opt2.7b\", is_eval=True, device=device\n",
    "# )\n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_opt\", model_type=\"pretrain_opt6.7b\", is_eval=True, device=device\n",
    "# )\n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_opt\", model_type=\"caption_coco_opt2.7b\", is_eval=True, device=device\n",
    "# )\n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_opt\", model_type=\"caption_coco_opt6.7b\", is_eval=True, device=device\n",
    "# )\n",
    "#\n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_t5\", model_type=\"pretrain_flant5xl\", is_eval=True, device=device\n",
    "# )\n",
    "#\n",
    "# model, vis_processors, _ = load_model_and_preprocess(\n",
    "#     name=\"blip2_t5\", model_type=\"caption_coco_flant5xl\", is_eval=True, device=device\n",
    "# )\n",
    "\n",
    "vis_processors.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef7c4079",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def run_gqa_on_vlm_model(qset_in, model, vis_processors, img_dir, output_file='gqa_blip2_response_data.json'):\n",
    "    i = 0\n",
    "    for q_id, q_struct in tqdm(qset_in.items()):\n",
    "        # Get the image associated with the question\n",
    "        img_filename = q_struct['imageId'] + '.jpg'\n",
    "        # break\n",
    "        raw_image_path = os.path.join(img_dir, img_filename)\n",
    "        if not os.path.exists(raw_image_path):\n",
    "            continue\n",
    "        raw_image = Image.open(raw_image_path).convert('RGB') \n",
    "        \n",
    "        # Prepare image\n",
    "        image = vis_processors[\"eval\"](raw_image).unsqueeze(0).to(device)\n",
    "\n",
    "        # Extract the ground truth and the question\n",
    "        question = q_struct['question']\n",
    "\n",
    "        # Combine the question and prepare the model input\n",
    "        qn_to_ask = 'Question: ' + question + ' Answer:'\n",
    "        model_ans = model.generate({\n",
    "            'image': image,\n",
    "            'prompt': qn_to_ask\n",
    "        })\n",
    "        \n",
    "        # Store the model answer in the question structure\n",
    "        qset_in[q_id]['blip2_t5'] = model_ans\n",
    "\n",
    "        # Save results to JSON file after every 100 questions\n",
    "        if output_file and i%100==0:\n",
    "            with open(output_file, 'w') as f:\n",
    "                json.dump(qset_in, f)\n",
    "        i +=1\n",
    "    \n",
    "    with open('gqa_blip2_response_final.json', 'w') as f:\n",
    "        json.dump(qset_in, f)\n",
    "        \n",
    "    return qset_in"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68198f79",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = run_gqa_on_vlm_model(final_qset_by_n, model,vis_processors, img_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c010069",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_qset_by_n.items()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6316618",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
