{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../\")\n",
    "sys.path.append(\"../../models/episodic_transformer_memory_ppo\")\n",
    "\n",
    "from environments.Passive_T_Maze_Flag.env.env_passive_t_maze_flag import TMazeClassicPassive\n",
    "from models.episodic_transformer_memory_ppo.model import ActorCriticModel\n",
    "import os \n",
    "\n",
    "import numpy as np\n",
    "import gym\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import torch\n",
    "import yaml\n",
    "import time\n",
    "from moviepy.editor import ImageSequenceClip, VideoFileClip\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_transformer_memory(trxl_conf, max_episode_steps, device):\n",
    "    \"\"\"Returns initial tensors for the episodic memory of the transformer.\n",
    "\n",
    "    All three returned tensors are allocated directly on `device`.\n",
    "\n",
    "    Arguments:\n",
    "        trxl_conf {dict} -- Transformer configuration dictionary; the keys \"memory_length\", \"num_blocks\" and \"embed_dim\" are read\n",
    "        max_episode_steps {int} -- Maximum number of steps per episode (must be >= trxl_conf[\"memory_length\"], otherwise there is no sliding window to stack)\n",
    "        device {torch.device} -- Target device for the tensors\n",
    "\n",
    "    Returns:\n",
    "        memory {torch.Tensor}, memory_mask {torch.Tensor}, memory_indices {torch.Tensor} -- Initial episodic memory, episodic memory mask, and sliding memory window indices\n",
    "    \"\"\"\n",
    "    memory_length = trxl_conf[\"memory_length\"]\n",
    "    # Episodic memory mask used in attention: strictly lower triangular, so row t is 1 only for columns 0..t-1\n",
    "    memory_mask = torch.tril(torch.ones((memory_length, memory_length), device=device), diagonal=-1)\n",
    "    # Episodic memory tensor: zero-initialised, one slot per episode step and transformer block\n",
    "    memory = torch.zeros((1, max_episode_steps, trxl_conf[\"num_blocks\"], trxl_conf[\"embed_dim\"]), device=device)\n",
    "    # Setup sliding memory window indices\n",
    "    # The first memory_length - 1 rows all address the fixed window [0, memory_length)\n",
    "    repetitions = torch.arange(0, memory_length, dtype=torch.long, device=device).unsqueeze(0).repeat(memory_length - 1, 1)\n",
    "    # The remaining rows slide the window forward by one slot per row\n",
    "    memory_indices = torch.stack([torch.arange(i, i + memory_length, dtype=torch.long, device=device) for i in range(max_episode_steps - memory_length + 1)])\n",
    "    memory_indices = torch.cat((repetitions, memory_indices))\n",
    "    return memory, memory_mask, memory_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train on L = 15, test L = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_path = '/opt/Memory-RL-Codebase/configs/GTRXL_configs/Passive_T_Maze_Flag/Dense/Passive_T_Maze_Flag_SHORT_TERM.yaml'\n",
    "\n",
    "# Episode length handed to the environment; the corridor length is derived from it\n",
    "episode_timeout = 15\n",
    "corridor_length = episode_timeout - 2\n",
    "# Penalty value handed to the environment, chosen so that episode_timeout - 1 such penalties add up to -1\n",
    "penalty = -1/(episode_timeout - 1)\n",
    "\n",
    "# Prefer the first GPU, but fall back to the CPU so the notebook also runs on machines without CUDA\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "\n",
    "# Base model/training configuration; later cells overwrite the architecture entries per checkpoint\n",
    "with open(config_path, 'r') as file:\n",
    "    config = yaml.safe_load(file)\n",
    "\n",
    "env = TMazeClassicPassive(episode_length=episode_timeout, \n",
    "                            corridor_length=corridor_length, \n",
    "                            goal_reward=1.0,\n",
    "                            penalty=penalty)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Checkpoint 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ckp 1\n",
    "\n",
    "checkpoint_path = '/opt/Memory-RL-Codebase/autorun/checkpoints_2024_09_29_19_00/Passive_T_Maze_Flag/GTXL/GTXL_Passive_T_Maze_Flag_SHORT_TERM_dense/2024_09_29-12_25_14.pt'\n",
    "# map_location makes loading independent of the device the checkpoint was saved from.\n",
    "# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "checkpoint = torch.load(checkpoint_path, map_location=device)\n",
    "\n",
    "# The episodic memory window is set to the corridor length of the evaluation environment\n",
    "config['transformer']['memory_length'] = corridor_length\n",
    "\n",
    "# Architecture overrides - presumably these have to match the run that produced this checkpoint (TODO confirm)\n",
    "config['transformer']['num_blocks'] = 3\n",
    "config['transformer']['embed_dim'] = 64\n",
    "config['transformer']['num_heads'] = 4\n",
    "config['hidden_layer_size'] = 64\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Checkpoint 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ckp 2\n",
    "\n",
    "checkpoint_path = '/opt/Memory-RL-Codebase/autorun/checkpoints_2024_09_29_22_00/Passive_T_Maze_Flag/GTXL/GTXL_Passive_T_Maze_Flag_SHORT_TERM_dense/2024_09_28-23_36_02.pt'\n",
    "# map_location makes loading independent of the device the checkpoint was saved from.\n",
    "# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "checkpoint = torch.load(checkpoint_path, map_location=device)\n",
    "\n",
    "# The episodic memory window is set to the corridor length of the evaluation environment\n",
    "config['transformer']['memory_length'] = corridor_length\n",
    "\n",
    "\n",
    "# Architecture overrides - presumably these have to match the run that produced this checkpoint (TODO confirm)\n",
    "config['transformer']['num_blocks'] = 6\n",
    "config['transformer']['embed_dim'] = 128\n",
    "config['transformer']['num_heads'] = 8\n",
    "config['hidden_layer_size'] = 128\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Checkpoint 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ckp 3\n",
    "\n",
    "checkpoint_path = '/opt/Memory-RL-Codebase/autorun/checkpoints/Passive_T_Maze_Flag/GTXL/GTXL_Passive_T_Maze_Flag_SHORT_TERM_dense/2024_09_29-12_23_46.pt'\n",
    "# map_location makes loading independent of the device the checkpoint was saved from.\n",
    "# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "checkpoint = torch.load(checkpoint_path, map_location=device)\n",
    "\n",
    "# The episodic memory window is set to the corridor length of the evaluation environment\n",
    "config['transformer']['memory_length'] = corridor_length\n",
    "\n",
    "# Architecture overrides - presumably these have to match the run that produced this checkpoint (TODO confirm)\n",
    "config['transformer']['num_blocks'] = 6\n",
    "config['transformer']['embed_dim'] = 128\n",
    "config['transformer']['num_heads'] = 8\n",
    "config['hidden_layer_size'] = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the actor-critic from the (overridden) config and place it on the evaluation device\n",
    "agent = ActorCriticModel(config, env.observation_space, (env.action_space.n,), env.max_episode_steps).to(device)\n",
    "# Restore the trained weights; the parameters stay on `device`, so no second .to(device) is needed\n",
    "agent.load_state_dict(checkpoint[\"model_state_dict\"])\n",
    "# Inference only: put the module into evaluation mode\n",
    "agent.eval()\n",
    "# Tensors created without an explicit device from here on are allocated on `device`\n",
    "torch.set_default_device(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import permutations\n",
    "\n",
    "def generate_permutations(nums):\n",
    "    \"\"\"Return every ordering of `nums`, each encoded as a single integer.\n",
    "\n",
    "    A permutation is encoded by concatenating the string form of its\n",
    "    elements and parsing the result as an int, e.g. (1, 2, 3) -> 123.\n",
    "    The output follows the order produced by itertools.permutations.\n",
    "    \"\"\"\n",
    "    encoded = []\n",
    "    for ordering in permutations(nums):\n",
    "        digits = ''.join(str(item) for item in ordering)\n",
    "        encoded.append(int(digits))\n",
    "    return encoded\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "### evaluate !\n",
    "\n",
    "# NOTE(review): the folder name says Minigrid_Memory although the environment above is the passive T-Maze - confirm the intended output location\n",
    "videos_dir = '/opt/Memory-RL-Codebase/eval/Minigrid_Memory/GTRXL'\n",
    "\n",
    "# One evaluation episode per permutation of these digits; each encoded permutation is used as the reset seed\n",
    "nums = [1, 2, 3, 4, 5]\n",
    "eval_seeds = generate_permutations(nums)\n",
    "\n",
    "# Larger than the number of episodes, so every episode is recorded whenever render is enabled\n",
    "videos_limit = len(eval_seeds) + 1\n",
    "n_episode = len(eval_seeds)\n",
    "\n",
    "# Names of the video output folders, derived from the checkpoint path.\n",
    "# splitext removes only the extension; str.strip('.pt') would strip any leading/trailing '.', 'p' or 't' characters.\n",
    "run_name = os.path.splitext(checkpoint_path.split('/')[-1])[0]\n",
    "run_type = checkpoint_path.split('/')[-2]\n",
    "\n",
    "render = False\n",
    "\n",
    "total_reward = 0\n",
    "num_successes = 0\n",
    "total_steps = 0\n",
    "\n",
    "\n",
    "for i in range(n_episode):\n",
    "\n",
    "    if render:\n",
    "        frames = []\n",
    "\n",
    "    done = False\n",
    "    # Fresh episodic memory for every episode\n",
    "    memory, memory_mask, memory_indices = init_transformer_memory(config[\"transformer\"], env.max_episode_steps, device)\n",
    "\n",
    "    memory = memory.to(device)\n",
    "    memory_mask = memory_mask.to(device)\n",
    "    memory_indices = memory_indices.to(device)\n",
    "\n",
    "    memory_length = config[\"transformer\"][\"memory_length\"]\n",
    "    t = 0\n",
    "    ep_reward = 0\n",
    "\n",
    "    curr_seed = eval_seeds[i]\n",
    "    obs = env.reset(curr_seed)\n",
    "\n",
    "    if render and i < videos_limit:\n",
    "        frame = env.render()\n",
    "        time.sleep(0.5)\n",
    "        frames.append(frame)\n",
    "\n",
    "    while not done:\n",
    "        # Prepare observation and memory\n",
    "        obs = torch.tensor(np.expand_dims(obs, 0), dtype=torch.float32, device=device)\n",
    "        indices = memory_indices[t].unsqueeze(0)\n",
    "        in_memory = memory[0, indices]\n",
    "        # The mask has memory_length rows, so later steps reuse the last row\n",
    "        t_ = max(0, min(t, memory_length - 1))\n",
    "        mask = memory_mask[t_].unsqueeze(0)\n",
    "        # Forward model; no gradients are needed for evaluation, and without no_grad the\n",
    "        # in-place memory write below would keep an autograd graph alive for the whole episode\n",
    "        with torch.no_grad():\n",
    "            policy, value, new_memory = agent(obs, in_memory, mask, indices)\n",
    "        memory[:, t] = new_memory\n",
    "        # Sample one action per policy branch\n",
    "        action = [action_branch.sample().item() for action_branch in policy]\n",
    "        # Step environment\n",
    "        obs, reward, done, info = env.step(action)\n",
    "\n",
    "        # Accumulate before reporting so the message below includes the final step's reward\n",
    "        ep_reward += reward\n",
    "        t += 1\n",
    "\n",
    "        if render and i < videos_limit:\n",
    "            frame = env.render()\n",
    "            if done:\n",
    "                print(f\"Episode terminated. Episode reward: {ep_reward}\")\n",
    "            time.sleep(0.5)\n",
    "            frames.append(frame)\n",
    "\n",
    "    if info.get(\"is_success\"):\n",
    "        num_successes += 1\n",
    "    total_reward += ep_reward\n",
    "    total_steps += t\n",
    "\n",
    "    if render and i < videos_limit:\n",
    "        desired_resolution = (945, 540)\n",
    "        original_aspect_ratio = 112 / 64\n",
    "        width = int(desired_resolution[0] * original_aspect_ratio)\n",
    "        height = desired_resolution[1]\n",
    "\n",
    "        observations = [np.squeeze(o) for o in frames]\n",
    "\n",
    "        clip = ImageSequenceClip(observations, fps=2)\n",
    "        clip = clip.resize(width=width, height=height)\n",
    "\n",
    "        # assumes the final info dict carries a 'reward' entry - TODO confirm against the environment\n",
    "        curr_reward = float(info['reward'])\n",
    "\n",
    "        out_dir = videos_dir + f\"/{run_type}/{run_name}\"\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "        clip.write_videofile(out_dir + f\"/{run_name}_seed={curr_seed}_reward={curr_reward:0.2}.mp4\", fps=2)\n",
    "\n",
    "    print(f'Episode: {i}, seed: {curr_seed} Reward: {ep_reward}, Steps: {t} Mean reward: {total_reward / (i + 1)}, Mean steps: {total_steps / (i + 1)}')\n",
    "\n",
    "\n",
    "print(f'Total num episodes: {n_episode} Success rate: {num_successes / n_episode}, Mean reward: {total_reward / n_episode}, Mean steps: {total_steps / n_episode}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
