{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sparse point robot navigation (adapted from https://github.com/katerakelly/oyster)\n",
    "In this task, a point robot must navigate to different goal locations on the x-y plane under sparse rewards. The reward is zero outside the goal region and shaped inside it. Goals are distributed along the half unit circle.\n",
    "\n",
    "This notebook is for visualizing rollouts collected from a learned policy.\n",
    "\n",
    "\n",
    "For now, you must first train the policy.\n",
    " - edit `sparse-point-robot.json` to add `dump_eval_paths=1`\n",
    " - run `python launch_experiment.py ./configs/sparse-point-robot.json`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import pickle\n",
    "import os\n",
    "\n",
    "# config\n",
    "exp_id = '' # EDIT THIS\n",
    "tlow, thigh = 80, 100 # task ID range, half-open: range(tlow, thigh) covers IDs tlow..thigh-1\n",
    "# see `n_tasks` and `n_eval_tasks` args in the training config json\n",
    "# by convention, the test tasks are always the last `n_eval_tasks` IDs\n",
    "# task IDs are zero-based, so with 100 tasks total and 20 test tasks, the test tasks are IDs 80-99\n",
    "epoch = 100 # training epoch to load data from\n",
    "gr = 0.2 # goal radius, for visualization purposes\n",
    "\n",
    "expdir = 'output/sparse-point-robot/{}/eval_trajectories/'.format(exp_id) # directory to load data from\n",
    "\n",
    "# helpers\n",
    "def load_pkl(task):\n",
    "    with open(os.path.join(expdir, 'task{}-epoch{}-run0.pkl'.format(task, epoch)), 'rb') as f:\n",
    "        data = pickle.load(f)\n",
    "    return data\n",
    "\n",
    "def load_pkl_prior():\n",
    "    with open(os.path.join(expdir, 'prior-epoch{}.pkl'.format(epoch)), 'rb') as f:\n",
    "        data = pickle.load(f)\n",
    "    return data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Trajectories sampled from the meta-learned prior\n",
    "Over the course of training, the distribution over the latent context variable is optimized to represent the task distribution. Trajectories sampled from the prior should therefore be spread roughly uniformly over the goals: visually, we should see the agent navigating to all of the goal regions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "paths = load_pkl_prior()\n",
    "goals = [load_pkl(task)[0]['goal'] for task in range(tlow, thigh)]\n",
    "\n",
    "plt.figure(figsize=(8,8))\n",
    "axes = plt.axes()\n",
    "axes.set(aspect='equal')\n",
    "plt.axis([-1.25, 1.25, -0.25, 1.25])\n",
    "for g in goals:\n",
    "    circle = plt.Circle((g[0], g[1]), radius=gr)\n",
    "    axes.add_artist(circle)\n",
    "rewards = 0\n",
    "final_rewards = 0\n",
    "for traj in paths:\n",
    "    rewards += sum(traj['rewards'])\n",
    "    final_rewards += traj['rewards'][-1]\n",
    "    states = traj['observations']\n",
    "    plt.plot(states[:-1, 0], states[:-1, 1], '-o')\n",
    "    plt.plot(states[-1, 0], states[-1, 1], '-x', markersize=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Online adaptation trajectories\n",
    "Online adaptation proceeds by posterior sampling: roll out a trajectory, add it to the pool of context seen so far in the task, update the posterior over the latent context variable via feed-forward inference, and sample a new latent context from the posterior. \n",
    "In this visualization, each figure shows the agent in a different task (goal region shown in dark blue). Early (dark) trajectories should resemble samples from the prior, while later (light) trajectories that are conditioned on more context should navigate to the correct goal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "mpl = 20\n",
    "num_trajs = 60\n",
    "\n",
    "all_paths = []\n",
    "for task in range(tlow, thigh):\n",
    "    paths = [t['observations'] for t in load_pkl(task)]\n",
    "    all_paths.append(paths)\n",
    "\n",
    "# color trajectories in the order they were collected\n",
    "# plt.get_cmap is used because matplotlib.cm.get_cmap was removed in newer Matplotlib releases\n",
    "cmap = plt.get_cmap('plasma')\n",
    "sample_locs = np.linspace(0, 0.9, num_trajs)\n",
    "colors = [cmap(s) for s in sample_locs]\n",
    "\n",
    "fig, axes = plt.subplots(5, 2, figsize=(12, 20))\n",
    "t = 0\n",
    "for j in range(2):\n",
    "    for i in range(5):\n",
    "        axes[i, j].set_xlim([-1.25, 1.25])\n",
    "        axes[i, j].set_ylim([-0.25, 1.25])\n",
    "        for k, g in enumerate(goals):\n",
    "            alpha = 1 if k == t else 0.2\n",
    "            circle = plt.Circle((g[0], g[1]), radius=gr, alpha=alpha)\n",
    "            axes[i, j].add_artist(circle)\n",
    "        # pick num_trajs evenly spaced trajectory indices (the np.int alias was removed from NumPy; use int)\n",
    "        indices = list(np.linspace(0, len(all_paths[t]), num_trajs, endpoint=False).astype(int))\n",
    "        counter = 0\n",
    "        for idx in indices:\n",
    "            states = all_paths[t][idx]\n",
    "            axes[i, j].plot(states[:-1, 0], states[:-1, 1], '-', color=colors[counter])\n",
    "            axes[i, j].plot(states[-1, 0], states[-1, 1], '-x', markersize=10, color=colors[counter])\n",
    "            axes[i, j].set(aspect='equal')\n",
    "            counter += 1\n",
    "        t += 1"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
