{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d49d5c60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Before running: comment out `import tensorflow as tf` in `mimi/utils.py` and `mimi/models.py`;\n",
    "# otherwise hand tracking setup will hang."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "064d3455",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Auto-reload edited modules before each cell runs, and render matplotlib figures inline.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3352e1cd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "\n",
    "from copy import deepcopy\n",
    "import pickle\n",
    "import os\n",
    "import uuid\n",
    "\n",
    "import scipy\n",
    "import scipy.stats\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import matplotlib as mpl\n",
    "\n",
    "from mimi import envs\n",
    "from mimi import utils\n",
    "from mimi import user_models\n",
    "from mimi import opt\n",
    "from mimi import reward_models\n",
    "from mimi import models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a264006d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Results and figures for this experiment are read from / written to <utils.data_dir>/lander.\n",
    "data_dir = os.path.join(utils.data_dir, 'lander')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61e233ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Order matters: the hand-tracking user model is created BEFORE TensorFlow is\n",
    "# imported (the note at the top of the notebook says hand tracking setup hangs\n",
    "# if `mimi/utils.py` / `mimi/models.py` import TensorFlow themselves).\n",
    "user_model = user_models.HumanHandUser()\n",
    "import tensorflow as tf\n",
    "# Hand the module to the two mimi modules whose own import was commented out.\n",
    "utils.tf = tf\n",
    "models.tf = tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4354834",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# TensorFlow session passed to both the environment and the MI model below\n",
    "# (gpu_mode=False presumably keeps it on the CPU -- TODO confirm).\n",
    "sess = utils.make_tf_session(gpu_mode=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5294e18",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3facc5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Episode settings passed to envs.LanderEnv below.\n",
    "max_ep_len = 500  # presumably the maximum number of steps per episode -- TODO confirm\n",
    "step_delay = 0.02  # presumably seconds to wait between steps -- TODO confirm\n",
    "reset_delay = 1  # presumably seconds to wait on reset -- TODO confirm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8403d0e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `user_model` was already created in an earlier cell, before\n",
    "# TensorFlow was imported. This builds a second HumanHandUser after that import;\n",
    "# confirm this is intentional and does not hit the hang described at the top.\n",
    "user_model = user_models.HumanHandUser()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83aa7ba1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Lander environment wired to the TF session and the hand-tracking user model.\n",
    "env = envs.LanderEnv(\n",
    "  sess, user_model,\n",
    "  step_delay=step_delay,\n",
    "  reset_delay=reset_delay,\n",
    "  max_ep_len=max_ep_len,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceae69b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positional and keyword constructor arguments for the mutual-information (MI) model.\n",
    "mi_model_init_args = [sess]\n",
    "mi_model_init_kwargs = dict(\n",
    "  n_env_obs_dim=env.n_min_env_obs_dim,\n",
    "  n_user_obs_dim=env.n_user_obs_dim,\n",
    "  n_act_dim=env.n_act_dim,\n",
    "  n_layers=2,\n",
    "  layer_size=64,\n",
    ")\n",
    "# Training settings; also forwarded to gp_optimizer.run() as reward_model_train_kwargs.\n",
    "mi_model_train_kwargs = dict(\n",
    "  iterations=1000,\n",
    "  ftol=1e-6,\n",
    "  learning_rate=1e-3,\n",
    "  batch_size=64,\n",
    "  val_update_freq=None,\n",
    "  verbose=False,\n",
    "  warm_start=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b638df84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MI-based reward model, configured with the init/train settings defined above.\n",
    "reward_model = reward_models.MIRewardModel(\n",
    "  env, mi_model_init_args, mi_model_init_kwargs, mi_model_train_kwargs,\n",
    "  use_min_env_obs=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0933c31b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def true_rew_of_rollout(rollout):\n",
    "  \"\"\"Return 1 if the last element of the rollout's final entry has a truthy 'succ', else 0.\"\"\"\n",
    "  final_info = rollout[-1][-1]\n",
    "  if final_info['succ']:\n",
    "    return 1\n",
    "  return 0\n",
    "\n",
    "\n",
    "def true_reward_model(rollouts):\n",
    "  \"\"\"Return the mean of `true_rew_of_rollout` over `rollouts` (i.e. the success rate).\"\"\"\n",
    "  return np.mean([true_rew_of_rollout(rollout) for rollout in rollouts])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70a52d6c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "292849fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimizer that is run below against `env`, using `reward_model` for rewards.\n",
    "gp_optimizer = opt.GP(env, reward_model, param_bounds=(-1., 1.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "430e7c7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings forwarded to gp_optimizer.run() below.\n",
    "gp_min_kwargs = {'n_initial_points': 3}\n",
    "ep_kwargs = {'init_delay': 1, 'render': True}\n",
    "n_eps_per_pol = 10  # passed as run(n_eps_per_pol=...)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5e8cf5b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run the optimizer with the settings defined above; it returns a policy and a result object.\n",
    "gp_policy, res = gp_optimizer.run(\n",
    "  n_pols=50, n_eps_per_pol=n_eps_per_pol,\n",
    "  ep_kwargs=ep_kwargs,\n",
    "  gp_min_kwargs=gp_min_kwargs,\n",
    "  reward_model_train_kwargs=mi_model_train_kwargs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "757436f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#eval_data_of_pol.extend(gp_optimizer.eval_data_of_pol)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82aa30b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation data held by the optimizer; this is what gets pickled to disk below.\n",
    "eval_data_of_pol = gp_optimizer.eval_data_of_pol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fe22528",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identifier of the current user; used as the sub-directory name for the saved results.\n",
    "user_id = '0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f08a977e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pickle file for this user's results: <data_dir>/<user_id>/gp_results.pkl\n",
    "gp_results_path = os.path.join(data_dir, user_id, 'gp_results.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "516d719c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the per-user results directory first: open(..., 'wb') does not create\n",
    "# missing parent directories, and failing here would lose the collected data.\n",
    "os.makedirs(os.path.dirname(gp_results_path), exist_ok=True)\n",
    "with open(gp_results_path, 'wb') as f:\n",
    "  pickle.dump(eval_data_of_pol, f, pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af61ce39",
   "metadata": {},
   "outputs": [],
   "source": [
    "#with open(gp_results_path, 'rb') as f:\n",
    "#  eval_data_of_pol = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e875623",
   "metadata": {},
   "outputs": [],
   "source": [
    "#x0, _, y0 = [list(z) for z in zip(*eval_data_of_pol)]\n",
    "#gp_min_kwargs.update({'x0': x0, 'y0': y0})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f31c9b85",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5b8ea0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load each participant's saved results and derive per-episode / per-policy metrics.\n",
    "# Loop-local names are kept distinct from the notebook-level `user_id`,\n",
    "# `gp_results_path` and `eval_data_of_pol`, so re-running the save cell after this\n",
    "# cell cannot write one participant's data into another participant's file.\n",
    "perf_evals = []\n",
    "for uid in range(3):\n",
    "  user_results_path = os.path.join(data_dir, str(uid), 'gp_results.pkl')\n",
    "  if not os.path.exists(user_results_path):\n",
    "    continue  # nothing saved for this participant\n",
    "  with open(user_results_path, 'rb') as f:\n",
    "    # pickle.load can execute arbitrary code: only load files written by this notebook.\n",
    "    user_eval_data = pickle.load(f)\n",
    "  # Each entry is indexed as: [1] -> list of rollouts, [2] -> reward stored for the policy.\n",
    "  true_rews_of_pol = [[true_reward_model([rollout]) for rollout in eval_data[1]] for eval_data in user_eval_data]\n",
    "  true_rew_of_pol = [true_reward_model(eval_data[1]) for eval_data in user_eval_data]\n",
    "  rew_of_pol = [eval_data[2] for eval_data in user_eval_data]\n",
    "\n",
    "  # The first 3 policies are plotted as the 'Random Interfaces' baseline\n",
    "  # (presumably the n_initial_points=3 initial samples -- TODO confirm).\n",
    "  baseline_true_rewards = true_rew_of_pol[:3]\n",
    "  # Flatten per-policy lists into one per-episode list (linear time, unlike sum(..., [])).\n",
    "  true_rews = [rew for pol_rews in true_rews_of_pol for rew in pol_rews]\n",
    "  perf_evals.append({\n",
    "    'true_rews': true_rews,\n",
    "    'true_rew': true_rew_of_pol,\n",
    "    'rews': rew_of_pol,\n",
    "    'rew_xs': np.cumsum([len(eval_data[1]) for eval_data in user_eval_data]),\n",
    "    'true_xs': np.cumsum(np.ones(len(true_rews)))-1,\n",
    "    'baseline_true_reward': np.mean(baseline_true_rewards),\n",
    "  })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb2cad18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default font size for the figures below.\n",
    "mpl.rcParams.update({'font.size': 12})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a774389",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Success rate vs. online training episode (smooth_win=10), plus a dashed line at\n",
    "# the mean `baseline_true_reward` across participants.\n",
    "plt.title('Lunar Lander with Hand Gestures')\n",
    "plt.xlabel('Number of Online Training Episodes')\n",
    "plt.ylabel('Success Rate')\n",
    "utils.plot_perf_evals(perf_evals, 'true_xs', 'true_rews', label='MIMI (Ours)', smooth_win=10, color='orange')\n",
    "plt.axhline(y=np.mean([perf['baseline_true_reward'] for perf in perf_evals]), linestyle='--', color='gray', label='Random Interfaces (Baseline)')\n",
    "plt.legend(loc='upper left')\n",
    "#plt.savefig(os.path.join(data_dir, 'lander-study-truerew-vs-eps.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b735e91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mutual-information reward vs. online training episode (smooth_win=1);\n",
    "# the figure is also written to a PDF under data_dir.\n",
    "plt.title('Mutual Information Reward')\n",
    "plt.xlabel('Number of Online Training Episodes')\n",
    "plt.ylabel(r\"$\\mathcal{I}(\\mathbf{x}_t, (\\mathbf{s}_t, \\mathbf{s}_{t+1}))$\")\n",
    "utils.plot_perf_evals(perf_evals, 'rew_xs', 'rews', label='MIMI (Ours)', smooth_win=1, color='teal')\n",
    "plt.legend(loc='upper left')\n",
    "plt.savefig(os.path.join(data_dir, 'lander-study-rew-vs-eps.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15f4e6cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool per-policy success rates (x) and MI rewards (y) across all participants.\n",
    "# Comprehensions flatten in linear time, unlike repeated list concatenation via sum(..., []).\n",
    "xs = [rew for perf_eval in perf_evals for rew in perf_eval['true_rew']]\n",
    "ys = [rew for perf_eval in perf_evals for rew in perf_eval['rews']]\n",
    "# Spearman rank correlation; needs `scipy.stats` to be imported explicitly\n",
    "# (a bare `import scipy` does not reliably load the subpackage).\n",
    "rho = scipy.stats.spearmanr(xs, ys)[0]\n",
    "plt.title(r'Success Rate vs. Mutual Information Reward ($\\rho = %0.2f$)' % rho)\n",
    "plt.xlabel('Success Rate')\n",
    "plt.ylabel(r\"$\\mathcal{I}(\\mathbf{x}_t, (\\mathbf{s}_t, \\mathbf{s}_{t+1}))$\")\n",
    "plt.scatter(\n",
    "  xs, \n",
    "  ys,\n",
    "  color='orange',\n",
    "  s=50,\n",
    "  alpha=0.5\n",
    ")\n",
    "plt.xticks(fontsize=10)\n",
    "plt.savefig(os.path.join(data_dir, 'lander-study-truerew-vs-mi.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73fa6888",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
