{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebd6f222",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pip install accelerate\n",
    "from transformers import T5Tokenizer, T5ForConditionalGeneration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "144f290d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# pip install accelerate\n",
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "MODEL_NAME = \"google/flan-t5-xxl\"\n",
    "\n",
    "tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)\n",
    "# device_map=\"auto\" lets accelerate decide where the weights live, so inputs are sent to\n",
    "# model.device rather than a hard-coded \"cuda\" (which breaks when the model is not on the GPU).\n",
    "model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, device_map=\"auto\")\n",
    "\n",
    "# Smoke test: translate one sentence to confirm the model loads and generates.\n",
    "input_text = \"translate English to German: How old are you?\"\n",
    "input_ids = tokenizer(input_text, return_tensors=\"pt\").input_ids.to(model.device)\n",
    "\n",
    "outputs = model.generate(input_ids)\n",
    "# skip_special_tokens drops the <pad> / </s> markers from the printed text.\n",
    "print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "809fb3f0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4eb75c2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Total parameter count of the loaded model, with thousands separators for readability.\n",
    "num_params = model.num_parameters()\n",
    "print(f\"Number of parameters: {num_params:,}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa9afed8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "3c97c04e",
   "metadata": {},
   "source": [
    "# Define functions to load and process the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f79ed45",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import requests\n",
    "import json\n",
    "import os\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb8aacf2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Replace with the path to the directory containing the CLEVR dataset\n",
    "CLEVR_DIR = '/home/user/Desktop/vqa_research/CLEVR_v1.0'\n",
    "SCENE_MAPPING_DIR = '/home/user/Desktop/vqa_research'\n",
    "\n",
    "def load_data(scenes_path=None, questions_path='CLEVR_t_questions.json'):\n",
    "    \"\"\"\n",
    "    Load scene metadata and questions, then group the questions by scene.\n",
    "\n",
    "    Args:\n",
    "        scenes_path: scenes JSON file; defaults to\n",
    "            CLEVR_DIR/scenes/CLEVR_train_scenes.json.\n",
    "        questions_path: questions JSON file. The default is resolved against the\n",
    "            current working directory, not CLEVR_DIR.\n",
    "            NOTE(review): confirm this cwd-relative file is the intended question set.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each scene's 'image_index' to\n",
    "        {'metadata': <scene dict>, 'questions': [<that scene's question dicts, in file order>]}.\n",
    "    \"\"\"\n",
    "    if scenes_path is None:\n",
    "        scenes_path = os.path.join(CLEVR_DIR, 'scenes', 'CLEVR_train_scenes.json')\n",
    "\n",
    "    print('opening scenes...')\n",
    "    with open(scenes_path, encoding='utf-8') as f:\n",
    "        scenes = json.load(f)['scenes']\n",
    "\n",
    "    print('Opening questions...')\n",
    "    with open(questions_path, encoding='utf-8') as f:\n",
    "        questions = json.load(f)['questions']\n",
    "\n",
    "    print('Creating map...')\n",
    "    # Bucket the questions by image index in one pass; filtering the full question list\n",
    "    # once per scene was O(scenes * questions).\n",
    "    questions_by_scene = {}\n",
    "    for q in questions:\n",
    "        questions_by_scene.setdefault(q['image_index'], []).append(q)\n",
    "\n",
    "    scene_data = {}\n",
    "    for scene in tqdm.tqdm(scenes):\n",
    "        scene_number = scene['image_index']\n",
    "        scene_data[scene_number] = {\n",
    "            'metadata': scene,\n",
    "            'questions': questions_by_scene.get(scene_number, []),\n",
    "        }\n",
    "\n",
    "    return scene_data\n",
    "\n",
    "def load_scene_data():\n",
    "    \"\"\"\n",
    "    Loads scene_data.json \n",
    "    \"\"\"\n",
    "    with open(os.path.join('./clevr_val_scene_mapping.json')) as f:\n",
    "        scene_data = json.load(f)\n",
    "    \n",
    "    return scene_data\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3a7bd07",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_scene_data(scene_number, scene_data):\n",
    "    \"\"\"\n",
    "    Get metadata and questions for a given scene number in the CLEVR dataset.\n",
    "    \"\"\"\n",
    "    # Find the scene with the given number\n",
    "    scene = scene_data.get(scene_number)\n",
    "    if scene is None:\n",
    "        print(f'Scene {scene_number} not found.')\n",
    "        return\n",
    "    \n",
    "    # Print metadata\n",
    "    metadata = scene['metadata']\n",
    "    print()\n",
    "    print(f'Scene {scene_number}:')\n",
    "    print()\n",
    "    print(f'  Objects: {len(metadata[\"objects\"])}')\n",
    "    for obj in metadata[\"objects\"]:\n",
    "        print(f'    Object:')\n",
    "        print(f'      Color: {obj[\"color\"]}')\n",
    "        print(f'      Size: {obj[\"size\"]}')\n",
    "        print(f'      Rotation: {obj[\"rotation\"]}')\n",
    "        print(f'      Shape: {obj[\"shape\"]}')\n",
    "        print(f'      Material: {obj[\"material\"]}')\n",
    "        print(f'      3D Coords: {obj[\"3d_coords\"]}')\n",
    "        print(f'      Pixel Coords: {obj[\"pixel_coords\"]}')\n",
    "        print()\n",
    "    print(f'  Relationships: {metadata[\"relationships\"]}')\n",
    "    print()\n",
    "    print(f'  Directions: {metadata[\"directions\"]}')\n",
    "    print()\n",
    "    print(f'  Image Filename: {metadata[\"image_filename\"]}')\n",
    "    print()\n",
    "    \n",
    "    # Print questions and answers\n",
    "    questions = scene['questions']\n",
    "    print(f'  Questions: {len(questions)}')\n",
    "    for q in questions:\n",
    "        print(f'    Question: {q[\"question\"]}')\n",
    "        print()\n",
    "        print(f'      Answer: {q[\"answer\"]}')\n",
    "        print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e941ea22",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the precomputed scene mapping (default file: ./clevr_val_scene_mapping.json).\n",
    "scene_data = load_scene_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "462bc719",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Pretty-print the metadata and question/answer pairs of scene '0'.\n",
    "get_scene_data('0',scene_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64e1f22b",
   "metadata": {},
   "source": [
    "# Design Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dea9a5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw record for scene '0', shown to inform the prompt design below.\n",
    "scene_data['0']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ff1221b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_scene_metadata(scene_number, scene_data):\n",
    "    \"\"\"\n",
    "    Get metadata for a given scene number in the CLEVR dataset.\n",
    "    \"\"\"\n",
    "    # Find the scene with the given number\n",
    "    scene = scene_data.get(scene_number)\n",
    "    if scene is None:\n",
    "        return f'Scene {scene_number} not found.'\n",
    "    output = []\n",
    "    output_qns = []\n",
    "    output_ans = []\n",
    "\n",
    "    # Add metadata to output\n",
    "    metadata = scene['metadata']\n",
    "    output.append(f'Scene {scene_number}:\\n')\n",
    "    output.append(f'  Objects: {len(metadata[\"objects\"])}')\n",
    "    for obj in metadata[\"objects\"]:\n",
    "        output.append(f'    Object:')\n",
    "        output.append(f'      Color: {obj[\"color\"]}')\n",
    "        output.append(f'      Size: {obj[\"size\"]}')\n",
    "        output.append(f'      Rotation: {obj[\"rotation\"]}')\n",
    "        output.append(f'      Shape: {obj[\"shape\"]}')\n",
    "        output.append(f'      Material: {obj[\"material\"]}')\n",
    "        output.append(f'      3D Coords: {obj[\"3d_coords\"]}')\n",
    "        output.append(f'      Pixel Coords: {obj[\"pixel_coords\"]}')\n",
    "        output.append('\\n')\n",
    "    output.append(f'  Relationships: {metadata[\"relationships\"]}\\n')\n",
    "    output.append(f'  Directions: {metadata[\"directions\"]}\\n')\n",
    "    output.append(f'  Image Filename: {metadata[\"image_filename\"]}\\n')\n",
    "\n",
    "    # Combine the output strings and return\n",
    "    return (\"\\n\".join(output))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd31dcb9-6d86-4212-a46c-040bddc7ae2f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_prompt(scene_number, scene_data,cot=False):\n",
    "    # metadata, qns, ans = get_scene_input(scene_number,scene_data)\n",
    "    # setup = 'Given the following scene:\\n'\n",
    "    scene_setup = 'You may assume that any metal object is shiny, and any rubber object is not shiny (\"matte\"). All objects are either \"metal\" or \"rubber\", and in 2 sizes: \"large\" or \"small\". All objects are one of the following colours: \"blue\", \"brown\", \"cyan\", \"gray\", \"green\", \"purple\", \"red\", \"yellow\". All objects are one of the following shapes: \"cube\", \"cylinder\", \"sphere\". For numeric answers, give an integer and not in words.\\n'\n",
    "    if not cot:\n",
    "        instruction = 'Now answer the following question in one word.'\n",
    "    else:\n",
    "        instruction = 'Now answer the following questions with step-by-step reasoning.'\n",
    "    prompt = scene_setup + instruction\n",
    "    if cot:\n",
    "        prompt += 'Give the final one word answer at the end of your reasoning. Thus, your response format should be:\\nReasoning for the answer\\nFinal answer:\\nFinal one word answer'\n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67c92404",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76fb3f63",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2ac6fc8-67c0-4c58-9a5f-da93c4fcc8b7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def run_on_t5(scene_start=0, scene_end=14999, response_mapping=None, response_mapping_cot=None,\n",
    "              output_path='response_mapping_flan_t5_xxl_cot.json', checkpoint_every=1500):\n",
    "    \"\"\"\n",
    "    Generate Flan-T5 chain-of-thought responses for every question of the scenes\n",
    "    in range(scene_start, scene_end).\n",
    "\n",
    "    Uses the notebook globals `scene_data`, `tokenizer` and `model`.\n",
    "\n",
    "    Args:\n",
    "        scene_start: first scene index (inclusive).\n",
    "        scene_end: end scene index (exclusive, as in `range`).\n",
    "            NOTE(review): with the default, scene '14999' itself is not processed - confirm intent.\n",
    "        response_mapping: dict filled with scene key (str) -> decoded responses, one per\n",
    "            question in question order. A fresh dict is used when None.\n",
    "        response_mapping_cot: dict that only receives an empty list per processed scene;\n",
    "            kept so the return shape is unchanged. A fresh dict is used when None.\n",
    "        output_path: JSON file that response_mapping is written to.\n",
    "        checkpoint_every: also write output_path after every scene whose index is a\n",
    "            multiple of this value (0 or None disables intermediate checkpoints).\n",
    "\n",
    "    Returns:\n",
    "        (response_mapping, response_mapping_cot)\n",
    "    \"\"\"\n",
    "    # A mutable `{}` default is created once and shared by every call, so a second call\n",
    "    # would silently start from the previous run's results.\n",
    "    if response_mapping is None:\n",
    "        response_mapping = {}\n",
    "    if response_mapping_cot is None:\n",
    "        response_mapping_cot = {}\n",
    "\n",
    "    def _save():\n",
    "        # Write the responses collected so far to output_path.\n",
    "        with open(output_path, 'w') as f:\n",
    "            json.dump(response_mapping, f)\n",
    "\n",
    "    setup = 'Given the following scene:\\n'\n",
    "    error_scene_indexes = []\n",
    "\n",
    "    for scene_index in tqdm.tqdm(range(scene_start, scene_end)):\n",
    "        i = str(scene_index)\n",
    "        response_mapping[i] = []\n",
    "        response_mapping_cot[i] = []\n",
    "\n",
    "        scene = scene_data.get(i)\n",
    "        if scene is None:\n",
    "            # Record and skip missing scenes instead of failing on scene['questions'].\n",
    "            error_scene_indexes.append(scene_index)\n",
    "            continue\n",
    "\n",
    "        metadata = get_scene_metadata(i, scene_data)\n",
    "        cot_instructions = get_prompt(i, scene_data, True)\n",
    "\n",
    "        for q in scene['questions']:\n",
    "            # A space after 'Question:' and a newline before 'Answer:' keep the question\n",
    "            # from running into the surrounding cues.\n",
    "            cot_prompt = (setup + metadata + cot_instructions + \"\\n\"\n",
    "                          + \"Question: \" + q['question'] + \"\\n\"\n",
    "                          + \"Answer: Let's think step by step.\")\n",
    "            input_text = f\"context: {cot_prompt}\"\n",
    "            # Send inputs to wherever device_map placed the model instead of assuming \"cuda\".\n",
    "            input_ids = tokenizer(input_text, return_tensors=\"pt\").input_ids.to(model.device)\n",
    "            outputs = model.generate(input_ids, max_length=200)\n",
    "            response_mapping[i].append(tokenizer.decode(outputs[0], skip_special_tokens=True))\n",
    "\n",
    "        # Checkpoint once per selected scene rather than once per question of that scene.\n",
    "        if checkpoint_every and scene_index % checkpoint_every == 0:\n",
    "            _save()\n",
    "\n",
    "    _save()\n",
    "    if error_scene_indexes:\n",
    "        print(f'Scenes not found in scene_data: {error_scene_indexes}')\n",
    "    return response_mapping, response_mapping_cot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4a56822",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Full run over run_on_t5's default scene range: one model.generate call per question, so expect this to take a long time.\n",
    "response_map, response_map_cot = run_on_t5()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7055f1db",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Model response for question index 5 of scene '0'.\n",
    "response_map['0'][5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6579660d-bf0e-4e69-9590-5c449ebc1bd4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Ground-truth question record for question index 5 of scene '0', to compare with the model response.\n",
    "scene_data['0']['questions'][5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fe6d2cb-c691-492f-91f6-051aa60a0474",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
