{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64452494",
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae30e7ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d5cad78",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tqdm\n",
    "import openai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e7df93e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset locations. Point CLEVR_DIR at your local copy of the CLEVR dataset.\n",
    "# NOTE(review): SCENE_MAPPING_DIR is not referenced anywhere else in this notebook.\n",
    "CLEVR_DIR, SCENE_MAPPING_DIR = 'CLEVR_v1.0', 'vqa_research'\n",
    "\n",
    "def load_data():\n",
    "    \"\"\"\n",
    "    Load metadata and questions for all scenes in the CLEVR train split.\n",
    "\n",
    "    Reads scenes/CLEVR_train_scenes.json and questions/CLEVR_train_questions.json\n",
    "    under CLEVR_DIR.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each scene's 'image_index' to\n",
    "        {'metadata': scene_dict, 'questions': [question_dict, ...]}.\n",
    "        Questions keep the order they have in the questions file; a scene\n",
    "        without questions gets an empty list.\n",
    "    \"\"\"\n",
    "    print('opening scenes...')\n",
    "    # Load metadata for all scenes\n",
    "    with open(os.path.join(CLEVR_DIR, 'scenes', 'CLEVR_train_scenes.json')) as f:\n",
    "        scenes = json.load(f)['scenes']\n",
    "\n",
    "    print('Opening questions...')\n",
    "    # Load questions for all scenes\n",
    "    with open(os.path.join(CLEVR_DIR, 'questions', 'CLEVR_train_questions.json')) as f:\n",
    "        questions = json.load(f)['questions']\n",
    "\n",
    "    # Group questions by image index in one pass. Re-scanning the whole\n",
    "    # question list for every scene was O(#scenes * #questions).\n",
    "    questions_by_image = {}\n",
    "    for q in questions:\n",
    "        questions_by_image.setdefault(q['image_index'], []).append(q)\n",
    "\n",
    "    print('Creating map...')\n",
    "    # Create a dictionary mapping scene numbers to scene data and questions\n",
    "    scene_data = {}\n",
    "    for scene in tqdm.tqdm(scenes):\n",
    "        scene_number = scene['image_index']\n",
    "        scene_data[scene_number] = {\n",
    "            'metadata': scene,\n",
    "            # copy so that scenes never share one list object\n",
    "            'questions': list(questions_by_image.get(scene_number, [])),\n",
    "        }\n",
    "\n",
    "    return scene_data\n",
    "\n",
    "def load_scene_data(path='ptr_val_scene_mapping.json'):\n",
    "    \"\"\"\n",
    "    Load the PTR validation scene mapping from a JSON file.\n",
    "\n",
    "    Args:\n",
    "        path: JSON file to read; defaults to 'ptr_val_scene_mapping.json'\n",
    "            relative to the current working directory.\n",
    "\n",
    "    Returns:\n",
    "        The decoded JSON. The rest of this notebook uses it as a dict keyed by\n",
    "        scene id (e.g. 'PTR_val_007665') whose values hold 'metadata' and\n",
    "        'questions'.\n",
    "    \"\"\"\n",
    "    with open(path) as f:\n",
    "        return json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7986213a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_scene_data(scene_key, scene_data):\n",
    "    \"\"\"\n",
    "    Print the metadata and question/answer pairs of one PTR scene.\n",
    "\n",
    "    Args:\n",
    "        scene_key: key of the scene in scene_data (e.g. 'PTR_val_007665').\n",
    "        scene_data: mapping of scene key -> {'metadata': ..., 'questions': ...}.\n",
    "\n",
    "    Prints 'Scene {scene_key} not found.' and returns None if the key is missing.\n",
    "    The optional 'stability' (per object) and 'physics' (per scene) fields are\n",
    "    printed only when present, as get_scene_input does, instead of raising\n",
    "    KeyError on scenes that lack them.\n",
    "    \"\"\"\n",
    "    # Find the scene with the given key\n",
    "    scene = scene_data.get(scene_key)\n",
    "    if scene is None:\n",
    "        print(f'Scene {scene_key} not found.')\n",
    "        return\n",
    "\n",
    "    # Print metadata\n",
    "    metadata = scene['metadata']\n",
    "    print()\n",
    "    print(f'Scene {scene_key}:')\n",
    "    print()\n",
    "    print(f'  Objects: {len(metadata[\"objects\"])}')\n",
    "    for obj in metadata[\"objects\"]:\n",
    "        print(f'    Object:')\n",
    "        print(f'      Category: {obj[\"category\"]}')\n",
    "        print(f'      Rotation: {obj[\"rotation\"]}')\n",
    "        print(f'      Scale: {obj[\"scale\"]}')\n",
    "        if \"stability\" in obj:\n",
    "            print(f'      Stability: {obj[\"stability\"]}')\n",
    "        print(f'      3D Coords: {obj[\"3d_coords\"]}')\n",
    "        # NOTE(review): the 'Support' label actually shows pixel_coords.\n",
    "        print(f'      Support: {obj[\"pixel_coords\"]}')\n",
    "        print(f'      Part Colors: {obj[\"part_color\"]}')\n",
    "        print(f'      Part Count: {obj[\"part_count\"]}')\n",
    "        print()\n",
    "    print(f'  Relationships: {metadata[\"relationships\"]}')\n",
    "    print()\n",
    "    print(f'  Directions: {metadata[\"directions\"]}')\n",
    "    print()\n",
    "    print(f'  Image Filename: {metadata[\"image_filename\"]}')\n",
    "    print()\n",
    "    if \"physics\" in metadata:\n",
    "        print(f'  Physics: {metadata[\"physics\"]}')\n",
    "        print()\n",
    "    print(f'  Cam location: {metadata[\"cam_location\"]}')\n",
    "    print()\n",
    "    print(f'  Cam Rotation: {metadata[\"cam_rotation\"]}')\n",
    "    print()\n",
    "\n",
    "    # Print questions and answers\n",
    "    questions = scene['questions']\n",
    "    print(f'  Questions: {len(questions)}')\n",
    "    for q in questions:\n",
    "        print(f'    Question: {q[\"question\"]}')\n",
    "        print()\n",
    "        print(f'      Answer: {q[\"answer\"]}')\n",
    "        print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbc5112a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scene mapping used by the rest of the notebook (scene id -> metadata/questions).\n",
    "scene_data = load_scene_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "692c6440",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scene ids available in the mapping.\n",
    "scene_data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60c107f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pretty-print one scene's metadata and question/answer pairs.\n",
    "get_scene_data(scene_key='PTR_val_007665', scene_data=scene_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62695710",
   "metadata": {},
   "source": [
    "# Ask GPT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "787c2ad3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate completion models for `model`:\n",
    "#   text-ada-001, text-babbage-001, text-curie-001, text-davinci-003\n",
    "def ask_gpt(prompt, model='text-babbage-001', max_tokens=150, temperature=0.5):\n",
    "    \"\"\"\n",
    "    Send a single prompt to the legacy OpenAI Completion endpoint.\n",
    "\n",
    "    Args:\n",
    "        prompt: prompt text.\n",
    "        model: completion model name (default 'text-babbage-001').\n",
    "        max_tokens: completion length cap passed to the API (default 150).\n",
    "        temperature: sampling temperature passed to the API (default 0.5).\n",
    "\n",
    "    Returns:\n",
    "        (text, usage): the text of the first choice and the response's usage.\n",
    "    \"\"\"\n",
    "    response = openai.Completion.create(\n",
    "        model=model,\n",
    "        prompt=prompt,\n",
    "        max_tokens=max_tokens,\n",
    "        temperature=temperature,\n",
    "    )\n",
    "    return (response.choices[0].text, response.usage)\n",
    "\n",
    "def ask_chat_gpt(messages, model=\"gpt-3.5-turbo\", **kwargs):\n",
    "    \"\"\"\n",
    "    Send a chat conversation to the OpenAI ChatCompletion endpoint.\n",
    "\n",
    "    Args:\n",
    "        messages: list of {\"role\": ..., \"content\": ...} dicts, e.g.\n",
    "            [\n",
    "                {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
    "                {\"role\": \"user\", \"content\": \"Who won the world series in 2020?\"},\n",
    "                {\"role\": \"assistant\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"},\n",
    "                {\"role\": \"user\", \"content\": \"Where was it played?\"}\n",
    "            ]\n",
    "        model: chat model name (default \"gpt-3.5-turbo\").\n",
    "        **kwargs: extra keyword arguments forwarded unchanged to\n",
    "            openai.ChatCompletion.create (e.g. temperature, max_tokens).\n",
    "\n",
    "    Returns:\n",
    "        (choice, usage): the whole first choice object (not only its text)\n",
    "        and the response's usage record.\n",
    "    \"\"\"\n",
    "    response = openai.ChatCompletion.create(\n",
    "        model=model,\n",
    "        messages=messages,\n",
    "        **kwargs\n",
    "    )\n",
    "    return (response.choices[0], response.usage)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "267507b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the API key from the environment instead of hard-coding it in the\n",
    "# notebook (a literal key here would be committed along with the file).\n",
    "openai.api_key = os.environ.get('OPENAI_API_KEY', '')\n",
    "openai.Model.list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "558751b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
    "    {'role': 'user', 'content': 'I am going to provide you scene metadata and ask you questions about the scene'},\n",
    "]\n",
    "# Single chat request with a short system + user conversation.\n",
    "response, usage = ask_chat_gpt(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faafc044",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First choice object returned for the request above.\n",
    "response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ca9627a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Usage record returned for the request above.\n",
    "usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9fe06fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_number(n):\n",
    "    \"\"\"\n",
    "    Return a float rendered with 4 significant digits (as a string).\n",
    "\n",
    "    Anything that is not a float is passed through unchanged.\n",
    "    \"\"\"\n",
    "    return format(n, '.4g') if isinstance(n, float) else n\n",
    "\n",
    "def _format_nested(value):\n",
    "    \"\"\"Format one element: floats via format_number, containers recursively.\"\"\"\n",
    "    if isinstance(value, float):\n",
    "        return format_number(value)\n",
    "    if isinstance(value, dict):\n",
    "        return format_dict(value)\n",
    "    if isinstance(value, list):\n",
    "        return format_list(value)\n",
    "    return value\n",
    "\n",
    "def format_dict(d):\n",
    "    \"\"\"\n",
    "    Recursively replace every float inside dict d by its 4-significant-digit\n",
    "    string. d (and any nested dict/list) is modified in place and returned.\n",
    "    \"\"\"\n",
    "    for key in list(d):\n",
    "        d[key] = _format_nested(d[key])\n",
    "    return d\n",
    "\n",
    "def format_list(l):\n",
    "    \"\"\"\n",
    "    Recursively replace every float inside list l by its 4-significant-digit\n",
    "    string. l (and any nested dict/list) is modified in place and returned.\n",
    "    \"\"\"\n",
    "    l[:] = [_format_nested(item) for item in l]\n",
    "    return l\n",
    "\n",
    "def _render_scene_metadata(scene_number, metadata):\n",
    "    \"\"\"\n",
    "    Render one scene's metadata as the list of text lines used in prompts.\n",
    "\n",
    "    Floats are rounded to 4 significant digits via format_dict. format_dict\n",
    "    rewrites its argument in place, so it is applied to a deep copy: the\n",
    "    caller's scene_data keeps its original numeric values.\n",
    "    \"\"\"\n",
    "    import copy\n",
    "\n",
    "    metadata = format_dict(copy.deepcopy(metadata))\n",
    "\n",
    "    output = [f'Scene {scene_number}:\\n', f'  Objects: {len(metadata[\"objects\"])}']\n",
    "    for obj in metadata[\"objects\"]:\n",
    "        output.append(f'    Object:')\n",
    "        output.append(f'      Category: {obj[\"category\"]}')\n",
    "        output.append(f'      Rotation: {obj[\"rotation\"]}')\n",
    "        output.append(f'      Scale: {obj[\"scale\"]}')\n",
    "        if \"stability\" in obj:\n",
    "            output.append(f'      Stability: {obj[\"stability\"]}')\n",
    "        output.append(f'      3D Coords: {obj[\"3d_coords\"]}')\n",
    "        # NOTE(review): 'Support' labels pixel_coords; kept verbatim so the\n",
    "        # generated text stays identical to earlier runs.\n",
    "        output.append(f'      Support: {obj[\"pixel_coords\"]}')\n",
    "        output.append(f'      Part Colors: {obj[\"part_color\"]}')\n",
    "        output.append(f'      Part Count: {obj[\"part_count\"]}')\n",
    "        output.append('\\n')\n",
    "    output.append(f'  Relationships: {metadata[\"relationships\"]}\\n')\n",
    "    output.append(f'  Directions: {metadata[\"directions\"]}\\n')\n",
    "    output.append(f'  Image Filename: {metadata[\"image_filename\"]}\\n')\n",
    "    if \"physics\" in metadata:\n",
    "        output.append(f'  Physics: {metadata[\"physics\"]}\\n')\n",
    "    output.append(f'  Cam location: {metadata[\"cam_location\"]}\\n')\n",
    "    output.append(f'  Cam Rotation: {metadata[\"cam_rotation\"]}\\n')\n",
    "    return output\n",
    "\n",
    "def get_scene_input(scene_number, scene_data):\n",
    "    \"\"\"\n",
    "    Get metadata text plus numbered questions and answers for one PTR scene.\n",
    "\n",
    "    Args:\n",
    "        scene_number: key of the scene in scene_data.\n",
    "        scene_data: mapping of scene key -> {'metadata': ..., 'questions': ...}.\n",
    "\n",
    "    Returns:\n",
    "        (metadata_text, questions_text, answers_text), each newline-joined;\n",
    "        questions_text starts with a 'Questions: N' header and numbers the\n",
    "        questions from 1. If the scene is missing, the plain string\n",
    "        'Scene {scene_number} not found.' is returned instead of a tuple.\n",
    "    \"\"\"\n",
    "    scene = scene_data.get(scene_number)\n",
    "    if scene is None:\n",
    "        # NOTE(review): a str, not a 3-tuple -- callers that unpack will raise.\n",
    "        return f'Scene {scene_number} not found.'\n",
    "\n",
    "    output = _render_scene_metadata(scene_number, scene['metadata'])\n",
    "\n",
    "    questions = scene['questions']\n",
    "    output_qns = [f'Questions: {len(questions)}']\n",
    "    output_qns += [f'    Question {c+1} : {q[\"question\"]}\\n' for c, q in enumerate(questions)]\n",
    "    output_ans = [f'{q[\"answer\"]}\\n' for q in questions]\n",
    "\n",
    "    return (\"\\n\".join(output), \"\\n\".join(output_qns), \"\\n\".join(output_ans))\n",
    "\n",
    "def get_scene_input_separate(scene_number, scene_data):\n",
    "    \"\"\"\n",
    "    Like get_scene_input, but return the questions as a list of strings.\n",
    "\n",
    "    Returns:\n",
    "        (metadata_text, question_list, answers_text) where question_list holds\n",
    "        one newline-terminated 'Question : ...' string per question (no count\n",
    "        header, no numbering). If the scene is missing, the plain string\n",
    "        'Scene {scene_number} not found.' is returned instead of a tuple.\n",
    "    \"\"\"\n",
    "    scene = scene_data.get(scene_number)\n",
    "    if scene is None:\n",
    "        return f'Scene {scene_number} not found.'\n",
    "\n",
    "    output = _render_scene_metadata(scene_number, scene['metadata'])\n",
    "\n",
    "    questions = scene['questions']\n",
    "    output_qns = [f'Question : {q[\"question\"]}\\n' for q in questions]\n",
    "    output_ans = [f'{q[\"answer\"]}\\n' for q in questions]\n",
    "\n",
    "    return (\"\\n\".join(output), output_qns, \"\\n\".join(output_ans))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0446dda5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b51c8ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "metadata, qns, ans = get_scene_input_separate(\n",
    "    scene_number='PTR_val_007239', scene_data=scene_data\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0fb42ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_scene_input_separate returns the questions as a list (not a joined string).\n",
    "type(qns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ac02e09",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prompt(scene_number, scene_data, cot=False):\n",
    "    \"\"\"\n",
    "    Build one prompt per question of a PTR validation scene.\n",
    "\n",
    "    Every prompt is: a setup line naming the scene's image file, the PTR\n",
    "    category/part/direction/colour vocabulary, a single-word-answer\n",
    "    instruction, the question, and a trailing 'Answer:'. The scene metadata\n",
    "    text from get_scene_input_separate is not put into the prompt.\n",
    "\n",
    "    Args:\n",
    "        scene_number: scene key in scene_data.\n",
    "        scene_data: mapping of scene key -> {'metadata': ..., 'questions': ...}.\n",
    "        cot: accepted from callers but currently has no effect.\n",
    "\n",
    "    Returns:\n",
    "        List of prompt strings, in question order.\n",
    "    \"\"\"\n",
    "    _, question_lines, _ = get_scene_input_separate(scene_number, scene_data)\n",
    "    scene_filename = scene_data[scene_number]['metadata']['image_filename']\n",
    "\n",
    "    # Everything before the question is the same for all questions of the scene.\n",
    "    setup = f'Answer the following question from the val split of the PTR Dataset for image {scene_filename}\\n'\n",
    "    # NOTE(review): the part list starts with arm' (missing opening quote). Kept\n",
    "    # verbatim so new prompts match the ones behind the saved checkpoint.\n",
    "    scene_setup = \"The objects or things can have the following categories: 'Bed', 'Cart', 'Chair', 'Refrigerator', 'Table'. The different parts of the things can have the following categories: arm', 'arm horizontal bar', 'arm vertical bar', 'back', 'behind', 'body', 'central support', 'door', 'drawer', 'leg', 'leg bar', 'pedestal', 'seat', 'shelf', 'sleep area', 'top', 'wheel'. The things or objects can move in the following directions to make themselves stable: 'front', 'left', 'right'. The objects or their parts can have the following colors: 'blue', 'brown', 'cyan', 'gray', 'green', 'purple', 'red', 'yellow'. For numeric answers, give an answer in integers and not in words.\\n\"\n",
    "    instruction = 'Always answer the following question in a single word from the options provided above. Your response should be just a single word.\\n'\n",
    "    preamble = setup + scene_setup + instruction\n",
    "\n",
    "    return [preamble + question + '\\n' + 'Answer:' for question in question_lines]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5962cfa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# with open('response_mapping_ptr_gpt.json') as json_file:\n",
    "#     response_mapping = json.load(json_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fff7b4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# response_mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba7437cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompts = get_prompt('PTR_val_007239', scene_data, cot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4766e2e0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82665a1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the first prompt; print() keeps its line breaks readable.\n",
    "print(prompts[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a14dcf5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_json(filename):\n",
    "    \"\"\"\n",
    "    Read a JSON file and return its parsed content.\n",
    "\n",
    "    :param filename: path of the JSON file to read.\n",
    "    :return: the decoded Python object.\n",
    "    \"\"\"\n",
    "    with open(filename, 'r') as handle:\n",
    "        return json.load(handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b29d3455",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resume from the previous run's checkpoint when it exists; otherwise start\n",
    "# from an empty mapping so the notebook also runs from a clean checkout.\n",
    "ckpt_path = './response_mapping_ptr_chat_gpt_ablation_non_cot.json'\n",
    "response_ckpt = load_json(ckpt_path) if os.path.exists(ckpt_path) else {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afb62de6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of scenes already stored in the checkpoint.\n",
    "len(response_ckpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7846101",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_on_chat_gpt_non_cot_ablation(scene_start=0, scene_end=5, response_mapping=None):\n",
    "    \"\"\"\n",
    "    Ask ChatGPT every non-CoT prompt of every scene in the global scene_data.\n",
    "\n",
    "    Scenes already in response_mapping are skipped, so a loaded checkpoint can\n",
    "    be resumed. A newly answered scene is stored as\n",
    "    response_mapping[scene] = {'non_cot_ablation': [[choice, usage], ...]}\n",
    "    with one pair per prompt, in question order. If a request for a scene\n",
    "    fails, the error is printed and the scene is left out, so the next call\n",
    "    retries it instead of keeping a list whose entries no longer line up with\n",
    "    the questions.\n",
    "\n",
    "    The mapping is written to 'response_mapping_ptr_chat_gpt_ablation_non_cot.json'\n",
    "    every 20 scenes and once more when the loop finishes or is interrupted.\n",
    "\n",
    "    Args:\n",
    "        scene_start, scene_end: currently ignored; all scenes are processed.\n",
    "        response_mapping: dict to extend in place (e.g. a loaded checkpoint).\n",
    "            Defaults to a new empty dict on every call.\n",
    "\n",
    "    Returns:\n",
    "        The updated response_mapping.\n",
    "    \"\"\"\n",
    "    if response_mapping is None:\n",
    "        # A shared `{}` default would carry results over between calls.\n",
    "        response_mapping = {}\n",
    "    checkpoint_file = 'response_mapping_ptr_chat_gpt_ablation_non_cot.json'\n",
    "    error_scene_indices = []\n",
    "    print(response_mapping.keys())\n",
    "\n",
    "    def save_checkpoint():\n",
    "        with open(checkpoint_file, 'w') as f:\n",
    "            json.dump(response_mapping, f)\n",
    "\n",
    "    try:\n",
    "        for index, i in tqdm.tqdm(enumerate(scene_data), total=len(scene_data)):\n",
    "            if i not in response_mapping:\n",
    "                responses = []\n",
    "                for p in get_prompt(i, scene_data, True):\n",
    "                    messages = [{\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
    "                                {'role': 'user', 'content': p}\n",
    "                                ]\n",
    "                    try:\n",
    "                        prompt_response, prompt_usage = ask_chat_gpt(messages)\n",
    "                    except Exception as e:\n",
    "                        print(e)\n",
    "                        error_scene_indices.append((i, str(e)))\n",
    "                        print('Except: ',len(error_scene_indices))\n",
    "                        # Give up on this scene for now; the next call retries it.\n",
    "                        responses = None\n",
    "                        break\n",
    "                    responses.append([prompt_response, prompt_usage])\n",
    "                if responses is not None:\n",
    "                    response_mapping[i] = {'non_cot_ablation': responses}\n",
    "\n",
    "            if index % 20 == 1:\n",
    "                save_checkpoint()\n",
    "    finally:\n",
    "        # Persist progress even if the loop is interrupted (e.g. KeyboardInterrupt).\n",
    "        save_checkpoint()\n",
    "\n",
    "    return response_mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efc3b50d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "692c24cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "response_map_instruct_gpt = run_on_chat_gpt_non_cot_ablation(\n",
    "    scene_start=0, scene_end=2, response_mapping=response_ckpt\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12f7bf6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of scenes in the returned mapping.\n",
    "len(response_map_instruct_gpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed7ffa7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the complete response mapping under a separate 'final' file name.\n",
    "with open('response_mapping_ptr_chat_gpt_ablation_non_cot_final.json', 'w') as f:\n",
    "    json.dump(obj=response_map_instruct_gpt, fp=f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
