{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../\")\n",
    "sys.path.append(\"../../models/Memory_RL\")\n",
    "\n",
    "# from environments.Passive_T_Maze_Flag.env.env_passive_t_maze_flag import TMazeClassicPassive\n",
    "from models.Memory_RL.envs.tmaze import TMazeClassicPassive\n",
    "from models.Memory_RL.policies.models.policy_rnn_dqn import ModelFreeOffPolicy_DQN_RNN\n",
    "import os \n",
    "\n",
    "import numpy as np\n",
    "import gym\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import torch\n",
    "import yaml\n",
    "import time\n",
    "from moviepy.editor import ImageSequenceClip, VideoFileClip\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "from configs.rl.name_fns import name_fn as name_fn1\n",
    "from ml_collections import ConfigDict\n",
    "from typing import Tuple\n",
    "from torchkit import pytorch_utils as ptu\n",
    "\n",
    "def dqn_name_fn(\n",
    "    config: ConfigDict, max_episode_steps: int, max_training_steps: int\n",
    ") -> Tuple[ConfigDict, str]:\n",
    "    \"\"\"Finalise a DQN ConfigDict with an epsilon-greedy schedule tied to the horizon.\n",
    "\n",
    "    Runs the generic RL ``name_fn`` first, then sets\n",
    "      * ``init_eps = 1.0``;\n",
    "      * ``end_eps = 1 / T`` with ``T = max_episode_steps``, so the asymptotic\n",
    "        probability of sampling a fully exploited trajectory while exploring\n",
    "        is (1 - 1/T)^T, roughly 1/e;\n",
    "      * ``schedule_steps = schedule_end * max_training_steps``.\n",
    "\n",
    "    Returns:\n",
    "        The updated config and the run name produced by ``name_fn``.\n",
    "    \"\"\"\n",
    "    updated, run_name = name_fn1(config)\n",
    "    horizon = max_episode_steps\n",
    "\n",
    "    updated.init_eps = 1.0\n",
    "    updated.end_eps = 1.0 / horizon\n",
    "    updated.schedule_steps = updated.schedule_end * max_training_steps\n",
    "\n",
    "    return updated, run_name\n",
    "\n",
    "\n",
    "def get_rl_config():\n",
    "    \"\"\"Build the default RL (DQN) hyper-parameter ConfigDict.\n",
    "\n",
    "    ``name_fn`` is ``dqn_name_fn``, which later derives the epsilon schedule\n",
    "    from ``schedule_end``, the episode horizon and the training budget.\n",
    "    \"\"\"\n",
    "    rl = ConfigDict()\n",
    "    rl.name_fn = dqn_name_fn\n",
    "    rl.algo = \"dqn\"\n",
    "\n",
    "    # optimiser\n",
    "    rl.critic_lr = 3e-4\n",
    "\n",
    "    # critic head\n",
    "    rl.config_critic = ConfigDict()\n",
    "    rl.config_critic.hidden_dims = (256, 256)\n",
    "\n",
    "    rl.discount = 0.99\n",
    "    rl.tau = 0.005\n",
    "    rl.schedule_end = 0.1  # at least good for TMaze-like envs\n",
    "\n",
    "    # replay buffer sizes (kept as floats, exactly as specified)\n",
    "    rl.replay_buffer_size = 1e6\n",
    "    rl.replay_buffer_num_episodes = 1e3\n",
    "\n",
    "    return rl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ml_collections import ConfigDict\n",
    "from typing import Tuple\n",
    "from configs.seq_models.name_fns import name_fn\n",
    "\n",
    "\n",
    "def attn_name_fn(config: ConfigDict, max_episode_steps: int) -> Tuple[ConfigDict, str]:\n",
    "    \"\"\"Finalise an attention (GPT) sequence-model ConfigDict.\n",
    "\n",
    "    After the generic seq-model ``name_fn``:\n",
    "      * ``seq_model_config.hidden_size`` becomes the sum of the ``hidden_size``\n",
    "        of every embedder (observation, action, reward) that is not ``None``;\n",
    "      * ``seq_model_config.max_seq_length`` becomes ``sampled_seq_len + 1``\n",
    "        (the +1 accounts for the zero-prepend).\n",
    "\n",
    "    Returns:\n",
    "        The updated config and the run name produced by ``name_fn``.\n",
    "    \"\"\"\n",
    "    config, name = name_fn(config, max_episode_steps)\n",
    "    model = config.model\n",
    "\n",
    "    embedders = (model.observ_embedder, model.action_embedder, model.reward_embedder)\n",
    "    model.seq_model_config.hidden_size = sum(\n",
    "        embedder.hidden_size for embedder in embedders if embedder is not None\n",
    "    )\n",
    "\n",
    "    # NOTE: zero-prepend\n",
    "    model.seq_model_config.max_seq_length = config.sampled_seq_len + 1\n",
    "\n",
    "    return config, name\n",
    "\n",
    "\n",
    "def get_seq_config():\n",
    "    \"\"\"Build the default attention (GPT) sequence-model ConfigDict.\n",
    "\n",
    "    ``name_fn`` is ``attn_name_fn``, which later overwrites\n",
    "    ``model.seq_model_config.hidden_size`` with the summed embedder widths and\n",
    "    sets ``model.seq_model_config.max_seq_length`` from ``sampled_seq_len``.\n",
    "    \"\"\"\n",
    "\n",
    "    def mlp_embedder(width):\n",
    "        # All embedders here are MLPs; they differ only in output width.\n",
    "        embedder = ConfigDict()\n",
    "        embedder.name = \"mlp\"\n",
    "        embedder.hidden_size = width\n",
    "        return embedder\n",
    "\n",
    "    seq = ConfigDict()\n",
    "    seq.name_fn = attn_name_fn\n",
    "\n",
    "    seq.is_markov = False\n",
    "    seq.is_attn = True\n",
    "    seq.use_dropout = True\n",
    "\n",
    "    seq.sampled_seq_len = -1\n",
    "\n",
    "    seq.clip = False\n",
    "    seq.max_norm = 1.0\n",
    "    seq.use_l2_norm = False\n",
    "\n",
    "    # fed into Module\n",
    "    seq.model = ConfigDict()\n",
    "\n",
    "    # seq_model_config specific\n",
    "    gpt = ConfigDict()\n",
    "    gpt.name = \"gpt\"\n",
    "    gpt.hidden_size = 128  # NOTE: will be overwritten by name_fn\n",
    "    gpt.n_layer = 1\n",
    "    gpt.n_head = 1\n",
    "    gpt.pdrop = 0.1\n",
    "    gpt.position_encoding = \"sine\"\n",
    "    seq.model.seq_model_config = gpt\n",
    "\n",
    "    # embedders (the zero-width reward embedder adds nothing to hidden_size)\n",
    "    seq.model.observ_embedder = mlp_embedder(64)\n",
    "    seq.model.action_embedder = mlp_embedder(64)\n",
    "    seq.model.reward_embedder = mlp_embedder(0)\n",
    "\n",
    "    return seq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import permutations\n",
    "\n",
    "def generate_permutations(nums):\n",
    "\n",
    "    perms = permutations(nums)\n",
    "    result = [int(''.join(map(str, perm))) for perm in perms]\n",
    "    \n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "set device: cuda:0\n"
     ]
    }
   ],
   "source": [
    "from torchkit.pytorch_utils import set_gpu_mode\n",
    "\n",
    "# Single place to choose the GPU, so torchkit and torch's default device agree.\n",
    "GPU_ID = 0\n",
    "set_gpu_mode('cuda', GPU_ID)\n",
    "device = torch.device(f'cuda:{GPU_ID}')\n",
    "torch.set_default_device(device)\n",
    "\n",
    "# Agent class evaluated in this notebook (DQN with a sequence-model critic).\n",
    "agent_class = ModelFreeOffPolicy_DQN_RNN\n",
    "agent_arch = agent_class.ARCH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0]\n",
      " [0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]\n"
     ]
    }
   ],
   "source": [
    "# YAML run config. The path comes from the TMAZE_CONFIG_PATH environment\n",
    "# variable instead of a hard-coded local path; leave it unset to skip loading\n",
    "# (config is then None).\n",
    "config_path = os.environ.get('TMAZE_CONFIG_PATH')\n",
    "\n",
    "# Environment geometry, all derived from the episode length.\n",
    "episode_timeout = 18\n",
    "corridor_length = episode_timeout - 2\n",
    "penalty = -1/(episode_timeout - 1)\n",
    "\n",
    "config = None\n",
    "if config_path:\n",
    "    with open(config_path, 'r') as file:\n",
    "        config = yaml.safe_load(file)\n",
    "\n",
    "env = TMazeClassicPassive(episode_length=episode_timeout, \n",
    "                            corridor_length=corridor_length, \n",
    "                            goal_reward=1.0,\n",
    "                            penalty=penalty)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the configs' horizon in sync with the environment built above instead\n",
    "# of repeating the literal 18 (the two could otherwise silently drift apart).\n",
    "max_episode_steps = episode_timeout\n",
    "# Only feeds the epsilon schedule length: schedule_end * max_training_steps.\n",
    "max_training_steps = 999\n",
    "\n",
    "config_seq, _ = attn_name_fn(get_seq_config(), max_episode_steps=max_episode_steps)\n",
    "config_rl, _ = dqn_name_fn(\n",
    "    config=get_rl_config(),\n",
    "    max_episode_steps=max_episode_steps,\n",
    "    max_training_steps=max_training_steps,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'h.0.ln_1.weight': torch.Size([128]), 'h.0.ln_1.bias': torch.Size([128]), 'h.0.attn.c_attn.weight': torch.Size([128, 384]), 'h.0.attn.c_attn.bias': torch.Size([384]), 'h.0.attn.c_proj.weight': torch.Size([128, 128]), 'h.0.attn.c_proj.bias': torch.Size([128]), 'h.0.ln_2.weight': torch.Size([128]), 'h.0.ln_2.bias': torch.Size([128]), 'h.0.mlp.c_fc.weight': torch.Size([128, 512]), 'h.0.mlp.c_fc.bias': torch.Size([512]), 'h.0.mlp.c_proj.weight': torch.Size([512, 128]), 'h.0.mlp.c_proj.bias': torch.Size([128]), 'ln_f.weight': torch.Size([128]), 'ln_f.bias': torch.Size([128])}\n"
     ]
    }
   ],
   "source": [
    "# No image encoder: observations are used as flat vectors (obs_dim = shape[0]).\n",
    "image_encoder_fn = lambda: None\n",
    "\n",
    "obs_dim = env.observation_space.shape[0]\n",
    "act_dim = 4  # NOTE(review): hard-coded; presumably the maze's 4 moves - confirm against env.action_space\n",
    "freeze_critic = False\n",
    "\n",
    "agent_kwargs = dict(\n",
    "    obs_dim=obs_dim,\n",
    "    action_dim=act_dim,\n",
    "    config_seq=config_seq,\n",
    "    config_rl=config_rl,\n",
    "    image_encoder_fn=image_encoder_fn,\n",
    "    freeze_critic=freeze_critic,\n",
    ")\n",
    "agent = agent_class(**agent_kwargs).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List\n",
    "from collections import OrderedDict, namedtuple\n",
    "\n",
    "\n",
    "class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])):\n",
    "    def __repr__(self):\n",
    "        if not self.missing_keys and not self.unexpected_keys:\n",
    "            return '<All keys matched successfully>'\n",
    "        return super().__repr__()\n",
    "\n",
    "    __str__ = __repr__\n",
    "\n",
    "def load_state_dict(module, state_dict: Mapping[str, Any],\n",
    "                    strict: bool = True, assign: bool = False):\n",
    "    r\"\"\"Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.\n",
    "\n",
    "    If :attr:`strict` is ``True``, then\n",
    "    the keys of :attr:`state_dict` must exactly match the keys returned\n",
    "    by this module's :meth:`~torch.nn.Module.state_dict` function.\n",
    "\n",
    "    .. warning::\n",
    "        If :attr:`assign` is ``True`` the optimizer must be created after\n",
    "        the call to :attr:`load_state_dict` unless\n",
    "        :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.\n",
    "\n",
    "    Args:\n",
    "        state_dict (dict): a dict containing parameters and\n",
    "            persistent buffers.\n",
    "        strict (bool, optional): whether to strictly enforce that the keys\n",
    "            in :attr:`state_dict` match the keys returned by this module's\n",
    "            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``\n",
    "        assign (bool, optional): When ``False``, the properties of the tensors\n",
    "            in the current module are preserved while when ``True``, the\n",
    "            properties of the Tensors in the state dict are preserved. The only\n",
    "            exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s\n",
    "            for which the value from the module is preserved.\n",
    "            Default: ``False``\n",
    "\n",
    "    Returns:\n",
    "        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:\n",
    "            * **missing_keys** is a list of str containing the missing keys\n",
    "            * **unexpected_keys** is a list of str containing the unexpected keys\n",
    "\n",
    "    Note:\n",
    "        If a parameter or buffer is registered as ``None`` and its corresponding key\n",
    "        exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a\n",
    "        ``RuntimeError``.\n",
    "    \"\"\"\n",
    "    if not isinstance(state_dict, Mapping):\n",
    "        raise TypeError(f\"Expected state_dict to be dict-like, got {type(state_dict)}.\")\n",
    "\n",
    "    missing_keys: List[str] = []\n",
    "    unexpected_keys: List[str] = []\n",
    "    error_msgs: List[str] = []\n",
    "\n",
    "    # copy state_dict so _load_from_state_dict can modify it\n",
    "    metadata = getattr(state_dict, '_metadata', None)\n",
    "    state_dict = OrderedDict(state_dict)\n",
    "    if metadata is not None:\n",
    "        # mypy isn't aware that \"_metadata\" exists in state_dict\n",
    "        state_dict._metadata = metadata  # type: ignore[attr-defined]\n",
    "\n",
    "\n",
    "    def load(module, local_state_dict, prefix=''):\n",
    "        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})\n",
    "        if assign:\n",
    "            local_metadata['assign_to_params_buffers'] = assign\n",
    "        module._load_from_state_dict(\n",
    "            local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)\n",
    "        for name, child in module._modules.items():\n",
    "            if child is not None:\n",
    "                child_prefix = prefix + name + '.'\n",
    "                child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}\n",
    "                load(child, child_state_dict, child_prefix)  # noqa: F821\n",
    "\n",
    "        # Note that the hook can modify missing_keys and unexpected_keys.\n",
    "        incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)\n",
    "        for hook in module._load_state_dict_post_hooks.values():\n",
    "            out = hook(module, incompatible_keys)\n",
    "            assert out is None, (\n",
    "                \"Hooks registered with ``register_load_state_dict_post_hook`` are not\"\n",
    "                \"expected to return new values, if incompatible_keys need to be modified,\"\n",
    "                \"it should be done inplace.\"\n",
    "            )\n",
    "\n",
    "    load(module, state_dict)\n",
    "    del load\n",
    "\n",
    "    if strict:\n",
    "        if len(unexpected_keys) > 0:\n",
    "            error_msgs.insert(\n",
    "                0, 'Unexpected key(s) in state_dict: {}. '.format(\n",
    "                    ', '.join(f'\"{k}\"' for k in unexpected_keys)))\n",
    "        if len(missing_keys) > 0:\n",
    "            error_msgs.insert(\n",
    "                0, 'Missing key(s) in state_dict: {}. '.format(\n",
    "                    ', '.join(f'\"{k}\"' for k in missing_keys)))\n",
    "\n",
    "    if len(error_msgs) > 0:\n",
    "        print('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n",
    "                        module.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n",
    "        #raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n",
    "        #                module.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n",
    "    return _IncompatibleKeys(missing_keys, unexpected_keys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint to evaluate; set the CKPT_PATH environment variable before running.\n",
    "ckpt_path = os.environ['CKPT_PATH']\n",
    "# NOTE: torch.load unpickles arbitrary objects - only load checkpoints you trust.\n",
    "ckp = torch.load(ckpt_path, map_location=device)\n",
    "\n",
    "# The checkpoint is a flat agent state_dict (keys like 'critic.seq_model...').\n",
    "# Use the lenient load_state_dict defined in this notebook so entries that do\n",
    "# not fit the freshly built agent are reported instead of aborting the load.\n",
    "# Example: the attention-mask buffers, whose sizes differ.\n",
    "# Without this call the agent would be evaluated with untrained weights.\n",
    "load_state_dict(agent, ckp, strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]],\n",
       "       device='cuda:0', dtype=torch.uint8)"
      ]
     },
     "execution_count": 116,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the agent's lower-triangular (causal) attention-mask buffer in the first block.\n",
    "agent.critic.seq_model.transformer.h[0].attn.bias #.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],\n",
       "          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]],\n",
       "       device='cuda:0', dtype=torch.uint8)"
      ]
     },
     "execution_count": 117,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same mask buffer as stored in the checkpoint; its size differs from the agent's.\n",
    "ckp['critic.seq_model.transformer.h.0.attn.bias'] #.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import helpers as utl\n",
    "\n",
    "# NOTE(review): presumably selects greedy (non-exploratory) actions during evaluation - confirm in the eval loop.\n",
    "deterministic = True\n",
    "# NOTE(review): the seed cell sets n_episode = len(eval_seeds) independently - confirm which count the loop uses.\n",
    "eval_episodes = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Redundant safety net: the agent was already moved with .to(device) when it was built.\n",
    "agent = agent.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation seeds: every permutation of the digits 1-5 as an integer (5! = 120 seeds).\n",
    "nums = [1, 2, 3, 4, 5]\n",
    "eval_seeds = generate_permutations(nums)\n",
    "\n",
    "# NOTE(review): videos_limit is one more than the episode count, presumably so no episode hits a video cap - confirm.\n",
    "videos_limit = len(eval_seeds) + 1\n",
    "n_episode = len(eval_seeds)  # one episode per seed\n",
    "\n",
    "\n",
    "render = False\n",
    "\n",
    "# Running totals, presumably accumulated by the evaluation loop below.\n",
    "total_reward = 0\n",
    "num_successes = 0\n",
    "total_steps = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Episode: 0, seed: 12345 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 1, seed: 12354 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 2, seed: 12435 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 3, seed: 12453 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 4, seed: 12534 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 5, seed: 12543 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 6, seed: 13245 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 7, seed: 13254 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 8, seed: 13425 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 9, seed: 13452 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 10, seed: 13524 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 11, seed: 13542 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 12, seed: 14235 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 13, seed: 14253 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 14, seed: 14325 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 15, seed: 14352 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 16, seed: 14523 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 17, seed: 14532 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 18, seed: 15234 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 19, seed: 15243 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 20, seed: 15324 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 21, seed: 15342 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 22, seed: 15423 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 23, seed: 15432 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 24, seed: 21345 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 25, seed: 21354 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 26, seed: 21435 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 27, seed: 21453 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 28, seed: 21534 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 29, seed: 21543 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 30, seed: 23145 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 31, seed: 23154 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 32, seed: 23415 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 33, seed: 23451 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 34, seed: 23514 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 35, seed: 23541 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 36, seed: 24135 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 37, seed: 24153 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 38, seed: 24315 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 39, seed: 24351 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 40, seed: 24513 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 41, seed: 24531 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 42, seed: 25134 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 43, seed: 25143 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 44, seed: 25314 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 45, seed: 25341 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 46, seed: 25413 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 47, seed: 25431 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 48, seed: 31245 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 49, seed: 31254 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 50, seed: 31425 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 51, seed: 31452 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 52, seed: 31524 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 53, seed: 31542 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 54, seed: 32145 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 55, seed: 32154 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 56, seed: 32415 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 57, seed: 32451 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 58, seed: 32514 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 59, seed: 32541 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 60, seed: 34125 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 61, seed: 34152 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 62, seed: 34215 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 63, seed: 34251 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 64, seed: 34512 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 65, seed: 34521 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 66, seed: 35124 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 67, seed: 35142 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 68, seed: 35214 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 69, seed: 35241 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 70, seed: 35412 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 71, seed: 35421 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 72, seed: 41235 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 73, seed: 41253 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 74, seed: 41325 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 75, seed: 41352 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 76, seed: 41523 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 77, seed: 41532 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 78, seed: 42135 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 79, seed: 42153 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 80, seed: 42315 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 81, seed: 42351 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 82, seed: 42513 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 83, seed: 42531 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 84, seed: 43125 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 85, seed: 43152 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 86, seed: 43215 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 87, seed: 43251 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 88, seed: 43512 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 89, seed: 43521 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 90, seed: 45123 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 91, seed: 45132 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 92, seed: 45213 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 93, seed: 45231 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 94, seed: 45312 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 95, seed: 45321 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 96, seed: 51234 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 97, seed: 51243 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 98, seed: 51324 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 99, seed: 51342 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 100, seed: 51423 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 101, seed: 51432 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 102, seed: 52134 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 103, seed: 52143 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 104, seed: 52314 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 105, seed: 52341 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 106, seed: 52413 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 107, seed: 52431 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 108, seed: 53124 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 109, seed: 53142 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 110, seed: 53214 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 111, seed: 53241 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 112, seed: 53412 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 113, seed: 53421 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 114, seed: 54123 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 115, seed: 54132 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 116, seed: 54213 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 117, seed: 54231 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 118, seed: 54312 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Episode: 119, seed: 54321 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0\n",
      "Total num episodes: 120 Success rate: 0.0, Mean reward: -0.1764705888926983, Mean steps: 18.0\n"
     ]
    }
   ],
   "source": [
    "\n",
    "agent.eval()  # set to eval mode for deterministic dropout\n",
    "\n",
    "# Per-episode results and running totals. They are (re)initialised here rather than\n",
    "# relying on an earlier cell, so the cell works on a fresh kernel and re-running it\n",
    "# does not keep accumulating into totals left over from a previous run.\n",
    "returns_per_episode = np.zeros(n_episode)\n",
    "steps_per_episode = np.zeros(n_episode, dtype=int)\n",
    "success_rate = np.zeros(n_episode)\n",
    "num_successes = 0\n",
    "total_reward = 0.0\n",
    "total_steps = 0\n",
    "\n",
    "for task_idx in range(n_episode):\n",
    "    step = 0\n",
    "    running_reward = 0.0\n",
    "    done_rollout = False\n",
    "\n",
    "    # Seed the global RNGs per episode so the seed printed below is the one that was\n",
    "    # actually in effect. Seeding via env.reset(seed=...) was disabled before (guarded\n",
    "    # by `and False`); presumably this env's reset() does not accept a seed -- TODO\n",
    "    # confirm, and switch to env.reset(seed=curr_seed) if it does.\n",
    "    # int() because random.seed rejects numpy integer types on newer Pythons.\n",
    "    curr_seed = int(eval_seeds[task_idx]) if eval_seeds is not None else None\n",
    "    if curr_seed is not None:\n",
    "        random.seed(curr_seed)\n",
    "        np.random.seed(curr_seed)\n",
    "        torch.manual_seed(curr_seed)\n",
    "    obs = ptu.from_numpy(env.reset()).to(device)  # reset\n",
    "\n",
    "    obs = obs.reshape(1, obs.shape[-1])\n",
    "\n",
    "    # assume initial reward = 0.0\n",
    "    action, reward, internal_state = agent.get_initial_info(\n",
    "        config_seq.sampled_seq_len\n",
    "    )\n",
    "\n",
    "    while not done_rollout:\n",
    "        action, internal_state = agent.act(\n",
    "            prev_internal_state=internal_state,\n",
    "            prev_action=action.to(device),\n",
    "            reward=reward.to(device),\n",
    "            obs=obs.to(device),\n",
    "            deterministic=deterministic,\n",
    "        )\n",
    "\n",
    "        # observe reward and next obs\n",
    "        next_obs, reward, done, info = utl.env_step(\n",
    "            env, action.squeeze(dim=0)\n",
    "        )\n",
    "\n",
    "        # add raw reward\n",
    "        running_reward += reward.item()\n",
    "        step += 1\n",
    "        done_rollout = bool(ptu.get_numpy(done[0][0]) != 0.0)\n",
    "\n",
    "        # set: obs <- next_obs\n",
    "        obs = next_obs.clone()\n",
    "\n",
    "    returns_per_episode[task_idx] = running_reward\n",
    "    steps_per_episode[task_idx] = step\n",
    "    if info.get(\"success\", False):  # keytodoor\n",
    "        success_rate[task_idx] = 1.0\n",
    "        num_successes += 1\n",
    "\n",
    "    total_reward += running_reward\n",
    "    total_steps += step\n",
    "\n",
    "    print(f'Episode: {task_idx}, seed: {curr_seed} Reward: {running_reward}, Steps: {step} Mean reward: {total_reward / (task_idx + 1)}, Mean steps: {total_steps / (task_idx + 1)}')\n",
    "\n",
    "print(f'Total num episodes: {n_episode} Success rate: {num_successes / n_episode}, Mean reward: {total_reward / n_episode}, Mean steps: {total_steps / n_episode}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
