{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import argparse\n",
    "import time\n",
    "import datetime\n",
    "import random\n",
    "import numpy as np\n",
    "import yaml\n",
    "import gzip\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image, ImageDraw\n",
    "import cv2\n",
    "\n",
    "import torch\n",
    "from pytorch3d import transforms\n",
    "\n",
    "from edf.utils import preprocess, voxelize_sample, OrthoTransform, binomial_test\n",
    "from edf.visual_utils import scatter_plot_ax\n",
    "from edf.pybullet_env.env import MugTask\n",
    "from edf.dist import GaussianDistSE3\n",
    "\n",
    "from baselines.equiv_tn.sixdof_non_equi_transporter import TransporterAgent\n",
    "from baselines.equiv_tn.utils import perturb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 0\n",
    "device = 'cuda'\n",
    "\n",
    "\n",
    "task_config_dir = 'config/task_config/mug_task.yaml'\n",
    "visualize_plot = True\n",
    "save_plot = False\n",
    "plot_path = 'logs/baselines/TN/'\n",
    "use_gui = True\n",
    "\n",
    "with open(task_config_dir) as file:\n",
    "    config = yaml.load(file, Loader=yaml.FullLoader)\n",
    "sleep = config['sleep']\n",
    "d = config['d']\n",
    "d_pick = config['d_pick']\n",
    "d_place = config['d_place']\n",
    "\n",
    "plot_figsize = [28,7]\n",
    "pick_attempt_max = 100\n",
    "place_attempt_max = 100\n",
    "pick_only = False\n",
    "\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "if device == 'cpu':\n",
    "    torch.use_deterministic_algorithms(True)\n",
    "elif device == 'cuda':\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    \n",
    "torch.set_printoptions(precision=4, sci_mode=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pix2pose(p, yaw, z, roll, pitch):\n",
    "    if yaw > np.pi:\n",
    "        yaw -= 2*np.pi\n",
    "    yaw = yaw * 180 / np.pi\n",
    "\n",
    "    if roll > np.pi:\n",
    "        roll -= 2*np.pi\n",
    "    roll = roll * 180 / np.pi\n",
    "\n",
    "    pitch = min(max(pitch, -np.pi), np.pi)\n",
    "    pitch = pitch * 180 / np.pi\n",
    "\n",
    "    T = ortho_transform.pix_yaw_zrp2pose(grasp_pix=p, yaw=yaw, height=z, roll=roll, pitch=pitch, grasp='top')\n",
    "    return T\n",
    "\n",
    "def save_plot_func():\n",
    "    if os.path.exists(plot_path + \"inference/\") is False:\n",
    "        os.makedirs(plot_path + \"inference/\")\n",
    "    fig.savefig(plot_path + \"inference/\" + f\"{seed}.png\")\n",
    "    if os.path.exists(plot_path + \"result/\") is False:\n",
    "        os.makedirs(plot_path + \"result/\")\n",
    "    fig_img.savefig(plot_path + \"result/\" + f\"{seed}.png\")\n",
    "\n",
    "def draw_result():\n",
    "    #pc = task.observe_pointcloud(stride = (1, 1))\n",
    "    #scatter_plot_ax(axes[3], pc['coord'], pc['color'], pc['ranges'])\n",
    "    axes[0].imshow(img_out)\n",
    "    images = task.observe()\n",
    "    for i in range(3):\n",
    "        axes_img[i].imshow(images[i]['color'])\n",
    "\n",
    "def plot():\n",
    "    draw_result()\n",
    "    if save_plot:\n",
    "        save_plot_func()\n",
    "    if visualize_plot:\n",
    "        plt.show()\n",
    "    else:\n",
    "        plt.close(fig)\n",
    "        plt.close(fig_img)\n",
    "\n",
    "def report():\n",
    "    confidence = 0.95\n",
    "    _, _, _, pick_result = binomial_test(success=N_success_pick, n=N_tests, confidence=confidence)\n",
    "    _, _, _, place_result = binomial_test(success=N_success_place, n=N_success_pick, confidence=confidence)\n",
    "    _, _, _, total_result = binomial_test(success=N_success_place, n=N_tests, confidence=confidence)\n",
    "\n",
    "    print(f\"Pick Success Rate: {pick_result}    ||   Place Success Rate: {place_result}    ||   Place-and-Place Success Rate: {total_result})\", flush=True)\n",
    "    plot()\n",
    "    print(\"======================================\", flush=True)\n",
    "\n",
    "def pick(T):\n",
    "    # R, X = transforms.quaternion_to_matrix(T[...,:4]), T[...,4:]\n",
    "    # X_sdg, R_sdg = data_transform.inv_transform_T(X.detach().cpu().numpy(), R.detach().cpu().numpy())\n",
    "    X_sdg, R_sdg = T\n",
    "    z_axis = R_sdg[:,-1]\n",
    "    \n",
    "    R_dg_dgpre = np.eye(3)\n",
    "    R_s_dgpre = R_sdg @ R_dg_dgpre\n",
    "    X_dg_dgpre = np.array([0., 0., -0.03])\n",
    "    sX_dg_dgpre = R_sdg @ X_dg_dgpre\n",
    "    X_s_dgpre = X_sdg + sX_dg_dgpre\n",
    "\n",
    "    pre_pick = (X_s_dgpre, R_s_dgpre)\n",
    "    pick = (X_sdg, R_sdg)\n",
    "\n",
    "    try:\n",
    "        task.pick(pre_pick, pick)\n",
    "        print(\"Pick IK Success\", flush=True)\n",
    "        return True\n",
    "    except StopIteration:\n",
    "        #print(\"Pick IK Failed\", flush=True)\n",
    "        return False\n",
    "\n",
    "def place(T):\n",
    "    # R, X = transforms.quaternion_to_matrix(T[...,:4]), T[...,4:]\n",
    "    # X_sdg, R_sdg = data_transform_K.inv_transform_T(X.detach().cpu().numpy(), R.detach().cpu().numpy())\n",
    "    X_sdg, R_sdg = T\n",
    "\n",
    "    R_dg_dgpre = np.eye(3)\n",
    "    R_s_dgpre = R_sdg @ R_dg_dgpre\n",
    "    X_dg_dgpre = np.array([0., 0., -0.03])\n",
    "    sX_dg_dgpre = R_sdg @ X_dg_dgpre\n",
    "    X_s_dgpre = X_sdg + sX_dg_dgpre\n",
    "\n",
    "    pre_place = (X_s_dgpre, R_s_dgpre)\n",
    "    place = (X_sdg, R_sdg)\n",
    "\n",
    "    try:\n",
    "        task.place(pre_place, place)\n",
    "        print(\"Place IK Success\", flush=True)\n",
    "        return True\n",
    "    except StopIteration:\n",
    "        #print(\"Place IK Failed\", flush=True)\n",
    "        return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "H = W = 160\n",
    "crop_size = 16*6\n",
    "ortho_ranges = np.array([[0.4, 0.8],[-0.2, 0.2], [0., 0.4]])\n",
    "ortho_transform = OrthoTransform(W = W, ranges = ortho_ranges[:2])\n",
    "pix_size = (ortho_ranges[0,1] - ortho_ranges[0,0]) / H\n",
    "\n",
    "perturb_dist = GaussianDistSE3(std_theta = 2./180*np.pi, std_X = 0.2 * 0.01)\n",
    "perturb_dist.dist_R.get_inv_cdf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent = TransporterAgent(name='any', task='any', root_dir='checkpoint_tn/rim', device=device, load=False, crop_size = crop_size, pix_size = pix_size, bounds = ortho_ranges, H=H, W=W, n_rotations=36)\n",
    "agent.load(n_iter=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##### Initialize task env #####\n",
    "task = MugTask(use_gui=use_gui)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##### Evaluate #####\n",
    "N_tests = 0\n",
    "N_success_pick = 0\n",
    "N_success_place = 0\n",
    "N_IKFAIL_pick = 0\n",
    "N_IKFAIL_place = 0\n",
    "pick_times = []\n",
    "place_times = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_seed = 100\n",
    "end_seed = init_seed + 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "schedule = {'mug_pose': 'upright', 'mug_type': 'default', \n",
    "            'distractor': False, 'use_support': False, \n",
    "            'init_seed': init_seed, 'end_seed': end_seed}\n",
    "\n",
    "schedules = [schedule]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for schedule in schedules:\n",
    "    mug_pose = schedule['mug_pose']\n",
    "    mug_type = schedule['mug_type']\n",
    "    distractor = schedule['distractor']\n",
    "    use_support = schedule['use_support']\n",
    "    for seed in range(schedule['init_seed'], schedule['end_seed']):\n",
    "        N_tests += 1\n",
    "        print(f\"=================Sample {seed}==================\", flush=True)\n",
    "        fig, axes = plt.subplots(1,4, figsize=plot_figsize)\n",
    "        fig_img, axes_img = plt.subplots(1,3, figsize=plot_figsize)\n",
    "\n",
    "        ##### Observe #####\n",
    "        task.reset(seed = seed, mug_pose=mug_pose, mug_type=mug_type, distractor=distractor, use_support=use_support)\n",
    "        pc = task.observe_pointcloud(stride = (1,1))\n",
    "        sample = {}\n",
    "        sample['coord'], sample['color'] = pc['coord'], pc['color']\n",
    "        sample['range'] = pc['ranges']\n",
    "        sample['d'] = 0.001\n",
    "        sample = voxelize_sample(sample, coord_jitter=3., color_jitter=0.03, pick=True, place=False)\n",
    "\n",
    "        in_range_idx = ((sample['coord'][..., -1] > ortho_ranges[-1][0]) * (sample['coord'][..., -1] < ortho_ranges[-1][1]))\n",
    "        coord = sample['coord'][in_range_idx]\n",
    "        color = sample['color'][in_range_idx]\n",
    "\n",
    "        img = ortho_transform.orthographic(coord, color)\n",
    "\n",
    "        img_mean, img_std = np.array([[0.5, 0.5, 0.5, 0.25]]), np.array([[0.5, 0.5, 0.5, 0.25]])\n",
    "        img = (img - img_mean) / img_std\n",
    "        img = np.concatenate((img[Ellipsis, :3],\n",
    "                            img[Ellipsis, 3:4],\n",
    "                            img[Ellipsis, 3:4],\n",
    "                            img[Ellipsis, 3:4]), axis=2).astype(np.float32)\n",
    "\n",
    "        img_visual = img[...,:4].copy() * img_std + img_mean\n",
    "        img_visual = img_visual - img_visual.min()\n",
    "        img_visual = img_visual / img_visual.max()\n",
    "\n",
    "        img_out = (img_visual.copy()[...,:3]*255).astype(np.uint8)\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            pick_conf, zrp, zrp_log_std = agent.act_pick(img) # (H,W,nRot), (H,W,nRot,3), (H,W,nRot,3)\n",
    "        indices = pick_conf.reshape(-1).argsort()[-pick_attempt_max:][::-1]\n",
    "        hs,ws,theta_is = np.unravel_index(indices, pick_conf.shape)\n",
    "\n",
    "        for (h, w, theta_i) in zip(hs,ws,theta_is):\n",
    "            p0 = np.array((h, w))\n",
    "            p0_theta = theta_i * (2 * np.pi / pick_conf.shape[2])\n",
    "            z,r,p = zrp[h,w,theta_i].detach().cpu().numpy() #+ np.random.randn(3) * zrp_log_std[h,w,theta_i].detach().cpu().exp().numpy()\n",
    "            T = pix2pose(p0, p0_theta, z, r, p)\n",
    "\n",
    "            #T[0][-1] = 0.09\n",
    "            pick_ik_success = pick(T)\n",
    "            if pick_ik_success:\n",
    "                break\n",
    "            axes[1].imshow(pick_conf[...,np.unravel_index(np.argmax(pick_conf), pick_conf.shape)[-1]])\n",
    "            img_out = cv2.arrowedLine(img_out, np.array(p0)[...,::-1], (np.array(p0)[...,::-1] + np.array([np.cos(p0_theta), -np.sin(p0_theta)]) * 30).astype(int), (255,0,255), thickness = 3, tipLength=0.3)\n",
    "\n",
    "\n",
    "        if not pick_ik_success:\n",
    "            print(\"Pick fail: Couldn't find IK solution\", flush=True)\n",
    "            N_IKFAIL_pick += 1\n",
    "            report()\n",
    "            continue\n",
    "\n",
    "        if task.check_pick_success():\n",
    "            print(\"Pick success\", flush=True)\n",
    "            N_success_pick += 1\n",
    "        else:\n",
    "            print(\"Pick fail: Found IK solution but failed\", flush=True)\n",
    "            report()\n",
    "            continue\n",
    "        \n",
    "        if pick_only:\n",
    "            report()\n",
    "            continue\n",
    "        \n",
    "        ############################################# Pick Finished #######################################\n",
    "        ############################################# Place Starts  #######################################\n",
    "\n",
    "        task.retract_robot(gripper_val=1., IK_time=1., back=True)\n",
    "\n",
    "\n",
    "        crop_test = (img_visual.copy()[...,:3]*255).astype(np.uint8)\n",
    "        crop_test = np.pad(crop_test, ((crop_size//2,crop_size//2), (crop_size//2,crop_size//2), (0, 0)))\n",
    "        crop_test = crop_test[p0[0]:p0[0]+crop_size, p0[1]:p0[1]+crop_size]\n",
    "        axes[2].imshow(crop_test)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            place_conf, zrp_place, zrp_log_std_place = agent.act_place(img, p0_pix = p0, p0_z = z, p0_roll = r, p0_pitch = p) # (H,W,nRot), (H,W,nRot,3), (H,W,nRot,3)\n",
    "        indices = place_conf.reshape(-1).argsort()[-place_attempt_max:][::-1]\n",
    "        hs,ws,theta_is = np.unravel_index(indices, place_conf.shape)\n",
    "        \n",
    "        for (h, w, theta_i) in zip(hs,ws,theta_is):\n",
    "            p1 = np.array((h, w))\n",
    "            p1_theta = theta_i * (2 * np.pi / place_conf.shape[2]) + p0_theta\n",
    "            p1_theta = (p1_theta + 2*np.pi) % (2*np.pi)\n",
    "            z,r,p = zrp_place[h,w,theta_i].detach().cpu().numpy() #+ np.random.randn(3) * zrp_log_std_place[h,w,theta_i].detach().cpu().exp().numpy()\n",
    "\n",
    "            T = pix2pose(p1, p1_theta, z, r, p)\n",
    "\n",
    "            place_ik_success = place(T)\n",
    "            if place_ik_success:\n",
    "                break\n",
    "        axes[3].imshow(place_conf[...,np.unravel_index(np.argmax(place_conf), place_conf.shape)[-1]])\n",
    "        img_out = cv2.arrowedLine(img_out, np.array(p1)[...,::-1], (np.array(p1)[...,::-1] + np.array([np.cos(p1_theta), -np.sin(p1_theta)]) * 30).astype(int), (0,0,255), thickness = 3, tipLength=0.3)\n",
    "\n",
    "\n",
    "        if not place_ik_success:\n",
    "            print(\"Place fail: Couldn't find IK solution\", flush=True)\n",
    "            N_IKFAIL_place += 1\n",
    "            report()\n",
    "            continue\n",
    "\n",
    "        if task.check_place_success():\n",
    "            N_success_place += 1\n",
    "            print('Place Success', flush=True)\n",
    "        else:\n",
    "            print('Place Fail', flush=True)\n",
    "\n",
    "        ##### Visualize final #####\n",
    "        report()\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.12 ('edf')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "c0953e14757fdc95393fc3797619880897b68d726561afa3ebe46daeb55f7087"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
