{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import json\n",
    "os.environ[\"HF_HOME\"] = \"/your/path/hf_cache\"\n",
    "from datasets import load_dataset \n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the COCO-Stuff validation annotation file into `data`.\n",
    "# Later cells iterate over data['annotations'] and read each entry's\n",
    "# 'category_id' and 'image_id'.\n",
    "with open('COCO-Stuff/stuff_val.json', 'r') as annotation_file:\n",
    "    data = json.load(annotation_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group image ids by material using each annotation's category id.\n",
    "# The materials are: wood, stone, metal, paper, brick\n",
    "# Several category ids feed the same material list, so one image can match a\n",
    "# material more than once; each image id is kept once per material, in\n",
    "# first-seen order, so counts and later sampling are not skewed by repeats.\n",
    "CATEGORY_TO_MATERIAL = {\n",
    "    182: 'wood', 177: 'wood', 118: 'wood',\n",
    "    116: 'stone', 162: 'stone', 175: 'stone',\n",
    "    132: 'metal',\n",
    "    139: 'paper',\n",
    "    171: 'brick',\n",
    "}\n",
    "wood_ids, stone_ids, metal_ids, paper_ids, brick_ids = [], [], [], [], []\n",
    "ids_by_material = {\n",
    "    'wood': wood_ids,\n",
    "    'stone': stone_ids,\n",
    "    'metal': metal_ids,\n",
    "    'paper': paper_ids,\n",
    "    'brick': brick_ids,\n",
    "}\n",
    "seen_by_material = {name: set() for name in ids_by_material}\n",
    "for anno in data['annotations']:\n",
    "    material_name = CATEGORY_TO_MATERIAL.get(anno['category_id'])\n",
    "    if material_name is None:\n",
    "        continue  # category is not one of the five materials\n",
    "    image_id = anno['image_id']\n",
    "    if image_id not in seen_by_material[material_name]:\n",
    "        seen_by_material[material_name].add(image_id)\n",
    "        ids_by_material[material_name].append(image_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find images that are listed under more than one material.\n",
    "# Sets make every membership test O(1); the original list-in-list checks were\n",
    "# quadratic and appended the same image id several times.\n",
    "material_id_sets = [\n",
    "    set(wood_ids),\n",
    "    set(stone_ids),\n",
    "    set(metal_ids),\n",
    "    set(paper_ids),\n",
    "    set(brick_ids),\n",
    "]\n",
    "seen_ids, mixed_ids = set(), set()\n",
    "for id_set in material_id_sets:\n",
    "    # Anything already seen under an earlier material is a mixed image.\n",
    "    mixed_ids |= seen_ids & id_set\n",
    "    seen_ids |= id_set\n",
    "# Sorted list of unique ids: later cells only test membership in mix_ids.\n",
    "mix_ids = sorted(mixed_ids)\n",
    "\n",
    "print(len(mix_ids))\n",
    "print(len(wood_ids), len(stone_ids), len(metal_ids), len(paper_ids), len(brick_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop mixed-material images from every per-material list, then write the\n",
    "# remaining ids to material_ids.txt, one headed section per material.\n",
    "mix_id_set = set(mix_ids)\n",
    "material_lists = [\n",
    "    ('wood_ids', wood_ids),\n",
    "    ('stone_ids', stone_ids),\n",
    "    ('metal_ids', metal_ids),\n",
    "    ('paper_ids', paper_ids),\n",
    "    ('brick_ids', brick_ids),\n",
    "]\n",
    "with open('material_ids.txt', 'w') as f:\n",
    "    for header, ids in material_lists:\n",
    "        # Filter first instead of calling remove() while iterating the same\n",
    "        # list; slice assignment keeps the notebook-level list object, so\n",
    "        # later cells see the filtered ids.\n",
    "        ids[:] = [image_id for image_id in ids if image_id not in mix_id_set]\n",
    "        f.write(header + ':\\n')\n",
    "        # Same layout as before: ids are written last-to-first.\n",
    "        for image_id in reversed(ids):\n",
    "            f.write(str(image_id) + '\\n')\n",
    "    f.write('mix_ids:\\n')\n",
    "    # The header used to be written with nothing under it; list each mixed\n",
    "    # image id once.\n",
    "    for image_id in dict.fromkeys(mix_ids):\n",
    "        f.write(str(image_id) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove every image id listed in bad_ids.txt (one id per line) from the\n",
    "# per-material lists.\n",
    "with open(\"bad_ids.txt\", 'r') as f:\n",
    "    # Lines are text; blank lines are ignored.\n",
    "    bad_ids = {line.strip() for line in f if line.strip()}\n",
    "\n",
    "for ids in (wood_ids, stone_ids, metal_ids, paper_ids, brick_ids):\n",
    "    # Compare on str(image_id): the ids read from the file are strings, so the\n",
    "    # previous `id in wood_ids` test against non-string ids never matched and\n",
    "    # nothing was removed. Slice assignment filters the list in place.\n",
    "    ids[:] = [image_id for image_id in ids if str(image_id) not in bad_ids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# Generating negative samples\n",
    "# For every positive image of a material, draw NUM_NEGATIVES image ids from\n",
    "# the lists of the other materials.\n",
    "NUM_NEGATIVES = 99\n",
    "# Fixed seed so a fresh Restart & Run All reproduces the same negatives.\n",
    "NEGATIVE_SAMPLING_SEED = 0\n",
    "rng = random.Random(NEGATIVE_SAMPLING_SEED)\n",
    "\n",
    "\n",
    "def build_negative_map(positive_ids, negative_pools):\n",
    "    \"\"\"Map each positive image id to a list of sampled negative image ids.\n",
    "\n",
    "    Args:\n",
    "        positive_ids: image ids of the target material.\n",
    "        negative_pools: lists of image ids to draw negatives from.\n",
    "\n",
    "    Returns:\n",
    "        dict of str(positive id) -> list of NUM_NEGATIVES distinct negative ids.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if there are positives but fewer than NUM_NEGATIVES\n",
    "            distinct negative ids to sample from.\n",
    "    \"\"\"\n",
    "    if not positive_ids:\n",
    "        return {}\n",
    "    # The pool does not depend on the positive, so build it once.\n",
    "    # dict.fromkeys drops repeated ids and keeps a stable order, which the\n",
    "    # seeded sampler needs to be reproducible.\n",
    "    pool = list(dict.fromkeys(image_id for ids in negative_pools for image_id in ids))\n",
    "    if len(pool) < NUM_NEGATIVES:\n",
    "        raise ValueError(\n",
    "            f\"need at least {NUM_NEGATIVES} distinct negative ids, got {len(pool)}\"\n",
    "        )\n",
    "    return {str(image_id): rng.sample(pool, NUM_NEGATIVES) for image_id in positive_ids}\n",
    "\n",
    "\n",
    "# NOTE(review): stone negatives leave out brick and brick negatives leave out\n",
    "# stone, exactly as before - presumably because the two look alike; confirm\n",
    "# that this exclusion is intended.\n",
    "nega_map_wood = build_negative_map(wood_ids, [stone_ids, metal_ids, paper_ids, brick_ids])\n",
    "nega_map_stone = build_negative_map(stone_ids, [wood_ids, metal_ids, paper_ids])\n",
    "nega_map_metal = build_negative_map(metal_ids, [wood_ids, stone_ids, paper_ids, brick_ids])\n",
    "nega_map_paper = build_negative_map(paper_ids, [wood_ids, stone_ids, metal_ids, brick_ids])\n",
    "nega_map_brick = build_negative_map(brick_ids, [wood_ids, metal_ids, paper_ids])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_sample(material, image_ids):\n",
    "    # Generate a sample for the given material and image IDs\n",
    "    \n",
    "    text = f\"Find me an everyday image showing some object or surface made of {material}.\"\n",
    "\n",
    "    sample = {\n",
    "        \"qry_text\": f\"{text}\",\n",
    "        \"qry_img_path\": \"\",\n",
    "        \"tgt_text\": \"<|image_1|> Represent the given image.\",\n",
    "        \"tgt_img_path\": f\"{image_ids}\"\n",
    "    }\n",
    "    return sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble one retrieval sample per positive image, material by material.\n",
    "# Each key of a negative map is a positive image id stored as a string; the\n",
    "# candidate list puts that positive first, followed by its sampled negatives.\n",
    "material_maps = [\n",
    "    ('wood', nega_map_wood),\n",
    "    ('stone', nega_map_stone),\n",
    "    ('metal', nega_map_metal),\n",
    "    ('paper', nega_map_paper),\n",
    "    ('brick', nega_map_brick),\n",
    "]\n",
    "final = []\n",
    "for material, nega_map in material_maps:\n",
    "    for positive_id, negative_ids in nega_map.items():\n",
    "        image_ids = [int(positive_id)] + negative_ids\n",
    "        final.append(generate_sample(material, image_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write every generated sample to material/COCOStuff_material_retrieval.json,\n",
    "# creating the output directory if it does not exist yet.\n",
    "output_dir = \"material\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "output_file = os.path.join(output_dir, \"COCOStuff_material_retrieval.json\")\n",
    "with open(output_file, 'w', encoding='utf-8') as out_file:\n",
    "    json.dump(final, out_file, ensure_ascii=False, indent=4)"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
