{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pygame 2.1.2 (SDL 2.0.16, Python 3.8.19)\n",
      "Hello from the pygame community. https://www.pygame.org/contribute.html\n"
     ]
    }
   ],
   "source": [
    "from omegaconf import OmegaConf\n",
    "import torch\n",
    "torch.set_grad_enabled(False)\n",
    "import hydra\n",
    "import wandb.sdk.data_types.video as wv\n",
    "from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv\n",
    "from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper\n",
    "from torchvision.transforms.functional import to_pil_image\n",
    "from diffusion_policy.common.pytorch_util import dict_apply\n",
    "import torch.nn.functional as F\n",
    "from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder\n",
    "import pathlib\n",
    "\n",
    "def prepare_obs(obs):\n",
    "    np_obs_dict = {\n",
    "        'image': F.interpolate(torch.from_numpy(obs['image']), size=(96, 96), mode='bilinear', align_corners=False)[None, ...],\n",
    "        'origin_image': torch.from_numpy(obs['image']), \n",
    "        'agent_pos': torch.from_numpy(obs['agent_pos'][None, ])\n",
    "    }\n",
    "    return np_obs_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "============= Initialized Observation Utils with Obs Spec =============\n",
      "\n",
      "using obs modality: rgb with keys: ['image']\n",
      "using obs modality: depth with keys: []\n",
      "using obs modality: scan with keys: []\n",
      "using obs modality: low_dim with keys: []\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/common/users/xz653/anaconda3/envs/rvt/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
      "  warnings.warn(\n",
      "/common/users/xz653/anaconda3/envs/rvt/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n",
      "  warnings.warn(msg)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cfg = OmegaConf.load(f'./configs/arp.yaml')\n",
    "arp_policy = hydra.utils.instantiate(cfg.policy)\n",
    "arp_policy.load_state_dict(torch.load('/common/users/xz653/Workspace/iclr2025/release/pusht/weights/epoch=2000-test_mean_score=0.865.ckpt')['state_dicts']['model'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = MultiStepWrapper(VideoRecordingWrapper(\n",
    "                    PushTImageEnv(\n",
    "                        render_size=256\n",
    "                    ),\n",
    "                    video_recoder=VideoRecorder.create_h264(\n",
    "                        fps=10,\n",
    "                        codec='h264',\n",
    "                        input_pix_fmt='rgb24',\n",
    "                        crf=22,\n",
    "                        thread_type='FRAME',\n",
    "                        thread_count=1\n",
    "                    ),\n",
    "                    file_path=None,\n",
    "                    steps_per_render=1\n",
    "                ),\n",
    "                n_obs_steps=cfg.policy.n_obs_steps,\n",
    "                n_action_steps=cfg.policy.n_action_steps,\n",
    "                max_episode_steps=200)\n",
    "\n",
    "env.env.video_recoder.stop()\n",
    "demo_path = pathlib.Path('./outputs/demo')\n",
    "demo_path.mkdir(parents=True, exist_ok=True)\n",
    "filename = demo_path.joinpath(\n",
    "    'media', wv.util.generate_id() + \".mp4\")\n",
    "filename.parent.mkdir(parents=False, exist_ok=True)\n",
    "filename = str(filename)\n",
    "env.env.file_path = filename"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('demonstration is saved to ', 'outputs/demo/media/boybped3.mp4')"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.seed(10)\n",
    "obs = env.reset() # obs['agent_pos'].shape = (2, 2)\n",
    "\n",
    "arp_policy.eval()\n",
    "done = False\n",
    "while not done:\n",
    "    action_dict = arp_policy.predict_action(prepare_obs(obs))\n",
    "    obs, reward, done, info = env.step({k: v.detach().to('cpu').numpy() for k, v in action_dict.items()}['action'][0])\n",
    "\n",
    "obs = env.reset()\n",
    "\n",
    "\"demonstration is saved to \", env.env.file_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0637ef624ac242f89b27db6ea36c1791",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tqdm.auto import trange\n",
    "for s in trange(1, 101):\n",
    "    env.seed(s)\n",
    "    filename = demo_path.joinpath(\n",
    "    'media',  f\"{s}.mp4\")\n",
    "    filename.parent.mkdir(parents=False, exist_ok=True)\n",
    "    filename = str(filename)\n",
    "    env.env.file_path = filename\n",
    "\n",
    "    obs = env.reset() # obs['agent_pos'].shape = (2, 2)\n",
    "\n",
    "    arp_policy.eval()\n",
    "    done = False\n",
    "    while not done:\n",
    "        action_dict = arp_policy.predict_action(prepare_obs(obs))\n",
    "        obs, reward, done, info = env.step({k: v.detach().to('cpu').numpy() for k, v in action_dict.items()}['action'][0])\n",
    "\n",
    "    obs = env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rvt",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
