{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import json\n",
    "os.environ[\"HF_HOME\"] = \"/your/path/hf_cache\"\n",
    "from datasets import load_dataset \n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the COCO caption annotations.\n",
    "#\n",
    "# Later cells index `captions_data` like a plain dict (a list of per-caption\n",
    "# dicts under 'annotations', plus 'images', 'info' and 'licenses'), so the file\n",
    "# is parsed with the standard `json` module. The previous\n",
    "# `load_dataset('json', ..., split=\"val\")` call requested a split that a single\n",
    "# `data_files` path does not define (local files load as the \"train\" split).\n",
    "CAPTIONS_PATH = \"/your/path/annotations/captions_val2017.json\"\n",
    "\n",
    "with open(CAPTIONS_PATH, 'r', encoding='utf-8') as f:\n",
    "    captions_data = json.load(f)\n",
    "\n",
    "print(f\"Totally {len(captions_data['annotations'])} descriptions, regarding {len(captions_data['images'])} images.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Caption keywords used to pick out gesture-related descriptions.\n",
    "# This is only an example set; edit the list to target other gestures.\n",
    "# Base and third-person forms are padded with a space on both sides, which\n",
    "# limits substring matches to whole words; the '-ing' forms only carry a\n",
    "# leading space.\n",
    "keywords = [\n",
    "    ' stand ', ' stands ', ' standing',\n",
    "    ' sit ', ' sits ', ' sitting',\n",
    "    ' jump ', ' jumps ', ' jumping',\n",
    "    ' lie ', ' lies ', ' lying',\n",
    "    ' bend ', ' bends ', ' bending',\n",
    "    ' kneel ', ' kneels ', ' kneeling',\n",
    "    # ' squatting' is the correct spelling; the earlier misspelling\n",
    "    # ' squating' is kept so captions containing that typo still match.\n",
    "    ' squat ', ' squats ', ' squatting', ' squating',\n",
    "    ' crawl ', ' crawls ', ' crawling',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep every caption that mentions one of the gesture keywords and tag it with\n",
    "# a canonical gesture name ('stand', 'sit', 'jump', 'lie', ...).\n",
    "#\n",
    "# The keywords carry padding spaces, so they are stripped before being reduced\n",
    "# to their base form. Testing the padded keyword with endswith('s') or comparing\n",
    "# it to 'lying' never succeeds and leaks values such as ' stands ' or ' ly'\n",
    "# into the category field.\n",
    "IRREGULAR_FORMS = {\n",
    "    'lying': 'lie',\n",
    "    'sitting': 'sit',\n",
    "    'squatting': 'squat',\n",
    "    'squating': 'squat',  # misspelled form that appears in the keyword list\n",
    "}\n",
    "\n",
    "\n",
    "def canonical_gesture(keyword):\n",
    "    \"\"\"Return the base gesture name for a (possibly space-padded) keyword.\"\"\"\n",
    "    word = keyword.strip()\n",
    "    if word in IRREGULAR_FORMS:\n",
    "        return IRREGULAR_FORMS[word]\n",
    "    if word.endswith('ing'):\n",
    "        return word[:-3]\n",
    "    if word.endswith('s'):\n",
    "        return word[:-1]\n",
    "    return word\n",
    "\n",
    "\n",
    "filtered_annotations = []\n",
    "for ann in captions_data['annotations']:\n",
    "    caption = ann['caption'].lower()\n",
    "    # The first matching keyword wins, in the order given by `keywords`.\n",
    "    matched = next((kw for kw in keywords if kw in caption), None)\n",
    "    if matched is not None:\n",
    "        # The category is written onto the annotation dict itself.\n",
    "        ann[\"category\"] = canonical_gesture(matched)\n",
    "        filtered_annotations.append(ann)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the ids of the images referenced by the filtered annotations.\n",
    "# Several annotations can share one image_id, so the ids are de-duplicated\n",
    "# (first-seen order is kept) to avoid repeated entries in later membership\n",
    "# tests, removals and file copies.\n",
    "filtered_image_ids = list(dict.fromkeys(ann['image_id'] for ann in filtered_annotations))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the manually rejected images listed in bad_ids.txt (one image id per line).\n",
    "with open(\"bad_ids.txt\", 'r') as f:\n",
    "    # Parse to int so the ids compare equal to numeric image ids; a raw text\n",
    "    # line such as '139' never equals the integer 139. Blank lines are skipped.\n",
    "    # NOTE(review): assumes the annotation image ids are integers - confirm.\n",
    "    bad_ids = {int(line) for line in f if line.strip()}\n",
    "\n",
    "# Remove every annotation and every id that belongs to a rejected image.\n",
    "# New lists are built (instead of list.remove in a loop) so that all captions\n",
    "# of a rejected image are dropped and re-running the cell gives the same result.\n",
    "filtered_annotations = [ann for ann in filtered_annotations if ann['image_id'] not in bad_ids]\n",
    "filtered_image_ids = [image_id for image_id in filtered_image_ids if image_id not in bad_ids]\n",
    "\n",
    "# Set lookup keeps the scan over all images linear.\n",
    "kept_ids = set(filtered_image_ids)\n",
    "filtered_images = [\n",
    "    img for img in captions_data['images']\n",
    "    if img['id'] in kept_ids\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional\n",
    "# Save the filtered annotations and images to a new JSON file that reuses the\n",
    "# 'info' and 'licenses' entries of the loaded caption data.\n",
    "filtered_data = {\n",
    "    \"info\": captions_data[\"info\"],\n",
    "    \"licenses\": captions_data[\"licenses\"],\n",
    "    \"images\": filtered_images,\n",
    "    \"annotations\": filtered_annotations\n",
    "}\n",
    "# os.makedirs(..., exist_ok=True) replaces the shell `mkdir`: it does not fail\n",
    "# when the directory already exists (re-running the cell) and needs no shell.\n",
    "os.makedirs(\"anno_filter\", exist_ok=True)\n",
    "with open('anno_filter/filtered.json', 'w') as f:\n",
    "    json.dump(filtered_data, f, indent=2)\n",
    "print(f\"Extracting {len(filtered_annotations)} descriptions, regarding {len(filtered_images)} images.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: copy the filtered images into a separate directory so they can be\n",
    "# checked manually.\n",
    "import shutil\n",
    "\n",
    "SRC_DIR = \"val2017\"\n",
    "DST_DIR = \"fil_imgs\"\n",
    "# Pure-Python directory creation and copy instead of shelling out to\n",
    "# `mkdir` / `cp`: portable, safe to re-run, and no shell quoting involved.\n",
    "os.makedirs(DST_DIR, exist_ok=True)\n",
    "\n",
    "missing = []\n",
    "# dict.fromkeys drops repeated ids so each image is copied only once.\n",
    "for img_id in dict.fromkeys(filtered_image_ids):\n",
    "    # File names are the image id zero-padded to 12 digits.\n",
    "    name = str(img_id).zfill(12) + \".jpg\"\n",
    "    src = os.path.join(SRC_DIR, name)\n",
    "    if not os.path.isfile(src):\n",
    "        # Keep going, as the best-effort `cp` did, but report it afterwards.\n",
    "        missing.append(src)\n",
    "        continue\n",
    "    shutil.copy(src, os.path.join(DST_DIR, name))\n",
    "\n",
    "if missing:\n",
    "    print(f\"Skipped {len(missing)} missing image file(s), e.g. {missing[0]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Present participles that are not simply '<gesture>ing'.\n",
    "_ING_FORMS = {\n",
    "    \"lie\": \"lying\",\n",
    "    \"sit\": \"sitting\",\n",
    "    \"squat\": \"squatting\",\n",
    "}\n",
    "\n",
    "\n",
    "def generate_sample(gesture, image_ids):\n",
    "    \"\"\"Build one retrieval sample for a gesture.\n",
    "\n",
    "    Args:\n",
    "        gesture: Base gesture name such as 'stand', 'sit' or 'lie'.\n",
    "        image_ids: Candidate image ids for the target side.\n",
    "\n",
    "    Returns:\n",
    "        A dict with the keys 'qry_text', 'qry_img_path', 'tgt_text' and\n",
    "        'tgt_img_path'.\n",
    "    \"\"\"\n",
    "    # Appending 'ing' blindly produced the misspelling 'squating'; irregular\n",
    "    # forms are looked up, everything else still gets the plain suffix.\n",
    "    participle = _ING_FORMS.get(gesture, f\"{gesture}ing\")\n",
    "    text = f\"Find me an everyday image that contains someone {participle}.\"\n",
    "\n",
    "    sample = {\n",
    "        \"qry_text\": text,\n",
    "        \"qry_img_path\": \"\",\n",
    "        \"tgt_text\": \"<|image_1|> Represent the given image.\",\n",
    "        # NOTE(review): image_ids is stored as its string form (e.g. \"[1, 2]\"),\n",
    "        # unchanged from before; confirm whether the consumer expects a list.\n",
    "        \"tgt_img_path\": f\"{image_ids}\"\n",
    "    }\n",
    "    return sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# Build one retrieval sample per filtered caption: the caption's own image is\n",
    "# the positive (first entry), followed by randomly drawn negative images.\n",
    "GESTURES = ['stand', 'sit', 'jump', 'lie', 'bend', 'kneel', 'squat', 'crawl']\n",
    "# 'squat' and 'kneel' are never used as negatives for each other.\n",
    "SQUAT_KNEEL = frozenset({'squat', 'kneel'})\n",
    "NUM_NEGATIVES = 99\n",
    "SEED = 42\n",
    "\n",
    "# Dedicated, seeded generator so a fresh Run All reproduces the same samples.\n",
    "rng = random.Random(SEED)\n",
    "\n",
    "# image id -> every gesture category found among that image's filtered captions\n",
    "image_gestures = {}\n",
    "for anno in filtered_annotations:\n",
    "    image_gestures.setdefault(anno['image_id'], set()).add(anno['category'])\n",
    "\n",
    "negative_pools = {}  # excluded gesture group -> candidate negative image ids\n",
    "final = []\n",
    "for anno in filtered_annotations:\n",
    "    image_id = anno['image_id']\n",
    "    gesture = anno['category']\n",
    "    if gesture not in GESTURES:\n",
    "        raise ValueError(f\"Unknown gesture category {gesture!r} for image {image_id}\")\n",
    "    excluded = SQUAT_KNEEL if gesture in SQUAT_KNEEL else frozenset({gesture})\n",
    "\n",
    "    # Negatives are unique images none of whose captions carry an excluded\n",
    "    # gesture, which also keeps the positive image out of its own negatives.\n",
    "    # The pool depends only on `excluded`, so it is built once per group\n",
    "    # rather than once per annotation.\n",
    "    if excluded not in negative_pools:\n",
    "        negative_pools[excluded] = [\n",
    "            img_id for img_id, found in image_gestures.items()\n",
    "            if not (found & excluded)\n",
    "        ]\n",
    "    pool = negative_pools[excluded]\n",
    "\n",
    "    # random.sample raises ValueError when asked for more items than exist,\n",
    "    # so the request is capped at the pool size.\n",
    "    nega_ids = rng.sample(pool, min(NUM_NEGATIVES, len(pool)))\n",
    "    image_ids = [image_id] + nega_ids\n",
    "\n",
    "    # Generate a sample for the given gesture and image ids\n",
    "    final.append(generate_sample(gesture, image_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the assembled retrieval samples to gesture/COCO_gesture_retrieval.json\n",
    "output_dir = \"gesture\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "output_file = os.path.join(output_dir, \"COCO_gesture_retrieval.json\")\n",
    "\n",
    "# Serialise first, then write the text out in a single call.\n",
    "serialised = json.dumps(final, ensure_ascii=False, indent=4)\n",
    "with open(output_file, 'w', encoding='utf-8') as out_fh:\n",
    "    out_fh.write(serialised)"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
