{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b240965",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MISC\n",
    "import os\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "os.environ['MUJOCO_GL']='egl'\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
    "\n",
    "# import shutup\n",
    "# shutup.please()\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import numpy as np\n",
    "from functools import partial\n",
    "\n",
    "# VIS\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import mediapy\n",
    "from rich.pretty import pprint\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from agents.dynamics_aware_iql import GCIQLAgent\n",
    "from hydra import initialize, compose\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "plt.style.use(['seaborn-v0_8-colorblind', 'seaborn-v0_8-notebook'])\n",
    "colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] \n",
    "\n",
    "GLOBAL_KEY = jax.random.key(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a9c4264",
   "metadata": {},
   "outputs": [],
   "source": [
    "from envs.custom_mazes.darkroom import FourRoomsMazeEnv, Maze\n",
    "\n",
    "# Smoke test: build one maze of type 'fourrooms_random_layouts', reset it and render it.\n",
    "test = FourRoomsMazeEnv(\n",
    "    Maze(maze_type='fourrooms_random_layouts', seed=42)\n",
    ")\n",
    "test.reset()\n",
    "test.render(return_img=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c22afd62",
   "metadata": {},
   "outputs": [],
   "source": [
    "from envs.minigrid.env_utils import random_exploration_fourrooms, q_learning_fourrooms\n",
    "\n",
    "# Data-collection settings.\n",
    "NUM_TRAIN_LAYOUTS = 1\n",
    "NUM_TRAIN_STEPS = 100       # passed to the env as max_steps\n",
    "NUM_TRAIN_EPISODES = 1000   # random-exploration episodes collected per layout\n",
    "# NOTE(review): `seeds` is currently unused -- every env below is built with Maze(seed=0).\n",
    "seeds = np.arange(0, NUM_TRAIN_LAYOUTS)\n",
    "\n",
    "# One random-exploration dataset per layout index.\n",
    "train_layout_data = []\n",
    "for i in tqdm(range(NUM_TRAIN_LAYOUTS), desc='collecting layouts'):\n",
    "    env = FourRoomsMazeEnv(Maze(seed=0), max_steps=NUM_TRAIN_STEPS)\n",
    "    dataset, env = random_exploration_fourrooms(env, num_episodes=NUM_TRAIN_EPISODES, layout_type=i, num_mdp=NUM_TRAIN_LAYOUTS)\n",
    "    train_layout_data.append(dataset)\n",
    "\n",
    "pprint(jax.tree.map(lambda x: x.shape, train_layout_data[0]))\n",
    "\n",
    "# Visit counts per maze cell, accumulated in one vectorised pass per layout instead of a\n",
    "# Python loop over every observation. Element 0 of an observation is used as the column\n",
    "# index and element 1 as the row index. np.add.at is required (rather than `+=` with fancy\n",
    "# indexing) because the same cell appears many times in the index arrays.\n",
    "# Assumes layout['observations'] is a 2-D (num_transitions, >=2) array -- TODO confirm.\n",
    "coverage_map = np.zeros(shape=env.maze.size)\n",
    "for layout in train_layout_data:\n",
    "    positions = np.asarray(layout['observations']).astype(np.int16)\n",
    "    np.add.at(coverage_map, (positions[:, 1], positions[:, 0]), 1)\n",
    "\n",
    "plt.imshow(coverage_map, cmap='inferno', vmin=0)\n",
    "plt.colorbar()\n",
    "plt.title('Visit counts of the random-exploration data')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fb849c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "from utils.datasets import Dataset, GCDataset\n",
    "\n",
    "# Hydra overrides applied on top of 'entry.yaml'.\n",
    "config_overrides = [\n",
    "    'experiment=fb_dynamics_discrete_4rooms.yaml',\n",
    "    f'agent.number_of_meta_envs={NUM_TRAIN_LAYOUTS}',\n",
    "    'agent.discount=0.99',\n",
    "    'agent.z_mix_ratio=0.5',\n",
    "]\n",
    "\n",
    "with initialize(version_base=None, config_path=\"../configs/\"):\n",
    "    composed_config = compose(config_name='entry.yaml', overrides=config_overrides)\n",
    "    # Convert to plain Python containers with interpolations resolved.\n",
    "    fb_config = OmegaConf.to_container(composed_config, resolve=True)\n",
    "    pprint(fb_config)\n",
    "\n",
    "def concatenate_dicts(dict1, dict2):\n",
    "    return jax.tree.map(lambda x, y: jnp.concatenate([x, y]), dict1, dict2)\n",
    "\n",
    "# Fold the per-layout datasets into a single pytree of concatenated arrays.\n",
    "whole_data = functools.reduce(concatenate_dicts, train_layout_data)\n",
    "print(jax.tree.map(np.shape, whole_data))\n",
    "\n",
    "# Fetch the arrays to host memory, then build the plain and goal-conditioned dataset wrappers.\n",
    "host_data = jax.device_get(whole_data)\n",
    "whole_dataset = Dataset.create(**host_data)\n",
    "gc_whole_dataset = GCDataset(whole_dataset, config=fb_config['agent'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df01a13c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from envs.custom_mazes.env_utils import policy_image_fourrooms, value_image_fourrooms\n",
    "from functools import partial\n",
    "from utils.evaluation import supply_rng\n",
    "from IPython.display import clear_output\n",
    "\n",
    "from envs.env_utils import EpisodeMonitor\n",
    "from utils.evaluation import evaluate_fourrooms_dynamics\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def visualize_value_image(env, layout_type, task_num):\n",
    "    \"\"\"Build the value image predicted by the global ``fb_agent`` for one goal in ``env``.\n",
    "\n",
    "    Args:\n",
    "        env: environment exposing ``reset()``, ``setup_goals()`` and ``maze.size``.\n",
    "        layout_type: layout index forwarded to ``random_exploration_fourrooms`` when a\n",
    "            context rollout is collected (only used if ``use_context`` is enabled).\n",
    "        task_num: task index forwarded to ``env.setup_goals``.\n",
    "\n",
    "    Returns:\n",
    "        The image produced by ``value_image_fourrooms``.\n",
    "    \"\"\"\n",
    "    env.reset()\n",
    "    _, info = env.setup_goals(seed=None, task_num=task_num)\n",
    "    goal = info.get(\"goal_pos\", None)\n",
    "\n",
    "    # Bound up front so the non-context path no longer fails with UnboundLocalError.\n",
    "    # NOTE(review): assumes the agent accepts dynamics_embedding=None when use_context is off -- confirm.\n",
    "    dynamics_embedding = None\n",
    "    if fb_config['agent']['use_context']:\n",
    "        # Collect one exploration episode and encode it into a dynamics embedding.\n",
    "        dataset_inference, env = random_exploration_fourrooms(env, num_episodes=1, layout_type=layout_type, num_mdp=NUM_TRAIN_LAYOUTS)\n",
    "        print(jax.tree.map(lambda x: x.shape, dataset_inference))\n",
    "        dynamics_embedding_mean, dynamics_mean_std = fb_agent.network.select('dynamic_transformer')(\n",
    "            dataset_inference['observations'][None],\n",
    "            dataset_inference['actions'][None, :, None],\n",
    "            dataset_inference['next_observations'][None],\n",
    "            train=False, return_embedding=True)\n",
    "        # The second network output is exponentiated, i.e. it scales the Gaussian noise on a log scale.\n",
    "        # NOTE(review): GLOBAL_KEY is reused here, so the same noise is drawn on every call.\n",
    "        noise = jax.random.normal(key=GLOBAL_KEY, shape=dynamics_embedding_mean.shape)\n",
    "        dynamics_embedding = (dynamics_embedding_mean + noise * jnp.exp(dynamics_mean_std)).squeeze()\n",
    "\n",
    "    latent_z = jax.device_get(fb_agent.infer_z(goal, mdp_num=None, dynamics_embedding=dynamics_embedding)[None])\n",
    "    # `example_batch` used to be read from a name that is defined nowhere in the notebook\n",
    "    # (NameError); sample it locally, exactly as visualize_policy does.\n",
    "    example_batch = whole_dataset.sample(1)\n",
    "    batched_embedding = None if dynamics_embedding is None else dynamics_embedding[None]\n",
    "    N, M = env.maze.size\n",
    "    pred_value_img = value_image_fourrooms(\n",
    "        env, example_batch, N=N, M=M,\n",
    "        value_fn=partial(fb_agent.predict_q, z=latent_z, mdp_num=None, dynamics_embedding=batched_embedding),\n",
    "        action_fn=None, goal=goal)\n",
    "    return pred_value_img\n",
    "\n",
    "def visualize_policy(env, layout_type, task_num):\n",
    "    \"\"\"Build the policy image of the global ``fb_agent`` for one goal in ``env``.\n",
    "\n",
    "    Args:\n",
    "        env: environment exposing ``reset()``, ``setup_goals()`` and ``maze.size``.\n",
    "        layout_type: layout index forwarded to ``random_exploration_fourrooms`` when a\n",
    "            context rollout is collected (only used if ``use_context`` is enabled).\n",
    "        task_num: task index forwarded to ``env.setup_goals``.\n",
    "\n",
    "    Returns:\n",
    "        The image produced by ``policy_image_fourrooms``.\n",
    "    \"\"\"\n",
    "    env.reset()\n",
    "    _, info = env.setup_goals(seed=None, task_num=task_num)\n",
    "    goal = info.get(\"goal_pos\", None)\n",
    "\n",
    "    # Bound up front: previously `mdp_type` and `dynamics_embedding` were only assigned inside\n",
    "    # the use_context branch, so the other path raised UnboundLocalError.\n",
    "    # NOTE(review): assumes the agent accepts dynamics_embedding=None when use_context is off -- confirm.\n",
    "    dynamics_embedding = None\n",
    "    if fb_config['agent']['use_context']:\n",
    "        # Collect one exploration episode and encode it into a dynamics embedding.\n",
    "        dataset_inference, env = random_exploration_fourrooms(env, num_episodes=1, layout_type=layout_type, num_mdp=NUM_TRAIN_LAYOUTS)\n",
    "        print(jax.tree.map(lambda x: x.shape, dataset_inference))\n",
    "        dynamics_embedding_mean, dynamics_mean_std = fb_agent.network.select('dynamic_transformer')(\n",
    "            dataset_inference['observations'][None],\n",
    "            dataset_inference['actions'][None, :, None],\n",
    "            dataset_inference['next_observations'][None],\n",
    "            train=False, return_embedding=True)\n",
    "        # The second network output is exponentiated, i.e. it scales the Gaussian noise on a log scale.\n",
    "        # NOTE(review): GLOBAL_KEY is reused here, so the same noise is drawn on every call.\n",
    "        noise = jax.random.normal(key=GLOBAL_KEY, shape=dynamics_embedding_mean.shape)\n",
    "        dynamics_embedding = (dynamics_embedding_mean + noise * jnp.exp(dynamics_mean_std)).squeeze()\n",
    "\n",
    "    latent_z = fb_agent.infer_z(goal, mdp_num=None, dynamics_embedding=dynamics_embedding)\n",
    "    example_batch = whole_dataset.sample(1)\n",
    "    batched_embedding = None if dynamics_embedding is None else dynamics_embedding[None]\n",
    "    N, M = env.maze.size\n",
    "    # A fresh PRNG key is drawn for the action sampler; temperature=0.0 is passed through to sample_actions.\n",
    "    action_fn = partial(\n",
    "        supply_rng(fb_agent.sample_actions, rng=jax.random.PRNGKey(np.random.randint(0, 2**32))),\n",
    "        latent_z=latent_z, mdp_num=None, dynamics_embedding=batched_embedding, temperature=0.0)\n",
    "    pred_policy_img = policy_image_fourrooms(env, example_batch, N=N, M=M, action_fn=action_fn, goal=goal)\n",
    "    return pred_policy_img\n",
    "\n",
    "# NOTE(review): `fb_agent` is never constructed in this notebook (GCIQLAgent is imported but not\n",
    "# instantiated), so this cell raises NameError on a fresh kernel until the agent is created.\n",
    "\n",
    "NUM_UPDATE_STEPS = 100_001\n",
    "LOG_EVERY = 20_000               # period (in update steps) of visualisation and evaluation\n",
    "PRETRAIN_ENCODER_STEPS = 80_000  # evaluation only starts at this update step\n",
    "NUM_EVAL_TASKS = 4               # static for 4 rooms\n",
    "# (maze seed, layout_type, task_num) for each row of the progress figure.\n",
    "# NOTE(review): layout_type=1 is requested although NUM_TRAIN_LAYOUTS is 1 -- confirm it is valid for num_mdp=1.\n",
    "VIS_CASES = [(4, 0, 0), (3, 1, 2)]\n",
    "\n",
    "\n",
    "def _sample_dynamics_embedding(context_data):\n",
    "    \"\"\"Encode one context rollout with the agent's 'dynamic_transformer' and sample an embedding.\n",
    "\n",
    "    The second network output is exponentiated and scales the Gaussian noise added to the first.\n",
    "    NOTE(review): GLOBAL_KEY is reused, so the noise is identical on every call.\n",
    "    \"\"\"\n",
    "    embedding_mean, embedding_log_scale = fb_agent.network.select('dynamic_transformer')(\n",
    "        context_data['observations'][None],\n",
    "        context_data['actions'][None, :, None],\n",
    "        context_data['next_observations'][None],\n",
    "        train=False, return_embedding=True)\n",
    "    noise = jax.random.normal(key=GLOBAL_KEY, shape=embedding_mean.shape)\n",
    "    return (embedding_mean + noise * jnp.exp(embedding_log_scale)).squeeze()\n",
    "\n",
    "\n",
    "def _evaluate_on_layouts(env_seeds, num_mdp, num_eval_episodes, suffix):\n",
    "    \"\"\"Evaluate the global ``fb_agent`` on every (task, maze seed) pair.\n",
    "\n",
    "    Args:\n",
    "        env_seeds: re-iterable collection of maze seeds (iterated once per task).\n",
    "        num_mdp: forwarded to ``random_exploration_fourrooms``.\n",
    "        num_eval_episodes: forwarded to ``evaluate_fourrooms_dynamics``.\n",
    "        suffix: appended to the 'evaluation/overall_*' keys (e.g. 'train' or 'ood').\n",
    "\n",
    "    Returns:\n",
    "        Dict with 'evaluation/task_{id}_{metric}' entries and the per-metric means over all\n",
    "        runs under 'evaluation/overall_{metric}_{suffix}'.\n",
    "    \"\"\"\n",
    "    eval_metrics = {}\n",
    "    overall_metrics = defaultdict(list)\n",
    "\n",
    "    for task_id in range(NUM_EVAL_TASKS):\n",
    "        for env_seed in env_seeds:\n",
    "            env = FourRoomsMazeEnv(Maze(seed=env_seed, maze_type='fourrooms_random_layouts'), max_steps=fb_config['agent']['context_len'])\n",
    "            env = EpisodeMonitor(env, filter_regexes=['.*privileged.*', '.*proprio.*'])\n",
    "            env.reset()\n",
    "            # One exploration episode provides the context for the dynamics embedding.\n",
    "            # NOTE(review): layout_type is always 0 here, whatever the maze seed -- confirm this is intended.\n",
    "            context_data, env = random_exploration_fourrooms(env, num_episodes=1, layout_type=0, num_mdp=num_mdp)\n",
    "            eval_info, _, _ = evaluate_fourrooms_dynamics(\n",
    "                agent=fb_agent,\n",
    "                dynamics_embedding=_sample_dynamics_embedding(context_data),\n",
    "                env=env,\n",
    "                task_id=task_id,\n",
    "                config=None,\n",
    "                num_eval_episodes=num_eval_episodes,\n",
    "                num_video_episodes=0,\n",
    "                video_frame_skip=1,\n",
    "                eval_temperature=0.0,\n",
    "                eval_gaussian=None\n",
    "            )\n",
    "            # The per-task keys do not contain the maze seed, so later seeds overwrite earlier\n",
    "            # ones; only the overall_* means below aggregate over every seed.\n",
    "            eval_metrics.update(\n",
    "                {f'evaluation/task_{task_id}_{k}': v for k, v in eval_info.items() if k != 'total.timesteps'}\n",
    "            )\n",
    "            for k, v in eval_info.items():\n",
    "                overall_metrics[k].append(v)\n",
    "\n",
    "    for k, v in overall_metrics.items():\n",
    "        eval_metrics[f'evaluation/overall_{k}_{suffix}'] = np.mean(v)\n",
    "    return eval_metrics\n",
    "\n",
    "\n",
    "pbar = tqdm(range(NUM_UPDATE_STEPS))\n",
    "eval_history_train = []\n",
    "eval_history_test = []\n",
    "\n",
    "for update_step in pbar:\n",
    "    batch = gc_whole_dataset.sample(fb_config['agent']['batch_size'], layout_type=None, context_length=NUM_TRAIN_STEPS, get_traj_batch=True)[1]\n",
    "    fb_agent, info = fb_agent.update(batch)\n",
    "\n",
    "    # Periodic qualitative check: policy (left) and value (right) images, one row per case.\n",
    "    if update_step % LOG_EVERY == 0:\n",
    "        clear_output()\n",
    "        fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(20, 10))\n",
    "        for row, (maze_seed, layout_type, task_num) in enumerate(VIS_CASES):\n",
    "            env = FourRoomsMazeEnv(Maze(seed=maze_seed, maze_type='fourrooms_random_layouts'), max_steps=NUM_TRAIN_STEPS)\n",
    "            ax[row, 0].imshow(visualize_policy(env, layout_type=layout_type, task_num=task_num))\n",
    "            ax[row, 0].set_title(f'Policy (maze seed {maze_seed}, task {task_num})')\n",
    "            ax[row, 1].imshow(visualize_value_image(env, layout_type=layout_type, task_num=task_num))\n",
    "            ax[row, 1].set_title(f'Value (maze seed {maze_seed}, task {task_num})')\n",
    "\n",
    "        fig.suptitle(f\"Training step: {update_step}\")\n",
    "        plt.tight_layout()\n",
    "        display(fig)\n",
    "        plt.close(fig)\n",
    "\n",
    "    # Quantitative evaluation: at PRETRAIN_ENCODER_STEPS and every LOG_EVERY steps after it.\n",
    "    if (update_step > PRETRAIN_ENCODER_STEPS and update_step % LOG_EVERY == 0) or update_step == PRETRAIN_ENCODER_STEPS:\n",
    "        # Maze seeds 0..number_of_meta_envs-1.\n",
    "        eval_metrics = _evaluate_on_layouts(\n",
    "            range(fb_config['agent']['number_of_meta_envs']),\n",
    "            num_mdp=fb_config['agent']['number_of_meta_envs'],\n",
    "            num_eval_episodes=10,\n",
    "            suffix='train')\n",
    "        eval_history_train.append(eval_metrics['evaluation/overall_episode.final_reward_train'])\n",
    "\n",
    "        # Ten maze seeds offset by 50 from the training range, reported with the 'ood' suffix.\n",
    "        eval_metrics = _evaluate_on_layouts(\n",
    "            range(NUM_TRAIN_LAYOUTS + 50, NUM_TRAIN_LAYOUTS + 60),\n",
    "            num_mdp=NUM_TRAIN_LAYOUTS,\n",
    "            num_eval_episodes=20,\n",
    "            suffix='ood')\n",
    "        eval_history_test.append(eval_metrics['evaluation/overall_episode.final_reward_ood'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ad340f1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2bbed85",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61fcfcd8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
