{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b04cf286-3733-4d68-acfc-f1c768562cbb",
   "metadata": {},
   "outputs": [],
   "source": [
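    "# Optional one-time setup: uncomment to install the widget dependencies\n",
    "# into the environment of the running kernel.\n",
    "# %pip install ipywidgets\n",
    "# %pip install ipyevents"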
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6a78af8b-a35c-428a-bdbc-33ffe7b469c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from IPython.display import display\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.image as mpimg\n",
    "import numpy as np\n",
    "\n",
    "import ipywidgets as widgets\n",
    "from ipyevents import Event\n",
    "from ipywidgets import GridspecLayout\n",
    "import traceback\n",
    "import json\n",
    "import glob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "365480c0-7a35-4392-9861-8053894c8e85",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_label_index(root_folder: str) -> dict:\n",
    "    \"\"\"Scan ``root_folder`` and build a fresh label index.\n",
    "\n",
    "    Every immediate sub-folder of ``root_folder`` is one class. Folder names\n",
    "    are sorted, and the initial label of an image is the position of its\n",
    "    folder in that sorted list (plain string sort, so 10 comes before 2).\n",
    "\n",
    "    Args:\n",
    "        root_folder: Directory containing one sub-folder per class.\n",
    "\n",
    "    Returns:\n",
    "        A dict with three keys: ``classes`` (sorted folder names), ``data``\n",
    "        (image path -> dict with ``label`` and ``checked``, the latter\n",
    "        starting as False) and ``meta_data`` (``last_task_id`` set to 0).\n",
    "\n",
    "    Raises:\n",
    "        OSError: If ``root_folder`` or one of its sub-folders cannot be listed.\n",
    "    \"\"\"\n",
    "    image_extensions = {\".png\", \".jpg\", \".jpeg\"}\n",
    "\n",
    "    # os.listdir returns entries in arbitrary order; sorting keeps the class\n",
    "    # ids and the order of the images reproducible across machines and runs.\n",
    "    label_folders = sorted(\n",
    "        dn for dn in os.listdir(root_folder) if os.path.isdir(os.path.join(root_folder, dn))\n",
    "    )\n",
    "    label_index = {\"classes\": label_folders, \"data\": {}, \"meta_data\": {\"last_task_id\": 0}}\n",
    "\n",
    "    for idx, dn in enumerate(label_folders):\n",
    "        folder = os.path.join(root_folder, dn)\n",
    "        for file in sorted(os.listdir(folder)):\n",
    "            # Match on the real extension, ignoring case, so a.PNG is kept\n",
    "            # while a name that merely ends in the letters png is not.\n",
    "            if os.path.splitext(file)[1].lower() not in image_extensions:\n",
    "                continue\n",
    "            label_index[\"data\"][os.path.join(folder, file)] = {\"label\": idx, \"checked\": False}\n",
    "\n",
    "    return label_index\n",
    "\n",
    "def initialize_label_index(root_folder: str) -> dict:\n",
    "    \"\"\"Return the label index for ``root_folder``, creating it on first use.\n",
    "\n",
    "    When ``labels.json`` already exists inside ``root_folder`` its parsed\n",
    "    content is returned unchanged. Otherwise a new index is built with\n",
    "    create_label_index, written with save_label_index and returned.\n",
    "    \"\"\"\n",
    "    index_path = os.path.join(root_folder, \"labels.json\")\n",
    "\n",
    "    if not os.path.exists(index_path):\n",
    "        # First run for this dataset: build the index and persist it right away.\n",
    "        fresh_index = create_label_index(root_folder)\n",
    "        save_label_index(root_folder, fresh_index)\n",
    "        return fresh_index\n",
    "\n",
    "    with open(index_path, \"r\") as index_file:\n",
    "        return json.load(index_file)\n",
    "\n",
    "def save_label_index(root_folder: str, label_index: dict) -> None:\n",
    "    \"\"\"Write ``label_index`` as JSON to ``labels.json`` inside ``root_folder``.\n",
    "\n",
    "    An existing ``labels.json`` is overwritten.\n",
    "    \"\"\"\n",
    "    target_path = os.path.join(root_folder, \"labels.json\")\n",
    "    with open(target_path, mode=\"w\") as out_file:\n",
    "        json.dump(label_index, fp=out_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "d7b5149b-c5a1-460c-a40f-44316b3589be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid geometry: one task page shows N_IMAGES_V rows of N_IMAGES_H images.\n",
    "N_IMAGES_H = 5\n",
    "N_IMAGES_V = 5\n",
    "N_IMAGES_PER_TASK = N_IMAGES_H * N_IMAGES_V\n",
    "\n",
    "# Sizes in pixels.\n",
    "IMAGE_SIZE = 125\n",
    "MARGIN = 5\n",
    "IMAGE_BORDER_WIDTH = 6\n",
    "\n",
    "# Border colour per (label, state) pair; keys have the form <label>_<state>.\n",
    "# Label -1 only has a normal entry.\n",
    "_BORDER_COLORS = {\n",
    "    \"-1_normal\": \"black\",\n",
    "    \"0_normal\": \"#A93226\",\n",
    "    \"0_hover\": \"#CB4335\",\n",
    "    \"1_normal\": \"#2471A3\",\n",
    "    \"1_hover\": \"#2E86C1\",\n",
    "    \"2_normal\": \"#229954\",\n",
    "    \"2_hover\": \"#28B463\",\n",
    "}\n",
    "\n",
    "# CSS border declarations looked up by the event handlers and the grid update.\n",
    "IMAGE_BORDER_STYLES = {\n",
    "    key: f\"{IMAGE_BORDER_WIDTH}px solid {color}\" for key, color in _BORDER_COLORS.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "903f3c93-aaff-49a2-9467-b07ed6edbac8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def initialize_image(image_size: int) -> widgets.Box:\n",
    "    \"\"\"Build one clickable grid cell: a square Box wrapping an Image widget.\n",
    "\n",
    "    The box carries a ``meta_data`` dict with the keys ``file_name`` and\n",
    "    ``label``, both None until update_task_grid fills them in. A click\n",
    "    cycles the label through 0 -> 1 -> 2 -> 0; entering and leaving the box\n",
    "    with the mouse switches between the hover and normal border of the label.\n",
    "\n",
    "    Args:\n",
    "        image_size: Edge length of the square box in pixels.\n",
    "\n",
    "    Returns:\n",
    "        The configured box with its DOM event handler attached.\n",
    "    \"\"\"\n",
    "    image = widgets.Image(\n",
    "        value=bytes(),\n",
    "        format='jpeg',\n",
    "    )\n",
    "    # The image fills the box and is scaled to fit inside it.\n",
    "    image.layout.width = \"100%\"\n",
    "    image.layout.height = \"100%\"\n",
    "    image.layout.object_fit = \"contain\"\n",
    "    image.layout.overflow = \"hidden\"\n",
    "\n",
    "    box = widgets.Box([image], height=image_size, width=image_size)\n",
    "    box.meta_data = {\"file_name\": None, \"label\": None}\n",
    "    box.layout.border = \"\"\n",
    "    box.layout.width = f\"{image_size}px\"\n",
    "    box.layout.height = f\"{image_size}px\"\n",
    "\n",
    "    def _border_for(label, state):\n",
    "        # Not every label has a hover entry (-1 does not): fall back to its\n",
    "        # normal border, and to no border at all for an unknown label.\n",
    "        fallback = IMAGE_BORDER_STYLES.get(f\"{label}_normal\", \"\")\n",
    "        return IMAGE_BORDER_STYLES.get(f\"{label}_{state}\", fallback)\n",
    "\n",
    "    def handle_event(target, event):\n",
    "        try:\n",
    "            label = target.meta_data[\"label\"]\n",
    "            if label is None:\n",
    "                # Empty slot: there is nothing to relabel or to highlight.\n",
    "                return\n",
    "            event_type = event[\"type\"]\n",
    "            if event_type == \"click\":\n",
    "                label = target.meta_data[\"label\"] = (label + 1) % 3\n",
    "                # The pointer is still over the box right after a click.\n",
    "                target.layout.border = _border_for(label, \"hover\")\n",
    "            elif event_type == \"mouseenter\":\n",
    "                target.layout.border = _border_for(label, \"hover\")\n",
    "            elif event_type == \"mouseleave\":\n",
    "                target.layout.border = _border_for(label, \"normal\")\n",
    "        except Exception:\n",
    "            # Errors in a widget callback are otherwise invisible in the\n",
    "            # notebook, so show them in the global info widget.\n",
    "            info.value = traceback.format_exc()\n",
    "\n",
    "    box_event = Event(source=box, watched_events=['click', 'mouseenter', 'mouseleave'])\n",
    "    box_event.on_dom_event(lambda e: handle_event(box, e))\n",
    "\n",
    "    return box\n",
    "\n",
    "\n",
    "def update_task_grid(grid) -> None:\n",
    "    \"\"\"Save the labels of the page on screen, then show the current task page.\n",
    "\n",
    "    Works on the notebook globals ``label_index``, ``image_paths``,\n",
    "    ``data_root_folder`` and ``info``.\n",
    "\n",
    "    1. For every grid cell that holds an image, its label is copied into\n",
    "       ``label_index`` and the image is marked as checked; the index is then\n",
    "       written to disk together with the current task id.\n",
    "    2. The images of the task selected by ``current_task_id`` in\n",
    "       ``grid.meta_data`` are read from disk and loaded into the cells, each\n",
    "       with the normal border of its stored label.\n",
    "    3. Cells left over on a partially filled page are hidden and cleared.\n",
    "\n",
    "    Any error is reported as a traceback in the ``info`` widget.\n",
    "\n",
    "    Args:\n",
    "        grid: Grid whose cells were created by initialize_image.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        current_task_id = grid.meta_data[\"current_task_id\"]\n",
    "\n",
    "        # 1) Persist what was done on the page that is still displayed.\n",
    "        for idx in range(N_IMAGES_PER_TASK):\n",
    "            cell = grid[idx // N_IMAGES_H, idx % N_IMAGES_H]\n",
    "            fn = cell.meta_data[\"file_name\"]\n",
    "            if fn is None:\n",
    "                # Empty slot: nothing loaded yet, or hidden on a short page.\n",
    "                continue\n",
    "            label = cell.meta_data[\"label\"]\n",
    "            label_index[\"data\"][fn][\"label\"] = label\n",
    "            label_index[\"data\"][fn][\"checked\"] = True\n",
    "            info.value = \"updated\"  + str(label)\n",
    "        label_index[\"meta_data\"][\"last_task_id\"] = current_task_id\n",
    "        save_label_index(data_root_folder, label_index)\n",
    "\n",
    "        # 2) Load the images that belong to the current task.\n",
    "        first = current_task_id * N_IMAGES_PER_TASK\n",
    "        page_paths = image_paths[first:first + N_IMAGES_PER_TASK]\n",
    "        for idx, fn in enumerate(page_paths):\n",
    "            cell = grid[idx // N_IMAGES_H, idx % N_IMAGES_H]\n",
    "            cell.layout.visibility = \"visible\"\n",
    "            with open(fn, \"rb\") as f:\n",
    "                cell.children[0].value = f.read()\n",
    "            label = label_index[\"data\"][fn][\"label\"]\n",
    "\n",
    "            cell.meta_data[\"label\"] = label\n",
    "            cell.meta_data[\"file_name\"] = fn\n",
    "            # A label without a border style gets no border instead of\n",
    "            # aborting the load of the rest of the page.\n",
    "            cell.layout.border = IMAGE_BORDER_STYLES.get(f\"{label}_normal\", \"\")\n",
    "\n",
    "        # 3) Hide and clear the slots this page does not use. The start is\n",
    "        # len(page_paths) rather than a leftover loop variable, so an empty\n",
    "        # page hides every slot instead of none.\n",
    "        for idx in range(len(page_paths), N_IMAGES_PER_TASK):\n",
    "            cell = grid[idx // N_IMAGES_H, idx % N_IMAGES_H]\n",
    "            cell.layout.visibility = \"hidden\"\n",
    "            cell.meta_data[\"label\"] = None\n",
    "            cell.meta_data[\"file_name\"] = None\n",
    "    except Exception:\n",
    "        # Not a bare except: KeyboardInterrupt / SystemExit must get through.\n",
    "        info.value = traceback.format_exc()\n",
    "\n",
    "\n",
    "def goto_next_task(grid):\n",
    "    \"\"\"Select the following task, wrapping to the first after the last, and redraw.\"\"\"\n",
    "    state = grid.meta_data\n",
    "    state[\"current_task_id\"] = (state[\"current_task_id\"] + 1) % state[\"n_tasks\"]\n",
    "    update_task_grid(grid)\n",
    "\n",
    "\n",
    "def goto_previous_task(grid):\n",
    "    \"\"\"Select the preceding task, wrapping to the last before the first, and redraw.\"\"\"\n",
    "    state = grid.meta_data\n",
    "    state[\"current_task_id\"] = (state[\"current_task_id\"] - 1) % state[\"n_tasks\"]\n",
    "    update_task_grid(grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "c2901089-4d7c-4010-9d1e-da9679800af9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset root; set this before running the cell. Its immediate sub-folders are\n",
    "# used as the classes, and labels.json is read from / written to this folder.\n",
    "data_root_folder = \"\"  # expected layout: .../dataset/0/, .../dataset/1/, .../dataset/2/\n",
    "\n",
    "# Loads labels.json from the root if it exists, otherwise builds and saves a new index.\n",
    "label_index = initialize_label_index(data_root_folder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "d22c33bb-9d78-41a4-bd88-dd8a9b23ff90",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>.container { width:100% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Inject a CSS rule that forces elements with class container to 100% width,\n",
    "# presumably so the image grid below gets more horizontal room (TODO confirm).\n",
    "from IPython.display import display, HTML\n",
    "display(HTML(\"<style>.container { width:100% !important; }</style>\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "582114eb-b965-4244-8830-794f287b35d9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f0c80313e88e48739a8381ea0bb14b3d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "GridspecLayout(children=(Box(children=(Image(value=b'\\xff\\xd8\\xff\\xe0\\x00\\x10JFIF\\x00\\x01\\x01\\x00\\x00\\x01\\x00\\…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7fd3b4f350e741b48370dad58e2dd01b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HTML(value='')"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def grid_on_keyup(grid, event):\n",
    "    \"\"\"Page through the tasks with the left / right arrow keys.\n",
    "\n",
    "    Events other than keyup, and keys other than ArrowLeft / ArrowRight,\n",
    "    are ignored. Errors raised while handling the key are written to the\n",
    "    ``info`` widget instead of propagating.\n",
    "\n",
    "    Args:\n",
    "        grid: Task grid handed on to goto_previous_task / goto_next_task.\n",
    "        event: DOM event mapping; its ``type`` and ``key`` entries are read.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        if event[\"type\"] != \"keyup\":\n",
    "            return\n",
    "        key = str(event[\"key\"])\n",
    "        if key == \"ArrowLeft\":\n",
    "            goto_previous_task(grid)\n",
    "        elif key == \"ArrowRight\":\n",
    "            goto_next_task(grid)\n",
    "    except Exception:\n",
    "        # Not a bare except: KeyboardInterrupt / SystemExit must get through.\n",
    "        info.value = \"Exception: \" + traceback.format_exc()\n",
    "\n",
    "# Image paths in label-index order; every run of N_IMAGES_PER_TASK\n",
    "# consecutive paths forms one task page.\n",
    "image_paths = list(label_index[\"data\"].keys())\n",
    "n_tasks = int(np.ceil(len(image_paths) / N_IMAGES_PER_TASK))\n",
    "\n",
    "# Outer size of the grid: the images plus MARGIN pixels per gap between them.\n",
    "grid_height = IMAGE_SIZE * N_IMAGES_V + MARGIN * (N_IMAGES_V - 1)\n",
    "grid_width = IMAGE_SIZE * N_IMAGES_H + MARGIN * (N_IMAGES_H - 1)\n",
    "grid = GridspecLayout(N_IMAGES_V, N_IMAGES_H, height=f\"{grid_height}px\", width=f\"{grid_width}px\")\n",
    "\n",
    "# Paging state; labelling resumes at the task stored in the label index.\n",
    "grid.meta_data = {\n",
    "    \"n_tasks\": n_tasks,\n",
    "    \"current_task_id\": label_index[\"meta_data\"][\"last_task_id\"],\n",
    "}\n",
    "\n",
    "# keyup events on the grid are routed to grid_on_keyup.\n",
    "grid_event = Event(source=grid, watched_events=['keyup'])\n",
    "grid_event.on_dom_event(lambda e: grid_on_keyup(grid, e))\n",
    "\n",
    "# Has to exist before the first update_task_grid call: the helpers write\n",
    "# their status text and tracebacks into this widget.\n",
    "info = widgets.HTML(\"\")\n",
    "\n",
    "# One image box per grid position, row by row.\n",
    "for i in range(N_IMAGES_V):\n",
    "    for j in range(N_IMAGES_H):\n",
    "        grid[i, j] = initialize_image(IMAGE_SIZE)\n",
    "\n",
    "update_task_grid(grid)\n",
    "\n",
    "display(grid, info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7187ec32-217a-4b9a-a823-955b5e1bf649",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1719807e-b76a-4ffa-9559-80c74dcfe4e5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
