{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Point the Hugging Face cache at a custom directory. The variable is assigned\n",
    "# before `datasets` is imported, presumably because the library resolves its\n",
    "# cache location at import time -- keep this ordering (TODO confirm).\n",
    "HF_CACHE_DIR = \"/your/path/hf_cache\"\n",
    "os.environ[\"HF_HOME\"] = HF_CACHE_DIR\n",
    "\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Visual7W annotation file with the generic JSON loader.\n",
    "VISUAL7W_JSON = \"/your/path/visual7w/dataset.json\"\n",
    "eval_data = load_dataset(\"json\", data_files=VISUAL7W_JSON, split=\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "idx = 0\n",
    "# Record layout, as noted during interactive exploration:\n",
    "#   eval_data[idx].keys()           -> dict_keys(['images', 'version', 'dataset'])\n",
    "#   eval_data[idx]['images'].keys() -> dict_keys(['filename', 'image_id', 'qa_pairs', 'split'])\n",
    "sample_image = eval_data[idx][\"images\"]\n",
    "\n",
    "# Show every question/answer pair attached to the selected image.\n",
    "for pair in sample_image[\"qa_pairs\"]:\n",
    "    print(pair[\"question\"], pair[\"answer\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count = 0\n",
    "\n",
    "# (substring, bucket, echo) rules, tried in order against the lower-cased answer;\n",
    "# the first substring found wins. Answers hitting a rule with echo=True are\n",
    "# printed for manual inspection. Only the phrase 'after dark' counts as night here.\n",
    "TIME_RULES = [\n",
    "    (\"day\", \"day\", False),\n",
    "    (\"night\", \"night\", False),\n",
    "    (\"after dark\", \"night\", True),\n",
    "    (\"evening\", \"evening\", False),\n",
    "    (\"afternoon\", \"afternoon\", False),\n",
    "    (\"dusk\", \"dusk\", True),\n",
    "    (\"morning\", \"morning\", False),\n",
    "    (\"sunrise\", \"sunrise\", False),\n",
    "]\n",
    "tally = {rule[1]: 0 for rule in TIME_RULES}\n",
    "\n",
    "for item in eval_data:\n",
    "    image = item[\"images\"]\n",
    "    # Only the validation split is counted.\n",
    "    if image[\"split\"] != 'val':\n",
    "        continue\n",
    "    for qa_pair in image[\"qa_pairs\"]:\n",
    "        question = qa_pair[\"question\"]\n",
    "        asks_when_taken = \"When\" in question and \"taken\" in question\n",
    "        asks_when_it = \"When is it\" in question or \"When was it\" in question\n",
    "        if not (asks_when_taken or asks_when_it):\n",
    "            continue\n",
    "        answer = qa_pair[\"answer\"]\n",
    "        answer_lc = answer.lower()\n",
    "        for needle, bucket, echo in TIME_RULES:\n",
    "            if needle in answer_lc:\n",
    "                tally[bucket] += 1\n",
    "                if echo:\n",
    "                    print(answer)\n",
    "                break\n",
    "\n",
    "daycount = tally[\"day\"]\n",
    "nightcount = tally[\"night\"]\n",
    "eveningcount = tally[\"evening\"]\n",
    "afternooncount = tally[\"afternoon\"]\n",
    "duskcount = tally[\"dusk\"]\n",
    "morningcount = tally[\"morning\"]\n",
    "risecount = tally[\"sunrise\"]\n",
    "\n",
    "print(daycount, nightcount, eveningcount, afternooncount, duskcount, morningcount, risecount)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth time-of-day buckets: daytime, night (incl. 'dark'), evening,\n",
    "# afternoon, dusk, morning, sunrise. Each bucket collects image filenames\n",
    "# from the validation split.\n",
    "categories = {\"daytime\": [], \"night\": [], \"evening\": [], \"afternoon\": [], \"dusk\": [], \"morning\": [], \"sunrise\": []}\n",
    "\n",
    "# (substring, bucket) rules tried in order against the lower-cased answer;\n",
    "# the first substring found decides the bucket.\n",
    "ANSWER_TO_BUCKET = [\n",
    "    (\"day\", \"daytime\"),\n",
    "    (\"night\", \"night\"),\n",
    "    (\"dark\", \"night\"),\n",
    "    (\"evening\", \"evening\"),\n",
    "    (\"afternoon\", \"afternoon\"),\n",
    "    (\"dusk\", \"dusk\"),\n",
    "    (\"morning\", \"morning\"),\n",
    "    (\"sunrise\", \"sunrise\"),\n",
    "]\n",
    "\n",
    "for item in eval_data:\n",
    "    image = item[\"images\"]\n",
    "    if image[\"split\"] != 'val':\n",
    "        continue\n",
    "    gt = None\n",
    "    for qa_pair in image[\"qa_pairs\"]:\n",
    "        question = qa_pair[\"question\"]\n",
    "        # Both question patterns are labelled with the same rules.\n",
    "        is_time_question = (\n",
    "            (\"When\" in question and \"taken\" in question)\n",
    "            or \"When is it\" in question\n",
    "            or \"When was it\" in question\n",
    "        )\n",
    "        if not is_time_question:\n",
    "            continue\n",
    "        answer = qa_pair[\"answer\"].lower()\n",
    "        # Keep the previous label when this answer matches no rule; a later\n",
    "        # matching QA pair overrides an earlier one.\n",
    "        gt = next((bucket for needle, bucket in ANSWER_TO_BUCKET if needle in answer), gt)\n",
    "\n",
    "    # Each image lands in at most one bucket.\n",
    "    if gt is not None:\n",
    "        categories[gt].append(image[\"filename\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "NUM_NEGATIVES = 99  # each query gets 1 positive + 99 sampled negatives\n",
    "SAMPLING_SEED = 0   # fixed seed so a re-run regenerates the same candidate lists\n",
    "\n",
    "# Label hierarchy used to decide which buckets are safe negatives:\n",
    "# time_of_day\n",
    "#   daytime\n",
    "#     morning\n",
    "#       sunrise\n",
    "#     afternoon\n",
    "#   night\n",
    "#   evening = dusk\n",
    "# exclusive[label] lists the buckets whose images may serve as negatives for `label`.\n",
    "exclusive = {\"daytime\": [\"night\", \"evening\", \"dusk\"],\n",
    "             # afternoon is deliberately left out of the negatives for morning\n",
    "             \"morning\": [\"night\", \"evening\", \"dusk\"],\n",
    "             \"sunrise\": [\"night\", \"evening\", \"dusk\"],\n",
    "             \"afternoon\": [\"night\", \"evening\", \"dusk\"],\n",
    "             \"night\": [\"daytime\", \"morning\", \"sunrise\", \"afternoon\", \"dusk\", \"evening\"],\n",
    "             \"evening\": [\"daytime\", \"morning\", \"sunrise\", \"afternoon\", \"night\"],\n",
    "             \"dusk\": [\"daytime\", \"morning\", \"sunrise\", \"afternoon\", \"night\"]}\n",
    "\n",
    "rng = random.Random(SAMPLING_SEED)\n",
    "new_data = []\n",
    "for label, positives in categories.items():\n",
    "    # The negative pool depends only on the label, so build it once per label\n",
    "    # (bucket order follows `categories`).\n",
    "    negative_pool = [\n",
    "        path\n",
    "        for other in categories\n",
    "        if other in exclusive[label]\n",
    "        for path in categories[other]\n",
    "    ]\n",
    "    if positives and len(negative_pool) < NUM_NEGATIVES:\n",
    "        raise ValueError(\n",
    "            f\"Only {len(negative_pool)} negative images available for label \"\n",
    "            f\"{label!r}; need {NUM_NEGATIVES}.\"\n",
    "        )\n",
    "    for positive in positives:\n",
    "        new_data.append({\n",
    "            'qry_text': f\"Find me an everyday image that is taken during the {label}.\\n\",\n",
    "            'qry_img_path': '',\n",
    "            'tgt_text': \"<|image_1|> Represent the given image.\",\n",
    "            # Candidate 0 is the positive; the rest are sampled without replacement.\n",
    "            'tgt_img_path': [positive] + rng.sample(negative_pool, NUM_NEGATIVES),\n",
    "        })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: number of dusk images, then number of generated queries.\n",
    "dusk_images = categories[\"dusk\"]\n",
    "print(len(dusk_images))\n",
    "print(len(new_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "OUTPUT_JSON = 'Visual7W_time_retrieval.json'\n",
    "\n",
    "# Persist the retrieval queries as pretty-printed JSON in the working directory.\n",
    "with open(OUTPUT_JSON, 'w') as out_file:\n",
    "    json.dump(new_data, out_file, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Round-trip check: the exported file should be readable through load_dataset.\n",
    "RETRIEVAL_JSON = 'Visual7W_time_retrieval.json'\n",
    "new_eval_data = load_dataset('json', data_files=RETRIEVAL_JSON, split=\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 1\n",
    "record = new_eval_data[idx]\n",
    "print(record.keys())\n",
    "print(record['qry_text'])\n",
    "\n",
    "# Candidate 0 is the positive image; index 29 is one of the sampled negatives.\n",
    "candidate_name = record['tgt_img_path'][29]\n",
    "img = Image.open(\"/your/path/visual7w/images/\" + candidate_name).convert(\"RGB\")\n",
    "img"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
