{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from omegaconf import OmegaConf\n",
    "from lerobot.common.policies.factory import make_policy, _policy_cfg_from_hydra_cfg\n",
    "from pathlib import Path\n",
    "from lerobot.common.envs.factory import make_env\n",
    "from importlib import import_module\n",
    "from safetensors import safe_open\n",
    "from lerobot.scripts.eval import eval_policy\n",
    "import torch\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "def load_safetensors(path):\n",
    "    tensors = {}\n",
    "    with safe_open(path, framework=\"pt\", device=\"cpu\") as f:\n",
    "        for key in f.keys():\n",
    "            tensors[key] = f.get_tensor(key)\n",
    "    return tensors\n",
    "\n",
    "device = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# task = 'insertion'\n",
    "task = 'transfer_cube'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_path = './configs/arp.yaml'\n",
    "ckpt_path = f'./weights/model.{task}.safetensors'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = OmegaConf.load(config_path)\n",
    "cfg['dataset_repo_id'] = f'lerobot/aloha_sim_{task}_human'\n",
    "if task == 'insertion':\n",
    "    cfg.env.task = 'AlohaInsertion-v0'\n",
    "else:\n",
    "    cfg.env.task = 'AlohaTransferCube-v0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "importing policy module from: lerobot.common.policies.autoregressive_policy\n"
     ]
    }
   ],
   "source": [
    "prefix = 'lerobot.common.policies.' + cfg.policy['name']\n",
    "print(f'importing policy module from: {prefix}')\n",
    "config_mod = import_module(prefix + '.configuration')\n",
    "Config = config_mod.ARPConfig if hasattr(config_mod, 'ARPConfig') else config_mod.Config\n",
    "modeling_mod = import_module(prefix + '.modeling')\n",
    "Policy = modeling_mod.ARPPolicy if hasattr(modeling_mod, 'ARPPolicy') else modeling_mod.Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Hydra config is missing arguments: {'guide_chunk_size', 'action_chunk_size'}\n"
     ]
    }
   ],
   "source": [
    "config = _policy_cfg_from_hydra_cfg(Config, cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/common/users/xz653/anaconda3/envs/rvt/lib/python3.8/site-packages/torch/nn/modules/transformer.py:306: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
      "  warnings.warn(f\"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}\")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "policy = Policy(config)\n",
    "policy.load_state_dict(load_safetensors(ckpt_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "policy = policy.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = make_env(cfg, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Stepping through eval batches: 100%|██████████| 30/30 [02:08<00:00,  4.29s/it, running_success_rate=80.0%]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'per_episode': [{'episode_ix': 0,\n",
       "   'sum_reward': 242.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 1,\n",
       "   'sum_reward': 264.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 2,\n",
       "   'sum_reward': 252.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 3,\n",
       "   'sum_reward': 223.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 4,\n",
       "   'sum_reward': 35.0,\n",
       "   'max_reward': 2.0,\n",
       "   'success': False,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 5,\n",
       "   'sum_reward': 290.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 6,\n",
       "   'sum_reward': 102.0,\n",
       "   'max_reward': 2.0,\n",
       "   'success': False,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 7,\n",
       "   'sum_reward': 278.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 8,\n",
       "   'sum_reward': 64.0,\n",
       "   'max_reward': 1.0,\n",
       "   'success': False,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 9,\n",
       "   'sum_reward': 242.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 10,\n",
       "   'sum_reward': 228.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 11,\n",
       "   'sum_reward': 288.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 12,\n",
       "   'sum_reward': 272.0,\n",
       "   'max_reward': 2.0,\n",
       "   'success': False,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 13,\n",
       "   'sum_reward': 379.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 14,\n",
       "   'sum_reward': 278.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 15,\n",
       "   'sum_reward': 297.0,\n",
       "   'max_reward': 2.0,\n",
       "   'success': False,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 16,\n",
       "   'sum_reward': 271.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 17,\n",
       "   'sum_reward': 354.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 18,\n",
       "   'sum_reward': 303.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 19,\n",
       "   'sum_reward': 291.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 20,\n",
       "   'sum_reward': 181.0,\n",
       "   'max_reward': 2.0,\n",
       "   'success': False,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 21,\n",
       "   'sum_reward': 362.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 22,\n",
       "   'sum_reward': 281.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 23,\n",
       "   'sum_reward': 285.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 24,\n",
       "   'sum_reward': 268.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 25,\n",
       "   'sum_reward': 278.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 26,\n",
       "   'sum_reward': 230.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 27,\n",
       "   'sum_reward': 324.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 28,\n",
       "   'sum_reward': 293.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None},\n",
       "  {'episode_ix': 29,\n",
       "   'sum_reward': 245.0,\n",
       "   'max_reward': 4.0,\n",
       "   'success': True,\n",
       "   'seed': None}],\n",
       " 'aggregated': {'avg_sum_reward': 256.6666666666667,\n",
       "  'avg_max_reward': 3.566666666666667,\n",
       "  'pc_success': 80.0,\n",
       "  'eval_s': 129.17113661766052,\n",
       "  'eval_ep_s': 4.305704577763875},\n",
       " 'video_paths': ['outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_0.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_1.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_2.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_3.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_4.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_5.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_6.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_7.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_8.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_9.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_10.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_11.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_12.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_13.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_14.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_15.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_16.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_17.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_18.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_19.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_20.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_21.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_22.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_23.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_24.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_25.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_26.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_27.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_28.mp4',\n",
       "  'outputs/demo/lerobot/aloha_sim_transfer_cube_human/eval_episode_29.mp4']}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_policy(env, policy, 30, max_episodes_rendered=30, videos_dir=Path('./outputs/demo/' + cfg['dataset_repo_id']), \n",
    "            enable_progbar=True, enable_inner_progbar=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rvt",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
