{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "import shutil\n",
    "import time\n",
    "from typing import Dict, List, Union\n",
    "import skill_generator.models.skill_generator as model_sg\n",
    "import hulc\n",
    "import torch"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "def get_all_checkpoints(experiment_folder: Path) -> List:\n",
    "    if experiment_folder.is_dir():\n",
    "        checkpoint_folder = experiment_folder / \"saved_models\"\n",
    "        if checkpoint_folder.is_dir():\n",
    "            checkpoints = sorted(Path(checkpoint_folder).iterdir(), key=lambda chk: chk.stat().st_mtime)\n",
    "            if len(checkpoints):\n",
    "                return [chk for chk in checkpoints if chk.suffix == \".ckpt\"]\n",
    "    return []"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "def get_last_checkpoint(experiment_folder: Path) -> Union[Path, None]:\n",
    "    # return newest checkpoint according to creation time\n",
    "    checkpoints = get_all_checkpoints(experiment_folder)\n",
    "    if len(checkpoints):\n",
    "        return checkpoints[-1]\n",
    "    return None"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "outputs": [],
   "source": [
    "def _sample(mu: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:\n",
    "    eps = torch.randn(*mu.size()).to(mu)\n",
    "    return mu +  1.0 * scale * eps"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "outputs": [],
   "source": [
    "def _check_direction(dire):\n",
    "    p = torch.clone(dire)\n",
    "    n = torch.clone(dire)\n",
    "    p[dire < 0.] = 0.\n",
    "    n[dire > 0.] = 0.\n",
    "    n = torch.abs(n)\n",
    "    return p, n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "outputs": [],
   "source": [
    "def skill_classifier(actions, scale=[1.6, 3.0, 1.0], eps=0.5):\n",
    "    gripper_energy = 0.\n",
    "    _, T, _ = actions.shape\n",
    "    energy = torch.abs(torch.sum(actions[:, :, :6], dim=1))\n",
    "    for i in range(T - 1):\n",
    "        gripper_energy += abs(actions[:, i + 1, 6] - actions[:, i, 6])\n",
    "\n",
    "    translation = torch.norm(energy[:, :3], dim=1)\n",
    "    rotation = torch.norm(energy[:, 3:6], dim=1)\n",
    "    gripper = gripper_energy\n",
    "    rotation[rotation < eps] = 0.\n",
    "    translation /= scale[0]\n",
    "    rotation /= scale[1]\n",
    "    gripper /= scale[2]\n",
    "\n",
    "    t = torch.stack([translation, rotation, gripper], dim=-1)\n",
    "    B, _ = t.shape\n",
    "    skill_types = torch.argmax(t, dim=-1)\n",
    "    return skill_types"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "translation rate:  tensor(0.8091)\n",
      "rotation rate:  tensor(0.4681)\n",
      "grasp rate:  tensor(0.7188)\n"
     ]
    }
   ],
   "source": [
    "batch = 10000\n",
    "# load_checkpoint\n",
    "sg_chk_path = './sg_runs/2023-01-25/09-36-29'\n",
    "sg_chk_path = Path(hulc.__file__).parent.parent / sg_chk_path\n",
    "chk = get_last_checkpoint(sg_chk_path)\n",
    "skill_generator = getattr(model_sg, 'SkillGenerator').load_from_checkpoint(chk.as_posix())\n",
    "skill_generator.freeze()\n",
    "prior_locator = skill_generator.prior_locator.eval()\n",
    "action_decoder = skill_generator.decoder.eval()\n",
    "\n",
    "priors = prior_locator(repeat=batch)\n",
    "skill_len = torch.tensor(5)\n",
    "\n",
    "t_mu = priors['p_mu'][:,0,:]\n",
    "t_scale = priors['p_scale'][:,0,:]\n",
    "t_sampled = _sample(t_mu, t_scale)\n",
    "\n",
    "r_mu = priors['p_mu'][:,1,:]\n",
    "r_scale = priors['p_scale'][:,1,:]\n",
    "r_sampled = _sample(r_mu, r_scale)\n",
    "\n",
    "g_mu = priors['p_mu'][:,2,:]\n",
    "g_scale = priors['p_scale'][:,2,:]\n",
    "g_sampled = _sample(g_mu, g_scale)\n",
    "\n",
    "t_actions = action_decoder(t_sampled, skill_len.repeat(batch))\n",
    "r_actions = action_decoder(r_sampled, skill_len.repeat(batch))\n",
    "g_actions = action_decoder(g_sampled, skill_len.repeat(batch))\n",
    "\n",
    "rate_t = torch.sum(skill_classifier(t_actions) == 0) / batch\n",
    "rate_r = torch.sum(skill_classifier(r_actions) == 1) / batch\n",
    "rate_g = torch.sum(skill_classifier(g_actions) == 2) / batch\n",
    "\n",
    "print('translation rate: ', rate_t)\n",
    "print('rotation rate: ', rate_r)\n",
    "print('grasp rate: ', rate_g)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "hulc",
   "language": "python",
   "display_name": "hulc"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}