{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Load the libraries**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3.common.base_class import BaseAlgorithm\n",
    "from stable_baselines3.common.callbacks import BaseCallback\n",
    "from stable_baselines3.common.env_util import make_vec_env\n",
    "from stable_baselines3 import PPO, SAC\n",
    "import torch as th\n",
    "\n",
    "from rllte.xplore.reward import E3B"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**For on-policy RL algorithms**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cuda device\n",
      "torch.Size([2048, 4, 3]) torch.Size([2048, 4, 1]) torch.Size([2048, 4]) torch.Size([2048, 4]) torch.Size([2048, 4, 3])\n",
      "----------------------------------\n",
      "| rollout/           |           |\n",
      "|    ep_len_mean     | 200       |\n",
      "|    ep_rew_mean     | -1.31e+03 |\n",
      "| time/              |           |\n",
      "|    fps             | 1706      |\n",
      "|    iterations      | 1         |\n",
      "|    time_elapsed    | 4         |\n",
      "|    total_timesteps | 8192      |\n",
      "----------------------------------\n",
      "torch.Size([2048, 4, 3]) torch.Size([2048, 4, 1]) torch.Size([2048, 4]) torch.Size([2048, 4]) torch.Size([2048, 4, 3])\n",
      "------------------------------------------\n",
      "| rollout/                |              |\n",
      "|    ep_len_mean          | 200          |\n",
      "|    ep_rew_mean          | -1.25e+03    |\n",
      "| time/                   |              |\n",
      "|    fps                  | 1111         |\n",
      "|    iterations           | 2            |\n",
      "|    time_elapsed         | 14           |\n",
      "|    total_timesteps      | 16384        |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0031577125 |\n",
      "|    clip_fraction        | 0.0214       |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.4         |\n",
      "|    explained_variance   | 0.00427      |\n",
      "|    learning_rate        | 0.0003       |\n",
      "|    loss                 | 2.44e+03     |\n",
      "|    n_updates            | 10           |\n",
      "|    policy_gradient_loss | -0.00175     |\n",
      "|    std                  | 0.976        |\n",
      "|    value_loss           | 7.63e+03     |\n",
      "------------------------------------------\n",
      "torch.Size([2048, 4, 3]) torch.Size([2048, 4, 1]) torch.Size([2048, 4]) torch.Size([2048, 4]) torch.Size([2048, 4, 3])\n",
      "-----------------------------------------\n",
      "| rollout/                |             |\n",
      "|    ep_len_mean          | 200         |\n",
      "|    ep_rew_mean          | -1.22e+03   |\n",
      "| time/                   |             |\n",
      "|    fps                  | 984         |\n",
      "|    iterations           | 3           |\n",
      "|    time_elapsed         | 24          |\n",
      "|    total_timesteps      | 24576       |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 0.002031011 |\n",
      "|    clip_fraction        | 0.00642     |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.4        |\n",
      "|    explained_variance   | 0.0282      |\n",
      "|    learning_rate        | 0.0003      |\n",
      "|    loss                 | 2.21e+03    |\n",
      "|    n_updates            | 20          |\n",
      "|    policy_gradient_loss | -0.000329   |\n",
      "|    std                  | 0.987       |\n",
      "|    value_loss           | 6.21e+03    |\n",
      "-----------------------------------------\n",
      "torch.Size([2048, 4, 3]) torch.Size([2048, 4, 1]) torch.Size([2048, 4]) torch.Size([2048, 4]) torch.Size([2048, 4, 3])\n",
      "------------------------------------------\n",
      "| rollout/                |              |\n",
      "|    ep_len_mean          | 200          |\n",
      "|    ep_rew_mean          | -1.19e+03    |\n",
      "| time/                   |              |\n",
      "|    fps                  | 926          |\n",
      "|    iterations           | 4            |\n",
      "|    time_elapsed         | 35           |\n",
      "|    total_timesteps      | 32768        |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0026716313 |\n",
      "|    clip_fraction        | 0.0162       |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.4         |\n",
      "|    explained_variance   | 0.00133      |\n",
      "|    learning_rate        | 0.0003       |\n",
      "|    loss                 | 2.12e+03     |\n",
      "|    n_updates            | 30           |\n",
      "|    policy_gradient_loss | -0.0014      |\n",
      "|    std                  | 0.981        |\n",
      "|    value_loss           | 5.84e+03     |\n",
      "------------------------------------------\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<stable_baselines3.ppo.ppo.PPO at 0x1951542f3d0>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class RLeXploreWithOnPolicyRL(BaseCallback):\n",
    "    \"\"\"\n",
    "    A custom callback for combining RLeXplore and on-policy algorithms from SB3.\n",
    "    \"\"\"\n",
    "    def __init__(self, irs, verbose=0):\n",
    "        super(RLeXploreWithOnPolicyRL, self).__init__(verbose)\n",
    "        self.irs = irs\n",
    "        self.buffer = None\n",
    "\n",
    "    def init_callback(self, model: BaseAlgorithm) -> None:\n",
    "        super().init_callback(model)\n",
    "        self.buffer = self.model.rollout_buffer\n",
    "\n",
    "    def _on_step(self) -> bool:\n",
    "        \"\"\"\n",
    "        This method will be called by the model after each call to `env.step()`.\n",
    "\n",
    "        :return: (bool) If the callback returns False, training is aborted early.\n",
    "        \"\"\"\n",
    "        observations = self.locals[\"obs_tensor\"]\n",
    "        device = observations.device\n",
    "        actions = th.as_tensor(self.locals[\"actions\"], device=device)\n",
    "        rewards = th.as_tensor(self.locals[\"rewards\"], device=device)\n",
    "        dones = th.as_tensor(self.locals[\"dones\"], device=device)\n",
    "        next_observations = th.as_tensor(self.locals[\"new_obs\"], device=device)\n",
    "\n",
    "        # ===================== watch the interaction ===================== #\n",
    "        self.irs.watch(observations, actions, rewards, dones, dones, next_observations)\n",
    "        # ===================== watch the interaction ===================== #\n",
    "        return True\n",
    "\n",
    "    def _on_rollout_end(self) -> None:\n",
    "        # ===================== compute the intrinsic rewards ===================== #\n",
    "        # prepare the data samples\n",
    "        obs = th.as_tensor(self.buffer.observations)\n",
    "        # get the new observations\n",
    "        new_obs = obs.clone()\n",
    "        new_obs[:-1] = obs[1:]\n",
    "        new_obs[-1] = th.as_tensor(self.locals[\"new_obs\"])\n",
    "        actions = th.as_tensor(self.buffer.actions)\n",
    "        rewards = th.as_tensor(self.buffer.rewards)\n",
    "        dones = th.as_tensor(self.buffer.episode_starts)\n",
    "        print(obs.shape, actions.shape, rewards.shape, dones.shape, obs.shape)\n",
    "        # compute the intrinsic rewards\n",
    "        intrinsic_rewards = irs.compute(\n",
    "            samples=dict(observations=obs, actions=actions, \n",
    "                         rewards=rewards, terminateds=dones, \n",
    "                         truncateds=dones, next_observations=new_obs),\n",
    "            sync=True)\n",
    "        # add the intrinsic rewards to the buffer\n",
    "        self.buffer.advantages += intrinsic_rewards.cpu().numpy()\n",
    "        self.buffer.returns += intrinsic_rewards.cpu().numpy()\n",
    "        # ===================== compute the intrinsic rewards ===================== #\n",
    "\n",
    "# Parallel environments\n",
    "device = 'cuda'\n",
    "n_envs = 4\n",
    "envs = make_vec_env(\"Pendulum-v1\", n_envs=n_envs)\n",
    "\n",
    "# ===================== build the reward ===================== #\n",
    "irs = E3B(envs, device=device)\n",
    "# ===================== build the reward ===================== #\n",
    "\n",
    "model = PPO(\"MlpPolicy\", envs, verbose=1, device=device)\n",
    "model.learn(total_timesteps=25000, callback=RLeXploreWithOnPolicyRL(irs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**For off-policy RL algorithms**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RLeXploreWithOffPolicyRL(BaseCallback):\n",
    "    \"\"\"\n",
    "    A custom callback for combining RLeXplore and off-policy algorithms from SB3. \n",
    "    \"\"\"\n",
    "    def __init__(self, irs, verbose=0):\n",
    "        super(RLeXploreWithOffPolicyRL, self).__init__(verbose)\n",
    "        self.irs = irs\n",
    "        self.buffer = None\n",
    "\n",
    "    def init_callback(self, model: BaseAlgorithm) -> None:\n",
    "        super().init_callback(model)\n",
    "        self.buffer = self.model.replay_buffer\n",
    "        \n",
    "\n",
    "    def _on_step(self) -> bool:\n",
    "        \"\"\"\n",
    "        This method will be called by the model after each call to `env.step()`.\n",
    "\n",
    "        :return: (bool) If the callback returns False, training is aborted early.\n",
    "        \"\"\"\n",
    "        device = self.irs.device\n",
    "        obs = th.as_tensor(self.locals['self']._last_obs, device=device)\n",
    "        actions = th.as_tensor(self.locals[\"actions\"], device=device)\n",
    "        rewards = th.as_tensor(self.locals[\"rewards\"], device=device)\n",
    "        dones = th.as_tensor(self.locals[\"dones\"], device=device)\n",
    "        next_obs = th.as_tensor(self.locals[\"new_obs\"], device=device)\n",
    "\n",
    "        # ===================== watch the interaction ===================== #\n",
    "        self.irs.watch(obs, actions, rewards, dones, dones, next_obs)\n",
    "        # ===================== watch the interaction ===================== #\n",
    "        \n",
    "        # ===================== compute the intrinsic rewards ===================== #\n",
    "        intrinsic_rewards = irs.compute(samples={'observations':obs.unsqueeze(0), \n",
    "                                            'actions':actions.unsqueeze(0), \n",
    "                                            'rewards':rewards.unsqueeze(0),\n",
    "                                            'terminateds':dones.unsqueeze(0),\n",
    "                                            'truncateds':dones.unsqueeze(0),\n",
    "                                            'next_observations':next_obs.unsqueeze(0)}, \n",
    "                                            sync=False)\n",
    "        # ===================== compute the intrinsic rewards ===================== #\n",
    "\n",
    "        try:\n",
    "            # add the intrinsic rewards to the original rewards\n",
    "            self.locals['rewards'] += intrinsic_rewards.cpu().numpy().squeeze()\n",
    "            # update the intrinsic reward module\n",
    "            replay_data = self.buffer.sample(batch_size=self.irs.batch_size)\n",
    "            self.irs.update(samples={'observations': th.as_tensor(replay_data.observations).unsqueeze(1).to(device), # (n_steps, n_envs, *obs_shape)\n",
    "                                     'actions': th.as_tensor(replay_data.actions).unsqueeze(1).to(device),\n",
    "                                     'rewards': th.as_tensor(replay_data.rewards).to(device),\n",
    "                                     'terminateds': th.as_tensor(replay_data.dones).to(device),\n",
    "                                     'truncateds': th.as_tensor(replay_data.dones).to(device),\n",
    "                                     'next_observations': th.as_tensor(replay_data.next_observations).unsqueeze(1).to(device)\n",
    "                                     })\n",
    "        except:\n",
    "            pass\n",
    "\n",
    "        return True\n",
    "\n",
    "    def _on_rollout_end(self) -> None:\n",
    "        pass\n",
    "\n",
    "# Parallel environments\n",
    "device = 'cuda'\n",
    "n_envs = 4\n",
    "envs = make_vec_env(\"Pendulum-v1\", n_envs=n_envs)\n",
    "\n",
    "# ===================== build the reward ===================== #\n",
    "irs = E3B(envs, device=device)\n",
    "# ===================== build the reward ===================== #\n",
    "\n",
    "model = SAC(\"MlpPolicy\", envs, verbose=1, device=device)\n",
    "model.learn(total_timesteps=25000, callback=RLeXploreWithOffPolicyRL(irs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rllte",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
