{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "064d3455",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3352e1cd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "\n",
    "import os\n",
    "import pickle\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import scipy\n",
    "import scipy.stats\n",
    "\n",
    "import matplotlib as mpl\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from mimi import envs\n",
    "from mimi import utils\n",
    "from mimi import user_models\n",
    "from mimi import opt\n",
    "from mimi import reward_models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "982cca90",
   "metadata": {},
   "outputs": [],
   "source": [
    "sess = utils.make_tf_session(gpu_mode=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55d2380a",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_id = 'pilot'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a264006d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data_dir = os.path.join(utils.data_dir, 'cursor', user_id)\n",
    "if not os.path.exists(data_dir):\n",
    "  os.makedirs(data_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dae4521e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings shared by the human user model and the cursor environment built below.\n",
    "win_dims = np.array([1000, 1000])  # passed as win_dims to HumanMouseUser and CursorEnv\n",
    "max_ep_len = 300                   # passed as max_ep_len to CursorEnv\n",
    "speed = 0.02                       # passed as speed to CursorEnv\n",
    "goal_dist_thresh = 0.05            # passed as goal_dist_thresh to CursorEnv\n",
    "reset_delay = 0                    # passed as reset_delay to CursorEnv\n",
    "step_delay = 0.075                 # passed as step_delay to HumanMouseUser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cfc33fa",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "user_model = user_models.HumanMouseUser(win_dims=win_dims, step_delay=step_delay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83aa7ba1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Collect the environment settings first so the constructor call stays short.\n",
    "env_kwargs = dict(\n",
    "  max_ep_len=max_ep_len,\n",
    "  goal_dist_thresh=goal_dist_thresh,\n",
    "  speed=speed,\n",
    "  win_dims=win_dims,\n",
    "  reset_delay=reset_delay\n",
    ")\n",
    "env = envs.CursorEnv(sess, user_model, **env_kwargs)\n",
    "\n",
    "# Hook the human user model up to the freshly created environment.\n",
    "utils.prep_env_for_human_user(env, user_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b638df84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positional and keyword arguments handed to MIRewardModel for building its MI model.\n",
    "mi_model_init_args = [sess]\n",
    "mi_model_init_kwargs = dict(\n",
    "  n_env_obs_dim=env.n_env_obs_dim,\n",
    "  n_user_obs_dim=env.n_user_obs_dim,\n",
    "  n_act_dim=env.n_act_dim,\n",
    "  n_layers=2,\n",
    "  layer_size=64\n",
    ")\n",
    "\n",
    "# Training settings; the same dict is passed again to gp_optimizer.run further down.\n",
    "mi_model_train_kwargs = dict(\n",
    "  iterations=1000,\n",
    "  ftol=1e-6,\n",
    "  learning_rate=1e-3,\n",
    "  batch_size=64,\n",
    "  val_update_freq=None,\n",
    "  verbose=False,\n",
    "  warm_start=False\n",
    ")\n",
    "\n",
    "reward_model = reward_models.MIRewardModel(\n",
    "  env,\n",
    "  mi_model_init_args,\n",
    "  mi_model_init_kwargs,\n",
    "  mi_model_train_kwargs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed4d1bfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def true_rew_of_rollout(rollout):\n",
    "  \"\"\"Score one rollout by the negative distance to its goal.\n",
    "\n",
    "  The goal is read from the 'goal' entry of the last element of the final step.\n",
    "  Each step's reward is minus the Euclidean distance between the first two\n",
    "  entries of the step's first element and that goal.\n",
    "\n",
    "  The result blends the mean step reward with the final step reward, weighted\n",
    "  by p = len(rollout) / env.max_ep_len: p * mean + (1 - p) * final.\n",
    "  \"\"\"\n",
    "  goal = rollout[-1][-1]['goal']\n",
    "  rews = [-np.linalg.norm(step[0][:2] - goal) for step in rollout]\n",
    "  p = len(rollout) / env.max_ep_len\n",
    "  return p * np.mean(rews) + (1 - p) * rews[-1]\n",
    "\n",
    "\n",
    "def true_reward_model(rollouts):\n",
    "  \"\"\"Return the mean of true_rew_of_rollout over a list of rollouts.\"\"\"\n",
    "  return np.mean([true_rew_of_rollout(rollout) for rollout in rollouts])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18a651f1",
   "metadata": {},
   "source": [
    "**Pause here for instructions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39ed8231",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_rand_policy():\n",
    "  \"\"\"Build a policy that rotates the user's observation by a random fixed angle.\n",
    "\n",
    "  The angle is drawn once, when the policy is created, as\n",
    "  np.random.random() * 2*pi; every call of the returned policy reuses it.\n",
    "  \"\"\"\n",
    "  ang = np.random.random() * 2*np.pi\n",
    "\n",
    "  def policy(obs):\n",
    "    # extract_user_obses is called on a batch of one, hence the added axis and the [0].\n",
    "    user_obs = env.extract_user_obses(obs[np.newaxis])[0]\n",
    "    return utils.rotate_vec(user_obs, ang)\n",
    "\n",
    "  return policy\n",
    "\n",
    "n_policies = 5\n",
    "n_rollouts_per_policy = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00b4651e",
   "metadata": {},
   "outputs": [],
   "source": [
    "baseline_rollouts = [[] for _ in range(n_policies)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe644945",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# For every baseline slot: draw a fresh random policy, pause for five seconds,\n",
    "# then run rendered episodes until the slot holds n_rollouts_per_policy rollouts.\n",
    "for i in range(n_policies):\n",
    "  policy = make_rand_policy()\n",
    "  time.sleep(5)\n",
    "  collected = baseline_rollouts[i]\n",
    "  n_missing = max(n_rollouts_per_policy - len(collected), 0)\n",
    "  for _ in range(n_missing):\n",
    "    rollout = utils.run_ep(policy, env, render=True, init_delay=1)\n",
    "    collected.append(rollout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c27ea236",
   "metadata": {},
   "outputs": [],
   "source": [
    "baseline_path = os.path.join(data_dir, 'baseline_rollouts.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d714b611",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(baseline_path, 'wb') as f:\n",
    "  pickle.dump(baseline_rollouts, f, pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75660df6",
   "metadata": {},
   "outputs": [],
   "source": [
    "true_reward_model(sum(baseline_rollouts, []))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "822ad89d",
   "metadata": {},
   "source": [
    "**Pause here for instructions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8199d5cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rotation_from_params(w):\n",
    "  \"\"\"Turn the single policy parameter w[0], read as an angle, into a 2x2 rotation matrix.\"\"\"\n",
    "  cos_w, sin_w = np.cos(w[0]), np.sin(w[0])\n",
    "  return np.array([[cos_w, -sin_w], [sin_w, cos_w]])\n",
    "\n",
    "# One policy parameter, searched over [0, 2*pi].\n",
    "gp_optimizer = opt.GP(\n",
    "  env,\n",
    "  reward_model,\n",
    "  param_bounds=(0, 2*np.pi),\n",
    "  n_policy_params=1,\n",
    "  W_from_w=rotation_from_params\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11fb4ab8",
   "metadata": {},
   "outputs": [],
   "source": [
    "gp_min_kwargs = {'n_initial_points': 5}\n",
    "ep_kwargs = {'init_delay': 1, 'render': True}\n",
    "n_eps_per_pol = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f10261e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#x0, _, y0 = zip(*eval_data_of_pol)\n",
    "#gp_min_kwargs.update({'x0': list(x0), 'y0': list(y0)})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "def83b8f",
   "metadata": {},
   "source": [
    "**Pause here for instructions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed369c31",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "gp_policy, res = gp_optimizer.run(\n",
    "  n_pols=50,\n",
    "  n_eps_per_pol=n_eps_per_pol,\n",
    "  gp_min_kwargs=gp_min_kwargs,\n",
    "  ep_kwargs=ep_kwargs,\n",
    "  reward_model_train_kwargs=mi_model_train_kwargs\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "299cd597",
   "metadata": {},
   "source": [
    "**Pause here for instructions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5396785",
   "metadata": {},
   "outputs": [],
   "source": [
    "#eval_data_of_pol.extend(gp_optimizer.eval_data_of_pol)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "557215ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_data_of_pol = gp_optimizer.eval_data_of_pol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e59ca574",
   "metadata": {},
   "outputs": [],
   "source": [
    "gp_results_path = os.path.join(data_dir, 'gp_results.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2452b79a",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(gp_results_path, 'wb') as f:\n",
    "  pickle.dump(eval_data_of_pol, f, pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e632ee9",
   "metadata": {},
   "source": [
    "**Pause here for instructions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e767e9fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = os.path.join(utils.data_dir, 'cursor')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "708c96fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate every participant's saved study data into perf_evals / pols.\n",
    "# Loop-local names are deliberately distinct from the session variables used by the\n",
    "# data-collection cells (user_id, baseline_path, baseline_rollouts, gp_results_path,\n",
    "# eval_data_of_pol), so running this cell cannot redirect a later save.\n",
    "perf_evals = []\n",
    "pols = []\n",
    "\n",
    "for participant_id in range(12):\n",
    "  user_path = os.path.join(data_dir, str(participant_id))\n",
    "  if not os.path.exists(user_path):\n",
    "    continue\n",
    "\n",
    "  # NOTE: pickle.load can execute arbitrary code; only load files produced by this study.\n",
    "  with open(os.path.join(user_path, 'baseline_rollouts.pkl'), 'rb') as f:\n",
    "    rollouts_of_baseline_pol = pickle.load(f)\n",
    "  # Saved as one list of rollouts per baseline policy; pool them into a single list.\n",
    "  participant_baseline_rollouts = [\n",
    "    rollout for rollouts in rollouts_of_baseline_pol for rollout in rollouts\n",
    "  ]\n",
    "  baseline_true_reward = true_reward_model(participant_baseline_rollouts)\n",
    "\n",
    "  with open(os.path.join(user_path, 'gp_results.pkl'), 'rb') as f:\n",
    "    participant_eval_data = pickle.load(f)\n",
    "\n",
    "  # Each entry is indexed as [0] policy parameters, [1] rollouts, [2] reward value\n",
    "  # (plotted below as the mutual-information reward).\n",
    "  true_rews_of_pol = [\n",
    "    [true_reward_model([rollout]) for rollout in eval_data[1]]\n",
    "    for eval_data in participant_eval_data\n",
    "  ]\n",
    "  true_rew_of_pol = [true_reward_model(eval_data[1]) for eval_data in participant_eval_data]\n",
    "  rew_of_pol = [eval_data[2] for eval_data in participant_eval_data]\n",
    "  n_eps_of_pol = [len(eval_data[1]) for eval_data in participant_eval_data]\n",
    "\n",
    "  # Per-episode true rewards, in evaluation order.\n",
    "  true_rews = [rew for rews in true_rews_of_pol for rew in rews]\n",
    "  perf_evals.append({\n",
    "    'true_rews': true_rews,\n",
    "    'true_rew': true_rew_of_pol,\n",
    "    'rew': rew_of_pol,\n",
    "    'n_eps': np.cumsum(n_eps_of_pol),\n",
    "    'xs': np.arange(len(true_rews), dtype=float),\n",
    "    'baseline_true_reward': baseline_true_reward\n",
    "  })\n",
    "  # Parameters of the last policy this participant evaluated.\n",
    "  pols.append(participant_eval_data[-1][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a558a52",
   "metadata": {},
   "outputs": [],
   "source": [
    "mpl.rcParams.update({'font.size': 12})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7ded559",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Angle parameter of the final policy of every participant, wrapped onto [0, 2*pi).\n",
    "pol_angs = np.mod([pol[0] for pol in pols], 2*np.pi)\n",
    "\n",
    "N = 20\n",
    "width = (2*np.pi) / N\n",
    "\n",
    "# Bin over the full circle so that each bin is exactly `width` radians wide and the\n",
    "# bars line up with their bins. np.histogram only counts; unlike plt.hist it does\n",
    "# not draw an extra Cartesian histogram into the figure.\n",
    "radii, bin_edges = np.histogram(pol_angs, bins=N, range=(0, 2*np.pi))\n",
    "theta = (bin_edges[:-1] + bin_edges[1:]) / 2\n",
    "\n",
    "fig = plt.figure()\n",
    "ax = fig.add_subplot(111, polar=True)\n",
    "bars = ax.bar(theta, radii, width=width, color='orange')\n",
    "\n",
    "ax.set_title('Emergent Interfaces')\n",
    "fig.savefig(os.path.join(data_dir, 'user-study-learned-int.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e0b6332",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool the per-policy true rewards and mutual-information rewards of all participants.\n",
    "xs = [rew for perf_eval in perf_evals for rew in perf_eval['true_rew']]\n",
    "ys = [rew for perf_eval in perf_evals for rew in perf_eval['rew']]\n",
    "\n",
    "# Spearman rank correlation; needs the explicit `import scipy.stats` in the import cell.\n",
    "rho = scipy.stats.spearmanr(xs, ys)[0]\n",
    "\n",
    "plt.title(r'True Reward vs. Mutual Information Reward ($\\rho = %0.2f$)' % rho)\n",
    "plt.xlabel('True Reward (Avg. Distance to Target)')\n",
    "plt.ylabel(r\"$\\mathcal{I}(\\mathbf{x}_t, (\\mathbf{s}_t, \\mathbf{s}_{t+1}))$\")\n",
    "plt.scatter(\n",
    "  xs,\n",
    "  ys,\n",
    "  color='orange',\n",
    "  s=50,\n",
    "  alpha=0.5\n",
    ")\n",
    "plt.xticks(fontsize=10)\n",
    "plt.savefig(os.path.join(data_dir, 'user-study-truerew-vs-mi.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ff464f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.title('Mutual Information Reward')\n",
    "plt.xlabel('Number of Online Training Episodes')\n",
    "plt.ylabel(r\"$\\mathcal{I}(\\mathbf{x}_t, (\\mathbf{s}_t, \\mathbf{s}_{t+1}))$\")\n",
    "utils.plot_perf_evals(perf_evals, 'n_eps', 'rew', label='MIMI (Ours)', smooth_win=1, color='teal')\n",
    "plt.legend(loc='lower right')\n",
    "plt.savefig(os.path.join(data_dir, 'user-study-mi-vs-eps.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f2b362a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Horizontal reference levels: the mean baseline reward across participants and a\n",
    "# fixed oracle value (hard-coded; its derivation is not part of this notebook).\n",
    "baseline_level = np.mean([perf['baseline_true_reward'] for perf in perf_evals])\n",
    "oracle_level = -0.0726392950518139\n",
    "\n",
    "plt.title('User Study: 2D Cursor Control with Perturbed Mouse')\n",
    "plt.xlabel('Number of Online Training Episodes')\n",
    "plt.ylabel('True Reward (Avg. Distance to Target)')\n",
    "utils.plot_perf_evals(perf_evals, 'xs', 'true_rews', label='MIMI (Ours)', smooth_win=10, color='orange')\n",
    "plt.axhline(y=baseline_level, linestyle='--', color='gray', label='Random Interfaces (Baseline)')\n",
    "plt.axhline(y=oracle_level, linestyle='--', color='green', label='Oracle')\n",
    "plt.legend(loc='lower right')\n",
    "plt.ylim([-0.14, None])\n",
    "plt.savefig(os.path.join(data_dir, 'user-study-truerew-vs-eps.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ee581f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_trajs(eval_data_of_pol):\n",
    "  \"\"\"Scatter the first two entries of each step's first element, over all given evaluations.\n",
    "\n",
    "  Rollouts are read from index 1 of every evaluation entry. Axis ticks are\n",
    "  hidden and the aspect ratio is set to equal.\n",
    "  \"\"\"\n",
    "  points = []\n",
    "  for eval_data in eval_data_of_pol:\n",
    "    for rollout in eval_data[1]:\n",
    "      points.extend(step[0][:2] for step in rollout)\n",
    "  traj = np.array(points)\n",
    "  plt.scatter(traj[:, 0], traj[:, 1], alpha=0.25, linewidth=0, color='gray')\n",
    "  plt.gca().set_aspect('equal', adjustable='box')\n",
    "  plt.xticks([])\n",
    "  plt.yticks([])\n",
    "\n",
    "\n",
    "def plot_before_trajs(eval_data_of_pol):\n",
    "  \"\"\"Plot the steps of the first five evaluated policies.\"\"\"\n",
    "  early_evals = eval_data_of_pol[:5]\n",
    "  plt.title('<50 Training Episodes')\n",
    "  plot_trajs(early_evals)\n",
    "\n",
    "\n",
    "def plot_after_trajs(eval_data_of_pol):\n",
    "  \"\"\"Plot the steps of every evaluated policy from the sixteenth onwards.\"\"\"\n",
    "  late_evals = eval_data_of_pol[15:]\n",
    "  plt.title('>150 Training Episodes')\n",
    "  plot_trajs(late_evals)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9d96a7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(os.path.join(data_dir, '7', 'gp_results.pkl'), 'rb') as f:\n",
    "  eval_data_of_pol = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79a69101",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_before_trajs(eval_data_of_pol)\n",
    "plt.savefig(os.path.join(data_dir, 'before-trajs.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4693547",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_after_trajs(eval_data_of_pol)\n",
    "plt.savefig(os.path.join(data_dir, 'after-trajs.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeadbf29",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
