{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import json\n",
    "os.environ[\"HF_HOME\"] = \"/your/path/hf_cache\"\n",
    "from datasets import load_dataset \n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Using QA dialogues, get the answer for specific questions\n",
    "samples = []\n",
    "\n",
    "keywords = ['how many people', 'How many people', 'how many persons', 'How many persons', 'how many person', 'How many person']#, \\\n",
    "    #'how many folks', 'How many folks', \\\n",
    "    #'how many individuals', 'How many individuals', 'how many individual', 'How many individual', \\\n",
    "    #'count of people', 'Count of people', 'number of people', 'Number of people', \\\n",
    "    #'count of person', 'Count of person', 'number of person', 'Number of person', \\\n",
    "    #]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "samples = []\n",
    "with open(\"v7w_tell.json\", \"r\") as f:\n",
    "    data = json.load(f)\n",
    "for exp in data:\n",
    "    qa = exp[\"qa_pairs\"]\n",
    "    for i in qa:\n",
    "        qu = i[\"question\"]\n",
    "        if any(k in qu for k in keywords) and exp[\"split\"] == \"val\":\n",
    "            sample = {}\n",
    "            sample[\"image_id\"]= exp[\"image_id\"]\n",
    "            sample[\"question\"] = qu\n",
    "            sample[\"answer\"] = i[\"answer\"]\n",
    "            samples.append(sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "samples = sorted(samples, key=lambda x: x[\"image_id\"])\n",
    "os.makedirs(\"people_num\", exist_ok=True)\n",
    "# Save the samples to a JSON file\n",
    "with open(\"people_num/dialogs_v7w.json\", \"w\") as f:\n",
    "    json.dump(samples, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Loading the filtered annotations\n",
    "data = json.load(open(\"people_num/dialogs_v7w.json\"))\n",
    "# Below is just a list of image ids that are not good samples\n",
    "not_good_samples = json.load(open(\"people_num/not_good_samples.json\"))\n",
    "for i in data:\n",
    "    if i[\"image_id\"] in not_good_samples:\n",
    "        data.remove(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directly using the filtered annotations(We can manually check the filtered images while deleting the improper annotations)\n",
    "\n",
    "for i, q in enumerate(data):\n",
    "    if q[\"answer\"] == \"1.\" or q[\"answer\"] == \"one\" or q[\"answer\"] == \"One\" or q[\"answer\"] == \"one.\" or q[\"answer\"] == \"One.\" or q[\"answer\"] == \"one person\" or q[\"answer\"] == \"One person\" or q[\"answer\"] == \"one person.\" or q[\"answer\"] == \"One person.\":\n",
    "        data[i][\"answer\"] = \"1\"\n",
    "    elif q[\"answer\"] == \"2.\" or q[\"answer\"] == \"two\" or q[\"answer\"] == \"Two\" or q[\"answer\"] == \"two.\" or q[\"answer\"] == \"Two.\" or q[\"answer\"] == \"two people\" or q[\"answer\"] == \"Two people\" or q[\"answer\"] == \"two people.\" or q[\"answer\"] == \"Two people.\":\n",
    "        data[i][\"answer\"] = \"2\"\n",
    "    elif q[\"answer\"] == \"3.\" or q[\"answer\"] == \"three\" or q[\"answer\"] == \"Three\" or q[\"answer\"] == \"three.\" or q[\"answer\"] == \"Three.\" or q[\"answer\"] == \"three people\" or q[\"answer\"] == \"Three people\" or q[\"answer\"] == \"three people.\" or q[\"answer\"] == \"Three people.\":\n",
    "        data[i][\"answer\"] = \"3\"\n",
    "    elif q[\"answer\"] == \"4.\" or q[\"answer\"] == \"four\" or q[\"answer\"] == \"Four\" or q[\"answer\"] == \"four.\" or q[\"answer\"] == \"Four.\" or q[\"answer\"] == \"four people\" or q[\"answer\"] == \"Four people\" or q[\"answer\"] == \"four people.\" or q[\"answer\"] == \"Four people.\":\n",
    "        data[i][\"answer\"] = \"4\"\n",
    "    elif q[\"answer\"] == \"5.\" or q[\"answer\"] == \"five\" or q[\"answer\"] == \"Five\" or q[\"answer\"] == \"five.\" or q[\"answer\"] == \"Five.\" or q[\"answer\"] == \"five people\" or q[\"answer\"] == \"Five people\" or q[\"answer\"] == \"five people.\" or q[\"answer\"] == \"Five people.\":\n",
    "        data[i][\"answer\"] = \"5\"\n",
    "    elif q[\"answer\"] == \"6.\" or q[\"answer\"] == \"six\" or q[\"answer\"] == \"Six\" or q[\"answer\"] == \"six.\" or q[\"answer\"] == \"Six.\" or q[\"answer\"] == \"six people\" or q[\"answer\"] == \"Six people\" or q[\"answer\"] == \"six people.\" or q[\"answer\"] == \"Six people.\":\n",
    "        data[i][\"answer\"] = \"6\"\n",
    "    elif q[\"answer\"] == \"7.\" or q[\"answer\"] == \"seven\" or q[\"answer\"] == \"Seven\" or q[\"answer\"] == \"seven.\" or q[\"answer\"] == \"Seven.\" or q[\"answer\"] == \"seven people\" or q[\"answer\"] == \"Seven people\" or q[\"answer\"] == \"seven people.\" or q[\"answer\"] == \"Seven people.\":\n",
    "        data[i][\"answer\"] = \"7\"\n",
    "    elif q[\"answer\"] == \"8.\" or q[\"answer\"] == \"eight\" or q[\"answer\"] == \"Eight\" or q[\"answer\"] == \"eight.\" or q[\"answer\"] == \"Eight.\" or q[\"answer\"] == \"eight people\" or q[\"answer\"] == \"Eight people\" or q[\"answer\"] == \"eight people.\" or q[\"answer\"] == \"Eight people.\":\n",
    "        data[i][\"answer\"] = \"8\"\n",
    "    elif q[\"answer\"] == \"9.\" or q[\"answer\"] == \"nine\" or q[\"answer\"] == \"Nine\" or q[\"answer\"] == \"nine.\" or q[\"answer\"] == \"Nine.\" or q[\"answer\"] == \"nine people\" or q[\"answer\"] == \"Nine people\" or q[\"answer\"] == \"nine people.\" or q[\"answer\"] == \"Nine people.\":\n",
    "        data[i][\"answer\"] = \"9\"\n",
    "    elif q[\"answer\"] == \"10.\" or q[\"answer\"] == \"ten\" or q[\"answer\"] == \"Ten\" or q[\"answer\"] == \"ten.\" or q[\"answer\"] == \"Ten.\" or q[\"answer\"] == \"ten people\" or q[\"answer\"] == \"Ten people\" or q[\"answer\"] == \"ten people.\" or q[\"answer\"] == \"Ten people.\":\n",
    "        data[i][\"answer\"] = \"10\"\n",
    "    elif q[\"answer\"] == \"0.\" or q[\"answer\"] == \"None\" or q[\"answer\"] == \"none\" or q[\"answer\"] == \"None.\" or q[\"answer\"] == \"none.\" or q[\"answer\"] == \"0\" or q[\"answer\"] == \"0 people\" or q[\"answer\"] == \"None people\" or q[\"answer\"] == \"none people\" or q[\"answer\"] == \"None people.\" or q[\"answer\"] == \"none people.\" \\\n",
    "        or q[\"answer\"] == \"Zero\" or q[\"answer\"] == \"zero\" or q[\"answer\"] == \"Zero.\" or q[\"answer\"] == \"zero.\" or q[\"answer\"] == \"Zero people\" or q[\"answer\"] == \"zero people\" or q[\"answer\"] == \"Zero people.\" or q[\"answer\"] == \"zero people.\":\n",
    "        data[i][\"answer\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "samples = {}\n",
    "for d in data:\n",
    "    id = d[\"answer\"]\n",
    "    if id not in samples:\n",
    "        samples[id] = []\n",
    "    samples[id].append(\"v7w_\"+str(d[\"image_id\"])+\".jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_sample(count, image_ids):\n",
    "    # Generate a sample for the given people count and image IDs\n",
    "    if count == \"1\":\n",
    "        text = \"Find me an everyday image showing 1 person in the picture.\"\n",
    "    else:\n",
    "        text = f\"Find me an everyday image with {count} people in the picture.\"\n",
    "\n",
    "    sample = {\n",
    "        \"qry_text\": f\"{text}\",\n",
    "        \"qry_img_path\": \"\",\n",
    "        \"tgt_text\": \"<|image_1|> Represent the given image.\",\n",
    "        \"tgt_img_path\": f\"{image_ids}\"\n",
    "    }\n",
    "    return sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "# Generating negative samples\n",
    "# The negative samples are generated based on the positive samples\n",
    "final = []\n",
    "for k, v in samples.items():\n",
    "    nega = []\n",
    "    nega_ids = []\n",
    "    if k == \"0\":\n",
    "        nega = [\"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"10\", \"over 10\"]\n",
    "    elif k == \"1\":\n",
    "        nega = [\"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"10\", \"over 10\"]\n",
    "    elif k == \"2\":\n",
    "        nega = [\"5\", \"6\", \"7\", \"8\", \"9\", \"10\", \"over 10\"]\n",
    "    elif k == \"3\":\n",
    "        nega = [\"0\", \"6\", \"7\", \"8\", \"9\", \"10\", \"over 10\"]\n",
    "    elif k == \"4\":\n",
    "        nega = [\"0\", \"1\", \"7\", \"8\", \"9\", \"10\", \"over 10\"]\n",
    "    elif k == \"5\":\n",
    "        nega = [\"0\", \"1\", \"2\", \"8\", \"9\", \"10\", \"over 10\"]\n",
    "    elif k == \"6\":\n",
    "        nega = [\"0\", \"1\", \"2\", \"3\", \"9\", \"10\", \"over 10\"]\n",
    "    elif k == \"7\":\n",
    "        nega = [\"0\", \"1\", \"2\", \"3\", \"4\", \"10\", \"over 10\"]\n",
    "    elif k == \"8\":\n",
    "        nega = [\"0\", \"1\", \"2\", \"3\", \"4\", \"5\", \"over 10\"]\n",
    "    elif k == \"9\":\n",
    "        nega = [\"0\", \"1\", \"2\", \"3\", \"4\", \"5\", \"6\"]\n",
    "    elif k == \"10\":\n",
    "        nega = [\"0\", \"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\"]\n",
    "    elif k == \"over 10\":\n",
    "        nega = [\"0\", \"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\"]\n",
    "    for tag in nega:\n",
    "        if tag in samples:\n",
    "            nega_ids.extend(samples[tag])\n",
    "    for i in v:\n",
    "        nega_ids = random.sample(nega_ids, 99)\n",
    "        ids = [i]\n",
    "        ids.extend(nega_ids)\n",
    "        sample = generate_sample(k, ids)\n",
    "        final.append(sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Saving the final samples to a JSON file\n",
    "os.makedirs(\"people_num\", exist_ok=True)\n",
    "output_file = os.path.join(\"people_num\", \"people_num_retrieval.json\")\n",
    "with open(output_file, 'w', encoding='utf-8') as f:\n",
    "    json.dump(final, f, ensure_ascii=False, indent=4)"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
