{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import gym\n",
    "import hydra\n",
    "from omegaconf import DictConfig\n",
    "from pytorch_lightning import seed_everything\n",
    "from hydra import initialize, initialize_config_module, initialize_config_dir, compose\n",
    "from hulc.models.decoders.utils.gripper_control import tcp_to_world_frame, world_to_tcp_frame\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "def skill_type_classifier(acts):\n",
    "    \"\"\"Score how much an action sequence translates, rotates and moves the gripper.\n",
    "\n",
    "    Args:\n",
    "        acts: 3-D tensor or array indexed as ``[item, step, channel]``.\n",
    "            Channels 0-2 feed the translation score, channels 3-5 the\n",
    "            rotation score, and channel 6 is used as the gripper signal.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(translation, rotation, gripper)`` with one value per item:\n",
    "        the mean over channels 0-2 (resp. 3-5) of the absolute actions summed\n",
    "        over steps, and the summed absolute step-to-step change of channel 6.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``acts`` is not 3-dimensional.\n",
    "    \"\"\"\n",
    "    if acts.ndim != 3:\n",
    "        raise ValueError(f'acts must be 3-D (item, step, channel), got {acts.ndim}-D')\n",
    "    # Reduce over the step axis in one call. Unlike a Python accumulation loop,\n",
    "    # this yields per-item zeros (not a bare float) for 0- or 1-step inputs.\n",
    "    energy = abs(acts[:, :, :6]).sum(1)\n",
    "    # Differences between consecutive steps of the gripper channel.\n",
    "    energy_g = abs(acts[:, 1:, 6] - acts[:, :-1, 6]).sum(1)\n",
    "    translation = energy[:, :3].sum(1) / 3\n",
    "    rotation = energy[:, 3:6].sum(1) / 3\n",
    "    gripper = energy_g\n",
    "    return translation, rotation, gripper"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Compose the Hydra config, build the datamodule and gather per-sample skill scores.\n",
    "with initialize(version_base=None, config_path=\"../../conf_sg\"):\n",
    "    cfg = compose(config_name=\"config.yaml\", overrides=['datamodule.root_data_dir=dataset/calvin_debug_dataset'])\n",
    "    datamodule = hydra.utils.instantiate(cfg.datamodule)\n",
    "    print(datamodule.training_dir)\n",
    "    datamodule.prepare_data()\n",
    "    datamodule.setup()\n",
    "    train_dataloader = datamodule.train_dataloader()\n",
    "    # Each list holds one NumPy array per batch; they are joined once after the loop.\n",
    "    a_t = []\n",
    "    a_r = []\n",
    "    a_g = []\n",
    "    # Only statistics are computed, so no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        for data in train_dataloader:\n",
    "            gt_actions = data['actions']\n",
    "            robot_obs = data['state_info']['robot_obs']\n",
    "            tcp_actions = world_to_tcp_frame(gt_actions, robot_obs)\n",
    "            t, r, g = skill_type_classifier(tcp_actions)\n",
    "            for bucket, values in zip((a_t, a_r, a_g), (t, r, g)):\n",
    "                # `list += tensor` would extend the list with 0-d tensors one at a time;\n",
    "                # store a detached CPU array per batch instead.\n",
    "                bucket.append(np.atleast_1d(torch.as_tensor(values).detach().cpu().numpy()))\n",
    "\n",
    "    # An empty dataloader yields empty arrays (mean/std then print nan).\n",
    "    a_t, a_r, a_g = (np.concatenate(b) if b else np.empty(0) for b in (a_t, a_r, a_g))\n",
    "    print(a_t.mean(), a_t.std())\n",
    "    print(a_r.mean(), a_r.std())\n",
    "    print(a_g.mean(), a_g.std())"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "hulc",
   "language": "python",
   "display_name": "hulc"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}