{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "064d3455",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3352e1cd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "\n",
    "from copy import deepcopy\n",
    "import pickle\n",
    "import time\n",
    "import os\n",
    "import uuid\n",
    "\n",
    "import scipy\n",
    "import numpy as np\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import matplotlib as mpl\n",
    "\n",
    "from mimi import envs\n",
    "from mimi import utils\n",
    "from mimi import user_models\n",
    "from mimi import opt\n",
    "from mimi import reward_models\n",
    "from mimi import models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a264006d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data_dir = os.path.join(utils.data_dir, 'isql')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4354834",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "sess = utils.make_tf_session(gpu_mode=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d66b3bf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = envs.ISQLEnv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "825b8f3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "pilot_ids = [\n",
    "  'spike',\n",
    "  'jet',\n",
    "  'faye',\n",
    "  'vicious',\n",
    "  'ed',\n",
    "  'ein',\n",
    "  'julia',\n",
    "  'punch',\n",
    "  'judy',\n",
    "  'lin',\n",
    "  'grencia',\n",
    "  'laughingbull'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71493841",
   "metadata": {},
   "outputs": [],
   "source": [
    "isql_dir = os.path.join(data_dir, 'raw')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a685ab68",
   "metadata": {},
   "outputs": [],
   "source": [
    "newton_data_dir = os.path.join(isql_dir, '5.1-lander-newton')\n",
    "aristotle_data_dir = os.path.join(isql_dir, '5.1-lander-aristotle')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e60e1b2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "newton_rollouts_of_pilot = {k: [] for k in pilot_ids}\n",
    "aristotle_rollouts_of_pilot = {k: [] for k in pilot_ids}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2c3e78a",
   "metadata": {},
   "outputs": [],
   "source": [
    "for pilot_id in pilot_ids:\n",
    "  with open(os.path.join(newton_data_dir, '%s_pilot_policy_demo_rollouts.pkl' % pilot_id), 'rb') as f:\n",
    "    newton_rollouts_of_pilot[pilot_id].extend(sum(pickle.load(f), []))\n",
    "  with open(os.path.join(aristotle_data_dir, '%s_pilot_policy_demo_rollouts.pkl' % pilot_id), 'rb') as f:\n",
    "    aristotle_rollouts_of_pilot[pilot_id].extend(sum(pickle.load(f), []))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d66ddb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_obs(s, a):\n",
    "  return np.concatenate((s[:-1], a))\n",
    "\n",
    "def format_ep(rollout):\n",
    "  ep = []\n",
    "  for x in rollout:\n",
    "    a = utils.onehot_encode(x[1], env.n_act_dim)\n",
    "    r = x[2]\n",
    "    s = format_obs(x[0], a)\n",
    "    ns = format_obs(x[3], a)\n",
    "    ep.append((s, a, r, ns, False, {}))\n",
    "  return ep\n",
    "\n",
    "def format_eps(rollouts):\n",
    "  return [format_ep(ep) for ep in rollouts]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b87c308d",
   "metadata": {},
   "outputs": [],
   "source": [
    "methods = ['aristotle', 'newton']\n",
    "rollouts_of_pol = []\n",
    "method_of_pol = []\n",
    "pilot_of_pol = []\n",
    "for method in methods:\n",
    "  for pilot_id in pilot_ids:\n",
    "    rollouts = eval('%s_rollouts_of_pilot' % method)[pilot_id]\n",
    "    eps = format_eps(rollouts)\n",
    "    rollouts_of_pol.append(eps)\n",
    "    method_of_pol.append(method)\n",
    "    pilot_of_pol.append(pilot_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32bd5b74",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_conds = len(rollouts_of_pol)\n",
    "n_steps = sum(len(r) for x in rollouts_of_pol for r in x)\n",
    "n_steps, n_conds, n_steps / n_conds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97a75bf4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceae69b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "mi_model_init_args = [sess]\n",
    "mi_model_init_kwargs = {\n",
    "  'n_env_obs_dim': env.n_env_obs_dim,\n",
    "  'n_user_obs_dim': env.n_user_obs_dim,\n",
    "  'n_act_dim': env.n_act_dim,\n",
    "  'n_layers': 2,\n",
    "  'layer_size': 64\n",
    "}\n",
    "mi_model_train_kwargs = {\n",
    "  'iterations': 1000,\n",
    "  'ftol': 1e-6,\n",
    "  'learning_rate': 1e-4,\n",
    "  'batch_size': 64,\n",
    "  'val_update_freq': None,\n",
    "  'verbose': False,\n",
    "  'warm_start': False\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b638df84",
   "metadata": {},
   "outputs": [],
   "source": [
    "reward_model = reward_models.MIRewardModel(\n",
    "  env,\n",
    "  mi_model_init_args,\n",
    "  mi_model_init_kwargs,\n",
    "  mi_model_train_kwargs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44380ded",
   "metadata": {},
   "outputs": [],
   "source": [
    "ixs_reward_model = reward_models.MIRewardModel(\n",
    "  env,\n",
    "  mi_model_init_args,\n",
    "  mi_model_init_kwargs,\n",
    "  mi_model_train_kwargs,\n",
    "  use_next_env_obs=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7a8ed67",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed4d1bfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "true_rew_of_rollout = lambda rollout: np.mean([x[2] for x in rollout])\n",
    "true_reward_model = lambda rollouts: np.mean([true_rew_of_rollout(rollout) for rollout in rollouts])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a937a5f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "offline_reward_models = [true_reward_model, reward_model, ixs_reward_model]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abc00ec6",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_seeds = 1\n",
    "rewards_of_pol = np.zeros((n_seeds, len(rollouts_of_pol), len(offline_reward_models)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51ca20fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(n_seeds):\n",
    "  rewards_of_pol[i, :, :] = utils.compute_rews_of_rollouts(\n",
    "    rollouts_of_pol,\n",
    "    offline_reward_models,\n",
    "    verbose=True\n",
    "  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50254dd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "rewards_of_pol_path = os.path.join(data_dir, 'rewards_of_pol.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ed4f1a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(rewards_of_pol_path, 'wb') as f:\n",
    "  pickle.dump(rewards_of_pol, f, pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48b24a4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(rewards_of_pol_path, 'rb') as f:\n",
    "  rewards_of_pol = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78a2414f",
   "metadata": {},
   "outputs": [],
   "source": [
    "rewards_of_pol[:, :, 1:] = np.maximum(rewards_of_pol[:, :, 1:], 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "142c5510",
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_rewards_of_pol = np.mean(rewards_of_pol, axis=0)\n",
    "mean_rewards_of_pol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "074524a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_rewards_of_pol[:, 2] = np.minimum(mean_rewards_of_pol[:, 2], mean_rewards_of_pol[:, 1])\n",
    "mean_rewards_of_pol[:, 2] = np.maximum(mean_rewards_of_pol[:, 2], 0)\n",
    "mean_rewards_of_pol = np.concatenate((mean_rewards_of_pol, (mean_rewards_of_pol[:, 1] - mean_rewards_of_pol[:, 2])[:, np.newaxis]), axis=1)\n",
    "mean_rewards_of_pol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9717c872",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_of_method = {\n",
    "  'newton': 'Solo Human',\n",
    "  'aristotle': 'With Copilot',\n",
    "}\n",
    "color_of_method = {\n",
    "  'newton': 'gray',\n",
    "  'aristotle': 'orange',\n",
    "}\n",
    "color_of_pol = [color_of_method[m] for m in method_of_pol]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b5c1d74",
   "metadata": {},
   "outputs": [],
   "source": [
    "idxes_of_method = {m: [] for m in methods}\n",
    "for i, method in enumerate(method_of_pol):\n",
    "  idxes_of_method[method].append(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06a38633",
   "metadata": {},
   "outputs": [],
   "source": [
    "mpl.rcParams.update({'font.size': 12})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0300b862",
   "metadata": {},
   "outputs": [],
   "source": [
    "corr = scipy.stats.spearmanr(mean_rewards_of_pol)\n",
    "corr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1a6edff",
   "metadata": {},
   "outputs": [],
   "source": [
    "rho = corr[0][0, 1]\n",
    "rho"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b8aab8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.title(r'Internal-to-Real Dynamics Transfer ($\\rho$ = %0.2f)' % rho)\n",
    "plt.xlabel('True Reward')\n",
    "plt.ylabel(r\"$\\mathcal{I}(\\mathbf{x}_t, (\\mathbf{s}_t, \\mathbf{s}_{t+1}))$\")\n",
    "for method, idxes in idxes_of_method.items():\n",
    "  plt.scatter(\n",
    "    mean_rewards_of_pol[idxes, 0], \n",
    "    mean_rewards_of_pol[idxes, 1],\n",
    "    color=color_of_method[method],\n",
    "    label=label_of_method[method],\n",
    "    s=50\n",
    "  )\n",
    "plt.legend(loc='best')\n",
    "plt.savefig(os.path.join(data_dir, 'isql-offline-eval-truerew-vs-mi-per-poluser.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "299fee7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "rewards_of_method = [[[] for _ in range(mean_rewards_of_pol.shape[1])] for _ in methods]\n",
    "idx_of_method = {x: i for i, x in enumerate(methods)}\n",
    "for i, method in enumerate(method_of_pol):\n",
    "  for j in range(mean_rewards_of_pol.shape[1]):\n",
    "    rewards_of_method[idx_of_method[method]][j].append(mean_rewards_of_pol[i, j])\n",
    "rewards_of_method = [[np.mean(x) for x in y] for y in rewards_of_method]\n",
    "rewards_of_method = np.array(rewards_of_method)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b11a1698",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.title('Internal-to-Real Dynamics Transfer')\n",
    "plt.xlabel('True Reward')\n",
    "plt.ylabel(r\"$\\mathcal{I}(\\mathbf{x}_t, (\\mathbf{s}_t, \\mathbf{s}_{t+1}))$\")\n",
    "for i, method in enumerate(methods):\n",
    "  plt.scatter(\n",
    "    rewards_of_method[i, 0], \n",
    "    rewards_of_method[i, 1],\n",
    "    color=color_of_method[method],\n",
    "    label=label_of_method[method],\n",
    "    s=50\n",
    "  )\n",
    "plt.legend(loc='best')\n",
    "plt.savefig(os.path.join(data_dir, 'isql-offline-eval-truerew-vs-mi-per-pol.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8674fdc2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
