{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Script to label object points.\n",
    "\n",
    "Adapted from P3PO - [link](https://github.com/mlevy2525/P3PO/blob/main/p3po/data_generation/label_points.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this the first time you use this notebook for point annotation.\n",
    "# It installs the widget dependencies. %pip (rather than !pip) installs into\n",
    "# the environment of the running kernel instead of whichever pip is on PATH.\n",
    "%pip install ipywidgets ipycanvas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import pickle\n",
    "from pathlib import Path\n",
    "\n",
    "# TODO: Set the task name here -- it names the input pickle and the output directory.\n",
    "task_name = \"close_oven\"\n",
    "# Suffix of the saved coordinates file: coords/{pixel_key}_{object_name}.pkl\n",
    "object_name = \"objects\"\n",
    "\n",
    "# TODO: Point this at the processed-data pickle for the task.\n",
    "pickle_path = f\"/path/to/data/processed_data_pkl/{task_name}.pkl\"\n",
    "# Index into data['observations'] of the trajectory whose image is labelled.\n",
    "traj_idx = 0\n",
    "# Set to True if the image displayed by the labelling cell is in BGR channel order.\n",
    "original_bgr = True\n",
    "\n",
    "# TODO: If it's hard to see the image, increase size_multiplier to enlarge the\n",
    "# displayed canvas (intended to be display-only, not to change the selected coordinates).\n",
    "size_multiplier = 1\n",
    "\n",
    "# Directory under which the coords/ and images/ outputs are written.\n",
    "coordinates_path = f\"../../../coordinates/{task_name}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "%gui asyncio\n",
    "\n",
    "import os\n",
    "from PIL import Image\n",
    "\n",
    "import numpy as np\n",
    "from IPython.display import display, Javascript\n",
    "import ipywidgets as widgets\n",
    "from ipycanvas import Canvas, hold_canvas\n",
    "import pickle\n",
    "\n",
    "import io\n",
    "import asyncio\n",
    "import logging\n",
    "\n",
    "# Define an async function to wait for button click\n",
    "async def wait_for_click(button):\n",
    "    # Create a future object\n",
    "    future = asyncio.Future()\n",
    "    # Define the click event handler\n",
    "    def on_button_clicked(b):\n",
    "        future.set_result(None)\n",
    "    # Attach the event handler to the button\n",
    "    button.on_click(on_button_clicked)\n",
    "    # Wait until the future is set\n",
    "    await future\n",
    "\n",
    "class Points():\n",
    "    \"\"\"Interactive canvas for clicking points on an image and saving them.\n",
    "\n",
    "    Displays ``img`` on an ipycanvas ``Canvas`` (enlarged by ``size_multiplier``)\n",
    "    above a status label and a \"Save Points\" button. Every click is appended to\n",
    "    ``self.coords`` as ``(0, x, y)`` in pixel coordinates of the original image.\n",
    "\n",
    "    Args:\n",
    "        pixel_key: Identifier used in the names of the saved files.\n",
    "        img: Array whose first two axes are (height, width); it is passed to\n",
    "            ``Image.fromarray``.\n",
    "        coordinates_path: Directory under which ``coords/`` and ``images/`` are written.\n",
    "        size_multiplier: Zoom factor applied to the displayed canvas only.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, pixel_key, img, coordinates_path, size_multiplier=1):\n",
    "        logging.getLogger().setLevel(logging.DEBUG)\n",
    "        logging.info(\"Starting the Points class\")\n",
    "        self.img = img\n",
    "        self.size_multiplier = size_multiplier\n",
    "        self.coordinates_path = coordinates_path\n",
    "        self.pixel_key = pixel_key\n",
    "\n",
    "        # Enlarge the image for display. resize() and Canvas need integer\n",
    "        # dimensions, so round in case size_multiplier is not an integer.\n",
    "        height, width = img.shape[:2]\n",
    "        canvas_width = int(round(width * self.size_multiplier))\n",
    "        canvas_height = int(round(height * self.size_multiplier))\n",
    "        image = Image.fromarray(self.img).resize((canvas_width, canvas_height))\n",
    "        # Convert once; the same pixels are redrawn after every click.\n",
    "        pixels = np.array(image)\n",
    "\n",
    "        self.canvas = Canvas(width=canvas_width, height=canvas_height)\n",
    "        self.canvas.put_image_data(pixels, 0, 0)\n",
    "\n",
    "        # Status line showing the most recently selected point\n",
    "        coords_label = widgets.Label(value=\"Click on the image to select the coordinates\")\n",
    "\n",
    "        self.coords = []\n",
    "\n",
    "        def on_click(x, y):\n",
    "            # Mouse events arrive in enlarged-canvas pixels. Store the point in\n",
    "            # original-image pixels so size_multiplier never changes saved labels.\n",
    "            scale = self.size_multiplier\n",
    "            img_x, img_y = (x, y) if scale == 1 else (x / scale, y / scale)\n",
    "            coords_label.value = f\"Coordinates: ({img_x}, {img_y})\"\n",
    "            self.coords.append((0, img_x, img_y))\n",
    "\n",
    "            with hold_canvas(self.canvas):\n",
    "                # Redraw the image, then every marker, in a single batched update\n",
    "                self.canvas.put_image_data(pixels, 0, 0)\n",
    "                self.canvas.fill_style = 'red'\n",
    "                for _, point_x, point_y in self.coords:\n",
    "                    # Map image-space points back onto the enlarged canvas\n",
    "                    self.canvas.fill_circle(point_x * scale, point_y * scale, 2)\n",
    "\n",
    "        # Connect the click event to the handler\n",
    "        self.canvas.on_mouse_down(on_click)\n",
    "\n",
    "        self.button = widgets.Button(description=\"Save Points\")\n",
    "\n",
    "        # Stack the widgets vertically and show them\n",
    "        self.vbox = widgets.VBox([self.canvas, coords_label, self.button])\n",
    "        display(self.vbox)\n",
    "\n",
    "    def on_done(self):\n",
    "        \"\"\"Write the clicked points and the image under ``coordinates_path``.\n",
    "\n",
    "        Points are pickled to ``coords/{pixel_key}_{object_name}.pkl`` (``object_name``\n",
    "        is read from the notebook's global namespace) and the image is written to\n",
    "        ``images/{pixel_key}.png``. Serialisation errors are logged rather than\n",
    "        raised, so a failed points write does not prevent the image write.\n",
    "        \"\"\"\n",
    "        logging.info(\"saving\")\n",
    "        root = Path(self.coordinates_path)\n",
    "        coords_dir = root / \"coords\"\n",
    "        images_dir = root / \"images\"\n",
    "        succeeded = True\n",
    "\n",
    "        coords_dir.mkdir(parents=True, exist_ok=True)\n",
    "        with open(coords_dir / f\"{self.pixel_key}_{object_name}.pkl\", 'wb') as f:\n",
    "            try:\n",
    "                pickle.dump(self.coords, f)\n",
    "            except Exception:\n",
    "                succeeded = False\n",
    "                logging.exception(\"Failed to save coordinates\")\n",
    "\n",
    "        images_dir.mkdir(parents=True, exist_ok=True)\n",
    "        with open(images_dir / f\"{self.pixel_key}.png\", 'wb') as f:\n",
    "            try:\n",
    "                Image.fromarray(self.img).save(f)\n",
    "            except Exception:\n",
    "                succeeded = False\n",
    "                logging.exception(\"Failed to save image\")\n",
    "\n",
    "        # Only report success when both files were actually written\n",
    "        if succeeded:\n",
    "            logging.info(\"saved\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: Label points for each pixel key. Make sure the order\n",
    "# of points is the same across pixel keys.\n",
    "pixel_key = \"pixels2\"\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code; only open files you trust.\n",
    "with open(pickle_path, 'rb') as pkl_file:\n",
    "    data = pickle.load(pkl_file)\n",
    "# Entry 0 of the chosen trajectory for this key (presumably the first frame -- confirm).\n",
    "img = data['observations'][traj_idx][pixel_key][0]\n",
    "if original_bgr:\n",
    "    # Reverse the last axis so BGR data is displayed as RGB.\n",
    "    img = img[:, :, ::-1]\n",
    "\n",
    "async def f():\n",
    "    \"\"\"Show the labelling widget, wait for \"Save Points\", then write the results.\"\"\"\n",
    "    try:\n",
    "        point = Points(pixel_key, img, coordinates_path, size_multiplier)\n",
    "        await wait_for_click(point.button)\n",
    "        point.vbox.close()\n",
    "        point.canvas.close()\n",
    "        point.on_done()\n",
    "    except Exception:\n",
    "        # Errors raised inside a background task are otherwise never shown.\n",
    "        logging.exception(\"Point labelling failed\")\n",
    "        raise\n",
    "\n",
    "# Keep a reference to the task: the event loop only holds weak references, so an\n",
    "# unreferenced task can be garbage-collected before the button is clicked.\n",
    "label_task = asyncio.ensure_future(f())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "p3po",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
