{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**RLeXplore provides a standard workflow for creating new intrinsic rewards algorithms.**\n",
    "\n",
    "**The following code demonstrates how to implement the RND algorithm with RLeXplore.**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Load the libraries**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, Optional\n",
    "\n",
    "import torch as th\n",
    "import torch.nn.functional as F\n",
    "from gymnasium.vector import VectorEnv\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "import sys\n",
    "sys.path.append('../../')\n",
    "\n",
    "from rllte.common.prototype import BaseReward\n",
    "from rllte.xplore.reward.model import ObservationEncoder\n",
    "from rllte.agent import PPO\n",
    "from rllte.env import make_atari_env"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Define the RND class using the `BaseReward`**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RND(BaseReward):\n",
    "    \"\"\"Exploration by Random Network Distillation (RND).\n",
    "        See paper: https://arxiv.org/pdf/1810.12894.pdf\n",
    "\n",
    "    Args:\n",
    "        envs (VectorEnv): The vectorized environments.\n",
    "        device (str): Device (cpu, cuda, ...) on which the code should be run.\n",
    "        beta (float): The initial weighting coefficient of the intrinsic rewards.\n",
    "        kappa (float): The decay rate of the weighting coefficient.\n",
    "        gamma (Optional[float]): Intrinsic reward discount rate, default is `None`.\n",
    "        rwd_norm_type (str): Normalization type for intrinsic rewards from ['rms', 'minmax', 'none'].\n",
    "        obs_norm_type (str): Normalization type for observations data from ['rms', 'none'].\n",
    "\n",
    "        latent_dim (int): The dimension of encoding vectors.\n",
    "        lr (float): The learning rate.\n",
    "        batch_size (int): The batch size for training.\n",
    "        update_proportion (float): The proportion of the training data used for updating the forward dynamics models.\n",
    "        encoder_model (str): The network architecture of the encoder from ['mnih', 'pathak'].\n",
    "        weight_init (str): The weight initialization method from ['default', 'orthogonal'].\n",
    "\n",
    "    Returns:\n",
    "        Instance of RND.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        envs: VectorEnv,\n",
    "        device: str = \"cpu\",\n",
    "        beta: float = 1.0,\n",
    "        kappa: float = 0.0,\n",
    "        gamma: Optional[float] = None,\n",
    "        rwd_norm_type: str = \"rms\",\n",
    "        obs_norm_type: str = \"rms\",\n",
    "        latent_dim: int = 128,\n",
    "        lr: float = 0.001,\n",
    "        batch_size: int = 256,\n",
    "        update_proportion: float = 1.0,\n",
    "        encoder_model: str = \"mnih\",\n",
    "        weight_init: str = \"orthogonal\",\n",
    "    ) -> None:\n",
    "        super().__init__(envs, device, beta, kappa, gamma, rwd_norm_type, obs_norm_type)\n",
    "        # build the predictor and target networks\n",
    "        self.predictor = ObservationEncoder(obs_shape=self.obs_shape, \n",
    "                                            latent_dim=latent_dim, \n",
    "                                            encoder_model=encoder_model, \n",
    "                                            weight_init=weight_init\n",
    "                                            ).to(self.device)\n",
    "        self.target = ObservationEncoder(obs_shape=self.obs_shape, \n",
    "                                         latent_dim=latent_dim, \n",
    "                                         encoder_model=encoder_model, \n",
    "                                         weight_init=weight_init\n",
    "                                         ).to(self.device)\n",
    "\n",
    "        # freeze the randomly initialized target network parameters\n",
    "        for p in self.target.parameters():\n",
    "            p.requires_grad = False\n",
    "        # set the optimizer\n",
    "        self.opt = th.optim.Adam(self.predictor.parameters(), lr=lr)\n",
    "        # set the parameters\n",
    "        self.batch_size = batch_size\n",
    "        self.update_proportion = update_proportion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Define the `.compute()` function**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute(self, samples: Dict[str, th.Tensor], sync: bool = True) -> th.Tensor:\n",
    "    \"\"\"Compute the rewards for current samples.\n",
    "\n",
    "    Args:\n",
    "        samples (Dict[str, th.Tensor]): The collected samples. A python dict consists of multiple tensors,\n",
    "            whose keys are ['observations', 'actions', 'rewards', 'terminateds', 'truncateds', 'next_observations'].\n",
    "            For example, the data shape of 'observations' is (n_steps, n_envs, *obs_shape).\n",
    "        sync (bool): Whether to update the reward module after the `compute` function, default is `True`.\n",
    "\n",
    "    Returns:\n",
    "        The intrinsic rewards.\n",
    "    \"\"\"\n",
    "    super().compute(samples)\n",
    "    # get the number of steps and environments\n",
    "    (n_steps, n_envs) = samples.get(\"next_observations\").size()[:2]\n",
    "    # get the next observations\n",
    "    next_obs_tensor = samples.get(\"next_observations\").to(self.device)\n",
    "    # normalize the observations\n",
    "    next_obs_tensor = self.normalize(next_obs_tensor)\n",
    "    # compute the intrinsic rewards\n",
    "    intrinsic_rewards = th.zeros(size=(n_steps, n_envs)).to(self.device)\n",
    "    with th.no_grad():\n",
    "        # get source and target features\n",
    "        src_feats = self.predictor(next_obs_tensor.view(-1, *self.obs_shape))\n",
    "        tgt_feats = self.target(next_obs_tensor.view(-1, *self.obs_shape))\n",
    "        # compute the distance\n",
    "        dist = F.mse_loss(src_feats, tgt_feats, reduction=\"none\").mean(dim=1)\n",
    "        intrinsic_rewards = dist.view(n_steps, n_envs)\n",
    "\n",
    "    # update the reward module\n",
    "    if sync:\n",
    "        self.update(samples)\n",
    "\n",
    "    # scale the intrinsic rewards\n",
    "    return self.scale(intrinsic_rewards)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Define the `update()` function**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update(self, samples: Dict[str, th.Tensor]) -> None:\n",
    "    \"\"\"Update the reward module if necessary.\n",
    "\n",
    "    Args:\n",
    "        samples (Dict[str, th.Tensor]): The collected samples same as the `compute` function.\n",
    "\n",
    "    Returns:\n",
    "        None.\n",
    "    \"\"\"\n",
    "    # get the observations\n",
    "    obs_tensor = (\n",
    "        samples.get(\"observations\").to(self.device).view(-1, *self.obs_shape)\n",
    "    )\n",
    "    # normalize the observations\n",
    "    obs_tensor = self.normalize(obs_tensor)\n",
    "    # create the dataset and loader\n",
    "    dataset = TensorDataset(obs_tensor)\n",
    "    loader = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)\n",
    "\n",
    "    avg_loss = []\n",
    "    # update the predictor\n",
    "    for _idx, batch_data in enumerate(loader):\n",
    "        # get the batch data\n",
    "        obs = batch_data[0]\n",
    "        # zero the gradients\n",
    "        self.opt.zero_grad()\n",
    "        # get the source and target features\n",
    "        src_feats = self.predictor(obs)\n",
    "        with th.no_grad():\n",
    "            tgt_feats = self.target(obs)\n",
    "\n",
    "        # compute the loss\n",
    "        loss = F.mse_loss(src_feats, tgt_feats, reduction=\"none\").mean(dim=-1)\n",
    "        # use a random mask to select a subset of the training data\n",
    "        mask = th.rand(len(loss), device=self.device)\n",
    "        mask = (mask < self.update_proportion).type(th.FloatTensor).to(self.device)\n",
    "        # get the masked loss\n",
    "        loss = (loss * mask).sum() / th.max(\n",
    "            mask.sum(), th.tensor([1], device=self.device, dtype=th.float32)\n",
    "        )\n",
    "        # backward and update\n",
    "        loss.backward()\n",
    "        self.opt.step()\n",
    "        avg_loss.append(loss.item())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Merge the code**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RND(BaseReward):\n",
    "    \"\"\"Exploration by Random Network Distillation (RND).\n",
    "        See paper: https://arxiv.org/pdf/1810.12894.pdf\n",
    "\n",
    "    Args:\n",
    "        envs (VectorEnv): The vectorized environments.\n",
    "        device (str): Device (cpu, cuda, ...) on which the code should be run.\n",
    "        beta (float): The initial weighting coefficient of the intrinsic rewards.\n",
    "        kappa (float): The decay rate of the weighting coefficient.\n",
    "        gamma (Optional[float]): Intrinsic reward discount rate, default is `None`.\n",
    "        rwd_norm_type (str): Normalization type for intrinsic rewards from ['rms', 'minmax', 'none'].\n",
    "        obs_norm_type (str): Normalization type for observations data from ['rms', 'none'].\n",
    "\n",
    "        latent_dim (int): The dimension of encoding vectors.\n",
    "        lr (float): The learning rate.\n",
    "        batch_size (int): The batch size for training.\n",
    "        update_proportion (float): The proportion of the training data used for updating the forward dynamics models.\n",
    "        encoder_model (str): The network architecture of the encoder from ['mnih', 'pathak'].\n",
    "        weight_init (str): The weight initialization method from ['default', 'orthogonal'].\n",
    "\n",
    "    Returns:\n",
    "        Instance of RND.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        envs: VectorEnv,\n",
    "        device: str = \"cpu\",\n",
    "        beta: float = 1.0,\n",
    "        kappa: float = 0.0,\n",
    "        gamma: Optional[float] = None,\n",
    "        rwd_norm_type: str = \"rms\",\n",
    "        obs_norm_type: str = \"rms\",\n",
    "        latent_dim: int = 128,\n",
    "        lr: float = 0.001,\n",
    "        batch_size: int = 256,\n",
    "        update_proportion: float = 1.0,\n",
    "        encoder_model: str = \"mnih\",\n",
    "        weight_init: str = \"orthogonal\",\n",
    "    ) -> None:\n",
    "        super().__init__(envs, device, beta, kappa, gamma, rwd_norm_type, obs_norm_type)\n",
    "        # build the predictor and target networks\n",
    "        self.predictor = ObservationEncoder(\n",
    "            obs_shape=self.obs_shape,\n",
    "            latent_dim=latent_dim,\n",
    "            encoder_model=encoder_model,\n",
    "            weight_init=weight_init,\n",
    "        ).to(self.device)\n",
    "        self.target = ObservationEncoder(\n",
    "            obs_shape=self.obs_shape,\n",
    "            latent_dim=latent_dim,\n",
    "            encoder_model=encoder_model,\n",
    "            weight_init=weight_init,\n",
    "        ).to(self.device)\n",
    "\n",
    "        # freeze the randomly initialized target network parameters\n",
    "        for p in self.target.parameters():\n",
    "            p.requires_grad = False\n",
    "        # set the optimizer\n",
    "        self.opt = th.optim.Adam(self.predictor.parameters(), lr=lr)\n",
    "        # set the parameters\n",
    "        self.batch_size = batch_size\n",
    "        self.update_proportion = update_proportion\n",
    "\n",
    "    def compute(self, samples: Dict[str, th.Tensor], sync: bool = True) -> th.Tensor:\n",
    "        \"\"\"Compute the rewards for current samples.\n",
    "\n",
    "        Args:\n",
    "            samples (Dict[str, th.Tensor]): The collected samples. A python dict consists of multiple tensors,\n",
    "                whose keys are ['observations', 'actions', 'rewards', 'terminateds', 'truncateds', 'next_observations'].\n",
    "                For example, the data shape of 'observations' is (n_steps, n_envs, *obs_shape).\n",
    "            sync (bool): Whether to update the reward module after the `compute` function, default is `True`.\n",
    "\n",
    "        Returns:\n",
    "            The intrinsic rewards.\n",
    "        \"\"\"\n",
    "        super().compute(samples)\n",
    "        # get the number of steps and environments\n",
    "        (n_steps, n_envs) = samples.get(\"next_observations\").size()[:2]\n",
    "        # get the next observations\n",
    "        next_obs_tensor = samples.get(\"next_observations\").to(self.device)\n",
    "        # normalize the observations\n",
    "        next_obs_tensor = self.normalize(next_obs_tensor)\n",
    "        # compute the intrinsic rewards\n",
    "        intrinsic_rewards = th.zeros(size=(n_steps, n_envs)).to(self.device)\n",
    "        with th.no_grad():\n",
    "            # get source and target features\n",
    "            src_feats = self.predictor(next_obs_tensor.view(-1, *self.obs_shape))\n",
    "            tgt_feats = self.target(next_obs_tensor.view(-1, *self.obs_shape))\n",
    "            # compute the distance\n",
    "            dist = F.mse_loss(src_feats, tgt_feats, reduction=\"none\").mean(dim=1)\n",
    "            intrinsic_rewards = dist.view(n_steps, n_envs)\n",
    "\n",
    "        # update the reward module\n",
    "        if sync:\n",
    "            self.update(samples)\n",
    "\n",
    "        # scale the intrinsic rewards\n",
    "        return self.scale(intrinsic_rewards)\n",
    "\n",
    "    def update(self, samples: Dict[str, th.Tensor]) -> None:\n",
    "        \"\"\"Update the reward module if necessary.\n",
    "\n",
    "        Args:\n",
    "            samples (Dict[str, th.Tensor]): The collected samples same as the `compute` function.\n",
    "\n",
    "        Returns:\n",
    "            None.\n",
    "        \"\"\"\n",
    "        # get the observations\n",
    "        obs_tensor = (\n",
    "            samples.get(\"observations\").to(self.device).view(-1, *self.obs_shape)\n",
    "        )\n",
    "        # normalize the observations\n",
    "        obs_tensor = self.normalize(obs_tensor)\n",
    "        # create the dataset and loader\n",
    "        dataset = TensorDataset(obs_tensor)\n",
    "        loader = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)\n",
    "\n",
    "        avg_loss = []\n",
    "        # update the predictor\n",
    "        for _idx, batch_data in enumerate(loader):\n",
    "            # get the batch data\n",
    "            obs = batch_data[0]\n",
    "            # zero the gradients\n",
    "            self.opt.zero_grad()\n",
    "            # get the source and target features\n",
    "            src_feats = self.predictor(obs)\n",
    "            with th.no_grad():\n",
    "                tgt_feats = self.target(obs)\n",
    "\n",
    "            # compute the loss\n",
    "            loss = F.mse_loss(src_feats, tgt_feats, reduction=\"none\").mean(dim=-1)\n",
    "            # use a random mask to select a subset of the training data\n",
    "            mask = th.rand(len(loss), device=self.device)\n",
    "            mask = (mask < self.update_proportion).type(th.FloatTensor).to(self.device)\n",
    "            # get the masked loss\n",
    "            loss = (loss * mask).sum() / th.max(\n",
    "                mask.sum(), th.tensor([1], device=self.device, dtype=th.float32)\n",
    "            )\n",
    "            # backward and update\n",
    "            loss.backward()\n",
    "            self.opt.step()\n",
    "            avg_loss.append(loss.item())"
   ]
  },
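  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(Optional) Sanity-check the reward module with fabricated samples**\n",
    "\n",
    "Before plugging the module into an agent, it can be useful to feed it randomly generated samples and confirm that the returned intrinsic rewards have shape `(n_steps, n_envs)`. The cell below is only a minimal sketch: it builds the samples dict with the keys documented in `.compute()`, reuses the `obs_shape` resolved by `BaseReward`, and assumes the vectorized environments expose a `num_envs` attribute. RND only reads `'next_observations'`, so the remaining entries are placeholders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a minimal sanity check (a sketch, not part of the original workflow)\n",
    "device = 'cuda' if th.cuda.is_available() else 'cpu'\n",
    "envs = make_atari_env('PongNoFrameskip-v4', device=device)\n",
    "irs = RND(envs, device=device)\n",
    "\n",
    "n_steps = 128                  # arbitrary rollout length for this check\n",
    "n_envs = envs.num_envs         # number of parallel envs (assumed attribute of the vectorized env)\n",
    "obs_shape = irs.obs_shape      # observation shape resolved by BaseReward\n",
    "\n",
    "# fabricate a batch of samples with the keys expected by `.compute()`;\n",
    "# RND only uses 'next_observations', the other entries are placeholders\n",
    "samples = {\n",
    "    'observations': th.rand(n_steps, n_envs, *obs_shape),\n",
    "    'actions': th.zeros(n_steps, n_envs),\n",
    "    'rewards': th.zeros(n_steps, n_envs),\n",
    "    'terminateds': th.zeros(n_steps, n_envs),\n",
    "    'truncateds': th.zeros(n_steps, n_envs),\n",
    "    'next_observations': th.rand(n_steps, n_envs, *obs_shape),\n",
    "}\n",
    "intrinsic_rewards = irs.compute(samples)\n",
    "print(intrinsic_rewards.shape)  # expected: (n_steps, n_envs)"
   ]
  },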
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Now you can train RL agents with the implemented RND directly**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create the vectorized environments\n",
    "device = 'cuda' if th.cuda.is_available() else 'cpu'\n",
    "envs = make_atari_env('PongNoFrameskip-v4', device=device)\n",
    "print(device, envs.observation_space, envs.action_space)\n",
    "# create the intrinsic reward module\n",
    "irs = RND(envs, device=device)\n",
    "# create the PPO agent\n",
    "agent = PPO(envs, device=device)\n",
    "# set the intrinsic reward module\n",
    "agent.set(reward=irs)\n",
    "# train the agent\n",
    "agent.train(10000)"
   ]
  },
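  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(Optional) Cross-check against the built-in RND**\n",
    "\n",
    "RLeXplore also ships a reference implementation of RND, so you can swap it in to verify that the custom module above behaves comparably. This is a sketch that assumes the reference module is importable as `rllte.xplore.reward.RND` in your installed version of rllte."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# swap in the library's reference RND implementation (import path assumed)\n",
    "from rllte.xplore.reward import RND as BuiltinRND\n",
    "\n",
    "envs = make_atari_env('PongNoFrameskip-v4', device=device)\n",
    "builtin_irs = BuiltinRND(envs, device=device)\n",
    "agent = PPO(envs, device=device)\n",
    "agent.set(reward=builtin_irs)\n",
    "agent.train(10000)"
   ]
  },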
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "marllib",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
