{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "06ec3851-be47-42e1-9b81-59cc412f9c9b",
   "metadata": {},
   "source": [
    "# Experiments with GQA nouns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1c4bc99-81ee-484c-a152-5ec82db6296b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import nltk\n",
    "import argparse\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "from nltk.tokenize import word_tokenize\n",
    "from nltk.tag import pos_tag\n",
    "from nltk.corpus import wordnet\n",
    "from nltk.stem import WordNetLemmatizer\n",
    "from nltk import FreqDist\n",
    "import pprint\n",
    "\n",
    "pp = pprint.PrettyPrinter(indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85fd4ff9-8722-4a47-a567-717d6a56f560",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download necessary NLTK resources\n",
    "nltk.download(\"punkt\")\n",
    "nltk.download(\"averaged_perceptron_tagger\")\n",
    "nltk.download(\"stopwords\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39060fcd-ae61-48d0-946f-74dfed6921fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize lemmatizer\n",
    "lemmatizer = WordNetLemmatizer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ab8b144-88a5-4de4-80de-ad9938ab582d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to convert NLTK POS tags to WordNet POS tags\n",
    "def get_wordnet_pos(treebank_tag):\n",
    "    if treebank_tag.startswith(\"J\"):\n",
    "        return wordnet.ADJ\n",
    "    elif treebank_tag.startswith(\"V\"):\n",
    "        return wordnet.VERB\n",
    "    elif treebank_tag.startswith(\"N\"):\n",
    "        return wordnet.NOUN\n",
    "    elif treebank_tag.startswith(\"R\"):\n",
    "        return wordnet.ADV\n",
    "    else:\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b10d35d-4ffc-4adc-a7f3-ed857dd2f38e",
   "metadata": {},
   "outputs": [],
   "source": [
    "lemmatizer_cache = {}\n",
    "\n",
    "def lazy_lemmatize(word):\n",
    "    if word not in lemmatizer_cache:\n",
    "        lemmatizer_cache[word] = lemmatizer.lemmatize(word)\n",
    "    return lemmatizer_cache[word]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23417aca-ded9-40e3-9d07-30641f8e4665",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to extract and lemmatize nouns from text\n",
    "def extract_and_lemmatize_nouns(text):\n",
    "    tokens = word_tokenize(text)\n",
    "    tagged_tokens = pos_tag(tokens)\n",
    "    nouns = [\n",
    "        lazy_lemmatize(word)\n",
    "        for word, pos in tagged_tokens\n",
    "        if pos in [\"NN\", \"NNS\", \"NNP\", \"NNPS\"]\n",
    "    ]\n",
    "    return nouns"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cff92c55-20d3-460d-85ee-e35484ae27f9",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61e7006e-150b-44ba-9283-1c3a85cbe2b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import urllib.request\n",
    "from tqdm import tqdm\n",
    "\n",
    "def download_with_progress(url, zip_path):\n",
    "    # Define a function to update the progress bar\n",
    "    def show_progress(block_num, block_size, total_size):\n",
    "        if show_progress.bar is None:\n",
    "            # Initialize the progress bar\n",
    "            show_progress.bar = tqdm(total=total_size, unit='B', unit_scale=True, desc=\"Downloading\")\n",
    "        downloaded = block_num * block_size\n",
    "        if downloaded < total_size:\n",
    "            show_progress.bar.update(block_size)\n",
    "        else:\n",
    "            show_progress.bar.update(total_size - show_progress.bar.n)\n",
    "            show_progress.bar.close()\n",
    "\n",
    "    show_progress.bar = None\n",
    "\n",
    "    # Download the file with a progress bar\n",
    "    urllib.request.urlretrieve(url, zip_path, show_progress)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb293b9d-b2c0-4098-9a3a-64f076846616",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load image data\n",
    "url = 'https://<anon>/gqa_val.zip'\n",
    "# Unzip into ../data/gqa_images/val/\n",
    "\n",
    "import os\n",
    "import zipfile\n",
    "import urllib.request\n",
    "import requests, zipfile, io\n",
    "\n",
    "\n",
    "# Define the path to the directory\n",
    "gqa_images_dir = '../data/gqa_images/'\n",
    "\n",
    "# Define the path to the directory\n",
    "val_images_dir = '../data/gqa_images/val/'\n",
    "\n",
    "# Check if the directory does not exist\n",
    "if not os.path.exists(gqa_images_dir):\n",
    "    os.mkdir(gqa_images_dir)\n",
    "\n",
    "# Check if the directory does not exist\n",
    "if not os.path.exists(val_images_dir):\n",
    "    # Define the URL for the image data\n",
    "    url = 'https://<anon>/gqa_val_images.zip'\n",
    "    \n",
    "    # Define the path to save the zip file\n",
    "    zip_path = '../data/gqa_val.zip'\n",
    "    \n",
    "    # Download the image data\n",
    "    print(\"Downloading data...\")\n",
    "    \n",
    "    #r = requests.get(url)\n",
    "    #urllib.request.urlretrieve(url, zip_path)\n",
    "    \n",
    "\n",
    "    download_with_progress(url, zip_path)\n",
    "    print(\"Download complete.\")\n",
    "    \n",
    "    # Unzip the file into the specified directory\n",
    "    print(\"Unzipping data...\")\n",
    "    \n",
    "    #z = zipfile.ZipFile(io.BytesIO(r.content))\n",
    "    \n",
    "    #z.extractall(directory_path)\n",
    "    with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n",
    "        zip_ref.extractall(val_images_dir)\n",
    "    print(\"Unzipping complete.\")\n",
    "\n",
    "    \n",
    "    # Remove the zip file\n",
    "    print(\"Removing zip file...\")\n",
    "    os.remove(zip_path)\n",
    "    print(\"Zip file removed.\")\n",
    "\n",
    "print(\"Operation completed.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ead7aee-5718-4a21-b877-63a8ec35b52c",
   "metadata": {},
   "outputs": [],
   "source": [
    "questions_path = '../data/gqa_dataset/val_balanced_questions.json'\n",
    "questions = json.load(open(questions_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "226b973b-944d-4945-9cea-b0dd9fd20081",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_id = '05515938'\n",
    "sample = questions[sample_id]\n",
    "print(sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4893c68c-e278-4926-b0e6-730a0930211b",
   "metadata": {},
   "source": [
    "## Extracting nouns using WordNet lemmatizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a91aa0f-60dc-496b-9e8f-65abbd5e4f09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to process GQA dataset\n",
    "def extract_nouns_from_wordnet(data, include_answer=False):\n",
    "    extracted_data = {}\n",
    "\n",
    "    for id, entry in data.items():\n",
    "        question = entry.get(\"question\", \"\")\n",
    "        answer = entry.get(\"fullAnswer\", entry.get(\"answer\", \"\"))\n",
    "\n",
    "        nouns = extract_and_lemmatize_nouns(question)\n",
    "        if include_answer:\n",
    "            nouns += extract_and_lemmatize_nouns(answer)\n",
    "        entry['nouns'] = nouns\n",
    "\n",
    "        # question, answer, nouns\n",
    "        extracted_data[id] = dict(entry)\n",
    "\n",
    "    return extracted_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c1c1523-de14-471e-9e14-6a99f28c11f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_lemmatized = extract_nouns_from_wordnet(questions)\n",
    "sample = data_lemmatized[sample_id]\n",
    "print(sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4046826-1455-4899-b949-0126026b2dd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(sample['nouns'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e503f1c-047f-47a0-a646-fd8a63f71aed",
   "metadata": {},
   "source": [
    "## Extract semantic arguments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f7940ba-4da4-43e2-a535-31e997717175",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(sample['semantic'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7970b555-b6e5-4b0e-8c07-18605ce803eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to process GQA dataset\n",
    "def extract_argnouns_from_semantics(data, include_answer=False):\n",
    "    extracted_data = {}\n",
    "\n",
    "    for id, entry in data.items():\n",
    "        question = entry.get(\"question\", \"\")\n",
    "        semantics = entry.get('semantic', [])\n",
    "        for operation in semantics:\n",
    "            if operation['operation'] == 'select':\n",
    "                if 'argnouns' not in entry:\n",
    "                    entry['argnouns'] = []\n",
    "                noun = lemmatizer.lemmatize(operation['argument'].split(' (')[0])\n",
    "                if noun in question:\n",
    "                    entry['argnouns'].append(noun)\n",
    "        if len(entry['argnouns']) > 0:\n",
    "            extracted_data[id] = dict(entry)\n",
    "\n",
    "    return extracted_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f104ebc4-e9ec-4b60-972f-0a637fbaefca",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data_semantics = extract_argnouns_from_semantics(data_lemmatized)\n",
    "print(len(data_semantics))\n",
    "sample = data_semantics[sample_id]\n",
    "print(sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63bdf9c4-7bd5-44a3-bb49-a1c08b3eaf3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "counter = 0\n",
    "for id, entry in data_semantics.items():\n",
    "    \n",
    "    if len(entry['argnouns']) > 1:\n",
    "        counter = counter+1\n",
    "        print(id, entry['argnouns'])\n",
    "        if counter > 10:\n",
    "            break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ef7d5ae-9bd5-4e58-86ff-f15e2e519123",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(sample['nouns'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "184b1e22-620a-45d5-abf6-a731a070871e",
   "metadata": {},
   "source": [
    "## Extract nouns from scene descriptions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4b3a75d-e86f-430a-b830-5fe2b556ec89",
   "metadata": {},
   "outputs": [],
   "source": [
    "def map_images_to_objects(filepath):\n",
    "    \n",
    "    img2objs = {}\n",
    "    with open(filepath, \"r\") as file:\n",
    "        scenes = json.load(file)\n",
    "        for id, entry in scenes.items():\n",
    "            objects = entry.get(\"objects\", [])\n",
    "            all_objects = []\n",
    "            for o in objects:\n",
    "                o = objects[o]\n",
    "                all_objects.append(lemmatizer.lemmatize(o[\"name\"]))\n",
    "            \n",
    "            img2objs[id] = all_objects\n",
    "\n",
    "    return img2objs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2464dff-57aa-4381-8e62-dac9ee4dbc71",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_nouns_from_scenes(data, img2objs):\n",
    "    \n",
    "    extracted_data = {}\n",
    "    for id, entry in data.items():\n",
    "        img_id = entry['imageId']\n",
    "        objs =  [o for o in img2objs[img_id] if o in entry['question'] or o in entry['answer']]\n",
    "        \n",
    "        entry['scene_objects'] = objs\n",
    "        extracted_data[id] = dict(entry)\n",
    "\n",
    "    return extracted_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79ff1f27-e40b-43f1-bc5d-f3de83444916",
   "metadata": {},
   "outputs": [],
   "source": [
    "img2objs = map_images_to_objects('../data/gqa_dataset/val_sceneGraphs.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5867a994-d910-47b6-ae2d-068fdc35075e",
   "metadata": {},
   "outputs": [],
   "source": [
    "i = list(img2objs.keys())[0]\n",
    "print(img2objs[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5ed2f4b-4301-478e-b10c-acca2ead013e",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_scenes = extract_nouns_from_scenes(data_lemmatized, img2objs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "078f91d9-3523-4863-bc79-383d7a7ee671",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = data_scenes[sample_id]\n",
    "print(sample['question'], sample['scene_objects'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59996269-b67c-46fc-aa08-8a178ca3ac9d",
   "metadata": {},
   "source": [
    "## Comparison of nouns extracted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faac641d-55b4-48d6-9942-ae8b2cad7fcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "from IPython.display import display\n",
    "\n",
    "def visualize_image(image_path, img_id, size=300):\n",
    "    image = Image.open(f\"{image_path}/{img_id}.jpg\")\n",
    "    max_size = (size, size)  # The maximum width and height\n",
    "    image.thumbnail(max_size)\n",
    "    display(image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc195f82-ae83-4c57-8eed-26d5608bf06a",
   "metadata": {},
   "outputs": [],
   "source": [
    "counter = 0\n",
    "for entry_id in data_lemmatized.keys():\n",
    "    img_id = data_lemmatized[entry_id]['imageId']\n",
    "    visualize_image(\"../data/gqa_images/val/\", img_id)\n",
    "    print(f'QA: {data_lemmatized[entry_id][\"question\"]} - {data_lemmatized[entry_id][\"answer\"]}')\n",
    "    print('Lemmatized: ', data_lemmatized[entry_id]['nouns'])\n",
    "    if entry_id in data_semantics and 'argnouns' in data_semantics[entry_id]:\n",
    "        print('Semantics: ', data_semantics[entry_id]['argnouns'])\n",
    "    print('Scene objects: ', data_scenes[entry_id]['scene_objects'])\n",
    "    print()\n",
    "    counter = counter + 1\n",
    "    \n",
    "    if counter > 20:\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18e11b78-b70f-4de4-afb5-d56eca6151a6",
   "metadata": {},
   "source": [
    "Comparing the extracted nouns and the nouns from the semantic annotations. Perhaps we should only use the ones that are matching?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4353d280-c9ed-47af-b54e-ab703dfb1009",
   "metadata": {},
   "outputs": [],
   "source": [
    "mismatched_entries = []\n",
    "counter = 0\n",
    "for q_id in data_lemmatized.keys():\n",
    "    if q_id not in data_semantics or (set(data_lemmatized[q_id]['nouns']) != set(data_semantics[q_id]['argnouns'])):\n",
    "        mismatched_entries.append(q_id)\n",
    "    counter = counter + 1\n",
    "print(\"Number of samples where nouns and argnouns are the same: \", len(data_lemmatized.keys())-len(mismatched_entries), \"of\", len(data_lemmatized.keys()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adfe0ccc-65ad-495a-81cb-cb444a5a1edc",
   "metadata": {},
   "source": [
    "## Questions with equivalents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4684771-88b6-491a-be3d-106b64f42771",
   "metadata": {},
   "outputs": [],
   "source": [
    "questions_with_equivalents = {}\n",
    "for q_id, entry in data_lemmatized.items():\n",
    "    if len(entry['equivalent']) > 1:\n",
    "        questions_with_equivalents[q_id] = dict(entry)\n",
    "print(len(questions_with_equivalents))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a513b34-2643-492a-88e6-33310132483e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "def get_random_equivalent(equivalents, sample_id):\n",
    "\n",
    "    sample = questions_with_equivalents[sample_id]\n",
    "\n",
    "    eq_id = sample['equivalent'][0]\n",
    "    for eq_id in sample['equivalent']:\n",
    "        if eq_id != sample_id:\n",
    "            return eq_id\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0b21ffd-973d-4bc8-9818-6d824b7b01e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "eq_id = None\n",
    "tries_left = 10 #temp solution to break out if no equivalents are present (as in the balanced sets)\n",
    "while eq_id not in questions:\n",
    "    sample_id = random.choice(list(questions_with_equivalents.keys()))\n",
    "    eq_id = get_random_equivalent(questions_with_equivalents, sample_id)\n",
    "    tries_left = tries_left - 1\n",
    "    if tries_left == 0:\n",
    "        break\n",
    "if tries_left:\n",
    "    print(questions[sample_id]['question'])\n",
    "    print(questions[eq_id]['question'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76bd8ca2-c02a-4fed-aedb-1ae7488cbd21",
   "metadata": {},
   "source": [
    "## Filter data on hypernyms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6842b6f3-cccb-4530-92f0-f06481ef1898",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_on_hypernyms(data, hypernyms):\n",
    "    extended_data = {}\n",
    "    for id, entry in data.items():\n",
    "        nyms = {}\n",
    "        nouns = entry[\"nouns\"]\n",
    "        filtered_nouns = []\n",
    "        for noun in nouns:\n",
    "            if noun in hypernyms:\n",
    "                found_nyms = False\n",
    "                found_overlap = False\n",
    "                for hypnym in hypernyms[noun]:\n",
    "                    if not hypnym in nyms.values():\n",
    "                        nyms[noun] = hypernyms[noun]\n",
    "                        found_nyms = True\n",
    "                    else:\n",
    "                        found_overlap = True\n",
    "                if found_nyms:\n",
    "                    filtered_nouns.append(noun)\n",
    "        \n",
    "        if len(filtered_nouns) > 0 and not found_overlap:\n",
    "            entry[\"hypernyms\"] = nyms\n",
    "            entry[\"nouns\"] = list(filtered_nouns)\n",
    "            extended_data[id] = dict(entry)\n",
    "    return extended_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3ed372f-a771-46be-a278-cdda933a8efc",
   "metadata": {},
   "outputs": [],
   "source": [
    "hypernyms = '../data/gqa_entities/noun-hypernyms.json'\n",
    "hypernyms = json.load(open(hypernyms))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73e763db-0168-4bbc-b18e-83a58a48e7f7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "filtered_data = filter_on_hypernyms(data_semantics, hypernyms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "963463fc-77e1-4fcc-9e8e-16fa53e1df34",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(filtered_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcdc01e9-7c50-4e3a-98ee-71d09d77145d",
   "metadata": {},
   "outputs": [],
   "source": [
    "rnd_id = random.choice(list(filtered_data.keys()))\n",
    "print(filtered_data[rnd_id])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b8b6e07-07e1-47f2-8401-6d3589c5d719",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(filtered_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e25e3b6-69ef-48a0-b613-e9a3cc339a77",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(data_semantics))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d5603a7-21a1-4ca8-84f9-5d3d678fa00d",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(filtered_data[rnd_id])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45553c1b-b624-4611-b05f-5a6cff09e335",
   "metadata": {},
   "source": [
    "# Examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "911dafc3-f5ec-42fa-a2ce-7222054dfbd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_seed = 1984"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1c27546-cad4-4826-a6cc-a1d0eb00cf21",
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(sample_seed)\n",
    "rnd_samples = random.choices(list(filtered_data.keys()), k=12)\n",
    "print(rnd_samples)\n",
    "img_arr = [filtered_data[sample]['imageId'] for sample in rnd_samples]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c35955a4-d0dd-4329-91a7-bb4dc1b7c0c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def wrap_description(text, width=45):\n",
    "    return \"\\n\".join(textwrap.wrap(text, width=width))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2d2469f-e4d0-461a-a8fa-f506dd6c53a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1 import ImageGrid\n",
    "import numpy as np\n",
    "\n",
    "import textwrap\n",
    "\n",
    "fig = plt.figure(figsize=(20., 20.))\n",
    "grid = ImageGrid(fig, 111, \n",
    "                 nrows_ncols=(4, 3),  # creates 2x2 grid of axes\n",
    "                 axes_pad=1.2,  # pad between axes\n",
    "                 )\n",
    "\n",
    "for ax, (im, id) in zip(grid, zip(img_arr, rnd_samples)):\n",
    "    img = plt.imread(f\"../data/gqa_images/val/{im}.jpg\")\n",
    "    ax.imshow(img)\n",
    "    question = f\"{filtered_data[id]['question']}\"\n",
    "    wrapped_question = wrap_description(question, width=45)\n",
    "    title_nouns = f\"Nouns: {filtered_data[id]['nouns']}\"\n",
    "    wrapped_nouns = wrap_description(title_nouns, width=45)\n",
    "    title_argnoun = f\"Arg Noun: {filtered_data[id]['argnouns']}\"\n",
    "    wrapped_argnoun = wrap_description(title_argnoun, width=45)\n",
    "    title_hyper = f\"Hyper: {filtered_data[id]['hypernyms']}\"\n",
    "    wrapped_hyper = wrap_description(title_hyper, width=45)\n",
    "    ax.set_title(wrapped_question+\"\\n\"+wrapped_nouns+\"\\n\"+wrapped_argnoun+\"\\n\"+wrapped_hyper)\n",
    "\n",
    "plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6a130f4-6f7b-48bc-9d4d-1999e3eb3e79",
   "metadata": {},
   "source": [
    "## Hypernym substitution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5da44769-3e0d-4954-b481-0981b29a7fab",
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(sample_seed)\n",
    "def replace_with_hypernym(sample, main_noun, nym_id=-1):\n",
    "    question = f\"{sample['question']}\"\n",
    "    # TODO ugly quickfix\n",
    "    if nym_id != 0:\n",
    "        # TODO ugly quickfix\n",
    "        for word, nyms in sample['hypernyms'].items():\n",
    "            if word in main_noun:\n",
    "                if nym_id < 0:\n",
    "                    nym = random.choice(nyms)\n",
    "                else: \n",
    "                    nym = nyms[nym_id-1]\n",
    "                return question.replace(main_noun, nym)\n",
    "    return sample['question']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0b395ca-35a7-48ea-a35a-9051aad1c070",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "fig = plt.figure(figsize=(20., 20.))\n",
    "grid = ImageGrid(fig, 111, \n",
    "                 nrows_ncols=(4, 3),  # creates 2x2 grid of axes\n",
    "                 axes_pad=2.0,  # pad between axes\n",
    "                 )\n",
    "\n",
    "for ax, (im, id) in zip(grid, zip(img_arr, rnd_samples)):\n",
    "    img = plt.imread(f\"../data/gqa_images/val/{im}.jpg\")\n",
    "    ax.imshow(img)\n",
    "    noun = random.choice(list(filtered_data[id]['hypernyms'].keys()))\n",
    "    question = replace_with_hypernym(filtered_data[id], noun)\n",
    "    wrapped_question =wrap_description(question, width=45)\n",
    "    title_nouns = f\"Nouns: {filtered_data[id]['nouns']}\"\n",
    "    wrapped_nouns = wrap_description(title_nouns, width=45)\n",
    "    title_argnoun = f\"Arg Noun: {filtered_data[id]['argnouns']}\"\n",
    "    wrapped_argnoun = wrap_description(title_argnoun, width=45)\n",
    "    title_hyper = f\"Hyper: {filtered_data[id]['hypernyms']}\"\n",
    "    wrapped_hyper = wrap_description(title_hyper, width=45)\n",
    "    ax.set_title(wrapped_question+\"\\n\"+wrapped_nouns+\"\\n\"+wrapped_argnoun+\"\\n\"+wrapped_hyper)\n",
    "\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f4734f3-1fc9-48b9-8118-8030cac72a66",
   "metadata": {},
   "source": [
    "# Visualizing the hypernym substitutions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c31d2aab-0c5b-440e-bbde-18f39cf775ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_visualization_title(data, sample_id):\n",
    "    \n",
    "    question = data[sample_id]['question']\n",
    "    wrapped_question = wrap_description(question)\n",
    "    \n",
    "    answer = data[sample_id]['fullAnswer']\n",
    "    wrapped_answer = wrap_description(answer)\n",
    "    \n",
    "    title_nouns = f\"Nouns: {data[sample_id]['nouns']}\"\n",
    "    wrapped_nouns = wrap_description(title_nouns)\n",
    "    \n",
    "    title_hyper = f\"Hyper: {data[sample_id]['hypernyms']}\"\n",
    "    wrapped_hyper = wrap_description(title_hyper)\n",
    "    return wrapped_question+\"\\n\"+answer+\"\\n\"+wrapped_nouns+\"\\n\"+wrapped_hyper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4c509b6-4c8f-434d-8ca1-25e87a4e5f60",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_example(data, sample_id):\n",
    "    \n",
    "    fig = plt.figure()\n",
    "    img_id = data[sample_id]['imageId']\n",
    "\n",
    "    img = plt.imread(f\"../data/gqa_images/val/{img_id}.jpg\")\n",
    "    plt.imshow(img)\n",
    "    title = build_visualization_title(data, sample_id)\n",
    "    plt.title(title)\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dea517a7-8b58-4739-bb97-96ab53967c7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_number_of_hypernyms(data, sample_id):\n",
    "    main_noun = data[sample_id]['argnouns']\n",
    "    nr_nyms = 0\n",
    "    for nyms in data[sample_id]['hypernyms'].values():\n",
    "        nr_nyms = nr_nyms + len(nyms)\n",
    "    return nr_nyms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83068d65-2b11-4cd4-8fc6-9bdfc42f8257",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "\n",
    "def generate_combinations(data, sample_id):\n",
    "    # Generate the Cartesian product of the two lists\n",
    "    nouns = [[noun] for noun in data[sample_id]['hypernyms'].keys()]\n",
    "    all_combinations = []\n",
    "    \n",
    "    # Iterate over all possible subset sizes (including 0 for empty combination)\n",
    "    for r in range(len(nouns) + 1):\n",
    "        # Generate all possible combinations of lists for the current subset size\n",
    "        for subset in itertools.combinations(nouns, r):\n",
    "            if subset:\n",
    "                # Generate all possible products for the current subset\n",
    "                for product in itertools.product(*subset):\n",
    "                    combo = list(product)\n",
    "                    if combo not in all_combinations:\n",
    "                        all_combinations.append(combo)\n",
    "    \n",
    "    return all_combinations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5256a546-933a-44c9-8e67-20896a01d571",
   "metadata": {},
   "outputs": [],
   "source": [
    "combos = generate_combinations(filtered_data, '08549024')\n",
    "print(combos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8e38acc-1fc1-4eda-a4cf-9c4ea2ea0110",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_all_substitutions(data, sample_id):\n",
    "    nr_examples = get_number_of_hypernyms(data, sample_id) + 1# all subs + original\n",
    "    nr_rows = 1+(nr_examples // 3)\n",
    "    nr_cols = 3 if nr_examples > 3 else nr_examples\n",
    "    fig = plt.figure(figsize=(20., 20.))\n",
    "    plt.title(\"Question: \" + data[sample_id]['question'])\n",
    "    grid = ImageGrid(fig, 111, \n",
    "                 nrows_ncols=(nr_rows, nr_cols),  # creates rows x cols grid\n",
    "                 axes_pad=2.0,  # pad between axes\n",
    "                 )\n",
    "    sample = data[sample_id]\n",
    "    \n",
    "   # answer = \"A: \"+sample['fullAnswer']\n",
    "    #wrapped_answer = wrap_description(answer)\n",
    "\n",
    "    title_nouns = f\"Nouns: {sample['nouns']}\"\n",
    "    wrapped_nouns = wrap_description(title_nouns)\n",
    "\n",
    "    title_hyper = f\"Hyper: {sample['hypernyms']}\"\n",
    "    wrapped_hyper = wrap_description(title_hyper)\n",
    "\n",
    "    \n",
    "    img_id = data[sample_id]['imageId']\n",
    "    img = plt.imread(f\"../data/gqa_images/val/{img_id}.jpg\")\n",
    "\n",
    "    \n",
    "    nym_nouns = list(sample['hypernyms'].keys())\n",
    "    grid_counter = 0\n",
    "    for noun in sample['hypernyms']:\n",
    "        for nym in sample['hypernyms'][noun]:\n",
    "            \n",
    "            ax = grid[grid_counter]\n",
    "            ax.imshow(img)\n",
    "    \n",
    "            question = data[sample_id]['question'].replace(noun, nym)\n",
    "            \n",
    "            wrapped_question = wrap_description(question)\n",
    "        \n",
    "            ax.set_title(wrapped_question+\"\\n\"+wrapped_nouns+\"\\n\"+wrapped_hyper)\n",
    "            grid_counter = grid_counter + 1\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0ceb425-563a-43cc-9c46-11959cb37a6f",
   "metadata": {},
   "source": [
    "## Naive substitution not ideal\n",
    "There are some clear issues with a naive hypernym replacement, as I'm pretty sure using 'cheese dish' or 'cheese food' to describe pizza is pretty out there and would be a dataset generation/statistical artifact that a model could pick up on. `Cheese pizza` is not extracted by any of the methods. How do we fix or filter this?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "772b689d-88c2-4c65-87ca-22c8394b8836",
   "metadata": {},
   "outputs": [],
   "source": [
    "pizza_id = '11119108'\n",
    "pizza_sample = filtered_data[pizza_id]\n",
    "visualize_all_substitutions(filtered_data, pizza_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35f5de2d-d73a-499c-bfca-b116dc60a9b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "helmet_id = '08549024'\n",
    "helmet_example = filtered_data[helmet_id]\n",
    "visualize_all_substitutions(filtered_data, helmet_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a78b59a-bc03-409f-a161-cd58e49b5f22",
   "metadata": {},
   "source": [
    "The entire JSON object with all annotations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9b0fc79-5124-409a-b0f3-0d1d7d1dff8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "pp.pprint(helmet_example)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c56997ef-c0cc-464a-9fe9-3789035c5be0",
   "metadata": {},
   "source": [
    "### Add more interesting examples that we could discuss here\n",
    "\n",
    "Duplicate the cell below and change the ID. Choose from e.g. `rnd_samples` from before:\n",
    "\n",
    "* 11119108\n",
    "* 11714339\n",
    "* 13875779\n",
    "* 18394800\n",
    "* 19663962\n",
    "* 08549024\n",
    "* 011029890\n",
    "* 07726065\n",
    "* 091044035\n",
    "* 05925408\n",
    "* 15128686\n",
    "* 0450779"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "791716c1-07ad-4ab1-9b56-6dbeacbe6b08",
   "metadata": {},
   "outputs": [],
   "source": [
    "example_id = '11119108'#TODO\n",
    "example = filtered_data[example_id]\n",
    "visualize_all_substitutions(filtered_data, example_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e40133d8-0616-445a-abd0-c5513b791476",
   "metadata": {},
   "source": [
    "# Design decisions to make"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "693a6c17-ad2e-4686-a44f-796c64d8e176",
   "metadata": {},
   "source": [
    "* Using argnouns or nouns: argnouns ensures that the nouns and their substitutions refer to objects where we can better guarantee that the substitution sounds natural and also does not introduce ambiguity (e.g. replacing `gender` with `group affiliation` in `Do all these people have the same gender?` opens up for answers that do not match the ground truth).\n",
    "* Filtering out examples based on overlap in hypernyms does not cover all examples of ambiguity, how do we fix that?\n",
    "\n",
    "The following example illustrates issues with using argnouns, as the argnoun is not mentioned in the question:\n",
    "```\n",
    "QA: Which place is it? - shore\n",
    "nouns:  ['place']\n",
    "argnouns:  ['scene']\n",
    "Scene objects:  ['shore']\n",
    "```\n",
    "\n",
    "We can filter out all samples where nouns and argnouns do not match."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8442e283-060a-4f28-916d-fcc6d703c7d7",
   "metadata": {},
   "source": [
    "# Ideas"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75e062fb-91f1-4634-ad2b-0eb0f9779ac8",
   "metadata": {},
   "source": [
    "* Substitue the noun central to the question (i.e., as given by the semantics field) vs. non-central nouns\n",
    "* Use `equivalent` as control questions to help disentangle substitution from hierarchy\n",
    "* Test multimodal models with the extended text-only representation of the images.\n",
    "* Generate sub-datasets where all substitutions are from the same level in the hierarchy (or the same number of hops away at least)\n",
    "* Substitutions in combinations (e.g. only substituting `man` vs. substituting `man` and `helmet` in `Does the man wear a helmet?`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "457d5dcb-b9b7-49c6-8a3a-51f413673223",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
