{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../\")\n",
    "sys.path.append(\"../../models/Memory_RL\")\n",
    "\n",
    "# from environments.Passive_T_Maze_Flag.env.env_passive_t_maze_flag import TMazeClassicPassive\n",
    "from models.Memory_RL.envs.tmaze import TMazeClassicPassive\n",
    "from models.Memory_RL.policies.models.policy_rnn_dqn import ModelFreeOffPolicy_DQN_RNN\n",
    "import os \n",
    "\n",
    "import numpy as np\n",
    "import gym\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import torch\n",
    "import yaml\n",
    "import time\n",
    "from moviepy.editor import ImageSequenceClip, VideoFileClip\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "from configs.rl.name_fns import name_fn as name_fn1\n",
    "from ml_collections import ConfigDict\n",
    "from typing import Tuple\n",
    "from torchkit import pytorch_utils as ptu\n",
    "\n",
    "def dqn_name_fn(\n",
    "    config: ConfigDict,\n",
    "    max_episode_steps: int,\n",
    "    max_training_steps: int,\n",
    ") -> Tuple[ConfigDict, str]:\n",
    "    \"\"\"Finalize a DQN config by adding its epsilon exploration schedule.\n",
    "\n",
    "    The config is first passed through ``name_fn1`` (imported from\n",
    "    ``configs.rl.name_fns``), which returns a config and a run name.\n",
    "\n",
    "    Args:\n",
    "        config: RL config; it must already define ``schedule_end``.\n",
    "        max_episode_steps: Episode horizon T used to derive the final epsilon.\n",
    "        max_training_steps: Training length that ``schedule_end`` is scaled by.\n",
    "\n",
    "    Returns:\n",
    "        The updated config and the name produced by ``name_fn1``.\n",
    "    \"\"\"\n",
    "    rl_config, run_name = name_fn1(config)\n",
    "\n",
    "    # Epsilon goes from 1.0 down to 1/T: with eps = 1/T the asymptotic\n",
    "    # probability of sampling a fully exploited trajectory during\n",
    "    # exploration is (1 - 1/T)^T = 1/e.\n",
    "    rl_config.init_eps = 1.0\n",
    "    rl_config.end_eps = 1.0 / max_episode_steps\n",
    "\n",
    "    # The schedule length is the configured fraction of the training length.\n",
    "    rl_config.schedule_steps = rl_config.schedule_end * max_training_steps\n",
    "\n",
    "    return rl_config, run_name\n",
    "\n",
    "\n",
    "def get_rl_config():\n",
    "    \"\"\"Build the default RL config used for the DQN agent.\n",
    "\n",
    "    Returns:\n",
    "        ConfigDict holding the algorithm name, critic settings, discount,\n",
    "        tau, epsilon-schedule fraction and replay-buffer sizes. Its\n",
    "        ``name_fn`` field points at ``dqn_name_fn``.\n",
    "    \"\"\"\n",
    "    rl_config = ConfigDict()\n",
    "    rl_config.name_fn = dqn_name_fn\n",
    "    rl_config.algo = \"dqn\"\n",
    "\n",
    "    # Critic learning rate and hidden layer sizes.\n",
    "    rl_config.critic_lr = 3e-4\n",
    "    critic_config = ConfigDict()\n",
    "    critic_config.hidden_dims = (256, 256)\n",
    "    rl_config.config_critic = critic_config\n",
    "\n",
    "    rl_config.discount = 0.99\n",
    "    rl_config.tau = 0.005\n",
    "    # Fraction of training consumed by the epsilon schedule\n",
    "    # (at least good for TMaze-like envs).\n",
    "    rl_config.schedule_end = 0.1\n",
    "\n",
    "    # Replay buffer limits (kept as floats, as in the original config).\n",
    "    rl_config.replay_buffer_size = 1e6\n",
    "    rl_config.replay_buffer_num_episodes = 1e3\n",
    "\n",
    "    return rl_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ml_collections import ConfigDict\n",
    "from typing import Tuple\n",
    "\n",
    "\n",
    "def name_fn(config: ConfigDict, max_episode_steps: int) -> Tuple[ConfigDict, str]:\n",
    "    \"\"\"Resolve the sampled sequence length and derive the run-name prefix.\n",
    "\n",
    "    Args:\n",
    "        config: Sequence-model config; it is modified in place.\n",
    "        max_episode_steps: Value substituted for ``sampled_seq_len`` when the\n",
    "            config holds the placeholder ``-1``.\n",
    "\n",
    "    Returns:\n",
    "        The same config (with its ``name_fn`` entry removed) and a name of\n",
    "        the form ``<seq model name>-len-<sampled_seq_len>/``.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If ``config.clip`` is not ``False``.\n",
    "    \"\"\"\n",
    "    use_full_episode = config.sampled_seq_len == -1\n",
    "    if use_full_episode:\n",
    "        config.sampled_seq_len = max_episode_steps\n",
    "\n",
    "    seq_model_name = config.model.seq_model_config.name\n",
    "    run_name = f\"{seq_model_name}-len-{config.sampled_seq_len}/\"\n",
    "\n",
    "    # Only the non-clipping setup is accepted here.\n",
    "    assert config.clip is False\n",
    "\n",
    "    # Drop the name_fn entry before handing the config back.\n",
    "    del config.name_fn\n",
    "    return config, run_name\n",
    "\n",
    "\n",
    "def get_seq_config():\n",
    "    \"\"\"Build the default sequence-model config (LSTM with MLP embedders).\n",
    "\n",
    "    Returns:\n",
    "        ConfigDict with the top-level flags plus a ``model`` sub-config that\n",
    "        holds ``seq_model_config`` and the observation, action and reward\n",
    "        embedder configs. Its ``name_fn`` field points at ``name_fn``.\n",
    "    \"\"\"\n",
    "\n",
    "    def _mlp_embedder(hidden_size):\n",
    "        # All embedders are of type \"mlp\" and differ only in hidden size.\n",
    "        embedder = ConfigDict()\n",
    "        embedder.name = \"mlp\"\n",
    "        embedder.hidden_size = hidden_size\n",
    "        return embedder\n",
    "\n",
    "    # seq_model specific\n",
    "    seq_model = ConfigDict()\n",
    "    seq_model.name = \"lstm\"\n",
    "    seq_model.hidden_size = 128\n",
    "    seq_model.n_layer = 1\n",
    "\n",
    "    # fed into Module\n",
    "    model = ConfigDict()\n",
    "    model.seq_model_config = seq_model\n",
    "    model.observ_embedder = _mlp_embedder(32)\n",
    "    model.action_embedder = _mlp_embedder(16)\n",
    "    model.reward_embedder = _mlp_embedder(0)\n",
    "\n",
    "    seq_config = ConfigDict()\n",
    "    seq_config.name_fn = name_fn\n",
    "\n",
    "    seq_config.is_markov = False\n",
    "    seq_config.is_attn = False\n",
    "    seq_config.use_dropout = False\n",
    "\n",
    "    # -1 is a placeholder that name_fn replaces with max_episode_steps.\n",
    "    seq_config.sampled_seq_len = -1\n",
    "\n",
    "    seq_config.clip = False\n",
    "    seq_config.max_norm = 1.0\n",
    "    seq_config.use_l2_norm = False\n",
    "\n",
    "    seq_config.model = model\n",
    "    return seq_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import permutations\n",
    "\n",
    "def generate_permutations(nums):\n",
    "    \"\"\"Return every ordering of ``nums`` as one concatenated integer.\n",
    "\n",
    "    Each permutation's items are converted to strings, joined and parsed\n",
    "    with ``int``; results follow the order of ``itertools.permutations``.\n",
    "    For example ``[1, 2]`` gives ``[12, 21]``.\n",
    "    \"\"\"\n",
    "    numbers = []\n",
    "    for ordering in permutations(nums):\n",
    "        digits = \"\".join(str(item) for item in ordering)\n",
    "        numbers.append(int(digits))\n",
    "    return numbers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "set device: cuda:0\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Agent classes listed in the original notebook (kept for reference):\n",
    "# AGENT_CLASSES = {\n",
    "#     \"Policy_MLP\": Policy_MLP,\n",
    "#     \"Policy_RNN_MLP\": Policy_RNN_MLP,\n",
    "#     \"Policy_Separate_RNN\": Policy_Separate_RNN,\n",
    "#     \"Policy_Shared_RNN\": Policy_Shared_RNN,\n",
    "#     \"Policy_DQN_RNN\": Policy_DQN_RNN,\n",
    "# }\n",
    "from torchkit.pytorch_utils import set_gpu_mode\n",
    "\n",
    "# Point torchkit at GPU 0 (this call prints the selected device).\n",
    "set_gpu_mode('cuda', 0)\n",
    "\n",
    "# Only the recurrent DQN agent is used in this notebook.\n",
    "agent_class = ModelFreeOffPolicy_DQN_RNN\n",
    "agent_arch = ModelFreeOffPolicy_DQN_RNN.ARCH\n",
    "\n",
    "# Make the same GPU the default device for newly created torch tensors.\n",
    "device = torch.device('cuda', 0)\n",
    "torch.set_default_device(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0]\n",
      " [0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "# Episode budget; the corridor is two cells shorter than the timeout.\n",
    "episode_timeout = 31\n",
    "corridor_length = episode_timeout - 2\n",
    "# Per-step penalty value passed to the environment: -1 / (timeout - 1).\n",
    "penalty = -1 / (episode_timeout - 1)\n",
    "\n",
    "env_kwargs = dict(\n",
    "    episode_length=episode_timeout,\n",
    "    corridor_length=corridor_length,\n",
    "    goal_reward=1.0,\n",
    "    penalty=penalty,\n",
    ")\n",
    "env = TMazeClassicPassive(**env_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Horizon and training length used only to finalize the two configs below.\n",
    "# NOTE(review): max_episode_steps (15) differs from the env's episode_timeout\n",
    "# (31) - confirm this matches the setup the checkpoint was trained with.\n",
    "max_episode_steps = 15\n",
    "max_training_steps = 999\n",
    "\n",
    "# Both name functions return (config, name); the name is not needed here.\n",
    "config_seq, _ = name_fn(\n",
    "    get_seq_config(),\n",
    "    max_episode_steps=max_episode_steps,\n",
    ")\n",
    "config_rl, _ = dqn_name_fn(\n",
    "    config=get_rl_config(),\n",
    "    max_episode_steps=max_episode_steps,\n",
    "    max_training_steps=max_training_steps,\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def image_encoder_fn():\n",
    "    \"\"\"Image-encoder factory handed to the agent; it always returns None.\"\"\"\n",
    "    return None\n",
    "\n",
    "\n",
    "# Observation size is read from the environment; the action count is fixed.\n",
    "# NOTE(review): act_dim is hard-coded to 4 - presumably the env's number of\n",
    "# discrete actions; confirm against env.action_space.\n",
    "obs_dim = env.observation_space.shape[0]\n",
    "act_dim = 4\n",
    "\n",
    "freeze_critic = False\n",
    "\n",
    "agent_kwargs = dict(\n",
    "    obs_dim=obs_dim,\n",
    "    action_dim=act_dim,\n",
    "    config_seq=config_seq,\n",
    "    config_rl=config_rl,\n",
    "    image_encoder_fn=image_encoder_fn,\n",
    "    freeze_critic=freeze_critic,\n",
    ")\n",
    "agent = agent_class(**agent_kwargs).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Checkpoint 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `ckpt_path` is not assigned anywhere in this notebook (the cell above is\n",
    "# empty), so a fresh \"Restart & Run All\" stopped here with a bare NameError.\n",
    "# Use a value already set in the kernel, otherwise fall back to the\n",
    "# CKPT_PATH environment variable.\n",
    "ckpt_path = globals().get(\"ckpt_path\") or os.environ.get(\"CKPT_PATH\")\n",
    "if not ckpt_path:\n",
    "    raise NameError(\n",
    "        \"ckpt_path is not defined: assign `ckpt_path` in the cell above \"\n",
    "        \"or set the CKPT_PATH environment variable before running this cell.\"\n",
    "    )\n",
    "\n",
    "# NOTE: torch.load unpickles the file, so only load checkpoints that come\n",
    "# from a trusted source.\n",
    "agent.load_state_dict(torch.load(ckpt_path, map_location=device))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One evaluation entry per ordering of the digits 1-5, each read as an integer.\n",
    "nums = list(range(1, 6))\n",
    "eval_seeds = generate_permutations(nums)\n",
    "\n",
    "# One episode per entry; the video limit is one more than that.\n",
    "n_episode = len(eval_seeds)\n",
    "videos_limit = n_episode + 1\n",
    "\n",
    "render = False\n",
    "\n",
    "# Running totals for the evaluation.\n",
    "total_reward = num_successes = total_steps = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Evaluate the loaded agent for `n_episode` episodes and report per-episode\n",
    "# and aggregate reward, step count and success rate.\n",
    "\n",
    "# Action-selection flag passed to agent.act. It was never assigned in this\n",
    "# notebook, so the loop raised NameError on a fresh kernel; evaluation uses\n",
    "# deterministic actions.\n",
    "deterministic = True\n",
    "\n",
    "# NOTE(review): `utl` (used for `utl.env_step` below) is not imported anywhere\n",
    "# in this notebook; import the project's helper module in the imports cell -\n",
    "# TODO confirm its module path.\n",
    "\n",
    "agent.eval()  # set to eval mode for deterministic dropout\n",
    "\n",
    "# Start the running totals from zero in this cell so that re-running it does\n",
    "# not add to the totals of a previous run.\n",
    "total_reward = 0.0\n",
    "num_successes = 0\n",
    "total_steps = 0\n",
    "\n",
    "returns_per_episode = np.zeros(n_episode)\n",
    "success_rate = np.zeros(n_episode)\n",
    "\n",
    "for task_idx in range(n_episode):\n",
    "    step = 0\n",
    "    running_reward = 0.0\n",
    "    done_rollout = False\n",
    "\n",
    "    # Seeded resets were already switched off in the original cell (its guard\n",
    "    # ended in `and False`), so every episode starts from an unseeded reset;\n",
    "    # `eval_seeds` only labels the log line below.\n",
    "    obs = ptu.from_numpy(env.reset()).to(device)  # reset\n",
    "    obs = obs.reshape(1, obs.shape[-1])\n",
    "\n",
    "    # assume initial reward = 0.0\n",
    "    action, reward, internal_state = agent.get_initial_info(\n",
    "        config_seq.sampled_seq_len\n",
    "    )\n",
    "\n",
    "    while not done_rollout:\n",
    "        action, internal_state = agent.act(\n",
    "            prev_internal_state=internal_state,\n",
    "            prev_action=action.to(device),\n",
    "            reward=reward.to(device),\n",
    "            obs=obs.to(device),\n",
    "            deterministic=deterministic,\n",
    "        )\n",
    "\n",
    "        # observe reward and next obs\n",
    "        next_obs, reward, done, info = utl.env_step(\n",
    "            env, action.squeeze(dim=0)\n",
    "        )\n",
    "\n",
    "        # add raw reward\n",
    "        running_reward += reward.item()\n",
    "        step += 1\n",
    "        # The rollout ends as soon as the done flag is non-zero.\n",
    "        done_rollout = bool(ptu.get_numpy(done[0][0]) != 0.0)\n",
    "\n",
    "        # set: obs <- next_obs\n",
    "        obs = next_obs.clone()\n",
    "\n",
    "    # Record the per-episode return (this array used to stay all zeros).\n",
    "    returns_per_episode[task_idx] = running_reward\n",
    "    if \"success\" in info and info[\"success\"] == True:  # keytodoor\n",
    "        success_rate[task_idx] = 1.0\n",
    "        num_successes += 1\n",
    "\n",
    "    total_reward += running_reward\n",
    "    total_steps += step\n",
    "\n",
    "    curr_seed = eval_seeds[task_idx]\n",
    "    print(f'Episode: {task_idx}, seed: {curr_seed} Reward: {running_reward}, Steps: {step} Mean reward: {total_reward / (task_idx + 1)}, Mean steps: {total_steps / (task_idx + 1)}')\n",
    "\n",
    "\n",
    "print(f'Total num episodes: {n_episode} Success rate: {num_successes / n_episode}, Mean reward: {total_reward / n_episode}, Mean steps: {total_steps / n_episode}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
