{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "064d3455",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3352e1cd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "\n",
    "from copy import deepcopy\n",
    "import pickle\n",
    "import time\n",
    "import os\n",
    "import uuid\n",
    "\n",
    "import scipy\n",
    "import numpy as np\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import matplotlib as mpl\n",
    "\n",
    "from mimi import envs\n",
    "from mimi import utils\n",
    "from mimi import user_models\n",
    "from mimi import opt\n",
    "from mimi import reward_models\n",
    "from mimi import models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a264006d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data_dir = os.path.join(utils.data_dir, 'asha-switch')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4354834",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "sess = utils.make_tf_session(gpu_mode=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f664bf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = envs.ASHASwitchEnv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eef68d7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "asha_dir = os.path.join(data_dir, 'raw')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ad2acbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_obses(raw_obses):\n",
    "  obses = []\n",
    "  for raw_obs in raw_obses:\n",
    "    obses.append(np.concatenate((raw_obs['raw_obs'], raw_obs['gaze_features'])))\n",
    "  return obses\n",
    "\n",
    "def format_eps(data):\n",
    "  eps = []\n",
    "  for block in data:\n",
    "    for raw_ep in block:\n",
    "      T = raw_ep['rewards'].shape[0]\n",
    "      obses = format_obses(raw_ep['observations'])\n",
    "      actions = raw_ep['actions']\n",
    "      rewards = raw_ep['rewards'][:, 0]\n",
    "      next_obses = format_obses(raw_ep['next_observations'])\n",
    "      ep = []\n",
    "      for t in range(T):\n",
    "        state = obses[t]\n",
    "        action = actions[t]\n",
    "        reward = rewards[t]\n",
    "        next_state = next_obses[t]\n",
    "        ep.append((state, action, reward, next_state, False, {}))\n",
    "      eps.append(ep)\n",
    "  return eps\n",
    "\n",
    "def load_eps(filename):\n",
    "  with open(filename, 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "  return format_eps(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdc45b5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "methods = ['A', 'B', 'C']\n",
    "rollouts_of_pol = []\n",
    "method_of_pol = []\n",
    "for method in methods:\n",
    "  method_path = os.path.join(asha_dir, method)\n",
    "  user_ids = os.listdir(method_path)\n",
    "  for user_id in user_ids:\n",
    "    user_path = os.path.join(method_path, user_id, 'data.pkl')\n",
    "    if os.path.exists(user_path):\n",
    "      eps = load_eps(user_path)\n",
    "      rollouts_of_pol.append(eps)\n",
    "      method_of_pol.append(method)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81485ff8",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_conds = len(rollouts_of_pol)\n",
    "n_steps = sum(len(r) for x in rollouts_of_pol for r in x)\n",
    "n_steps, n_conds, n_steps / n_conds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdfb8a0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "rollouts_of_pol[0][0][0][0].shape, rollouts_of_pol[0][0][0][1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eb8de3d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceae69b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "mi_model_init_args = [sess]\n",
    "mi_model_init_kwargs = {\n",
    "  'n_env_obs_dim': env.n_env_obs_dim,\n",
    "  'n_user_obs_dim': env.n_user_obs_dim,\n",
    "  'n_act_dim': env.n_act_dim,\n",
    "  'n_layers': 2,\n",
    "  'layer_size': 64\n",
    "}\n",
    "mi_model_train_kwargs = {\n",
    "  'iterations': 1000,\n",
    "  'ftol': 1e-6,\n",
    "  'learning_rate': 1e-4,\n",
    "  'batch_size': 64,\n",
    "  'val_update_freq': None,\n",
    "  'verbose': False,\n",
    "  'warm_start': False\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b638df84",
   "metadata": {},
   "outputs": [],
   "source": [
    "reward_model = reward_models.MIRewardModel(\n",
    "  env,\n",
    "  mi_model_init_args,\n",
    "  mi_model_init_kwargs,\n",
    "  mi_model_train_kwargs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44380ded",
   "metadata": {},
   "outputs": [],
   "source": [
    "ixs_reward_model = reward_models.MIRewardModel(\n",
    "  env,\n",
    "  mi_model_init_args,\n",
    "  mi_model_init_kwargs,\n",
    "  mi_model_train_kwargs,\n",
    "  use_next_env_obs=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7a8ed67",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed4d1bfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "true_rew_of_rollout = lambda rollout: np.mean([x[2] for x in rollout])\n",
    "true_reward_model = lambda rollouts: np.mean([true_rew_of_rollout(rollout) for rollout in rollouts])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a937a5f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "offline_reward_models = [true_reward_model, reward_model, ixs_reward_model]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3f67e77",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_seeds = 10\n",
    "rewards_of_pol = np.zeros((n_seeds, len(rollouts_of_pol), len(offline_reward_models)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51ca20fe",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for i in range(n_seeds):\n",
    "  rewards_of_pol[i, :, :] = utils.compute_rews_of_rollouts(\n",
    "    rollouts_of_pol,\n",
    "    offline_reward_models,\n",
    "    verbose=True\n",
    "  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50254dd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "rewards_of_pol_path = os.path.join(data_dir, 'rewards_of_pol.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ed4f1a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(rewards_of_pol_path, 'wb') as f:\n",
    "  pickle.dump(rewards_of_pol, f, pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48b24a4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(rewards_of_pol_path, 'rb') as f:\n",
    "  rewards_of_pol = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c483c30f",
   "metadata": {},
   "outputs": [],
   "source": [
    "rewards_of_pol[:, :, 1:] = np.maximum(rewards_of_pol[:, :, 1:], 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59ecc094",
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_rewards_of_pol = np.mean(rewards_of_pol, axis=0)\n",
    "mean_rewards_of_pol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fc12b8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_rewards_of_pol[:, 2] = np.minimum(mean_rewards_of_pol[:, 2], mean_rewards_of_pol[:, 1])\n",
    "mean_rewards_of_pol[:, 2] = np.maximum(mean_rewards_of_pol[:, 2], 0)\n",
    "mean_rewards_of_pol = np.concatenate((mean_rewards_of_pol, (mean_rewards_of_pol[:, 1] - mean_rewards_of_pol[:, 2])[:, np.newaxis]), axis=1)\n",
    "mean_rewards_of_pol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5f66d4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "color_of_method = {\n",
    "  'A': 'gray',\n",
    "  'B': 'orange',\n",
    "  'C': 'teal'\n",
    "}\n",
    "label_of_method = {\n",
    "  'A': 'Non-Adaptive',\n",
    "  'B': 'ASHA',\n",
    "  'C': 'ASHA (With Distribution Shift)'\n",
    "}\n",
    "color_of_pol = [color_of_method[m] for m in method_of_pol]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "636500db",
   "metadata": {},
   "outputs": [],
   "source": [
    "idxes_of_method = {m: [] for m in methods}\n",
    "for i, method in enumerate(method_of_pol):\n",
    "  idxes_of_method[method].append(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "167e6c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "mpl.rcParams.update({'font.size': 12})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "864af26a",
   "metadata": {},
   "outputs": [],
   "source": [
    "corr = scipy.stats.spearmanr(mean_rewards_of_pol)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebf2d169",
   "metadata": {},
   "outputs": [],
   "source": [
    "rho = corr[0][0, 1]\n",
    "rho"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b8aab8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.title(r'ASHA Switch ($\\rho$ = %0.2f)' % rho)\n",
    "plt.xlabel('True Reward')\n",
    "plt.ylabel(r\"$\\mathcal{I}(\\mathbf{x}_t, (\\mathbf{s}_t, \\mathbf{s}_{t+1}))$\")\n",
    "for method, idxes in idxes_of_method.items():\n",
    "  plt.scatter(\n",
    "    mean_rewards_of_pol[idxes, 0], \n",
    "    mean_rewards_of_pol[idxes, 1],\n",
    "    color=color_of_method[method],\n",
    "    label=label_of_method[method],\n",
    "    s=50\n",
    "  )\n",
    "plt.legend(loc='best')\n",
    "plt.savefig(os.path.join(data_dir, 'asha-switch-offline-eval-truerew-vs-mi-per-poluser.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eec835ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "rewards_of_method = [[[] for _ in range(mean_rewards_of_pol.shape[1])] for _ in methods]\n",
    "idx_of_method = {x: i for i, x in enumerate(methods)}\n",
    "for i, method in enumerate(method_of_pol):\n",
    "  for j in range(mean_rewards_of_pol.shape[1]):\n",
    "    rewards_of_method[idx_of_method[method]][j].append(mean_rewards_of_pol[i, j])\n",
    "rewards_of_method = [[np.mean(x) for x in y] for y in rewards_of_method]\n",
    "rewards_of_method = np.array(rewards_of_method)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "010cc62d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.title('ASHA Switch')\n",
    "plt.xlabel('True Reward')\n",
    "plt.ylabel(r\"$\\mathcal{I}(\\mathbf{x}_t, (\\mathbf{s}_t, \\mathbf{s}_{t+1}))$\")\n",
    "for i, method in enumerate(methods):\n",
    "  plt.scatter(\n",
    "    rewards_of_method[i, 0], \n",
    "    rewards_of_method[i, 1],\n",
    "    color=color_of_method[method],\n",
    "    label=label_of_method[method],\n",
    "    s=50\n",
    "  )\n",
    "plt.legend(loc='lower right')\n",
    "plt.savefig(os.path.join(data_dir, 'asha-switch-offline-eval-truerew-vs-mi-per-pol.pdf'), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8674fdc2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
