{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import argparse\n",
    "import time\n",
    "import datetime\n",
    "import random\n",
    "import numpy as np\n",
    "import yaml\n",
    "import gzip\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image, ImageDraw\n",
    "import cv2\n",
    "\n",
    "import torch\n",
    "from pytorch3d import transforms\n",
    "\n",
    "from edf.utils import preprocess, voxelize_sample, OrthoTransform\n",
    "from edf.dist import GaussianDistSE3\n",
    "\n",
    "from baselines.equiv_tn.sixdof_non_equi_transporter import TransporterAgent\n",
    "from baselines.equiv_tn.utils import perturb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reproducibility: one seed is shared by every random number generator seeded below.\n",
    "seed = 0\n",
    "device = 'cuda'\n",
    "\n",
    "# Device-specific determinism switches.\n",
    "if device == 'cuda':\n",
    "    # torch.use_deterministic_algorithms is left disabled on this branch;\n",
    "    # only the cuDNN flags are pinned and the CUDA generators are seeded.\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "elif device == 'cpu':\n",
    "    torch.use_deterministic_algorithms(True)\n",
    "\n",
    "# Seed the torch, NumPy and Python generators.\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)\n",
    "\n",
    "# Print tensors with four decimals and without scientific notation.\n",
    "torch.set_printoptions(sci_mode=False, precision=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training samples from a gzip-compressed pickle file.\n",
    "# NOTE(review): pickle.load can execute arbitrary code while unpickling;\n",
    "# only open archives that come from a trusted source.\n",
    "# `sample_dir` holds the path of a single file, despite the `_dir` suffix.\n",
    "sample_dir = 'demo/mug_task_rim.gzip'\n",
    "with gzip.open(sample_dir,'rb') as f:\n",
    "    train_samples = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image geometry and pose-perturbation settings shared by the agent and the training loop.\n",
    "# H and W are passed to TransporterAgent as the image height/width arguments.\n",
    "H = W = 160\n",
    "# 96; passed to TransporterAgent as crop_size.\n",
    "crop_size = 16*6\n",
    "# Three [min, max] pairs; presumably the x, y, z workspace bounds in metres -- TODO confirm.\n",
    "ortho_ranges = np.array([[0.4, 0.8],[-0.2, 0.2], [0., 0.4]])\n",
    "# Only the first two ranges are handed to the orthographic transform.\n",
    "ortho_transform = OrthoTransform(W = W, ranges = ortho_ranges[:2])\n",
    "# Extent of the first range divided by the image height.\n",
    "pix_size = (ortho_ranges[0,1] - ortho_ranges[0,0]) / H\n",
    "\n",
    "# Pose perturbation distribution; std_theta is 2 degrees expressed in radians.\n",
    "perturb_dist = GaussianDistSE3(std_theta = 2./180*np.pi, std_X = 0.002)\n",
    "# NOTE(review): presumably precomputes the inverse CDF used when sampling rotations -- confirm in edf.dist.\n",
    "perturb_dist.dist_R.get_inv_cdf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the transporter agent with the image geometry configured above.\n",
    "agent = TransporterAgent(\n",
    "    name='any',\n",
    "    task='any',\n",
    "    root_dir='checkpoint_tn/rim',\n",
    "    device=device,\n",
    "    load=False,\n",
    "    crop_size=crop_size,\n",
    "    pix_size=pix_size,\n",
    "    bounds=ortho_ranges,\n",
    "    H=H,\n",
    "    W=W,\n",
    "    n_rotations=36,\n",
    "    lr=1e-4,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample = train_samples[1].copy()\n",
    "# sample['d'] = 0.001\n",
    "# sample = voxelize_sample(sample, coord_jitter=3., color_jitter=0.03, pick=True, place=False)\n",
    "\n",
    "# in_range_idx = ((sample['coord'][..., -1] > ortho_ranges[-1][0]) * (sample['coord'][..., -1] < ortho_ranges[-1][1]))\n",
    "# coord = sample['coord'][in_range_idx]\n",
    "# color = sample['color'][in_range_idx]\n",
    "\n",
    "# img = ortho_transform.orthographic(coord, color)\n",
    "\n",
    "\n",
    "# pick, place = sample['grasp'], sample['place']\n",
    "# pick = torch.cat([transforms.matrix_to_quaternion(torch.from_numpy(pick[1])), torch.from_numpy(pick[0])], dim=-1)\n",
    "# place = torch.cat([transforms.matrix_to_quaternion(torch.from_numpy(place[1])), torch.from_numpy(place[0])], dim=-1)\n",
    "\n",
    "# pick = perturb_dist.propose(pick)\n",
    "# place = perturb_dist.propose(place)\n",
    "\n",
    "# pick = (pick[4:].numpy(), transforms.quaternion_to_matrix(pick[:4]).numpy())\n",
    "# place = (place[4:].numpy(), transforms.quaternion_to_matrix(place[:4]).numpy())\n",
    "\n",
    "# pick = ortho_transform.pose2pix_yaw_zrp(pick) # grasp_pix, yaw, height, roll, pitch \n",
    "# place = ortho_transform.pose2pix_yaw_zrp(place) # grasp_pix, yaw, height, roll, pitch\n",
    "\n",
    "# img_test = (img.copy()[...,:3]*255).astype(np.uint8)\n",
    "# # img_test = cv2.arrowedLine(img_test, pick[0][...,::-1], (pick[0][...,::-1] + np.array([np.cos(pick[1]/180*np.pi), -np.sin(pick[1]/180*np.pi)]) * 30).astype(int), (255,0,255), thickness = 3, tipLength=0.3)\n",
    "# # img_test = cv2.arrowedLine(img_test, place[0][...,::-1], (place[0][...,::-1] + np.array([np.cos(place[1]/180*np.pi), -np.sin(place[1]/180*np.pi)]) * 30).astype(int), (0,0,255), thickness = 3, tipLength=0.3)\n",
    "# # #img_test = Image.fromarray(img_test)\n",
    "# # crop_size = 16*14\n",
    "# # crop_test = (img.copy()[...,:3]*255).astype(np.uint8)\n",
    "# # crop_test = crop_test[pick[0][0]-crop_size//2:pick[0][0]+crop_size//2, pick[0][1]-crop_size//2:pick[0][1]+crop_size//2]\n",
    "# # #crop_test = Image.fromarray(crop_test)\n",
    "\n",
    "# img_mean, img_std = np.array([[0.5, 0.5, 0.5, 0.25]]), np.array([[0.5, 0.5, 0.5, 0.25]])\n",
    "# img = (img - img_mean) / img_std\n",
    "# img = np.concatenate((img[Ellipsis, :3],\n",
    "#                     img[Ellipsis, 3:4],\n",
    "#                     img[Ellipsis, 3:4],\n",
    "#                     img[Ellipsis, 3:4]), axis=2).astype(np.float32)\n",
    "\n",
    "# pick[1] = pick[1] /180 *np.pi\n",
    "# pick[3] = pick[3] /180 *np.pi\n",
    "# pick[4] = pick[4] /180 *np.pi\n",
    "# place[1] = place[1] /180 *np.pi\n",
    "# place[3] = place[3] /180 *np.pi\n",
    "# place[4] = place[4] /180 *np.pi\n",
    "\n",
    "# p0, p1 = agent.act(img=img)\n",
    "# img_test = cv2.arrowedLine(img_test, np.array(p0[0])[...,::-1], (np.array(p0[0])[...,::-1] + np.array([np.cos(p0[1]), -np.sin(p0[1])]) * 30).astype(int), (255,0,255), thickness = 3, tipLength=0.3)\n",
    "# img_test = cv2.arrowedLine(img_test, np.array(p1[0])[...,::-1], (np.array(p1[0])[...,::-1] + np.array([np.cos(p1[1]), -np.sin(p1[1])]) * 30).astype(int), (0,0,255), thickness = 3, tipLength=0.3)\n",
    "# fig, axes = plt.subplots(1, figsize=(8,8))\n",
    "# axes.imshow(img_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#train_samples = train_samples[::2]\n",
    "#train_samples = train_samples[0:1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loop: every epoch visits each sample once in random order, builds a\n",
    "# normalised orthographic image plus perturbed pick/place labels, and takes one\n",
    "# agent.train step. On the first iteration and every 50th one, the checkpoint is\n",
    "# saved and the current prediction is plotted next to the ground truth.\n",
    "max_epochs = 200\n",
    "iters = 0\n",
    "\n",
    "# Loop-invariant per-channel normalisation constants (four image channels).\n",
    "img_mean, img_std = np.array([[0.5, 0.5, 0.5, 0.25]]), np.array([[0.5, 0.5, 0.5, 0.25]])\n",
    "\n",
    "for epoch in range(1, max_epochs+1):\n",
    "    train_sample_indices = list(range(len(train_samples)))\n",
    "    np.random.shuffle(train_sample_indices)\n",
    "    for train_sample_idx in train_sample_indices:\n",
    "        iters += 1\n",
    "        # Shallow copy so that setting 'd' does not modify the stored sample dict.\n",
    "        sample = train_samples[train_sample_idx].copy()\n",
    "        sample['d'] = 0.001\n",
    "        sample = voxelize_sample(sample, coord_jitter=3., color_jitter=0.03, pick=True, place=False)\n",
    "\n",
    "        # Keep only points whose last coordinate lies strictly inside the last ortho range.\n",
    "        in_range_idx = ((sample['coord'][..., -1] > ortho_ranges[-1][0]) * (sample['coord'][..., -1] < ortho_ranges[-1][1]))\n",
    "        coord = sample['coord'][in_range_idx]\n",
    "        color = sample['color'][in_range_idx]\n",
    "\n",
    "        img = ortho_transform.orthographic(coord, color)\n",
    "\n",
    "        # Pack each pose as [quaternion (4), position (3)], perturb it, then unpack\n",
    "        # back to a (position, rotation matrix) pair.\n",
    "        pick, place = sample['grasp'], sample['place']\n",
    "        pick = torch.cat([transforms.matrix_to_quaternion(torch.from_numpy(pick[1])), torch.from_numpy(pick[0])], dim=-1)\n",
    "        place = torch.cat([transforms.matrix_to_quaternion(torch.from_numpy(place[1])), torch.from_numpy(place[0])], dim=-1)\n",
    "\n",
    "        pick = perturb_dist.propose(pick)\n",
    "        place = perturb_dist.propose(place)\n",
    "\n",
    "        pick = (pick[4:].numpy(), transforms.quaternion_to_matrix(pick[:4]).numpy())\n",
    "        place = (place[4:].numpy(), transforms.quaternion_to_matrix(place[:4]).numpy())\n",
    "\n",
    "        pick = ortho_transform.pose2pix_yaw_zrp(pick, grasp='top') # grasp_pix, yaw, height, roll, pitch \n",
    "        place = ortho_transform.pose2pix_yaw_zrp(place, grasp='top') # grasp_pix, yaw, height, roll, pitch \n",
    "\n",
    "        # Normalise, then append the fourth channel two more times (six channels in total).\n",
    "        img = (img - img_mean) / img_std\n",
    "        img = np.concatenate((img[Ellipsis, :3],\n",
    "                            img[Ellipsis, 3:4],\n",
    "                            img[Ellipsis, 3:4],\n",
    "                            img[Ellipsis, 3:4]), axis=2).astype(np.float32)\n",
    "\n",
    "        pick[1] = (pick[1] /180 *np.pi + 2*np.pi)%(2*np.pi)     # (-180~180) -> (0 ~ 2pi) Yaw\n",
    "        pick[3] = (pick[3] /180 *np.pi + 2*np.pi)%(2*np.pi)     # (-180~180) -> (0 ~ 2pi) Roll\n",
    "        pick[4] = (pick[4] /180 *np.pi)                         # (-90~90)   -> (-pi/2 ~ pi/2) Pitch\n",
    "        place[1] = (place[1] /180 *np.pi + 2*np.pi)%(2*np.pi)   # (-180~180) -> (0 ~ 2pi) Yaw\n",
    "        place[3] = (place[3] /180 *np.pi + 2*np.pi)%(2*np.pi)   # (-180~180) -> (0 ~ 2pi) Roll\n",
    "        place[4] = (place[4] /180 *np.pi)                       # (-90~90)   -> (-pi/2 ~ pi/2) Pitch\n",
    "        # Augment the image and pixel labels together, then subtract the returned theta from both yaws.\n",
    "        img, _, (pick[0], place[0]), (theta, trans, pivot) = perturb(img, [pick[0], place[0]], rim_offset= H//6)\n",
    "        pick[1] = (pick[1] - theta + 2*np.pi) % (2*np.pi)\n",
    "        place[1] = (place[1] - theta + 2*np.pi) % (2*np.pi)\n",
    "\n",
    "        # Visualisation images are needed only on logging steps. They are still built\n",
    "        # before agent.train so they are taken from the same image state as before.\n",
    "        log_step = iters % 50 == 0 or iters == 1\n",
    "        if log_step:\n",
    "            # Undo the normalisation and rescale to [0, 1] for display.\n",
    "            img_visual = img[...,:4].copy() * img_std + img_mean\n",
    "            img_visual = img_visual - img_visual.min()\n",
    "            img_visual = img_visual / img_visual.max()\n",
    "\n",
    "            img_out = (img_visual.copy()[...,:3]*255).astype(np.uint8)\n",
    "            img_gt = (img_visual.copy()[...,:3]*255).astype(np.uint8)\n",
    "\n",
    "        data = (img, pick, place)\n",
    "\n",
    "        agent.train(data)\n",
    "\n",
    "        if log_step:\n",
    "            agent.save()\n",
    "            with torch.no_grad():\n",
    "                p0, p1, confs = agent.act(img=img, return_output=True, gt_data = data)\n",
    "            pick_conf, place_conf, crop = confs\n",
    "            # Rescale both confidence volumes to [0, 255] for display.\n",
    "            pick_conf = pick_conf - pick_conf.min()\n",
    "            pick_conf = pick_conf / pick_conf.max() * 255\n",
    "            place_conf = place_conf - place_conf.min()\n",
    "            place_conf = place_conf / place_conf.max() * 255\n",
    "\n",
    "            # Arrows start at the (reversed) pixel coordinates and point along the yaw angle.\n",
    "            img_out = cv2.arrowedLine(img_out, np.array(p0[0])[...,::-1], (np.array(p0[0])[...,::-1] + np.array([np.cos(p0[1]), -np.sin(p0[1])]) * 30).astype(int), (255,0,255), thickness = 3, tipLength=0.3)\n",
    "            img_out = cv2.arrowedLine(img_out, np.array(p1[0])[...,::-1], (np.array(p1[0])[...,::-1] + np.array([np.cos(p1[1]), -np.sin(p1[1])]) * 30).astype(int), (0,0,255), thickness = 3, tipLength=0.3)\n",
    "            img_gt = cv2.arrowedLine(img_gt, pick[0][...,::-1], (pick[0][...,::-1] + np.array([np.cos(pick[1]), -np.sin(pick[1])]) * 30).astype(int), (255,0,255), thickness = 3, tipLength=0.3)\n",
    "            img_gt = cv2.arrowedLine(img_gt, place[0][...,::-1], (place[0][...,::-1] + np.array([np.cos(place[1]), -np.sin(place[1])]) * 30).astype(int), (0,0,255), thickness = 3, tipLength=0.3)\n",
    "            print(f\"PICK || Target z, roll, pitch: {pick[2]}, {pick[3]}, {pick[4]}\")\n",
    "            print(f\"PICK || z, roll, pitch: {p0[2]}, {p0[3]}, {p0[4]}\")\n",
    "            print(f\"PLACE || Target z, roll, pitch: {place[2]}, {place[3]}, {place[4]}\")\n",
    "            print(f\"PLACE || z, roll, pitch: {p1[2]}, {p1[3]}, {p1[4]}\")\n",
    "\n",
    "            w = 7\n",
    "            fig, axes = plt.subplots(1, 5, figsize=(w*5,w))\n",
    "            axes[0].imshow(img_out)\n",
    "            axes[0].set_title('Prediction')\n",
    "            axes[1].imshow(img_gt)\n",
    "            axes[1].set_title('Ground truth')\n",
    "            axes[2].imshow(crop)\n",
    "            axes[2].set_title('Crop')\n",
    "            # Show the slice along the last axis that contains the global maximum.\n",
    "            axes[3].imshow(pick_conf[...,np.unravel_index(np.argmax(pick_conf), pick_conf.shape)[-1]])\n",
    "            axes[3].set_title('Pick confidence')\n",
    "            axes[4].imshow(place_conf[...,np.unravel_index(np.argmax(place_conf), place_conf.shape)[-1]])\n",
    "            axes[4].set_title('Place confidence')\n",
    "\n",
    "            plt.show()\n",
    "            # Release the figure explicitly so figures do not accumulate over the run.\n",
    "            plt.close(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.12 ('edf')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "c0953e14757fdc95393fc3797619880897b68d726561afa3ebe46daeb55f7087"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
