{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "064d3455",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3352e1cd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "\n",
    "from copy import deepcopy\n",
    "import pickle\n",
    "import time\n",
    "import os\n",
    "import uuid\n",
    "\n",
    "import scipy\n",
    "import numpy as np\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import matplotlib as mpl\n",
    "\n",
    "from mimi import envs\n",
    "from mimi import utils\n",
    "from mimi import user_models\n",
    "from mimi import opt\n",
    "from mimi import reward_models\n",
    "from mimi import models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a264006d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data_dir = os.path.join(utils.data_dir, 'deepassist')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4354834",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "sess = utils.make_tf_session(gpu_mode=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d66b3bf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = envs.DeepAssistEnv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e540e7d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "pilot_ids = ['spike', 'jet', 'faye', 'vicious', 'ed', 'ein', 'julia', 'punch', 'judy', 'lin', 'grencia', 'laughingbull']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dd54407",
   "metadata": {},
   "outputs": [],
   "source": [
    "deepassist_dir = os.path.join(data_dir, 'raw')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5643e711",
   "metadata": {},
   "outputs": [],
   "source": [
    "prune_state = lambda state: state[:1]\n",
    "n_eval_eps = 30\n",
    "\n",
    "def format_eps(data, method, delta=1):\n",
    "  eps = []\n",
    "  if type(data) == list:\n",
    "    all_rewards, outcomes, trajs, all_actions = data\n",
    "  else:\n",
    "    all_rewards = data['rewards'][0]\n",
    "    outcomes = data['outcomes'][0]\n",
    "    all_actions = data['actions'][0]\n",
    "    trajs = data['trajectories'][0]\n",
    "    \n",
    "  trajs = trajs[-n_eval_eps:]\n",
    "  all_actions = all_actions[-n_eval_eps:]\n",
    "  all_rewards = all_rewards[-n_eval_eps:]\n",
    "  outcomes = outcomes[-n_eval_eps:]\n",
    "  \n",
    "  for i, traj in enumerate(trajs):\n",
    "    actions = all_actions[i]\n",
    "    T = len(actions)\n",
    "    ep = []\n",
    "    for t in range(T):\n",
    "      state = traj[t]\n",
    "      action = utils.onehot_encode(actions[t], env.n_act_dim)\n",
    "      reward = 0 if t < T - 1 else all_rewards[i]\n",
    "      next_state = traj[min(T,t+delta)]\n",
    "      if method == 'pilot_eval':\n",
    "        format_state = lambda state, action: np.concatenate((prune_state(state), action))\n",
    "      else:\n",
    "        format_state = lambda state, action: np.concatenate((prune_state(state), state[-6:]))\n",
    "      state = format_state(state, action)\n",
    "      next_state = format_state(next_state, action)\n",
    "      assert state.size == env.n_obs_dim\n",
    "      assert next_state.size == env.n_obs_dim\n",
    "      ep.append((state, action, reward, next_state, False, {}))\n",
    "    eps.append(ep)\n",
    "  return eps\n",
    "\n",
    "def load_eps(filename, method, pilot_id, delta=1):\n",
    "  with open(filename, 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "    if method == 'pilot_eval':\n",
    "      data = data[pilot_id]\n",
    "    else:\n",
    "      data = list(data.values())[0]\n",
    "  return format_eps(data, method, delta=delta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b87c308d",
   "metadata": {},
   "outputs": [],
   "source": [
    "methods = ['pilot_eval', 'reward_logs']\n",
    "def make_dataset(delta=1):\n",
    "  rollouts_of_pol = []\n",
    "  method_of_pol = []\n",
    "  pilot_of_pol = []\n",
    "  for method in methods:\n",
    "    for pilot_id in pilot_ids:\n",
    "      path = os.path.join(deepassist_dir, '%s_%s.pkl' % (pilot_id, method))\n",
    "      eps = load_eps(path, method, pilot_id, delta=delta)\n",
    "      rollouts_of_pol.append(eps)\n",
    "      method_of_pol.append(method)\n",
    "      pilot_of_pol.append(pilot_id)\n",
    "  return rollouts_of_pol, method_of_pol, pilot_of_pol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2accb64f",
   "metadata": {},
   "outputs": [],
   "source": [
    "rollouts_of_pol, method_of_pol, pilot_of_pol = make_dataset(delta=np.inf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32bd5b74",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_conds = len(rollouts_of_pol)\n",
    "n_steps = sum(len(r) for x in rollouts_of_pol for r in x)\n",
    "n_steps, n_conds, n_steps / n_conds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97a75bf4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceae69b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "mi_model_init_args = [sess]\n",
    "mi_model_init_kwargs = {\n",
    "  'n_env_obs_dim': env.n_env_obs_dim,\n",
    "  'n_user_obs_dim': env.n_user_obs_dim,\n",
    "  'n_act_dim': env.n_act_dim,\n",
    "  'n_layers': 2,\n",
    "  'layer_size': 64\n",
    "}\n",
    "mi_model_train_kwargs = {\n",
    "  'iterations': 1000,\n",
    "  'ftol': 1e-6,\n",
    "  'learning_rate': 1e-4,\n",
    "  'batch_size': 64,\n",
    "  'val_update_freq': None,\n",
    "  'verbose': False,\n",
    "  'warm_start': False\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b638df84",
   "metadata": {},
   "outputs": [],
   "source": [
    "reward_model = reward_models.MIRewardModel(\n",
    "  env,\n",
    "  mi_model_init_args,\n",
    "  mi_model_init_kwargs,\n",
    "  mi_model_train_kwargs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44380ded",
   "metadata": {},
   "outputs": [],
   "source": [
    "ixs_reward_model = reward_models.MIRewardModel(\n",
    "  env,\n",
    "  mi_model_init_args,\n",
    "  mi_model_init_kwargs,\n",
    "  mi_model_train_kwargs,\n",
    "  use_next_env_obs=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7a8ed67",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed4d1bfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "true_rew_of_rollout = lambda rollout: np.mean([x[2] for x in rollout])\n",
    "true_reward_model = lambda rollouts: np.mean([true_rew_of_rollout(rollout) for rollout in rollouts])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a937a5f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "offline_reward_models = [true_reward_model, reward_model, ixs_reward_model]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b797d96",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_seeds = 10\n",
    "def compute_rewards(rollouts_of_pol):\n",
    "  rewards_of_pol = np.zeros((n_seeds, len(rollouts_of_pol), len(offline_reward_models)))\n",
    "  for i in range(n_seeds):\n",
    "    rewards_of_pol[i, :, :] = utils.compute_rews_of_rollouts(\n",
    "      rollouts_of_pol,\n",
    "      offline_reward_models,\n",
    "      verbose=True\n",
    "    )\n",
    "  return rewards_of_pol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51ca20fe",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "rewards_of_pol = compute_rewards(rollouts_of_pol)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50254dd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "rewards_of_pol_path = os.path.join(data_dir, 'rewards_of_pol.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ed4f1a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(rewards_of_pol_path, 'wb') as f:\n",
    "  pickle.dump(rewards_of_pol, f, pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48b24a4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(rewards_of_pol_path, 'rb') as f:\n",
    "  rewards_of_pol = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6829a5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "rewards_of_pol[:, :, 1:] = np.maximum(rewards_of_pol[:, :, 1:], 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31da70e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_rewards_of_pol = np.mean(rewards_of_pol, axis=0)\n",
    "mean_rewards_of_pol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "621d3373",
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_rewards_of_pol[:, 2] = np.minimum(mean_rewards_of_pol[:, 2], mean_rewards_of_pol[:, 1])\n",
    "mean_rewards_of_pol[:, 2] = np.maximum(mean_rewards_of_pol[:, 2], 0)\n",
    "mean_rewards_of_pol = np.concatenate((mean_rewards_of_pol, (mean_rewards_of_pol[:, 1] - mean_rewards_of_pol[:, 2])[:, np.newaxis]), axis=1)\n",
    "mean_rewards_of_pol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9717c872",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_of_method = {\n",
    "  'pilot_eval': 'Solo Human',\n",
    "  'reward_logs': 'With Copilot',\n",
    "}\n",
    "color_of_method = {\n",
    "  'pilot_eval': 'gray',\n",
    "  'reward_logs': 'orange',\n",
    "}\n",
    "color_of_pol = [color_of_method[m] for m in method_of_pol]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b5c1d74",
   "metadata": {},
   "outputs": [],
   "source": [
    "idxes_of_method = {m: [] for m in methods}\n",
    "for i, method in enumerate(method_of_pol):\n",
    "  idxes_of_method[method].append(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b84b0e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "mpl.rcParams.update({'font.size': 12})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88f9f391",
   "metadata": {},
   "outputs": [],
   "source": [
    "corr = scipy.stats.spearmanr(mean_rewards_of_pol)\n",
    "corr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2ab37de",
   "metadata": {},
   "outputs": [],
   "source": [
    "rho = corr[0][0, 1]\n",
    "rho"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b8aab8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.title(r'Shared Autonomy via Deep RL ($\\rho$ = %0.2f)' % rho)\n",
    "plt.xlabel('True Reward')\n",
    "plt.ylabel(r\"$\\mathcal{I}(\\mathbf{x}, (\\mathbf{s}_t, \\mathbf{s}_T))$\")\n",
    "for method, idxes in idxes_of_method.items():\n",
    "  plt.scatter(\n",
    "    mean_rewards_of_pol[idxes, 0], \n",
    "    mean_rewards_of_pol[idxes, 1],\n",
    "    color=color_of_method[method],\n",
    "    label=label_of_method[method],\n",
    "    s=50\n",
    "  )\n",
    "plt.legend(loc='upper center')\n",
    "plt.savefig(os.path.join(data_dir, 'deep-assist-offline-eval-truerew-vs-mi-per-poluser.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "299fee7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "rewards_of_method = [[[] for _ in range(mean_rewards_of_pol.shape[1])] for _ in methods]\n",
    "idx_of_method = {x: i for i, x in enumerate(methods)}\n",
    "for i, method in enumerate(method_of_pol):\n",
    "  for j in range(mean_rewards_of_pol.shape[1]):\n",
    "    rewards_of_method[idx_of_method[method]][j].append(mean_rewards_of_pol[i, j])\n",
    "rewards_of_method = [[np.mean(x) for x in y] for y in rewards_of_method]\n",
    "rewards_of_method = np.array(rewards_of_method)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b11a1698",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.title('Shared Autonomy via Deep RL')\n",
    "plt.xlabel('True Reward')\n",
    "plt.ylabel(r\"$\\mathcal{I}(\\mathbf{x}, (\\mathbf{s}_t, \\mathbf{s}_T))$\")\n",
    "for i, method in enumerate(methods):\n",
    "  plt.scatter(\n",
    "    rewards_of_method[i, 0], \n",
    "    rewards_of_method[i, 1],\n",
    "    color=color_of_method[method],\n",
    "    label=label_of_method[method],\n",
    "    s=50\n",
    "  )\n",
    "plt.legend(loc='best')\n",
    "plt.savefig(os.path.join(data_dir, 'deep-assist-offline-eval-truerew-vs-mi-per-pol.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b759ec1e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab41aadb",
   "metadata": {},
   "outputs": [],
   "source": [
    "deltas = [1, 10, 20, 100, 200]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e1fd89e",
   "metadata": {},
   "outputs": [],
   "source": [
    "rewards_of_pol_of_delta = {}\n",
    "rewards_of_pol_of_delta[np.inf] = rewards_of_pol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "013975b5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for delta in deltas:\n",
    "  if delta not in rewards_of_pol_of_delta:\n",
    "    rewards_of_pol_of_delta[delta] = compute_rewards(make_dataset(delta=delta)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdc4ddbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "rewards_of_pol_of_delta_path = os.path.join(data_dir, 'rewards_of_pol_of_delta.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd8d6f74",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(rewards_of_pol_of_delta_path, 'wb') as f:\n",
    "  pickle.dump(rewards_of_pol_of_delta, f, pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58dadb37",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(rewards_of_pol_of_delta_path, 'rb') as f:\n",
    "  rewards_of_pol_of_delta = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12c82e06",
   "metadata": {},
   "outputs": [],
   "source": [
    "deltas = sorted(rewards_of_pol_of_delta.keys())\n",
    "corrs = []\n",
    "for delta in deltas:\n",
    "  rewards_of_pol = rewards_of_pol_of_delta[delta]\n",
    "  rewards_of_pol[:, :, 1:] = np.maximum(rewards_of_pol[:, :, 1:], 0)\n",
    "  mean_rewards_of_pol = np.mean(rewards_of_pol, axis=0)\n",
    "  corr, _ = scipy.stats.spearmanr(mean_rewards_of_pol[:, 0], mean_rewards_of_pol[:, 1])\n",
    "  corrs.append(corr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "239dbbed",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_ep_len = max(len(r) for rs in rollouts_of_pol for r in rs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94cee1a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.title('Shared Autonomy via Deep RL\\n' + r'$\\mathcal{I}(\\mathbf{x}_t, (\\mathbf{s}_t, \\mathbf{s}_{t+\\Delta}))$ vs. True Reward')\n",
    "plt.xlabel(r'Time Offset $\\Delta$')\n",
    "plt.ylabel(r\"Spearman's Rank Correlation $\\rho$\")\n",
    "plt.plot([max_ep_len if d == np.inf else d for d in deltas], corrs, color='orange', marker='o')\n",
    "plt.axhline(y=0, linestyle='--', color='gray')\n",
    "plt.xscale('log')\n",
    "plt.savefig(os.path.join(data_dir, 'deep-assist-offline-eval-corr-vs-delta.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8674fdc2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
