{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../\")\n",
    "sys.path.append(\"../../models/DTQN\")\n",
    "\n",
    "from environments.Passive_T_Maze_Flag.env.env_passive_t_maze_flag import TMazeClassicPassive\n",
    "from models.DTQN.dtqn.agents.dtqn import DtqnAgent\n",
    "from models.DTQN.utils.agent_utils import get_agent\n",
    "\n",
    "import numpy as np\n",
    "import gym\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import torch\n",
    "import yaml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List\n",
    "from collections import OrderedDict, namedtuple\n",
    "\n",
    "\n",
    "class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])):\n",
    "    def __repr__(self):\n",
    "        if not self.missing_keys and not self.unexpected_keys:\n",
    "            return '<All keys matched successfully>'\n",
    "        return super().__repr__()\n",
    "\n",
    "    __str__ = __repr__\n",
    "\n",
    "def load_state_dict(module, state_dict: Mapping[str, Any],\n",
    "                    strict: bool = True, assign: bool = False):\n",
    "    r\"\"\"Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.\n",
    "\n",
    "    If :attr:`strict` is ``True``, then\n",
    "    the keys of :attr:`state_dict` must exactly match the keys returned\n",
    "    by this module's :meth:`~torch.nn.Module.state_dict` function.\n",
    "\n",
    "    .. warning::\n",
    "        If :attr:`assign` is ``True`` the optimizer must be created after\n",
    "        the call to :attr:`load_state_dict` unless\n",
    "        :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.\n",
    "\n",
    "    Args:\n",
    "        state_dict (dict): a dict containing parameters and\n",
    "            persistent buffers.\n",
    "        strict (bool, optional): whether to strictly enforce that the keys\n",
    "            in :attr:`state_dict` match the keys returned by this module's\n",
    "            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``\n",
    "        assign (bool, optional): When ``False``, the properties of the tensors\n",
    "            in the current module are preserved while when ``True``, the\n",
    "            properties of the Tensors in the state dict are preserved. The only\n",
    "            exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s\n",
    "            for which the value from the module is preserved.\n",
    "            Default: ``False``\n",
    "\n",
    "    Returns:\n",
    "        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:\n",
    "            * **missing_keys** is a list of str containing the missing keys\n",
    "            * **unexpected_keys** is a list of str containing the unexpected keys\n",
    "\n",
    "    Note:\n",
    "        If a parameter or buffer is registered as ``None`` and its corresponding key\n",
    "        exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a\n",
    "        ``RuntimeError``.\n",
    "    \"\"\"\n",
    "    if not isinstance(state_dict, Mapping):\n",
    "        raise TypeError(f\"Expected state_dict to be dict-like, got {type(state_dict)}.\")\n",
    "\n",
    "    missing_keys: List[str] = []\n",
    "    unexpected_keys: List[str] = []\n",
    "    error_msgs: List[str] = []\n",
    "\n",
    "    # copy state_dict so _load_from_state_dict can modify it\n",
    "    metadata = getattr(state_dict, '_metadata', None)\n",
    "    state_dict = OrderedDict(state_dict)\n",
    "    if metadata is not None:\n",
    "        # mypy isn't aware that \"_metadata\" exists in state_dict\n",
    "        state_dict._metadata = metadata  # type: ignore[attr-defined]\n",
    "\n",
    "\n",
    "    def load(module, local_state_dict, prefix=''):\n",
    "        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})\n",
    "        if assign:\n",
    "            local_metadata['assign_to_params_buffers'] = assign\n",
    "        module._load_from_state_dict(\n",
    "            local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)\n",
    "        for name, child in module._modules.items():\n",
    "            if child is not None:\n",
    "                child_prefix = prefix + name + '.'\n",
    "                child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}\n",
    "                load(child, child_state_dict, child_prefix)  # noqa: F821\n",
    "\n",
    "        # Note that the hook can modify missing_keys and unexpected_keys.\n",
    "        incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)\n",
    "        for hook in module._load_state_dict_post_hooks.values():\n",
    "            out = hook(module, incompatible_keys)\n",
    "            assert out is None, (\n",
    "                \"Hooks registered with ``register_load_state_dict_post_hook`` are not\"\n",
    "                \"expected to return new values, if incompatible_keys need to be modified,\"\n",
    "                \"it should be done inplace.\"\n",
    "            )\n",
    "\n",
    "    load(module, state_dict)\n",
    "    del load\n",
    "\n",
    "    if strict:\n",
    "        if len(unexpected_keys) > 0:\n",
    "            error_msgs.insert(\n",
    "                0, 'Unexpected key(s) in state_dict: {}. '.format(\n",
    "                    ', '.join(f'\"{k}\"' for k in unexpected_keys)))\n",
    "        if len(missing_keys) > 0:\n",
    "            error_msgs.insert(\n",
    "                0, 'Missing key(s) in state_dict: {}. '.format(\n",
    "                    ', '.join(f'\"{k}\"' for k in missing_keys)))\n",
    "\n",
    "    if len(error_msgs) > 0:\n",
    "        print('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n",
    "                        module.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n",
    "        #raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n",
    "        #                module.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n",
    "    return _IncompatibleKeys(missing_keys, unexpected_keys)\n",
    "\n",
    "from itertools import permutations\n",
    "\n",
    "def generate_permutations(nums):\n",
    "\n",
    "    perms = permutations(nums)\n",
    "    result = [int(''.join(map(str, perm))) for perm in perms]\n",
    "    \n",
    "    return result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_path = \n",
    "\n",
    "\n",
    "episode_timeout = 20\n",
    "corridor_length = episode_timeout - 2\n",
    "penalty = -1/(episode_timeout - 1)\n",
    "\n",
    "device = torch.device('cuda:0')\n",
    "\n",
    "\n",
    "with open(config_path, 'r') as file:\n",
    "    args = yaml.safe_load(file)\n",
    "\n",
    "env = TMazeClassicPassive(episode_length=episode_timeout, \n",
    "                            corridor_length=corridor_length, \n",
    "                            goal_reward=1.0,\n",
    "                            penalty=penalty)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Checkpoint 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_path = \n",
    "\n",
    "\n",
    "args['inembed'] = 64\n",
    "args['context'] = episode_timeout"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Checkpoint 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# args['pos'] = 'sin'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent = get_agent(\n",
    "        args['model'],\n",
    "        env,\n",
    "        env,\n",
    "        args['obsembed'],\n",
    "        args['inembed'],\n",
    "        args['buf_size'],\n",
    "        device,\n",
    "        args['lr'],\n",
    "        args['batch'],\n",
    "        args['context'],\n",
    "        args['history'],\n",
    "        args['num_steps'],\n",
    "        # DTQN specific\n",
    "        args['heads'],\n",
    "        args['layers'],\n",
    "        args['dropout'],\n",
    "        args['identity'],\n",
    "        args['gate'],\n",
    "        args['pos'],\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'tensorboard'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent.load_checkpoint(checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "load_state_dict(agent.policy_network, ckp['policy_net_state_dict'])\n",
    "load_state_dict(agent.target_network, ckp['target_net_state_dict'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "nums = [1, 2, 3, 4, 5]\n",
    "eval_seeds = generate_permutations(nums)\n",
    "\n",
    "videos_limit = len(eval_seeds) + 1\n",
    "n_episode = len(eval_seeds)\n",
    "\n",
    "videos_dir = '/opt/Memory-RL-Codebase/eval/Minigrid_Memory/DTQN'\n",
    "\n",
    "run_name = checkpoint_path.split('/')[-1].strip('.pt')\n",
    "run_type = checkpoint_path.split('/')[-2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent.evaluate(n_episode = n_episode, eval_seeds = eval_seeds, render = False, videos_limit = videos_limit, videos_dir = videos_dir, run_name = run_name, run_type =run_type )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
