{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**RLLTE allows you to use intrinsic reward modules in a simple and elegant way.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load libraries\n",
    "import torch as th\n",
    "\n",
    "from rllte.xplore.reward import ICM, PseudoCounts\n",
    "from rllte.env import make_mario_env, make_dmc_env\n",
    "from rllte.agent import PPO, DDPG"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Example with PPO**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "e:\\anaconda3\\envs\\marllib\\lib\\site-packages\\gym\\envs\\registration.py:627: UserWarning: \u001b[33mWARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']\u001b[0m\n",
      "  logger.warn(\n",
      "e:\\anaconda3\\envs\\marllib\\lib\\site-packages\\gymnasium\\core.py:311: UserWarning: \u001b[33mWARN: env.metadata to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.metadata` for environment variables or `env.get_wrapper_attr('metadata')` that will search the reminding wrappers.\u001b[0m\n",
      "  logger.warn(\n",
      "e:\\anaconda3\\envs\\marllib\\lib\\site-packages\\gymnasium\\core.py:311: UserWarning: \u001b[33mWARN: env.single_observation_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_observation_space` for environment variables or `env.get_wrapper_attr('single_observation_space')` that will search the reminding wrappers.\u001b[0m\n",
      "  logger.warn(\n",
      "e:\\anaconda3\\envs\\marllib\\lib\\site-packages\\gymnasium\\core.py:311: UserWarning: \u001b[33mWARN: env.single_action_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_action_space` for environment variables or `env.get_wrapper_attr('single_action_space')` that will search the reminding wrappers.\u001b[0m\n",
      "  logger.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda Box(0, 255, (3, 84, 84), uint8) Discrete(7)\n",
      "[05/14/2024 05:01:14 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Invoking RLLTE Engine...\n",
      "[05/14/2024 05:01:14 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - ================================================================================\n",
      "[05/14/2024 05:01:14 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Tag               : default\n",
      "[05/14/2024 05:01:14 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Device            : NVIDIA GeForce RTX 3080\n",
      "[05/14/2024 05:01:14 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Agent             : PPO\n",
      "[05/14/2024 05:01:14 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Encoder           : MnihCnnEncoder\n",
      "[05/14/2024 05:01:14 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Policy            : OnPolicySharedActorCritic\n",
      "[05/14/2024 05:01:14 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Storage           : VanillaRolloutStorage\n",
      "[05/14/2024 05:01:14 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Distribution      : Categorical\n",
      "[05/14/2024 05:01:14 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Augmentation      : None\n",
      "[05/14/2024 05:01:14 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Intrinsic Reward  : ICM\n",
      "[05/14/2024 05:01:14 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - ================================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "e:\\anaconda3\\envs\\marllib\\lib\\site-packages\\torch\\nn\\modules\\conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
      "  return F.conv2d(input, weight, bias, self.stride,\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[05/14/2024 05:01:15 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 1024        | E: 8           | L: 50          | R: 2.303       | FPS: 580.810   | T: 0:00:01    \n",
      "[05/14/2024 05:01:17 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 2048        | E: 16          | L: 50          | R: 2.303       | FPS: 674.402   | T: 0:00:03    \n",
      "[05/14/2024 05:01:18 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 3072        | E: 24          | L: 125         | R: 4.252       | FPS: 700.904   | T: 0:00:04    \n",
      "[05/14/2024 05:01:19 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 4096        | E: 32          | L: 207         | R: 5.124       | FPS: 719.410   | T: 0:00:05    \n",
      "[05/14/2024 05:01:20 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 5120        | E: 40          | L: 262         | R: 5.311       | FPS: 734.897   | T: 0:00:06    \n",
      "[05/14/2024 05:01:22 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 6144        | E: 48          | L: 262         | R: 5.311       | FPS: 741.157   | T: 0:00:08    \n",
      "[05/14/2024 05:01:23 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 7168        | E: 56          | L: 339         | R: 5.787       | FPS: 744.078   | T: 0:00:09    \n",
      "[05/14/2024 05:01:24 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 8192        | E: 64          | L: 430         | R: 7.972       | FPS: 746.637   | T: 0:00:10    \n",
      "[05/14/2024 05:01:26 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 9216        | E: 72          | L: 326         | R: 6.363       | FPS: 750.032   | T: 0:00:12    \n",
      "[05/14/2024 05:01:26 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Training Accomplished!\n",
      "[05/14/2024 05:01:26 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Model saved at: e:\\HSMRCS\\rllte\\examples\\rlexplore\\logs\\default\\2024-05-14-05-01-14\\model\n"
     ]
    }
   ],
   "source": [
    "# create the vectorized environments\n",
    "device = 'cuda' if th.cuda.is_available() else 'cpu'\n",
    "envs = make_mario_env('SuperMarioBros-1-1-v3', device=device)\n",
    "print(device, envs.observation_space, envs.action_space)\n",
    "# create the intrinsic reward module\n",
    "irs = ICM(envs, device=device)\n",
    "# create the PPO agent\n",
    "agent = PPO(envs, device=device)\n",
    "# set the intrinsic reward module\n",
    "agent.set(reward=irs)\n",
    "# train the agent\n",
    "agent.train(10000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Example with DDPG**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "e:\\anaconda3\\envs\\marllib\\lib\\site-packages\\gymnasium\\core.py:311: UserWarning: \u001b[33mWARN: env.single_observation_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_observation_space` for environment variables or `env.get_wrapper_attr('single_observation_space')` that will search the reminding wrappers.\u001b[0m\n",
      "  logger.warn(\n",
      "e:\\anaconda3\\envs\\marllib\\lib\\site-packages\\gymnasium\\core.py:311: UserWarning: \u001b[33mWARN: env.single_action_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_action_space` for environment variables or `env.get_wrapper_attr('single_action_space')` that will search the reminding wrappers.\u001b[0m\n",
      "  logger.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda Box(-inf, inf, (15,), float32) Box(-1.0, 1.0, (4,), float32)\n",
      "[05/14/2024 05:03:53 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Invoking RLLTE Engine...\n",
      "[05/14/2024 05:03:53 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - ================================================================================\n",
      "[05/14/2024 05:03:53 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Tag               : default\n",
      "[05/14/2024 05:03:53 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Device            : NVIDIA GeForce RTX 3080\n",
      "[05/14/2024 05:03:53 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Agent             : DDPG\n",
      "[05/14/2024 05:03:53 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Encoder           : IdentityEncoder\n",
      "[05/14/2024 05:03:53 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Policy            : OffPolicyDetActorDoubleCritic\n",
      "[05/14/2024 05:03:53 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Storage           : VanillaReplayStorage\n",
      "[05/14/2024 05:03:53 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Distribution      : TruncatedNormalNoise\n",
      "[05/14/2024 05:03:53 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Augmentation      : None\n",
      "[05/14/2024 05:03:53 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Intrinsic Reward  : PseudoCounts\n",
      "[05/14/2024 05:03:53 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - ================================================================================\n",
      "[05/14/2024 05:03:56 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 1000        | E: 1           | L: 1000        | R: 0.000       | FPS: 367.456   | T: 0:00:02    \n",
      "[05/14/2024 05:03:58 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 2000        | E: 2           | L: 1000        | R: 0.000       | FPS: 394.624   | T: 0:00:05    \n",
      "[05/14/2024 05:04:05 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 3000        | E: 3           | L: 1000        | R: 0.000       | FPS: 244.838   | T: 0:00:12    \n",
      "[05/14/2024 05:04:13 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 4000        | E: 4           | L: 1000        | R: 0.001       | FPS: 203.291   | T: 0:00:19    \n",
      "[05/14/2024 05:04:20 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 5000        | E: 5           | L: 1000        | R: 0.001       | FPS: 186.675   | T: 0:00:26    \n",
      "[05/14/2024 05:04:28 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 6000        | E: 6           | L: 1000        | R: 0.001       | FPS: 174.465   | T: 0:00:34    \n",
      "[05/14/2024 05:04:35 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 7000        | E: 7           | L: 1000        | R: 0.001       | FPS: 167.744   | T: 0:00:41    \n",
      "[05/14/2024 05:04:42 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 8000        | E: 8           | L: 1000        | R: 0.001       | FPS: 162.902   | T: 0:00:49    \n",
      "[05/14/2024 05:04:49 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 9000        | E: 9           | L: 1000        | R: 0.001       | FPS: 159.881   | T: 0:00:56    \n",
      "[05/14/2024 05:04:57 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 10000       | E: 10          | L: 1000        | R: 0.001       | FPS: 157.679   | T: 0:01:03    \n",
      "[05/14/2024 05:04:57 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Training Accomplished!\n",
      "[05/14/2024 05:04:57 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Model saved at: e:\\HSMRCS\\rllte\\examples\\rlexplore\\logs\\default\\2024-05-14-05-03-53\\model\n"
     ]
    }
   ],
   "source": [
    "# create the vectorized environments\n",
    "device = 'cuda' if th.cuda.is_available() else 'cpu'\n",
    "envs = make_dmc_env('hopper_hop', device=device)\n",
    "print(device, envs.observation_space, envs.action_space)\n",
    "# create the intrinsic reward module\n",
    "irs = PseudoCounts(envs, device=device)\n",
    "# create the DDPG agent\n",
    "agent = DDPG(envs, device=device)\n",
    "# set the intrinsic reward module\n",
    "agent.set(reward=irs)\n",
    "# train the agent\n",
    "agent.train(10000, log_interval=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rllte",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
