{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b9c1f385-e3af-4023-9600-6d99f7e36547",
   "metadata": {},
   "source": [
    "# Polyps Segmentation Masks Generator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8834f2f5-aca9-4925-b811-363b0855d5a3",
   "metadata": {},
   "source": [
    "This notebook reads the XML annotation files in the Polyps dataset and builds a dataset of images paired with segmentation masks (each annotated bounding box is drawn as a filled rectangle) for training SPADEGAN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b38be57-ec18-4fa6-a64d-98995e3b61c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install opencv-python numpy lxml\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ce39499-d414-4ea7-9b5a-8e81b3794da1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "import numpy as np\n",
    "import xml.etree.ElementTree as ET\n",
    "import shutil\n",
    "import random\n",
    "\n",
    "# Parameters\n",
    "sample_size = 500  # Set the number of samples you want to randomly select\n",
    "seed = 42  # Seed for the random sampling; set to None for a different sample on every run\n",
    "\n",
    "# Paths\n",
    "image_folder = 'path/to/Image'\n",
    "xml_folder = 'path/to/Annotation'\n",
    "output_image_folder = 'path/to/Images'\n",
    "output_mask_folder = 'path/to/Masks'\n",
    "\n",
    "# The output folders are wiped below, so refuse to continue if doing so would\n",
    "# delete the input data or make images and masks overwrite each other.\n",
    "input_dirs = [os.path.realpath(d) for d in (image_folder, xml_folder)]\n",
    "output_dirs = [os.path.realpath(d) for d in (output_image_folder, output_mask_folder)]\n",
    "if output_dirs[0] == output_dirs[1]:\n",
    "    raise ValueError(\"output_image_folder and output_mask_folder must be different directories.\")\n",
    "for out_dir in output_dirs:\n",
    "    for in_dir in input_dirs:\n",
    "        # True when the input folder is the output folder itself or lives inside it\n",
    "        if in_dir == out_dir or in_dir.startswith(out_dir + os.sep):\n",
    "            raise ValueError(f\"Refusing to wipe '{out_dir}': it contains the input folder '{in_dir}'.\")\n",
    "\n",
    "# Ensure output directories are empty\n",
    "for out_folder in (output_image_folder, output_mask_folder):\n",
    "    if os.path.exists(out_folder):\n",
    "        shutil.rmtree(out_folder)\n",
    "    os.makedirs(out_folder)\n",
    "    print(f\"Created new directory: {out_folder}\")\n",
    "\n",
    "# Get a list of all XML files (sorted, because os.listdir returns them in arbitrary order)\n",
    "xml_files = sorted(f for f in os.listdir(xml_folder) if f.endswith('.xml'))\n",
    "\n",
    "# Randomly sample the XML files with a dedicated, seeded generator so the\n",
    "# selection is reproducible and the global random state is left untouched\n",
    "sampled_xml_files = random.Random(seed).sample(xml_files, min(sample_size, len(xml_files)))\n",
    "\n",
    "# Function to create segmentation masks and save images\n",
    "def process_file(xml_file):\n",
    "    \"\"\"Write one image/mask pair for a single annotation file.\n",
    "\n",
    "    Parses ``xml_file`` from ``xml_folder``, draws every ``object/bndbox`` as a\n",
    "    filled white (255) rectangle on a black single-channel mask sized from the\n",
    "    annotation's ``size/width`` and ``size/height``, and saves the mask and a\n",
    "    PNG copy of the matching image under the same base name.\n",
    "\n",
    "    The image is looked up in ``image_folder`` as ``<base>.jpg``, then\n",
    "    ``<base>.png``. It is loaded before anything is written so that a missing\n",
    "    or unreadable image never leaves a mask without a matching image.\n",
    "\n",
    "    Args:\n",
    "        xml_file: File name (not a full path) of the annotation inside ``xml_folder``.\n",
    "\n",
    "    Returns:\n",
    "        True if both the mask and the image were written, False if the file was skipped.\n",
    "    \"\"\"\n",
    "    print(f\"Processing file: {xml_file}\")\n",
    "\n",
    "    # Extract the base filename without extension\n",
    "    base_filename = os.path.splitext(xml_file)[0]\n",
    "\n",
    "    # Load the corresponding image first (.jpg preferred, .png as fallback)\n",
    "    image_path = os.path.join(image_folder, base_filename + '.jpg')\n",
    "    if not os.path.exists(image_path):\n",
    "        image_path = os.path.join(image_folder, base_filename + '.png')\n",
    "\n",
    "    image = cv2.imread(image_path)\n",
    "\n",
    "    # cv2.imread returns None instead of raising when the file is missing or unreadable\n",
    "    if image is None:\n",
    "        print(f\"Warning: Image '{image_path}' could not be read. Skipping this file.\")\n",
    "        return False\n",
    "\n",
    "    # Parse XML file\n",
    "    root = ET.parse(os.path.join(xml_folder, xml_file)).getroot()\n",
    "    width = int(root.find('size/width').text)\n",
    "    height = int(root.find('size/height').text)\n",
    "\n",
    "    print(f\"Creating mask for {base_filename} with dimensions ({width}x{height})\")\n",
    "\n",
    "    # Create an empty mask\n",
    "    mask = np.zeros((height, width), dtype=np.uint8)\n",
    "\n",
    "    # Loop through each object in the XML\n",
    "    for obj in root.findall('object'):\n",
    "        # int(float(...)) also accepts coordinates stored as decimals such as '12.0'\n",
    "        xmin, ymin, xmax, ymax = (\n",
    "            int(float(obj.find(f'bndbox/{tag}').text))\n",
    "            for tag in ('xmin', 'ymin', 'xmax', 'ymax')\n",
    "        )\n",
    "\n",
    "        print(f\"Drawing bounding box: ({xmin}, {ymin}) to ({xmax}, {ymax})\")\n",
    "\n",
    "        # Draw the bounding box on the mask (thickness=-1 fills the rectangle)\n",
    "        cv2.rectangle(mask, (xmin, ymin), (xmax, ymax), color=255, thickness=-1)\n",
    "\n",
    "    # Both outputs share the XML file's base name and are stored as .png\n",
    "    mask_filename = os.path.join(output_mask_folder, base_filename + '.png')\n",
    "    output_image_filename = os.path.join(output_image_folder, base_filename + '.png')\n",
    "\n",
    "    # cv2.imwrite reports failure through its return value\n",
    "    if not cv2.imwrite(mask_filename, mask):\n",
    "        print(f\"Warning: Mask '{mask_filename}' could not be written. Skipping this file.\")\n",
    "        return False\n",
    "    print(f\"Mask saved as: {mask_filename}\")\n",
    "\n",
    "    if not cv2.imwrite(output_image_filename, image):\n",
    "        # Remove the mask just written so the dataset never holds an unpaired mask\n",
    "        os.remove(mask_filename)\n",
    "        print(f\"Warning: Image '{output_image_filename}' could not be written. Skipping this file.\")\n",
    "        return False\n",
    "    print(f\"Image saved as: {output_image_filename}\\n\")\n",
    "    return True\n",
    "\n",
    "# Run the conversion over the sampled annotations, one file at a time\n",
    "start_message = f\"Starting the process with {len(sampled_xml_files)} samples...\\n\"\n",
    "print(start_message)\n",
    "for xml_file in sampled_xml_files:\n",
    "    process_file(xml_file)\n",
    "\n",
    "# Reached only after every sampled file has been handled\n",
    "print(\"Dataset creation process completed.\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
