{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Load the libraries**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "import numpy as np\n",
    "import torch as th\n",
    "\n",
    "from rllte.env.utils import Gymnasium2Torch\n",
    "from rllte.xplore.reward import ICM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Create a fake Atari environment with image observations**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FakeAtari(gym.Env):\n",
    "    \"\"\"Stand-in for an Atari env that emits random image observations.\n",
    "\n",
    "    Observations are drawn from a (4, 84, 84) Box with bounds [0, 1], there are\n",
    "    7 discrete actions, and the reward is always 0. From step 101 on, each step\n",
    "    ends the episode with probability 0.1.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.action_space = gym.spaces.Discrete(7)\n",
    "        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(4, 84, 84))\n",
    "        self.count = 0\n",
    "\n",
    "    def reset(self, seed=None, options=None):\n",
    "        # Follow the Gymnasium reset signature so vector envs can forward seed/options.\n",
    "        super().reset(seed=seed)\n",
    "        if seed is not None:\n",
    "            # also seed the space so the sampled observations are reproducible\n",
    "            self.observation_space.seed(seed)\n",
    "        self.count = 0\n",
    "        return self.observation_space.sample(), {}\n",
    "\n",
    "    def step(self, action):\n",
    "        self.count += 1\n",
    "        # Use the per-env RNG rather than the global np.random: forked vector-env workers\n",
    "        # inherit identical global RNG state and would otherwise end episodes in lockstep.\n",
    "        done = self.count > 100 and self.np_random.random() < 0.1\n",
    "        # as in the original demo, terminated and truncated are raised together\n",
    "        return self.observation_space.sample(), 0, done, done, {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Synchronous Mode**:\n",
    "\n",
    "**In synchronous mode, `.compute()` invokes `.update()` automatically, so no separate update call is needed. This mode is typically used with on-policy RL algorithms.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "e:\\anaconda3\\envs\\marllib\\lib\\site-packages\\torch\\nn\\modules\\conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
      "  return F.conv2d(input, weight, bias, self.stride,\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[6.5928, 5.5006, 5.3346,  ..., 5.3286, 6.5831, 5.1960],\n",
      "        [3.4611, 3.4754, 5.3265,  ..., 4.9442, 5.3422, 3.7767],\n",
      "        [3.7612, 3.7736, 6.5909,  ..., 3.7735, 4.9679, 6.5922],\n",
      "        ...,\n",
      "        [3.4737, 4.9781, 6.5358,  ..., 5.2204, 5.3287, 6.5794],\n",
      "        [3.7659, 5.3463, 5.3620,  ..., 6.5735, 5.3437, 3.7666],\n",
      "        [5.4956, 4.9599, 5.3435,  ..., 6.5689, 5.2174, 3.7587]],\n",
      "       device='cuda:0')\n",
      "torch.Size([128, 8])\n"
     ]
    }
   ],
   "source": [
    "# set the parameters\n",
    "device = 'cuda' if th.cuda.is_available() else 'cpu'\n",
    "n_steps = 128\n",
    "n_envs = 8\n",
    "# create the vectorized environments\n",
    "envs = gym.vector.AsyncVectorEnv([FakeAtari for _ in range(n_envs)])\n",
    "# wrap the environments to convert the observations to torch tensors\n",
    "envs = Gymnasium2Torch(envs, device)\n",
    "try:\n",
    "    # create the intrinsic reward module\n",
    "    irs = ICM(envs, device)\n",
    "    # reset the environments and get the initial observations\n",
    "    obs, infos = envs.reset()\n",
    "    # per-key lists of transitions; the keys are also the keyword names of irs.watch\n",
    "    samples = {key: [] for key in ('observations', 'actions', 'rewards',\n",
    "                                   'terminateds', 'truncateds', 'next_observations')}\n",
    "    # sampling loop\n",
    "    for _ in range(n_steps):\n",
    "        # sample random actions (one draw per environment)\n",
    "        actions = th.stack([th.as_tensor(envs.action_space.sample()) for _ in range(n_envs)])\n",
    "        # environment step\n",
    "        next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)\n",
    "        transition = {'observations': obs,\n",
    "                      'actions': actions,\n",
    "                      'rewards': rewards,\n",
    "                      'terminateds': terminateds,\n",
    "                      'truncateds': truncateds,\n",
    "                      'next_observations': next_obs}\n",
    "        # watch the interactions and get necessary information for the intrinsic reward computation\n",
    "        irs.watch(**transition)\n",
    "        # store the samples\n",
    "        for key, value in transition.items():\n",
    "            samples[key].append(value)\n",
    "        obs = next_obs\n",
    "finally:\n",
    "    # AsyncVectorEnv runs each environment in a worker subprocess; close them once\n",
    "    # sampling is done (even on error) so re-running cells does not leak processes.\n",
    "    envs.close()\n",
    "# compute the intrinsic rewards (in sync mode this also updates the module)\n",
    "samples = {k: th.stack(v) for k, v in samples.items()}\n",
    "intrinsic_rewards = irs.compute(samples=samples)\n",
    "print(intrinsic_rewards)\n",
    "print(intrinsic_rewards.shape)"
   ]
  },
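  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Asynchronous Mode**:\n",
    "\n",
    "**With `sync=False`, `.compute()` does not update the module; `.update()` must be invoked separately (here, once after the rollout). This mode is typically used with off-policy RL algorithms.**"
   ]
  },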
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2.5189, 2.5474, 2.5163, 2.5503, 2.1224, 2.1203, 2.5226, 2.6890]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[3.5146, 2.5905, 3.6144, 3.4424, 3.3997, 3.4378, 3.5162, 2.5951]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[3.9397, 3.0138, 3.0003, 3.9741, 3.3031, 4.1907, 2.9930, 3.3006]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[3.7179, 3.7295, 3.7109, 4.5688, 3.3561, 3.7105, 4.4071, 3.7139]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[4.8140, 3.6262, 4.7395, 4.9179, 3.6130, 4.7960, 3.6326, 4.8056]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[5.0398, 5.0510, 5.1055, 5.3550, 3.8483, 4.2502, 5.3718, 4.2484]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[5.6157, 4.4320, 4.0090, 5.6285, 5.4394, 4.4302, 5.2698, 5.5971]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[5.6835, 5.6864, 4.1652, 5.6790, 4.6121, 4.6099, 5.6657, 5.6570]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[5.6814, 5.7421, 4.7672, 4.7597, 4.7601, 4.3097, 4.3259, 5.7375]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[4.9303, 4.9200, 5.8479, 6.2329, 4.9210, 5.8397, 6.0653, 6.2386]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.0084, 4.5753, 5.0504, 6.3834, 6.3831, 6.0676, 6.0940, 5.0500]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[5.1951, 6.3671, 6.2316, 4.7049, 6.5330, 6.1471, 6.2264, 6.1414]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[5.2593, 4.7726, 6.4777, 6.3163, 6.4697, 4.7580, 6.2402, 5.2581]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.4456, 5.3758, 5.3288, 6.3376, 4.8353, 6.5705, 5.3339, 6.7704]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.5045, 6.4952, 6.4789, 6.5752, 4.9466, 5.4617, 6.7029, 6.4732]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.5939, 6.5816, 5.0173, 5.5160, 6.6742, 5.5143, 7.0145, 6.6522]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.8995, 6.7711, 5.0903, 6.9133, 6.9366, 6.6753, 6.9412, 6.9126]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[5.6771, 7.2055, 6.8529, 7.0467, 5.7132, 7.0183, 5.7049, 6.8671]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[5.7847, 5.7781, 6.8816, 5.7667, 7.1182, 7.3120, 5.8266, 6.8931]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[5.8520, 7.0614, 7.3840, 6.9593, 7.4139, 5.8421, 7.2144, 7.1987]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[7.1393, 5.9264, 7.2863, 5.9334, 5.9343, 7.1650, 7.3139, 5.9198]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[5.4096, 7.5668, 5.9831, 7.1920, 5.9870, 7.2063, 7.0720, 7.3720]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[7.4056, 7.5959, 6.0306, 7.5999, 6.0336, 7.6374, 6.0452, 7.6443]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[7.6416, 6.0455, 6.0523, 7.6636, 6.0505, 5.4667, 6.0389, 7.4645]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[7.2210, 6.0524, 7.4637, 5.4693, 7.6710, 5.4991, 6.0689, 7.1928]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[5.5446, 7.2855, 7.7199, 7.2614, 7.5277, 7.3406, 7.3372, 7.3796]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.0634, 5.5116, 5.4904, 5.4997, 5.5174, 7.4939, 7.7054, 7.6895]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[5.5292, 6.1226, 7.7319, 6.1375, 7.5021, 7.3694, 7.3333, 7.7288]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[7.4000, 7.7558, 6.1351, 7.7527, 6.1455, 5.5215, 7.5265, 7.3050]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[7.5665, 7.4047, 6.1359, 7.7609, 5.5729, 7.4408, 5.5511, 7.4111]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[5.5638, 6.1537, 7.7850, 6.0955, 7.3921, 5.5569, 7.5695, 7.5742]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[7.3309, 7.3234, 5.5776, 7.8148, 7.4272, 6.1497, 6.1578, 5.5858]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[7.8540, 6.2200, 6.1938, 7.8392, 7.8433, 7.6165, 7.4645, 7.4662]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.2338, 6.2148, 5.6336, 7.8571, 7.6251, 5.6216, 7.8475, 7.3810]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[7.7001, 7.6813, 6.2371, 6.2457, 7.8997, 6.2469, 7.6550, 7.4445]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.2513, 7.7041, 7.4667, 6.2685, 5.6738, 5.6541, 7.5227, 7.5220]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[7.9752, 7.5654, 6.2762, 6.2868, 6.3125, 7.4704, 7.9608, 7.4869]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[7.5261, 6.3393, 6.3498, 7.6241, 7.5224, 7.6593, 7.8034, 7.5425]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[7.5744, 8.0477, 6.4003, 6.3686, 7.6020, 7.5972, 7.6196, 6.4099]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.4328, 7.6024, 6.3944, 5.7862, 6.4168, 5.7895, 6.4107, 6.3490]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.1178, 6.4018, 8.1104, 8.0850, 5.7984, 8.0718, 7.5951, 6.4191]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.4093, 7.7411, 7.9178, 5.8257, 6.4005, 7.9099, 7.6135, 6.4129]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[7.9136, 7.7384, 8.1548, 7.6241, 5.8339, 5.8249, 6.4197, 7.8855]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.1471, 6.4325, 5.7995, 7.9193, 8.1279, 5.8162, 7.8807, 7.7583]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.1265, 6.4460, 5.8118, 6.4054, 7.7365, 7.8935, 5.8426, 7.6337]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.4356, 7.9150, 6.4331, 6.4339, 5.8481, 5.8278, 7.6502, 6.4065]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[7.7760, 6.4581, 7.9355, 6.4533, 6.4415, 6.4529, 7.7674, 8.1853]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[7.8146, 7.9990, 7.9994, 7.7119, 7.7116, 7.7351, 6.4972, 8.2211]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.5144, 8.0027, 7.8617, 7.8226, 6.5235, 6.5109, 6.5245, 7.8415]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.5235, 6.5224, 7.8763, 7.7513, 5.9053, 5.9250, 8.2696, 6.5215]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.5555, 6.5527, 8.0437, 7.7585, 6.5333, 8.2738, 8.2881, 7.7885]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.5318, 8.2863, 6.5439, 7.7932, 5.9245, 8.0821, 8.2962, 8.0655]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.5701, 7.8167, 8.2960, 8.3254, 6.5946, 8.0762, 6.5734, 7.8333]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.6012, 6.5821, 7.8501, 8.3284, 7.8458, 6.5970, 8.1203, 6.5814]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[7.8703, 7.8616, 6.5897, 6.6091, 5.9940, 8.3374, 6.6075, 8.3650]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.6230, 7.8734, 7.9758, 6.5839, 6.6488, 8.1504, 6.6329, 8.1729]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.3946, 7.9885, 6.6454, 8.1497, 8.4235, 8.3977, 8.3952, 8.4141]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.6257, 8.1712, 6.6068, 8.0210, 8.4161, 6.0351, 6.0244, 8.3892]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.6589, 8.3948, 6.6505, 8.0002, 6.0335, 6.6388, 7.8725, 7.8852]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.1538, 6.6093, 6.0048, 7.9734, 6.6489, 7.8948, 8.4132, 6.0372]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.0134, 7.9058, 6.6598, 6.6376, 8.1580, 6.6524, 8.4214, 8.0338]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.6785, 6.6432, 8.1789, 6.6255, 8.4429, 6.0151, 8.4246, 6.6256]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.0298, 8.0484, 8.0407, 6.0433, 8.0293, 7.9335, 8.0287, 8.0244]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.6811, 6.6823, 8.0581, 6.6834, 8.4880, 7.9648, 7.9775, 8.2562]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.2161, 8.2022, 8.4435, 8.2210, 6.0334, 6.0583, 8.4802, 8.0467]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.6682, 6.6929, 6.6658, 7.9382, 8.4540, 6.6904, 6.0749, 8.4526]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.7233, 6.6992, 6.6813, 6.7164, 6.6973, 7.9534, 8.2242, 8.4854]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.7151, 8.2630, 6.7359, 6.7135, 6.6956, 6.7263, 8.2843, 8.2874]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.5250, 6.7474, 8.1178, 8.1024, 8.2682, 6.7160, 8.2719, 6.0902]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.3157, 6.7374, 8.2888, 6.7319, 8.2941, 6.7404, 6.7288, 6.7116]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.1529, 8.5258, 8.0289, 8.1354, 6.7358, 8.5729, 8.0293, 8.0347]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.1339, 8.0364, 8.0346, 8.3424, 6.1317, 6.7833, 8.0458, 6.7454]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.7609, 6.7531, 8.5327, 8.3447, 8.3435, 8.3396, 6.1294, 8.3299]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.7615, 6.1233, 6.1321, 6.1172, 8.0395, 6.7536, 6.1193, 8.0368]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.7583, 8.1704, 6.7604, 6.7672, 8.1699, 8.0348, 8.3242, 8.5868]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.5308, 6.1060, 6.7676, 6.7349, 6.7598, 6.1401, 6.7720, 8.5315]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.7612, 8.5406, 6.7576, 8.3222, 8.1523, 6.1131, 8.3296, 8.0375]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.3307, 8.3650, 6.7364, 8.5620, 8.3509, 8.1694, 6.7357, 8.5408]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.1665, 8.0898, 8.3833, 8.3359, 8.5848, 8.6063, 8.0929, 8.1680]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.8305, 6.8177, 8.0888, 8.0783, 8.1273, 8.1072, 6.8551, 8.1005]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.8148, 8.3965, 8.6342, 8.4003, 8.4274, 8.6757, 8.6204, 6.8126]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.1310, 6.8639, 8.6585, 8.6400, 6.8338, 8.2189, 6.8077, 8.1169]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.4355, 6.1733, 6.8518, 8.2428, 8.2181, 6.1838, 6.8269, 8.0936]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.4090, 6.1979, 8.4010, 8.4160, 8.2650, 6.1985, 6.8597, 8.2372]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.1474, 6.8097, 8.1304, 6.2080, 8.2196, 6.8178, 6.8505, 6.2248]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.2081, 6.8371, 8.1384, 8.2529, 8.1212, 8.2330, 6.1886, 8.1477]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.3795, 8.2516, 6.8465, 8.1333, 8.6696, 8.6656, 6.2023, 8.6656]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.6585, 8.1079, 6.8574, 6.1693, 8.2378, 6.2161, 6.8750, 8.2143]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.4081, 8.1471, 6.1968, 6.2077, 8.4072, 8.1461, 8.2626, 6.8819]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.4318, 8.2639, 6.8412, 6.8419, 8.2558, 6.8568, 8.6697, 6.2107]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.4493, 8.6288, 6.8302, 8.1581, 8.6536, 6.8365, 8.6570, 6.8290]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.8507, 8.1606, 8.1686, 8.4687, 6.8322, 8.1700, 8.2578, 8.4617]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.2331, 8.1584, 6.2056, 8.4490, 6.8615, 6.8434, 8.2541, 6.8773]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.9058, 6.8418, 8.4680, 8.1609, 8.6888, 8.2776, 6.9042, 8.4767]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.2250, 6.8609, 8.6997, 6.2016, 8.1977, 6.8639, 8.1894, 8.1649]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.1986, 8.6992, 8.4312, 8.6919, 6.8711, 8.7119, 8.6787, 8.7054]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.8604, 8.2680, 6.8611, 8.2887, 8.7327, 8.7004, 8.2953, 8.2793]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.3141, 8.7419, 8.2897, 8.3094, 6.2419, 6.8968, 6.8840, 8.7114]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.1897, 8.3088, 8.1816, 6.9062, 8.4766, 6.9056, 8.7292, 8.7444]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.8741, 8.7288, 6.8937, 6.9063, 6.8967, 6.2670, 8.2925, 8.7337]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.9418, 6.8914, 6.9281, 8.2858, 6.8695, 8.7314, 8.4836, 8.7418]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.7458, 6.9122, 8.2148, 8.7280, 6.2606, 8.2943, 8.7372, 6.2530]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.4752, 8.3063, 6.8897, 6.2333, 8.7319, 6.2651, 8.4943, 8.4909]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.4794, 6.8804, 6.2371, 8.2028, 6.8882, 8.3140, 8.7399, 8.6859]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.2095, 8.7098, 8.1918, 8.7073, 6.2252, 8.2765, 6.2286, 8.2643]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.1560, 8.2943, 8.7009, 6.2320, 8.4643, 6.8888, 6.8696, 8.4834]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.4894, 6.2508, 8.3219, 8.2882, 8.7281, 6.8887, 8.2556, 6.8984]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.9006, 6.2357, 6.2288, 8.1901, 6.8638, 6.8818, 6.8861, 8.7114]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.2029, 8.1787, 6.2239, 6.8546, 8.2052, 6.8559, 6.2152, 6.2269]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.8357, 8.1710, 8.7295, 6.8672, 8.1849, 8.1913, 6.9011, 6.8710]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.5132, 8.4854, 8.2192, 8.7438, 8.1793, 8.4898, 8.4861, 6.9018]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.3121, 8.2311, 8.2241, 6.2631, 8.4910, 6.9289, 8.2010, 6.9013]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.9124, 6.8876, 6.2415, 8.2359, 8.4993, 6.2779, 8.2022, 6.2898]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.2668, 8.4886, 6.8927, 6.2515, 8.2280, 6.8708, 6.8910, 8.2162]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.3145, 8.7234, 6.2575, 6.8985, 6.9208, 8.5134, 6.2417, 8.7341]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.7603, 6.8909, 8.2110, 8.5341, 8.3155, 6.8900, 8.2011, 6.8998]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.8704, 6.2291, 6.2593, 8.2752, 6.2550, 8.7256, 8.2760, 6.2612]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.1928, 8.7037, 8.3072, 8.3273, 8.2878, 8.7430, 6.2446, 8.4911]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.8711, 6.2402, 6.2049, 6.2454, 6.8962, 6.8983, 6.9248, 6.8769]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.3029, 6.9179, 8.3193, 6.8749, 6.2477, 6.8769, 8.7397, 8.2534]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.5013, 8.2306, 8.3101, 8.5164, 8.3227, 8.5005, 8.5087, 8.2972]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.3213, 8.2133, 6.2458, 6.2348, 6.8796, 6.8956, 6.9175, 8.2171]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.8987, 8.2319, 8.2047, 6.2503, 6.9289, 8.3158, 6.9315, 6.9145]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.2423, 8.2152, 8.3513, 6.9187, 8.5156, 8.5352, 8.7737, 6.9267]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.5061, 8.2281, 8.3405, 6.2778, 6.9141, 6.9359, 8.7346, 8.5349]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.9135, 6.9272, 6.9522, 8.3297, 8.2253, 6.2750, 6.9384, 6.9647]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[6.9333, 8.7747, 8.7607, 8.7728, 8.3614, 8.5240, 6.2737, 8.3637]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n",
      "tensor([[8.7714, 6.9247, 8.3714, 6.9230, 8.5350, 8.7865, 8.5471, 8.7704]],\n",
      "       device='cuda:0') torch.Size([1, 8])\n"
     ]
    }
   ],
   "source": [
    "# set the parameters\n",
    "device = 'cuda' if th.cuda.is_available() else 'cpu'\n",
    "n_steps = 128\n",
    "n_envs = 8\n",
    "# create the vectorized environments\n",
    "envs = gym.vector.AsyncVectorEnv([FakeAtari for _ in range(n_envs)])\n",
    "# wrap the environments to convert the observations to torch tensors\n",
    "envs = Gymnasium2Torch(envs, device)\n",
    "try:\n",
    "    # create the intrinsic reward module\n",
    "    irs = ICM(envs, device)\n",
    "    # reset the environments and get the initial observations\n",
    "    obs, infos = envs.reset()\n",
    "    # per-key lists of transitions; the keys are also the keyword names of irs.watch\n",
    "    samples = {key: [] for key in ('observations', 'actions', 'rewards',\n",
    "                                   'terminateds', 'truncateds', 'next_observations')}\n",
    "    # sampling loop\n",
    "    for _ in range(n_steps):\n",
    "        # sample random actions (one draw per environment)\n",
    "        actions = th.stack([th.as_tensor(envs.action_space.sample()) for _ in range(n_envs)])\n",
    "        # environment step\n",
    "        next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)\n",
    "        transition = {'observations': obs,\n",
    "                      'actions': actions,\n",
    "                      'rewards': rewards,\n",
    "                      'terminateds': terminateds,\n",
    "                      'truncateds': truncateds,\n",
    "                      'next_observations': next_obs}\n",
    "        # watch the interactions and get necessary information for the intrinsic reward computation\n",
    "        irs.watch(**transition)\n",
    "        # compute the intrinsic rewards for this single step; a leading time axis of\n",
    "        # length 1 is added, and sync=False leaves the module update to the explicit call below\n",
    "        intrinsic_rewards = irs.compute(samples={key: value.unsqueeze(0) for key, value in transition.items()},\n",
    "                                        sync=False)\n",
    "        print(intrinsic_rewards, intrinsic_rewards.shape)\n",
    "        # store the samples\n",
    "        for key, value in transition.items():\n",
    "            samples[key].append(value)\n",
    "        obs = next_obs\n",
    "finally:\n",
    "    # AsyncVectorEnv runs each environment in a worker subprocess; close them once\n",
    "    # sampling is done (even on error) so re-running cells does not leak processes.\n",
    "    envs.close()\n",
    "# update the intrinsic reward module once with the whole rollout\n",
    "samples = {k: th.stack(v) for k, v in samples.items()}\n",
    "irs.update(samples=samples)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rllte",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
