{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "064d3455",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3352e1cd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "\n",
    "from copy import deepcopy\n",
    "import pickle\n",
    "import time\n",
    "import os\n",
    "import uuid\n",
    "\n",
    "import scipy\n",
    "import scipy.stats\n",
    "import numpy as np\n",
    "import h5py\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import matplotlib as mpl\n",
    "\n",
    "from mimi import envs\n",
    "from mimi import utils\n",
    "from mimi import user_models\n",
    "from mimi import opt\n",
    "from mimi import reward_models\n",
    "from mimi import models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a264006d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data_dir = os.path.join(utils.data_dir, 'x2t')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4354834",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "sess = utils.make_tf_session(gpu_mode=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec6d981c",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = envs.X2TEnv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "474400d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "x2t_dir = os.path.join(data_dir, 'raw')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d13a44eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_eps(obses, actions, rewards):\n",
    "  \"\"\"Turn per-step arrays into a list of single-transition episodes.\n",
    "\n",
    "  Step i yields one episode holding one tuple\n",
    "  (state, action, reward, next_state, True, {}), where:\n",
    "    - action is utils.onehot_encode(actions[i], env.n_act_dim)\n",
    "    - state is env.n_act_dim zeros followed by obses[i]\n",
    "    - next_state is the one-hot action followed by obses[i]\n",
    "    - reward is rewards[i]\n",
    "\n",
    "  Returns a list with len(obses) episodes.\n",
    "  \"\"\"\n",
    "  eps = []\n",
    "  for i in range(len(obses)):\n",
    "    obs = obses[i]\n",
    "    action = utils.onehot_encode(actions[i], env.n_act_dim)\n",
    "    reward = rewards[i]\n",
    "    # The state carries an all-zero action slot; the next state carries the chosen action.\n",
    "    state = np.concatenate((np.zeros(env.n_act_dim), obs))\n",
    "    next_state = np.concatenate((action, obs))\n",
    "    eps.append([(state, action, reward, next_state, True, {})])\n",
    "  return eps\n",
    "\n",
    "def load_eps(filename):\n",
    "  \"\"\"Read one HDF5 log and return its steps as single-transition episodes.\n",
    "\n",
    "  Reads the 'obses', 'actions' and 'rewards' datasets from `filename`,\n",
    "  averages 'obses' over axis 1, and hands the three arrays to format_eps.\n",
    "  \"\"\"\n",
    "  with h5py.File(filename, 'r') as f:\n",
    "    # np.array(...) copies each dataset into memory before the file is closed.\n",
    "    obses = np.mean(np.array(f['obses']), axis=1)\n",
    "    actions = np.array(f['actions'])\n",
    "    rewards = np.array(f['rewards'])\n",
    "  return format_eps(obses, actions, rewards)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6ae7613",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every user's log for each method. rollouts_of_pol[k] holds the episodes of\n",
    "# policy k and method_of_pol[k] names the method folder it was read from.\n",
    "methods = ['default', 'x2t', 'offline']\n",
    "rollouts_of_pol = []\n",
    "method_of_pol = []\n",
    "for method in methods:\n",
    "  method_path = os.path.join(x2t_dir, method)\n",
    "  # os.listdir returns entries in arbitrary order; sort so the policy index is the\n",
    "  # same on every run (the rewards pickled later are indexed by it).\n",
    "  user_ids = sorted(os.listdir(method_path))\n",
    "  for user_id in user_ids:\n",
    "    user_dir = os.path.join(method_path, user_id)\n",
    "    if not os.path.isdir(user_dir):\n",
    "      continue  # stray files (e.g. .DS_Store) are not user folders\n",
    "    user_path = os.path.join(user_dir, 'data.hdf5')\n",
    "    eps = load_eps(user_path)\n",
    "    rollouts_of_pol.append(eps)\n",
    "    method_of_pol.append(method)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e36d928b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of loaded policies, total steps over all of their episodes, and steps per policy.\n",
    "n_conds = len(rollouts_of_pol)\n",
    "n_steps = sum(sum(map(len, rollouts)) for rollouts in rollouts_of_pol)\n",
    "n_steps, n_conds, n_steps / n_conds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8960e93",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceae69b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positional and keyword arguments handed to reward_models.MIRewardModel below.\n",
    "mi_model_init_args = [sess]\n",
    "mi_model_init_kwargs = dict(\n",
    "  n_env_obs_dim=env.n_env_obs_dim,\n",
    "  n_user_obs_dim=env.n_user_obs_dim,\n",
    "  n_act_dim=env.n_act_dim,\n",
    "  n_layers=2,\n",
    "  layer_size=64\n",
    ")\n",
    "mi_model_train_kwargs = dict(\n",
    "  iterations=1000,\n",
    "  ftol=1e-6,\n",
    "  learning_rate=1e-4,\n",
    "  batch_size=64,\n",
    "  val_update_freq=None,\n",
    "  verbose=False,\n",
    "  warm_start=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b638df84",
   "metadata": {},
   "outputs": [],
   "source": [
    "reward_model = reward_models.MIRewardModel(\n",
    "  env,\n",
    "  mi_model_init_args,\n",
    "  mi_model_init_kwargs,\n",
    "  mi_model_train_kwargs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44380ded",
   "metadata": {},
   "outputs": [],
   "source": [
    "ixs_reward_model = reward_models.MIRewardModel(\n",
    "  env,\n",
    "  mi_model_init_args,\n",
    "  mi_model_init_kwargs,\n",
    "  mi_model_train_kwargs,\n",
    "  use_next_env_obs=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7a8ed67",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed4d1bfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def true_rew_of_rollout(rollout):\n",
    "  \"\"\"Mean of the reward entry (index 2) over the steps of one rollout.\"\"\"\n",
    "  return np.mean([step[2] for step in rollout])\n",
    "\n",
    "\n",
    "def true_reward_model(rollouts):\n",
    "  \"\"\"Mean of the per-rollout mean rewards over a collection of rollouts.\"\"\"\n",
    "  return np.mean([true_rew_of_rollout(rollout) for rollout in rollouts])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a937a5f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "offline_reward_models = [true_reward_model, reward_model, ixs_reward_model]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10c0fc9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_seeds = 10\n",
    "rewards_of_pol = np.zeros((n_seeds, len(rollouts_of_pol), len(offline_reward_models)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74c5229a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One full evaluation per repetition; each fills a (policy, reward model) slice.\n",
    "for seed_idx in range(n_seeds):\n",
    "  rews = utils.compute_rews_of_rollouts(\n",
    "    rollouts_of_pol,\n",
    "    offline_reward_models,\n",
    "    verbose=True\n",
    "  )\n",
    "  rewards_of_pol[seed_idx] = rews"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50254dd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "rewards_of_pol_path = os.path.join(data_dir, 'rewards_of_pol.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ed4f1a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(rewards_of_pol_path, 'wb') as f:\n",
    "  pickle.dump(rewards_of_pol, f, pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48b24a4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(rewards_of_pol_path, 'rb') as f:\n",
    "  rewards_of_pol = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "321f3099",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Floor at zero the rewards from every model after the first in offline_reward_models\n",
    "# (last-axis index 1 and up); index 0 is left untouched.\n",
    "model_rewards = rewards_of_pol[:, :, 1:]  # basic slice, so writes go to rewards_of_pol\n",
    "model_rewards[...] = np.maximum(model_rewards, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99082a8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_rewards_of_pol = np.mean(rewards_of_pol, axis=0)\n",
    "mean_rewards_of_pol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2b1c175",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Columns follow offline_reward_models: 0 = true_reward_model, 1 = reward_model,\n",
    "# 2 = ixs_reward_model.\n",
    "# Recompute the seed average here so that re-running this cell cannot append a second\n",
    "# difference column to an array that was already post-processed.\n",
    "mean_rewards_of_pol = np.mean(rewards_of_pol, axis=0)\n",
    "# Cap column 2 at column 1, then floor it at 0.\n",
    "mean_rewards_of_pol[:, 2] = np.minimum(mean_rewards_of_pol[:, 2], mean_rewards_of_pol[:, 1])\n",
    "mean_rewards_of_pol[:, 2] = np.maximum(mean_rewards_of_pol[:, 2], 0)\n",
    "# Append (column 1 - column 2) as a new last column.\n",
    "mean_rewards_of_pol = np.column_stack(\n",
    "  (mean_rewards_of_pol, mean_rewards_of_pol[:, 1] - mean_rewards_of_pol[:, 2])\n",
    ")\n",
    "mean_rewards_of_pol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7655392",
   "metadata": {},
   "outputs": [],
   "source": [
    "color_of_method = {\n",
    "  'default': 'gray',\n",
    "  'offline': 'teal',\n",
    "  'x2t': 'orange'\n",
    "}\n",
    "label_of_method = {\n",
    "  'default': 'Non-Adaptive (After X2T)',\n",
    "  'offline': 'Non-Adaptive (Before X2T)',\n",
    "  'x2t': 'X2T'\n",
    "}\n",
    "color_of_pol = [color_of_method[m] for m in method_of_pol]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6b46d3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each method, the positions in method_of_pol (i.e. policy indices) that belong to it.\n",
    "idxes_of_method = {\n",
    "  m: [pol_idx for pol_idx, pol_method in enumerate(method_of_pol) if pol_method == m]\n",
    "  for m in methods\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f673efa",
   "metadata": {},
   "outputs": [],
   "source": [
    "mpl.rcParams.update({'font.size': 12})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a35e91c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spearman rank correlation between every pair of columns of mean_rewards_of_pol\n",
    "# (axis=0: each column is a variable, each row a policy).\n",
    "# Needs scipy.stats to be imported explicitly; a bare `import scipy` does not\n",
    "# guarantee the submodule is loaded.\n",
    "corr = scipy.stats.spearmanr(mean_rewards_of_pol, axis=0)\n",
    "corr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8375a5cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First element of the result holds the correlation values; entry [0, 1] pairs\n",
    "# column 0 with column 1 of mean_rewards_of_pol.\n",
    "rho_matrix = corr[0]\n",
    "rho = rho_matrix[0, 1]\n",
    "rho"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b8aab8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One point per policy: column 0 of mean_rewards_of_pol on x, column 1 on y,\n",
    "# colored and labeled by method.\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title(r'X2T ($\\rho$ = %0.2f)' % rho)\n",
    "ax.set_xlabel('True Reward (Classification Accuracy)')\n",
    "ax.set_ylabel(r\"$\\mathcal{I}(\\mathbf{x}_t, (\\mathbf{s}_t, \\mathbf{s}_{t+1}))$\")\n",
    "for method in idxes_of_method:\n",
    "  pol_idxes = idxes_of_method[method]\n",
    "  ax.scatter(\n",
    "    mean_rewards_of_pol[pol_idxes, 0],\n",
    "    mean_rewards_of_pol[pol_idxes, 1],\n",
    "    color=color_of_method[method],\n",
    "    label=label_of_method[method],\n",
    "    s=50\n",
    "  )\n",
    "ax.legend(loc='best')\n",
    "fig.savefig(os.path.join(data_dir, 'x2t-offline-eval-truerew-vs-mi-per-poluser.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "522f4734",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average each column of mean_rewards_of_pol over the policies of each method.\n",
    "# Rows of the result follow the order of `methods`.\n",
    "idx_of_method = {name: pos for pos, name in enumerate(methods)}\n",
    "n_cols = mean_rewards_of_pol.shape[1]\n",
    "rewards_of_method = np.array([\n",
    "  [\n",
    "    np.mean([mean_rewards_of_pol[i, j] for i, m in enumerate(method_of_pol) if m == method])\n",
    "    for j in range(n_cols)\n",
    "  ]\n",
    "  for method in methods\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96927291",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One point per method: column 0 of rewards_of_method on x, column 1 on y.\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title('X2T')\n",
    "ax.set_xlabel('True Reward (Classification Accuracy)')\n",
    "ax.set_ylabel(r\"$\\mathcal{I}(\\mathbf{x}_t, (\\mathbf{s}_t, \\mathbf{s}_{t+1}))$\")\n",
    "for row, method in enumerate(methods):\n",
    "  x_val, y_val = rewards_of_method[row, 0], rewards_of_method[row, 1]\n",
    "  ax.scatter(\n",
    "    x_val,\n",
    "    y_val,\n",
    "    color=color_of_method[method],\n",
    "    label=label_of_method[method],\n",
    "    s=50\n",
    "  )\n",
    "ax.legend(loc='lower right')\n",
    "fig.savefig(os.path.join(data_dir, 'x2t-offline-eval-truerew-vs-mi-per-pol.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddc7a53a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
