{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gzip\n",
    "import pickle\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from edf.utils import voxelize_sample\n",
    "from edf.visual_utils import visualize_samples, visualize_sample_cluster\n",
    "\n",
    "folder_name = 'demo'\n",
    "file_name = 'mug_task_mixed.gzip'\n",
    "path = f'{folder_name}/{file_name}'\n",
    "\n",
    "# Load the gzip-compressed pickle containing the list of demonstration samples\n",
    "with gzip.open(path, 'rb') as f:\n",
    "    samples = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = samples[0]\n",
    "\n",
    "# Row 0: 'images' (views 0-2), rows 1-2: 'images_pick' (views 0-5), row 3: 'images_place' (views 0-2)\n",
    "fig, axes = plt.subplots(4, 3, figsize=(20, 20))\n",
    "for j in range(3):\n",
    "    axes[0, j].imshow(sample['images'][j]['color'])\n",
    "    axes[1, j].imshow(sample['images_pick'][j]['color'])\n",
    "    axes[2, j].imshow(sample['images_pick'][j + 3]['color'])\n",
    "    axes[3, j].imshow(sample['images_place'][j]['color'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sample Demonstration\n",
    "\n",
    "- Point clouds: scene before pick, gripper after pick, and scene before place.\n",
    "- End-effector poses generated by the oracle policy, visualized as RGB coordinate frames."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_samples(samples)"
   ]
  },
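  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of how an end-effector pose can be drawn as an RGB frame, assuming the pose is stored as a position plus a unit quaternion in (w, x, y, z) order. `quat_to_matrix` and `frame_segments` are hypothetical helpers written for illustration only; `visualize_samples` may store and draw poses differently. Following the common convention, the x, y and z axes of the frame are drawn in red, green and blue."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical helper: unit quaternion (w, x, y, z) -> 3x3 rotation matrix\n",
    "def quat_to_matrix(q):\n",
    "    w, x, y, z = np.asarray(q, dtype=float) / np.linalg.norm(q)\n",
    "    return np.array([\n",
    "        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],\n",
    "        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],\n",
    "        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],\n",
    "    ])\n",
    "\n",
    "# Hypothetical helper: (start, end, color) for the x, y, z axes of a pose\n",
    "def frame_segments(position, quat, length=0.05):\n",
    "    origin = np.asarray(position, dtype=float)\n",
    "    R = quat_to_matrix(quat)\n",
    "    colors = ['red', 'green', 'blue']  # x -> red, y -> green, z -> blue\n",
    "    # Column k of R is the k-th axis of the frame expressed in world coordinates\n",
    "    return [(origin, origin + length * R[:, k], colors[k]) for k in range(3)]\n",
    "\n",
    "# Example pose: 90-degree rotation about z, 10 cm above the origin\n",
    "quat = [np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)]\n",
    "segments = frame_segments([0.0, 0.0, 0.1], quat)\n",
    "for start, end, color in segments:\n",
    "    print(color, np.round(end - start, 3))"
   ]
  },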
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Voxelize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "voxelized_samples = [voxelize_sample(sample) for sample in samples]\n",
    "visualize_samples(voxelized_samples)"
   ]
  },
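  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, a minimal sketch of generic voxel-grid downsampling: all points that fall into the same cubic cell are replaced by their mean position and mean color. `voxel_downsample` and `voxel_size=0.01` are assumptions for illustration; the internals and defaults of `voxelize_sample` are not shown in this notebook and may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical helper: generic voxel-grid downsampling, not edf's implementation\n",
    "def voxel_downsample(points, colors, voxel_size=0.01):\n",
    "    points = np.asarray(points, dtype=float)\n",
    "    colors = np.asarray(colors, dtype=float)\n",
    "    # Integer index of the voxel containing each point\n",
    "    keys = np.floor(points / voxel_size).astype(np.int64)\n",
    "    # inverse[i] is the id of the occupied voxel that point i falls into\n",
    "    _, inverse = np.unique(keys, axis=0, return_inverse=True)\n",
    "    inverse = inverse.reshape(-1)\n",
    "    n_voxels = inverse.max() + 1\n",
    "    counts = np.bincount(inverse, minlength=n_voxels)[:, None]\n",
    "    # Sum positions and colors per voxel, then divide by the point count\n",
    "    voxel_points = np.zeros((n_voxels, 3))\n",
    "    voxel_colors = np.zeros((n_voxels, colors.shape[1]))\n",
    "    np.add.at(voxel_points, inverse, points)\n",
    "    np.add.at(voxel_colors, inverse, colors)\n",
    "    return voxel_points / counts, voxel_colors / counts\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "pts = rng.uniform(0.0, 0.05, size=(1000, 3))  # toy cloud inside a 5 cm cube\n",
    "cols = rng.uniform(0.0, 1.0, size=(1000, 3))\n",
    "vp, vc = voxel_downsample(pts, cols, voxel_size=0.01)\n",
    "print(len(pts), '->', len(vp))"
   ]
  },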
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Jitter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Voxelize the first 4 samples with random coordinate and color jitter\n",
    "voxelized_samples = [\n",
    "    voxelize_sample(sample, coord_jitter=0.1, color_jitter=0.03)\n",
    "    for sample in samples[:4]\n",
    "]\n",
    "visualize_samples(voxelized_samples)"
   ]
  },
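  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of jitter augmentation, assuming independent Gaussian noise on point coordinates and RGB colors, with colors clipped to [0, 1]. `jitter`, `coord_std` and `color_std` are hypothetical names. How `voxelize_sample` actually interprets `coord_jitter` and `color_jitter` (units, noise distribution, before or after voxelization) is not documented here, so the sketch deliberately does not reuse the value 0.1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical helper: Gaussian jitter on coordinates and colors (assumption, not edf's code)\n",
    "def jitter(points, colors, coord_std, color_std, rng=None):\n",
    "    rng = np.random.default_rng() if rng is None else rng\n",
    "    points = np.asarray(points, dtype=float)\n",
    "    colors = np.asarray(colors, dtype=float)\n",
    "    noisy_points = points + rng.normal(0.0, coord_std, size=points.shape)\n",
    "    # Keep colors in the valid [0, 1] range after adding noise\n",
    "    noisy_colors = np.clip(colors + rng.normal(0.0, color_std, size=colors.shape), 0.0, 1.0)\n",
    "    return noisy_points, noisy_colors\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "pts = np.zeros((500, 3))\n",
    "cols = np.full((500, 3), 0.5)\n",
    "jp, jc = jitter(pts, cols, coord_std=0.001, color_std=0.03, rng=rng)\n",
    "print(jp.std(), jc.std())"
   ]
  },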
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize Neighbor Clustering\n",
    "\n",
    "- Blue points are the neighbors of the green point.\n",
    "- A point's neighbors are all points within `max_radius` of it.\n",
    "- Therefore, `max_radius` effectively defines the receptive field of a single layer.\n",
    "- Because the points lie on object surfaces, the number of neighbors per point, and therefore the number of graph edges and the memory usage, grows roughly quadratically with `max_radius`.\n",
    "- Decreasing `max_radius` shrinks the receptive field of each layer but substantially reduces memory consumption.\n",
    "- We use a 6-layer SE(3)-Transformer, so the receptive-field radius of the whole model is up to 6 times the single-layer radius shown by the blue points below.\n",
    "- Note that this is large enough to cover the whole mug, but not the whole scene."
   ]
  },
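  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal brute-force sketch of the radius graph described above, using NumPy only. `radius_graph` and the toy 1 cm grid are illustrative assumptions, not the neighbor search used by edf; a practical implementation would avoid building the O(N^2) distance matrix. The cell counts directed edges on a flat surface patch as `max_radius` doubles, then prints the upper bound on the receptive-field radius of a 6-layer stack with `max_radius=0.04`. Away from the patch borders, doubling `max_radius` should roughly quadruple the number of neighbors per point."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical helper: directed edges (i, j), i != j, with |p_i - p_j| <= max_radius\n",
    "def radius_graph(points, max_radius):\n",
    "    diff = points[:, None, :] - points[None, :, :]\n",
    "    dist = np.linalg.norm(diff, axis=-1)\n",
    "    mask = (dist <= max_radius) & ~np.eye(len(points), dtype=bool)\n",
    "    src, dst = np.nonzero(mask)\n",
    "    return src, dst\n",
    "\n",
    "# Toy flat surface patch: 1 cm grid covering 20 cm x 20 cm, z = 0\n",
    "xs = np.arange(20) * 0.01\n",
    "grid = np.stack(np.meshgrid(xs, xs), axis=-1).reshape(-1, 2)\n",
    "points = np.concatenate([grid, np.zeros((len(grid), 1))], axis=1)\n",
    "\n",
    "for r in (0.02, 0.04, 0.08):\n",
    "    src, _ = radius_graph(points, r)\n",
    "    print(f'max_radius={r}: {len(src)} edges')\n",
    "\n",
    "# Each layer propagates information at most max_radius, so L layers reach at most L * max_radius\n",
    "num_layers, max_radius = 6, 0.04\n",
    "print('receptive field radius <=', num_layers * max_radius)"
   ]
  },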
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = voxelize_sample(samples[0])\n",
    "visualize_sample_cluster(sample, max_radius=0.04, max_radius_pick=0.04, max_radius_place=0.04, figsize=10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.12 ('edf')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "c0953e14757fdc95393fc3797619880897b68d726561afa3ebe46daeb55f7087"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
