{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebd6f222",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pip install accelerate\n",
    "from transformers import T5Tokenizer, T5ForConditionalGeneration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "144f290d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "MODEL_NAME = \"google/flan-t5-xl\"\n",
    "\n",
    "tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)\n",
    "# device_map=\"auto\" (needs `accelerate`) decides where the weights live, so inputs follow\n",
    "# `model.device` instead of a hard-coded \"cuda\" (which fails on machines without CUDA).\n",
    "model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, device_map=\"auto\")\n",
    "\n",
    "# Smoke test: translate a short sentence to check that the model loads and generates.\n",
    "input_text = \"translate English to German: How old are you?\"\n",
    "input_ids = tokenizer(input_text, return_tensors=\"pt\").input_ids.to(model.device)\n",
    "\n",
    "outputs = model.generate(input_ids)\n",
    "# skip_special_tokens drops the <pad> / </s> markers from the printed text.\n",
    "print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "809fb3f0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4eb75c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total parameter count of the loaded checkpoint, with thousands separators for readability.\n",
    "num_params = model.num_parameters()\n",
    "print(f\"Number of parameters: {num_params:,}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa9afed8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-run the translation smoke test (needs `model` and `tokenizer` from the loading cell).\n",
    "input_text = \"translate English to German: How old are you?\"\n",
    "input_ids = tokenizer(input_text, return_tensors=\"pt\").input_ids.to(model.device)\n",
    "\n",
    "outputs = model.generate(input_ids)\n",
    "print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c97c04e",
   "metadata": {},
   "source": [
    "# Define functions to load and process the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f79ed45",
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import json\n",
    "import os\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb8aacf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace with the path to the directory containing the PTR dataset\n",
    "PTR_DIR = './'\n",
    "SCENE_MAPPING_DIR = './'\n",
    "\n",
    "def load_data():\n",
    "    \"\"\"\n",
    "    Load metadata and questions for all scenes in the PTR dataset.\n",
    "    \"\"\"\n",
    "    print('opening scenes...')\n",
    "    # Load metadata for all scenes\n",
    "    with open(os.path.join(PTR_DIR, 'scenes', 'PTR_val_scenes.json')) as f:\n",
    "        scenes = json.load(f)['scenes']\n",
    "    \n",
    "    print('Opening questions...')\n",
    "    # Load questions for all scenes\n",
    "    with open(os.path.join(PTR_DIR, 'questions', 'PTR_val_questions.json')) as f:\n",
    "        questions = json.load(f)['questions']\n",
    "    \n",
    "    print('Creating map...')\n",
    "    # Create a dictionary mapping scene numbers to scene data and questions\n",
    "    scene_data = {}\n",
    "    for scene in tqdm.tqdm(scenes):\n",
    "        scene_number = scene['image_index']\n",
    "        scene_questions = [q for q in questions if q['image_index'] == scene_number]\n",
    "        scene_data[scene_number] = {'metadata': scene, 'questions': scene_questions}\n",
    "\n",
    "    return scene_data\n",
    "\n",
    "def load_scene_data():\n",
    "    \"\"\"\n",
    "    Loads scene_data.json \n",
    "    \"\"\"\n",
    "    with open(os.path.join('./ptr_val_scene_mapping.json')) as f:\n",
    "        scene_data = json.load(f)\n",
    "    \n",
    "    return scene_data\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3a7bd07",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_scene_data(scene_key, scene_data):\n",
    "    \"\"\"\n",
    "    Get metadata and questions for a given scene number in the PTR dataset.\n",
    "    \"\"\"\n",
    "    # Find the scene with the given number\n",
    "    scene = scene_data.get(scene_key)\n",
    "    if scene is None:\n",
    "        print(f'Scene {scene_key} not found.')\n",
    "        return\n",
    "    \n",
    "    # Print metadata\n",
    "    metadata = scene['metadata']\n",
    "    print()\n",
    "    print(f'Scene {scene_key}:')\n",
    "    print()\n",
    "    print(f'  Objects: {len(metadata[\"objects\"])}')\n",
    "    for obj in metadata[\"objects\"]:\n",
    "        print(f'    Object:')\n",
    "        print(f'      Category: {obj[\"category\"]}')\n",
    "        print(f'      Rotation: {obj[\"rotation\"]}')\n",
    "        print(f'      Scale: {obj[\"scale\"]}')\n",
    "        if 'stability' in obj:\n",
    "            print(f'      Stability: {obj[\"stability\"]}')\n",
    "        print(f'      3D Coords: {obj[\"3d_coords\"]}')\n",
    "        print(f'      Support: {obj[\"pixel_coords\"]}')\n",
    "        print(f'      Part Colors: {obj[\"part_color\"]}')\n",
    "        print(f'      Part Count: {obj[\"part_count\"]}')\n",
    "        print()\n",
    "    print(f'  Relationships: {metadata[\"relationships\"]}')\n",
    "    print()\n",
    "    print(f'  Directions: {metadata[\"directions\"]}')\n",
    "    print()\n",
    "    print(f'  Image Filename: {metadata[\"image_filename\"]}')\n",
    "    print()\n",
    "    if 'physics' in metadata:\n",
    "        print(f'  Physics: {metadata[\"physics\"]}')\n",
    "    print()\n",
    "    print(f'  Cam location: {metadata[\"cam_location\"]}')\n",
    "    print()\n",
    "    print(f'  Cam Rotation: {metadata[\"cam_rotation\"]}')\n",
    "    print()\n",
    "\n",
    "    \n",
    "    # Print questions and answers\n",
    "    questions = scene['questions']\n",
    "    print(f'  Questions: {len(questions)}')\n",
    "    for q in questions:\n",
    "        print(f'    Question: {q[\"question\"]}')\n",
    "        print()\n",
    "        print(f'      Answer: {q[\"answer\"]}')\n",
    "        print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e941ea22",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene_data = load_scene_data()\n",
    "# Every record must provide the two fields the prompt-building functions read.\n",
    "assert all('metadata' in scene and 'questions' in scene for scene in scene_data.values()), 'unexpected scene record schema'\n",
    "len(scene_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "462bc719",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect which metadata fields a scene record carries.\n",
    "EXAMPLE_SCENE = 'PTR_val_007239'\n",
    "sorted(scene_data[EXAMPLE_SCENE]['metadata'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64e1f22b",
   "metadata": {},
   "source": [
    "# Design the prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dea9a5e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ff1221b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_scene_metadata(scene_number, scene_data):\n",
    "    \"\"\"\n",
    "    Get metadata for a given scene number in the PTR dataset.\n",
    "    \"\"\"\n",
    "    # Find the scene with the given number\n",
    "    scene = scene_data.get(scene_number)\n",
    "    if scene is None:\n",
    "        return f'Scene {scene_number} not found.'\n",
    "    output = []\n",
    "    output_qns = []\n",
    "    output_ans = []\n",
    "\n",
    "    # Add metadata to output\n",
    "    metadata = scene['metadata']\n",
    "    output.append(f'Scene {scene_number}:\\n')\n",
    "    output.append(f'  Objects: {len(metadata[\"objects\"])}')\n",
    "    for obj in metadata[\"objects\"]:\n",
    "        # print(obj.keys())\n",
    "        output.append(f'    Object:')\n",
    "        output.append(f'      Category: {obj[\"category\"]}')\n",
    "        output.append(f'      Rotation: {obj[\"rotation\"]}')\n",
    "        output.append(f'      Scale: {obj[\"scale\"]}')\n",
    "        if \"stability\" in obj:\n",
    "            output.append(f'      Stability: {obj[\"stability\"]}')\n",
    "        output.append(f'      3D Coords: {obj[\"3d_coords\"]}')\n",
    "        output.append(f'      Support: {obj[\"pixel_coords\"]}')\n",
    "        output.append(f'      Part Colors: {obj[\"part_color\"]}')\n",
    "        output.append(f'      Part Count: {obj[\"part_count\"]}')\n",
    "        output.append('\\n')\n",
    "    output.append(f'  Relationships: {metadata[\"relationships\"]}\\n')\n",
    "    output.append(f'  Directions: {metadata[\"directions\"]}\\n')\n",
    "    output.append(f'  Image Filename: {metadata[\"image_filename\"]}\\n')\n",
    "    if \"physics\" in metadata:\n",
    "        output.append(f'  Physics: {metadata[\"physics\"]}\\n')\n",
    "    output.append(f'  Cam location: {metadata[\"cam_location\"]}\\n')\n",
    "    output.append(f'  Cam Rotation: {metadata[\"cam_rotation\"]}\\n')\n",
    "    \n",
    "\n",
    "    # Combine the output strings and return\n",
    "    return (\"\\n\".join(output))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd31dcb9-6d86-4212-a46c-040bddc7ae2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prompt(scene_number, scene_data,cot=False):\n",
    "    # metadata, qns, ans = get_scene_input(scene_number,scene_data)\n",
    "    # setup = 'Given the following scene:\\n'\n",
    "    scene_setup = \"The objects or things can have the following categories: 'Bed', 'Cart', 'Chair', 'Refrigerator', 'Table'. The different parts of the things can have the following categories: arm', 'arm horizontal bar', 'arm vertical bar', 'back', 'behind', 'body', 'central support', 'door', 'drawer', 'leg', 'leg bar', 'pedestal', 'seat', 'shelf', 'sleep area', 'top', 'wheel'. The things or objects can move in the following directions to make themselves stable: 'front', 'left', 'right'. The objects or their parts can have the following colors: 'blue', 'brown', 'cyan', 'gray', 'green', 'purple', 'red', 'yellow'. For numeric answers, give an answer in integers and not in words.\"\n",
    "    if not cot:\n",
    "        instruction = 'Now answer the following question in one word.'\n",
    "    else:\n",
    "        instruction = 'Now answer the following questions with step-by-step reasoning.'\n",
    "    prompt = scene_setup + instruction\n",
    "    if cot:\n",
    "        prompt += 'Give the final one word answer at the end of your reasoning. Thus, your response format should be:\\nReasoning for the answer\\nFinal answer:\\nFinal one word answer'\n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67c92404",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2ac6fc8-67c0-4c58-9a5f-da93c4fcc8b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_on_t5(scene_start=0, scene_end=10000, response_mapping=None, response_mapping_cot=None,\n",
    "              run_cot=False, output_path='response_mapping_flan_t5_xl.json', save_every=100):\n",
    "    \"\"\"\n",
    "    Answer every question of the scenes at positions [scene_start, scene_end) with FLAN-T5.\n",
    "\n",
    "    Relies on the notebook globals ``scene_data``, ``tokenizer`` and ``model``; positions\n",
    "    follow the iteration order of ``scene_data``.\n",
    "\n",
    "    Args:\n",
    "        scene_start: position of the first scene to process (inclusive).\n",
    "        scene_end: position at which to stop (exclusive).\n",
    "        response_mapping: dict filled with {scene_key: [answer, ...]}; a fresh dict is\n",
    "            created when None, so results never leak between calls.\n",
    "        response_mapping_cot: dict filled with the step-by-step answers when ``run_cot``\n",
    "            is True; a fresh dict is created when None.\n",
    "        run_cot: also query the model with the chain-of-thought prompt.\n",
    "        output_path: JSON file the direct answers are checkpointed to; chain-of-thought\n",
    "            answers go to the same name with a '_cot' suffix.\n",
    "        save_every: checkpoint after this many processed scenes.\n",
    "\n",
    "    Returns:\n",
    "        response_mapping\n",
    "    \"\"\"\n",
    "    if response_mapping is None:\n",
    "        response_mapping = {}\n",
    "    if response_mapping_cot is None:\n",
    "        response_mapping_cot = {}\n",
    "    root, ext = os.path.splitext(output_path)\n",
    "    cot_output_path = f'{root}_cot{ext}'\n",
    "\n",
    "    def _answer(prompt, **generate_kwargs):\n",
    "        # Send the inputs to wherever device_map=\"auto\" placed the model.\n",
    "        input_ids = tokenizer(f\"context: {prompt}\", return_tensors=\"pt\").input_ids.to(model.device)\n",
    "        outputs = model.generate(input_ids, **generate_kwargs)\n",
    "        return tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "\n",
    "    def _save():\n",
    "        with open(output_path, 'w') as f:\n",
    "            json.dump(response_mapping, f)\n",
    "        if run_cot:\n",
    "            with open(cot_output_path, 'w') as f:\n",
    "                json.dump(response_mapping_cot, f)\n",
    "\n",
    "    setup = 'Given the following scene:\\n'\n",
    "    processed = 0\n",
    "    for i, scene_index in enumerate(tqdm.tqdm(scene_data)):\n",
    "        if i < scene_start:\n",
    "            continue\n",
    "        if i >= scene_end:\n",
    "            break\n",
    "\n",
    "        scene = scene_data[scene_index]\n",
    "        context = setup + get_scene_metadata(scene_index, scene_data)\n",
    "        instructions = get_prompt(scene_index, scene_data)\n",
    "        response_mapping[scene_index] = [\n",
    "            _answer(context + instructions + \"\\n\" + \"Question:\" + q['question'])\n",
    "            for q in scene['questions']\n",
    "        ]\n",
    "\n",
    "        if run_cot:\n",
    "            cot_instructions = get_prompt(scene_index, scene_data, True)\n",
    "            # NOTE(review): generate()'s default length limit is assumed too short for\n",
    "            # multi-step reasoning, so CoT answers get an explicit token budget.\n",
    "            response_mapping_cot[scene_index] = [\n",
    "                _answer(context + cot_instructions + \"\\n\" + \"Question:\" + q['question'], max_new_tokens=256)\n",
    "                for q in scene['questions']\n",
    "            ]\n",
    "\n",
    "        processed += 1\n",
    "        if processed % save_every == 0:\n",
    "            print('Saving progress...')\n",
    "            _save()\n",
    "\n",
    "    _save()\n",
    "    return response_mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4a56822",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass a fresh dict so re-running this cell never mixes in results from a previous call.\n",
    "response_map = run_on_t5(scene_start=0, scene_end=10000, response_mapping={})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7055f1db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Displaying every scene would flood the notebook; report the size and a small sample instead.\n",
    "print(f'Scenes answered: {len(response_map)}')\n",
    "dict(list(response_map.items())[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6579660d-bf0e-4e69-9590-5c449ebc1bd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check one question: ground truth next to the model's answer (None if that scene was not run).\n",
    "SPOT_SCENE, SPOT_QUESTION = 'PTR_val_001543', 5\n",
    "spot = scene_data[SPOT_SCENE]['questions'][SPOT_QUESTION]\n",
    "predicted = response_map.get(SPOT_SCENE, [])\n",
    "{\n",
    "    'question': spot['question'],\n",
    "    'answer': spot['answer'],\n",
    "    'predicted': predicted[SPOT_QUESTION] if SPOT_QUESTION < len(predicted) else None,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fe6d2cb-c691-492f-91f6-051aa60a0474",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
