{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d49d5c60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# comment out `import tensorflow as tf` from `mimi/utils.py` and `mimi/models.py`, \n",
    "# otherwise hand tracking setup will hang"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "064d3455",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3352e1cd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "\n",
    "from copy import deepcopy\n",
    "import pickle\n",
    "import os\n",
    "import uuid\n",
    "\n",
    "import scipy\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import matplotlib as mpl\n",
    "\n",
    "from mimi import envs\n",
    "from mimi import utils\n",
    "from mimi import user_models\n",
    "from mimi import opt\n",
    "from mimi import reward_models\n",
    "from mimi import models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a264006d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data_dir = os.path.join(utils.data_dir, 'mnist')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61e233ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_model = user_models.HumanHandUser()\n",
    "import tensorflow as tf\n",
    "utils.tf = tf\n",
    "models.tf = tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4354834",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "sess = utils.make_tf_session(gpu_mode=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f217f8ca",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8403d0e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_model = user_models.HumanHandUser()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83aa7ba1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "env = envs.MNISTEnv(\n",
    "  sess,\n",
    "  user_model,\n",
    "  max_ep_len=500,\n",
    "  reset_delay=1,\n",
    "  min_pos=-10,\n",
    "  max_pos=10,\n",
    "  goal_dist_thresh=-0.3\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceae69b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "mi_model_init_args = [sess]\n",
    "mi_model_init_kwargs = {\n",
    "  'n_env_obs_dim': env.n_env_obs_dim,\n",
    "  'n_user_obs_dim': env.n_user_obs_dim,\n",
    "  'n_act_dim': env.n_act_dim,\n",
    "  'n_layers': 2,\n",
    "  'layer_size': 64\n",
    "}\n",
    "mi_model_train_kwargs = {\n",
    "  'iterations': 1000,\n",
    "  'ftol': 1e-6,\n",
    "  'learning_rate': 1e-3,\n",
    "  'batch_size': 64,\n",
    "  'val_update_freq': None,\n",
    "  'verbose': False,\n",
    "  'warm_start': False\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b638df84",
   "metadata": {},
   "outputs": [],
   "source": [
    "reward_model = reward_models.MIRewardModel(\n",
    "  env,\n",
    "  mi_model_init_args,\n",
    "  mi_model_init_kwargs,\n",
    "  mi_model_train_kwargs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0933c31b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def true_rew_of_rollout(rollout):\n",
    "  rews = [-x[-1]['goal_dist'] for x in rollout]\n",
    "  p = len(rollout) / env.max_ep_len\n",
    "  return p * np.mean(rews) + (1-p) * rews[-1]\n",
    "true_reward_model = lambda rollouts: np.mean([true_rew_of_rollout(rollout) for rollout in rollouts])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70a52d6c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "292849fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "gp_optimizer = opt.GP(\n",
    "  env,\n",
    "  reward_model,\n",
    "  param_bounds=(-2., 2.) # DEBUG\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "430e7c7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "gp_min_kwargs = {'n_initial_points': 3}\n",
    "ep_kwargs = {'init_delay': 1, 'render': True}\n",
    "n_eps_per_pol = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5e8cf5b",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "gp_policy, res = gp_optimizer.run(\n",
    "  n_pols=50,\n",
    "  n_eps_per_pol=n_eps_per_pol,\n",
    "  gp_min_kwargs=gp_min_kwargs,\n",
    "  ep_kwargs=ep_kwargs,\n",
    "  reward_model_train_kwargs=mi_model_train_kwargs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "757436f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_data_of_pol.extend(gp_optimizer.eval_data_of_pol)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82aa30b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_data_of_pol = gp_optimizer.eval_data_of_pol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fe22528",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_id = '0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f08a977e",
   "metadata": {},
   "outputs": [],
   "source": [
    "gp_results_path = os.path.join(data_dir, user_id, 'gp_results.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "516d719c",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(gp_results_path, 'wb') as f:\n",
    "  pickle.dump(eval_data_of_pol, f, pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af61ce39",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(gp_results_path, 'rb') as f:\n",
    "  eval_data_of_pol = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e875623",
   "metadata": {},
   "outputs": [],
   "source": [
    "x0, _, y0 = [list(z) for z in zip(*eval_data_of_pol)]\n",
    "gp_min_kwargs.update({'x0': x0, 'y0': y0})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f31c9b85",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5b8ea0f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "perf_evals = []\n",
    "for user_id in range(1):\n",
    "  user_path = os.path.join(data_dir, str(user_id))\n",
    "  if os.path.exists(user_path):\n",
    "    gp_results_path = os.path.join(user_path, 'gp_results.pkl')\n",
    "    with open(gp_results_path, 'rb') as f:\n",
    "      eval_data_of_pol = pickle.load(f)\n",
    "    true_rews_of_pol = [[true_reward_model([rollout]) for rollout in eval_data[1]] for eval_data in eval_data_of_pol]\n",
    "    true_rew_of_pol = [true_reward_model(eval_data[1]) for eval_data in eval_data_of_pol]\n",
    "    rew_of_pol = [eval_data[2] for eval_data in eval_data_of_pol]\n",
    "    \n",
    "    baseline_true_rewards = true_rew_of_pol[:3]\n",
    "    baseline_rewards = rew_of_pol[:3]\n",
    "    true_rews = sum(true_rews_of_pol, [])\n",
    "    perf_evals.append({\n",
    "      'true_rews': true_rews, \n",
    "      'true_rew': true_rew_of_pol,\n",
    "      'rews': rew_of_pol,\n",
    "      'rew_xs': np.arange(0, len(rew_of_pol), 1),\n",
    "      'true_xs': np.cumsum(np.ones(len(true_rews)))-1,\n",
    "      'baseline_true_reward': np.mean(baseline_true_rewards),\n",
    "      'baseline_reward': np.mean(baseline_rewards)\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb2cad18",
   "metadata": {},
   "outputs": [],
   "source": [
    "mpl.rcParams.update({'font.size': 12})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a774389",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.title('MNIST Latent Space Navigation with Hand Gestures')\n",
    "plt.xlabel('Number of Online Training Episodes')\n",
    "plt.ylabel('True Reward (Margin Between \\nGoal Prob. and Next Largest Prob.)')\n",
    "utils.plot_perf_evals(perf_evals, 'true_xs', 'true_rews', label='MIMI (Ours)', smooth_win=10, color='orange')\n",
    "plt.axhline(y=np.mean([perf['baseline_true_reward'] for perf in perf_evals]), linestyle='--', color='gray', label='Random Interfaces (Baseline)')\n",
    "plt.legend(loc='upper left')\n",
    "#plt.savefig(os.path.join(data_dir, 'mnist-truerew-vs-eps.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c39378c",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.title('Mutual Information Reward')\n",
    "plt.xlabel('Number of Online Training Episodes')\n",
    "plt.ylabel(r\"$\\mathcal{I}(\\mathbf{x}_t, (\\mathbf{s}_t, \\mathbf{s}_{t+1}))$\")\n",
    "utils.plot_perf_evals(perf_evals, 'rew_xs', 'rews', label='MIMI (Ours)', smooth_win=1, color='teal')\n",
    "plt.axhline(y=np.mean([perf['baseline_reward'] for perf in perf_evals]), linestyle='--', color='gray', label='Random Interfaces (Baseline)')\n",
    "plt.legend(loc='upper left')\n",
    "#plt.savefig(os.path.join(data_dir, 'mnist-study-rew-vs-eps.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66c66548",
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = sum([perf_eval['true_rew'] for perf_eval in perf_evals], [])\n",
    "ys = sum([perf_eval['rews'] for perf_eval in perf_evals], [])\n",
    "rho = scipy.stats.spearmanr(xs, ys)[0]\n",
    "plt.title(r'Success Rate vs. Mutual Information Reward ($\\rho = %0.2f$)' % rho)\n",
    "plt.xlabel('Success Rate')\n",
    "plt.ylabel(r\"$\\mathcal{I}(\\mathbf{x}_t, (\\mathbf{s}_t, \\mathbf{s}_{t+1}))$\")\n",
    "plt.scatter(\n",
    "  xs, \n",
    "  ys,\n",
    "  color='orange',\n",
    "  s=50,\n",
    "  alpha=0.5\n",
    ")\n",
    "plt.xticks(fontsize=10)\n",
    "#plt.savefig(os.path.join(data_dir, 'mnist-study-truerew-vs-mi.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "027c7b00",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
