{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "from matplotlib.offsetbox import OffsetImage, AnnotationBbox\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "\n",
    "from tqdm import tqdm\n",
    "from typing import Any, Sequence, Union, List, Dict\n",
    "\n",
    "import numpy as np\n",
    "from torch import nn\n",
    "import torch as tr\n",
    "\n",
    "import copy\n",
    "\n",
    "from grid_env import GridEnv, GridSize, GridObservation\n",
    "from envs import FourRooms\n",
    "from grid_templates import GridTemplate\n",
    "from agents import GeneralQ, LambdaQ\n",
    "from runners import run_nn_experiment_episodic\n",
    "from utils import (\n",
    "    save_results,\n",
    "    load_results,\n",
    "    gif_from_frames,\n",
    "    set_seed_everywhere,\n",
    "    create_neuronav_gif,\n",
    "    plot_neuronav_frame,\n",
    "    errorfill,\n",
    "    plot_values,\n",
    "    plot_action_values,)\n",
    "\n",
    "# display options\n",
    "np.set_printoptions(precision=4, suppress=1)\n",
    "\n",
    "colors = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "# Define the color segments for the colormap\n",
    "segments = [(i/(len(colors)-1), colors[i]) for i in range(len(colors))]\n",
    "# Create a LinearSegmentedColormap from the color segments\n",
    "cmap = LinearSegmentedColormap.from_list(name='my_colormap', colors=segments)\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_lambda_ = 0.5\n",
    "\n",
    "goals = [12, 20, 67]\n",
    "goal_rewards = [10, 5, 5]\n",
    "env_lambda_ = 0.5\n",
    "discount = 0.99\n",
    "max_ep_len = 100\n",
    "\n",
    "env = GridEnv(\n",
    "    template=GridTemplate.two_rooms,\n",
    "    size=GridSize.small,\n",
    "    use_noop=True,\n",
    "    lambda_=env_lambda_,\n",
    "    obs_type=GridObservation.index,\n",
    "    im_size=128)\n",
    "\n",
    "goal_xys = [env.idx_to_state_coords(g) for g in goals]\n",
    "objects = {\"rewards\": dict(zip(goal_xys, goal_rewards))}\n",
    "\n",
    "agent = LambdaQ(\n",
    "  env.state_size,\n",
    "  env.action_space.n,\n",
    "  env.reset(objects=objects),\n",
    "  method='q',\n",
    "  double=False,\n",
    "  step_size=0.1,\n",
    "  use_ez_greedy=False,\n",
    "  epsilon=0.2,\n",
    "  optimistic_init=True,\n",
    "  decay_explore=None,\n",
    "  lambda_=agent_lambda_,)\n",
    "\n",
    "\n",
    "results = run_nn_experiment_episodic(\n",
    "  env, agent, 500, objects=objects, discount=discount,\n",
    "  terminate_on_reward=False, random_start=False, use_underlying_pos=False,\n",
    "  display_eps=10, respect_done=True, max_ep_len=max_ep_len, record=True,\n",
    "  goals_always_available=True, eval_every=20)\n",
    "\n",
    "\n",
    "lim = max_ep_len\n",
    "create_neuronav_gif(results['frames'][:lim], results['rewards_remaining_hist'][:lim],\n",
    "                  f\"{agent_lambda_}.gif\", return_sequence=results['ep_return_hist'][:lim],\n",
    "                  show_map=False, contains_map=False, num_goals=len(goals),\n",
    "                  max_reward=np.max(goal_rewards), max_value=200, duration=300,\n",
    "                  show_values=True, value_sequence=results['value_hist'][:lim])\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# actor critic\n",
    "from utils import DmEnvWrapper, smooth\n",
    "from runners import jax_run\n",
    "from lambda_rac import default_lambda_agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_lambda_ = 0.5\n",
    "print(f\"\\nTraining with lambda = {agent_lambda_} ==========================\")\n",
    "goals = [12, 20, 67]\n",
    "goal_rewards = [10, 5, 5]\n",
    "env_lambda_ = 0.5\n",
    "discount = 0.99\n",
    "max_ep_len = 100\n",
    "num_episodes = 7_500\n",
    "\n",
    "env = GridEnv(\n",
    "    template=GridTemplate.two_rooms,\n",
    "    size=GridSize.small,\n",
    "    use_noop=True,\n",
    "    lambda_=env_lambda_,\n",
    "    obs_type=GridObservation.onehot,\n",
    "    im_size=128)\n",
    "\n",
    "goal_xys = [env.idx_to_state_coords(g) for g in goals]\n",
    "objects = {\"rewards\": dict(zip(goal_xys, goal_rewards))}\n",
    "\n",
    "env = DmEnvWrapper(env, objects=objects, terminate_on_reward=False, random_start=False, discount=discount)\n",
    "\n",
    "# agent = default_agent(env.observation_spec(), env.action_spec(), seed=n)\n",
    "agent = default_lambda_agent(\n",
    "    env.observation_spec(),\n",
    "    env.action_spec(),\n",
    "    lf_lambda=agent_lambda_,\n",
    "    lf_wt=0.1,\n",
    "    hidden_size=128,\n",
    "    feature_dim=121,\n",
    "    entropy_cost=0.01)\n",
    "\n",
    "\n",
    "results = jax_run(agent, env, num_episodes, log_every=5, max_episode_len=max_ep_len)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
