{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Point the Hugging Face cache at a custom location. This is set before `datasets` is\n",
    "# imported, presumably so the library sees it when it reads its configuration.\n",
    "os.environ[\"HF_HOME\"] = \"/your/path/hf_cache\"\n",
    "\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visual7W annotations; the single JSON file is exposed as the \"train\" split.\n",
    "v7w_data = load_dataset(\n",
    "    \"json\",\n",
    "    data_files=\"/your/path/visual7w/dataset.json\",\n",
    "    split=\"train\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "# Record layout observed for Visual7W:\n",
    "#   v7w_data[i].keys()           -> dict_keys(['images', 'version', 'dataset'])\n",
    "#   v7w_data[i]['images'].keys() -> dict_keys(['filename', 'image_id', 'qa_pairs', 'split'])\n",
    "idx = 0\n",
    "sample_image = v7w_data[idx][\"images\"]\n",
    "for pair in sample_image[\"qa_pairs\"]:\n",
    "    print(pair[\"question\"], pair[\"answer\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with open(\"/your/path/visdial/dataset.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    visdial_data = json.load(f)\n",
    "\n",
    "# Indices into visdial_data['data']['questions'] of every question mentioning the weather.\n",
    "# Kept as a set: the next cell tests membership once per dialog round, which is O(1) here\n",
    "# instead of a scan over the whole list.\n",
    "Qs = {\n",
    "    i for i, ques in enumerate(visdial_data[\"data\"][\"questions\"])\n",
    "    if \"weather\" in ques\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image ids flagged as unusable. They are read from a text file, so they are strings;\n",
    "# blank lines are ignored.\n",
    "with open(\"/your/path/visdial/bad_data.txt\", \"r\") as f:\n",
    "    bad_data = {line.strip() for line in f if line.strip()}\n",
    "\n",
    "# One record per dialog round whose question mentions the weather. Flagged images are\n",
    "# skipped here. image_id is normalised to str so it matches bad_data and supports .zfill()\n",
    "# when file names are built later (VisDial presumably stores ids as ints - confirm).\n",
    "questions = visdial_data[\"data\"][\"questions\"]\n",
    "answers = visdial_data[\"data\"][\"answers\"]\n",
    "dial_data = []\n",
    "for dialog in visdial_data[\"data\"][\"dialogs\"]:\n",
    "    image_id = str(dialog[\"image_id\"])\n",
    "    if image_id in bad_data:\n",
    "        continue\n",
    "    for qa in dialog[\"dialog\"]:\n",
    "        # Rounds without an answer index cannot be labelled, so they are skipped.\n",
    "        if qa[\"question\"] in Qs and \"answer\" in qa:\n",
    "            dial_data.append({\n",
    "                \"question\": questions[qa[\"question\"]],\n",
    "                \"answer\": answers[qa[\"answer\"]],\n",
    "                \"image_id\": image_id,\n",
    "            })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# COCO val2017 image ids flagged as unusable (read as text, hence strings; blank lines ignored).\n",
    "with open(\"/your/path/visdial/bad_data_coco.txt\", \"r\") as f:\n",
    "    bad_data_coco = {line.strip() for line in f if line.strip()}\n",
    "\n",
    "with open(\"/your/path/COCO-2017/annotations/captions_val2017.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    coco_data = json.load(f)\n",
    "\n",
    "# Drop captions of flagged images up front. The caption file presumably stores image_id as\n",
    "# an int, so its string form is compared with the ids read from the text file.\n",
    "# NOTE(review): assumes bad_data_coco.txt lists unpadded ids (e.g. 139, not 000000000139).\n",
    "coco_data[\"annotations\"] = [\n",
    "    anno for anno in coco_data[\"annotations\"]\n",
    "    if str(anno[\"image_id\"]) not in bad_data_coco\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from collections import Counter\n",
    "\n",
    "# Coarse weather buckets, tried in order; the first bucket with a matching keyword wins.\n",
    "COARSE_KEYWORDS = (\n",
    "    (\"sunny\", (\"sunny\", \"clear\", \"sunshine\")),\n",
    "    (\"cloudy\", (\"cloudy\", \"overcast\", \"misty\")),\n",
    "    (\"rainy\", (\"rain\", \"drizzly\")),\n",
    "    (\"snowy\", (\"snow\",)),\n",
    "    (\"stormy\", (\"stormy\",)),\n",
    "    (\"warm\", (\"warm\",)),\n",
    "    (\"cold\", (\"cold\", \"chilly\")),\n",
    ")\n",
    "\n",
    "def coarse_weather(text):\n",
    "    \"\"\"Return the first coarse weather bucket whose keyword starts a word of `text`, else None.\n",
    "\n",
    "    Keywords are matched against the start of whole words rather than as raw substrings,\n",
    "    so 'rain' still matches 'rainy'/'raining' but no longer fires inside 'train', 'brain'\n",
    "    or 'terrain'; likewise 'clear' skips 'nuclear' and 'warm' skips 'swarm'.\n",
    "    \"\"\"\n",
    "    words = re.findall(r\"[a-z]+\", text.lower())\n",
    "    for bucket, keywords in COARSE_KEYWORDS:\n",
    "        if any(word.startswith(keyword) for word in words for keyword in keywords):\n",
    "            return bucket\n",
    "    return None\n",
    "\n",
    "weather_counts = Counter()\n",
    "all_other_answers = []  # texts that matched no bucket (fed to the word cloud)\n",
    "\n",
    "def tally(text):\n",
    "    \"\"\"Count `text` under its weather bucket, or keep it in all_other_answers if it has none.\"\"\"\n",
    "    bucket = coarse_weather(text)\n",
    "    if bucket is None:\n",
    "        all_other_answers.append(text)\n",
    "    else:\n",
    "        weather_counts[bucket] += 1\n",
    "\n",
    "# Visual7W (val split). An image may carry several weather questions; each one is counted.\n",
    "for item in v7w_data:\n",
    "    if item[\"images\"][\"split\"] != \"val\":\n",
    "        continue\n",
    "    for qa_pair in item[\"images\"][\"qa_pairs\"]:\n",
    "        if \"weather\" in qa_pair[\"question\"]:\n",
    "            tally(qa_pair[\"answer\"])\n",
    "        elif \"cloudy\" in qa_pair[\"answer\"].lower():\n",
    "            # A 'cloudy' answer to a non-weather question still counts as cloudy.\n",
    "            weather_counts[\"cloudy\"] += 1\n",
    "\n",
    "# Visual Dialog. Flagged images are filtered out before anything is counted; ids are\n",
    "# compared as strings because bad_data was read from a text file.\n",
    "bad_dialog_ids = {str(image_id) for image_id in bad_data}\n",
    "dial_data = [item for item in dial_data if str(item[\"image_id\"]) not in bad_dialog_ids]\n",
    "for item in dial_data:\n",
    "    tally(item[\"answer\"])\n",
    "\n",
    "# COCO val2017 captions, filtered the same way.\n",
    "bad_coco_ids = {str(image_id) for image_id in bad_data_coco}\n",
    "coco_data[\"annotations\"] = [\n",
    "    anno for anno in coco_data[\"annotations\"] if str(anno[\"image_id\"]) not in bad_coco_ids\n",
    "]\n",
    "for anno in coco_data[\"annotations\"]:\n",
    "    tally(anno[\"caption\"])\n",
    "\n",
    "count = sum(weather_counts.values())\n",
    "print(count, weather_counts[\"sunny\"], weather_counts[\"rainy\"], weather_counts[\"snowy\"],\n",
    "      weather_counts[\"cloudy\"], weather_counts[\"stormy\"], weather_counts[\"warm\"], weather_counts[\"cold\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from wordcloud import WordCloud\n",
    "\n",
    "# Word cloud of every answer/caption that matched no weather keyword; useful for spotting\n",
    "# phrasings the keyword lists miss.\n",
    "unmatched_text = \" \".join(all_other_answers)\n",
    "wordcloud = WordCloud(width=800, height=400, background_color=\"white\").generate(unmatched_text)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.imshow(wordcloud, interpolation=\"bilinear\")\n",
    "ax.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Fine-grained weather labels, tried in order; the first label with a matching keyword wins.\n",
    "FINE_KEYWORDS = (\n",
    "    (\"sunny\", (\"sunny\", \"sunshine\")),\n",
    "    (\"clear\", (\"clear\",)),\n",
    "    (\"cloudy\", (\"cloudy\",)),\n",
    "    (\"overcast\", (\"overcast\",)),\n",
    "    (\"misty\", (\"misty\",)),\n",
    "    (\"rainy\", (\"rain\",)),\n",
    "    (\"drizzly\", (\"drizzly\",)),\n",
    "    (\"snowy\", (\"snow\",)),\n",
    "    (\"stormy\", (\"stormy\",)),\n",
    "    (\"warm\", (\"warm\",)),\n",
    "    (\"cold\", (\"cold\",)),\n",
    "    (\"chilly\", (\"chilly\",)),\n",
    ")\n",
    "\n",
    "def fine_weather(text):\n",
    "    \"\"\"Return the first fine weather label whose keyword starts a word of `text`, else None.\n",
    "\n",
    "    Matching on word starts (not raw substrings) keeps 'rain' from firing inside\n",
    "    'train'/'brain'/'terrain' and 'clear' from firing inside 'nuclear'.\n",
    "    \"\"\"\n",
    "    words = re.findall(r\"[a-z]+\", text.lower())\n",
    "    for label, keywords in FINE_KEYWORDS:\n",
    "        if any(word.startswith(keyword) for word in words for keyword in keywords):\n",
    "            return label\n",
    "    return None\n",
    "\n",
    "# Image file names per weather label, plus the images contributed by each source dataset\n",
    "# (images_1: Visual7W, images_2: Visual Dialog, images_3: COCO val2017).\n",
    "categories = {\"sunny\":[], \"cloudy\":[], \"rainy\":[], \"snowy\":[], \"stormy\": [], \"warm\": [], \"cold\": [],\n",
    "              \"clear\":[], \"misty\": [], \"overcast\":[], \"drizzly\": [], \"chilly\": []}\n",
    "images_1, images_2, images_3 = [], [], []\n",
    "\n",
    "# Visual7W (val split): one label per image. The last weather answer that matches a label\n",
    "# (or a 'cloudy' answer to any other question) wins.\n",
    "for item in v7w_data:\n",
    "    if item[\"images\"][\"split\"] != \"val\":\n",
    "        continue\n",
    "    gt = None\n",
    "    for qa_pair in item[\"images\"][\"qa_pairs\"]:\n",
    "        if \"weather\" in qa_pair[\"question\"]:\n",
    "            gt = fine_weather(qa_pair[\"answer\"]) or gt\n",
    "        elif \"cloudy\" in qa_pair[\"answer\"].lower():\n",
    "            gt = \"cloudy\"\n",
    "    if gt is not None:\n",
    "        categories[gt].append(item[\"images\"][\"filename\"])\n",
    "        images_1.append(item[\"images\"][\"filename\"])\n",
    "\n",
    "# Visual Dialog: one entry per weather round. str() because the raw ids may be ints, which\n",
    "# have no .zfill(). Flagged ids are skipped here so this cell does not rely on other cells\n",
    "# having removed them.\n",
    "for item in dial_data:\n",
    "    label = fine_weather(item[\"answer\"])\n",
    "    if label is None or str(item[\"image_id\"]) in bad_data:\n",
    "        continue\n",
    "    filename = \"VisualDialog_val2018_\" + str(item[\"image_id\"]).zfill(12) + \".jpg\"\n",
    "    categories[label].append(filename)\n",
    "    images_2.append(filename)\n",
    "\n",
    "# COCO val2017: one entry per caption. image_id is presumably numeric, hence str() before .zfill().\n",
    "for item in coco_data[\"annotations\"]:\n",
    "    label = fine_weather(item[\"caption\"])\n",
    "    if label is None or str(item[\"image_id\"]) in bad_data_coco:\n",
    "        continue\n",
    "    filename = str(item[\"image_id\"]).zfill(12) + \".jpg\"\n",
    "    categories[label].append(filename)\n",
    "    images_3.append(filename)\n",
    "\n",
    "# The same file can be collected more than once (several weather rounds per dialog, and\n",
    "# presumably several captions per COCO image). Keep only the first occurrence, preserving\n",
    "# order, to avoid duplicate queries and redundant hashing later.\n",
    "for label in categories:\n",
    "    categories[label] = list(dict.fromkeys(categories[label]))\n",
    "images_1 = list(dict.fromkeys(images_1))\n",
    "images_2 = list(dict.fromkeys(images_2))\n",
    "images_3 = list(dict.fromkeys(images_3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Near-duplicate detection between images of the different source datasets\n",
    "import imagehash\n",
    "\n",
    "def _hash_distance(image_path1, image_path2, hash_fn, hash_size):\n",
    "    \"\"\"Return the Hamming distance between `hash_fn` hashes of two image files.\n",
    "\n",
    "    The images are handed to imagehash unmodified, because it converts to grayscale and\n",
    "    downsamples by itself. Shrinking them to hash_size x hash_size first would discard the\n",
    "    detail phash needs, since its DCT runs on a larger (hash_size * highfreq_factor) thumbnail.\n",
    "    Files are opened in `with` blocks so their handles are released even on error.\n",
    "    \"\"\"\n",
    "    with Image.open(image_path1) as img1, Image.open(image_path2) as img2:\n",
    "        return hash_fn(img1, hash_size=hash_size) - hash_fn(img2, hash_size=hash_size)\n",
    "\n",
    "def _are_similar(image_path1, image_path2, hash_fn, hash_size, similarity_threshold):\n",
    "    \"\"\"True if the hash distance is below the threshold; False (with a message) if an image cannot be read.\"\"\"\n",
    "    try:\n",
    "        return _hash_distance(image_path1, image_path2, hash_fn, hash_size) < similarity_threshold\n",
    "    except FileNotFoundError:\n",
    "        print(\"Error: One or both image files not found.\")\n",
    "        print(f\"File not found: {image_path1} or {image_path2}\")\n",
    "        return False\n",
    "    except Exception as e:\n",
    "        # Best effort on purpose: an unreadable image is reported and treated as \"not similar\".\n",
    "        print(f\"An error occurred: {e}\")\n",
    "        return False\n",
    "\n",
    "def are_similar_hash(image_path1, image_path2, hash_size=8, similarity_threshold=10):\n",
    "    \"\"\"\n",
    "    Decide whether two images are near-duplicates using the perceptual hash (phash).\n",
    "\n",
    "    Args:\n",
    "        image_path1 (str): path of the first image\n",
    "        image_path2 (str): path of the second image\n",
    "        hash_size (int): side length of the hash grid passed to imagehash\n",
    "        similarity_threshold (int): images are similar when the Hamming distance is below this\n",
    "\n",
    "    Returns:\n",
    "        bool: True if similar; False otherwise or when an image cannot be read\n",
    "    \"\"\"\n",
    "    return _are_similar(image_path1, image_path2, imagehash.phash, hash_size, similarity_threshold)\n",
    "\n",
    "def are_similar_ahash(image_path1, image_path2, hash_size=8, similarity_threshold=10):\n",
    "    \"\"\"\n",
    "    Decide whether two images are near-duplicates using the average hash (ahash).\n",
    "\n",
    "    Arguments and return value are the same as for are_similar_hash.\n",
    "    \"\"\"\n",
    "    return _are_similar(image_path1, image_path2, imagehash.average_hash, hash_size, similarity_threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove Visual Dialog / COCO images that near-duplicate an image from an earlier source:\n",
    "# images_2 is checked against images_1, and images_3 against images_1 and images_2.\n",
    "IMAGE_DIR = \"/your/path/images\"  # folder holding all source images; adjust to your setup\n",
    "HASH_SIZE = 8\n",
    "SIMILARITY_THRESHOLD = 10  # Hamming distance below which two hashes count as the same image\n",
    "\n",
    "def _image_hashes(filename):\n",
    "    \"\"\"Return (phash, average_hash) of IMAGE_DIR/filename, or None if it cannot be read.\"\"\"\n",
    "    path = os.path.join(IMAGE_DIR, filename)\n",
    "    try:\n",
    "        with Image.open(path) as img:\n",
    "            return (imagehash.phash(img, hash_size=HASH_SIZE),\n",
    "                    imagehash.average_hash(img, hash_size=HASH_SIZE))\n",
    "    except Exception as e:\n",
    "        # Best effort: an unreadable image is reported and never treated as a duplicate.\n",
    "        print(f\"An error occurred: {e}\")\n",
    "        return None\n",
    "\n",
    "def _hash_all(filenames):\n",
    "    \"\"\"Hash each distinct readable file once; returns {filename: (phash, average_hash)}.\"\"\"\n",
    "    hashes = {}\n",
    "    for filename in dict.fromkeys(filenames):\n",
    "        pair = _image_hashes(filename)\n",
    "        if pair is not None:\n",
    "            hashes[filename] = pair\n",
    "    return hashes\n",
    "\n",
    "def _is_near_duplicate(pair, references):\n",
    "    \"\"\"True if either hash in `pair` is within SIMILARITY_THRESHOLD of the same hash of a reference.\"\"\"\n",
    "    phash, ahash = pair\n",
    "    return any(phash - ref_phash < SIMILARITY_THRESHOLD or ahash - ref_ahash < SIMILARITY_THRESHOLD\n",
    "               for ref_phash, ref_ahash in references)\n",
    "\n",
    "def _drop_from_categories(filename):\n",
    "    \"\"\"Remove every occurrence of `filename` from every weather category.\"\"\"\n",
    "    for key in categories:\n",
    "        categories[key] = [f for f in categories[key] if f != filename]\n",
    "\n",
    "# Each image is opened and hashed once; the pairwise comparisons then run in memory.\n",
    "hashes_1 = _hash_all(images_1)\n",
    "hashes_2 = _hash_all(images_2)\n",
    "hashes_3 = _hash_all(images_3)\n",
    "\n",
    "references = list(hashes_1.values())\n",
    "for filename, pair in hashes_2.items():\n",
    "    if _is_near_duplicate(pair, references):\n",
    "        _drop_from_categories(filename)\n",
    "\n",
    "references += list(hashes_2.values())\n",
    "for filename, pair in hashes_3.items():\n",
    "    if _is_near_duplicate(pair, references):\n",
    "        _drop_from_categories(filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "NUM_NEGATIVES = 99  # each query gets 1 positive followed by NUM_NEGATIVES distractor images\n",
    "rng = random.Random(0)  # fixed seed so the generated benchmark is reproducible\n",
    "\n",
    "# For each query label, the labels whose images can serve as negatives.\n",
    "exclusive = {\"sunny\": [\"cloudy\", \"misty\", \"overcast\", \"rainy\", \"drizzly\", \"snowy\", \"stormy\", \"cold\", \"chilly\"],\n",
    "             \"clear\":  [\"cloudy\", \"misty\", \"overcast\", \"rainy\", \"drizzly\", \"snowy\", \"stormy\", \"cold\", \"chilly\"],\n",
    "             \"misty\":  [\"sunny\", \"clear\", \"snowy\"],\n",
    "             \"overcast\": [\"sunny\", \"clear\", \"snowy\"],\n",
    "             \"cloudy\": [\"sunny\", \"clear\", \"snowy\"],\n",
    "             \"rainy\": [\"sunny\", \"clear\", \"warm\", \"snowy\"],\n",
    "             \"drizzly\": [\"sunny\", \"clear\", \"warm\", \"snowy\"],\n",
    "             \"stormy\": [\"sunny\", \"clear\", \"warm\", \"snowy\"],\n",
    "             \"snowy\": [\"sunny\", \"clear\", \"rainy\", \"drizzly\", \"stormy\", \"warm\"],\n",
    "             \"warm\": [\"rainy\", \"stormy\", \"snowy\", \"cold\", \"chilly\", \"overcast\", \"misty\"],\n",
    "             \"cold\": [\"sunny\", \"clear\", \"warm\"],\n",
    "             \"chilly\": [\"sunny\", \"clear\", \"warm\"]}\n",
    "\n",
    "new_data = []\n",
    "for label, positives in categories.items():\n",
    "    if label == \"warm\":\n",
    "        continue  # We exclude warm queries here; warm images still serve as negatives for other labels\n",
    "    # Negative pool for this label, built once per label. It holds every image of an exclusive\n",
    "    # label, without duplicates and without any image that is itself labelled `label` (an image\n",
    "    # collected under two conflicting labels must never be a negative for its own query).\n",
    "    own_images = set(positives)\n",
    "    pool = [path for other in categories if other in exclusive[label] for path in categories[other]]\n",
    "    pool = [path for path in dict.fromkeys(pool) if path not in own_images]\n",
    "    if positives and len(pool) < NUM_NEGATIVES:\n",
    "        raise ValueError(f\"Only {len(pool)} negatives available for '{label}', need {NUM_NEGATIVES}.\")\n",
    "    article = \"an\" if label == \"overcast\" else \"a\"\n",
    "    qry_text = f\"Find me an everyday image that is taken in {article} {label} day.\\n\"\n",
    "    for data in positives:\n",
    "        new_data.append({\n",
    "            \"qry_text\": qry_text,\n",
    "            \"qry_img_path\": \"\",\n",
    "            \"tgt_text\": \"<|image_1|> Represent the given image.\",\n",
    "            \"tgt_img_path\": [data] + rng.sample(pool, NUM_NEGATIVES),\n",
    "        })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the benchmark items; indent=4 keeps the file human-readable.\n",
    "import json\n",
    "\n",
    "with open(\"mix_weather_retrieval.json\", \"w\") as out_file:\n",
    "    out_file.write(json.dumps(new_data, indent=4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the generated file must be loadable through `datasets`.\n",
    "new_v7w_data = load_dataset(\n",
    "    \"json\",\n",
    "    data_files=\"mix_weather_retrieval.json\",\n",
    "    split=\"train\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check one generated item: print its query and show one of its target images\n",
    "# (position 0 of tgt_img_path is the positive; later positions are sampled negatives).\n",
    "idx = 1\n",
    "candidate = 60\n",
    "record = new_v7w_data[idx]\n",
    "print(record.keys())\n",
    "print(record[\"qry_text\"])\n",
    "img = Image.open(\"/your/path/visual7w/images/\" + record[\"tgt_img_path\"][candidate]).convert(\"RGB\")\n",
    "img"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
