{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sepsis Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "\n",
    "from core.sepsisSimDiabetes.State import State\n",
    "from core.sepsisSimDiabetes.Action import Action\n",
    "from core import generator_confounded_mdp as DGEN\n",
    "from core import conf_wis as CWIS\n",
    "from core import loss_minimization as LB\n",
    "from utils.utils import *\n",
    "\n",
    "import seaborn as sns\n",
    "sns.set_style(\"whitegrid\")\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "figdir='./figs'\n",
    "\n",
    "from more_itertools import locate\n",
    "import pickle\n",
    "import math"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading the preprocessed data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_pickle(path):\n",
    "    \"\"\"Return the object unpickled from the file at `path`.\"\"\"\n",
    "    with open(path, 'rb') as handle:\n",
    "        return pickle.load(handle)\n",
    "\n",
    "# Each array is transposed right after loading so that its axes match\n",
    "# the order the rest of the notebook indexes them in.\n",
    "value_function = _load_pickle('data/value_function.pkl')\n",
    "t0_policy = _load_pickle('data/t0_policy.pkl').transpose([0, 2, 1])\n",
    "optimal_policy_st = _load_pickle('data/optimal_policy_st.pkl').transpose([1, 0])\n",
    "# this pickle holds a pair that is unpacked into `tx` and `tr`\n",
    "tx, tr = _load_pickle('data/tx_tr.pkl')\n",
    "mixed_policy = _load_pickle('data/mixed_policy.pkl').transpose([1, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## 0. read varied policies\n",
    "# The two slices of `t0_policy` along its first axis.\n",
    "t0_with_antibiotics = t0_policy[0, :, :]\n",
    "t0_without_antibiotics = t0_policy[1, :, :]\n",
    "\n",
    "# Every policy below is a pair of tables; presumably (initial-step policy,\n",
    "# later-step policy) -- confirm against `conf_wis.compute`.\n",
    "with_antibiotics_init = (t0_with_antibiotics, optimal_policy_st)\n",
    "without_antibiotics_init = (t0_without_antibiotics, optimal_policy_st)\n",
    "optimal_policy = (optimal_policy_st, optimal_policy_st)\n",
    "\n",
    "with_antibiotics_allway = (t0_with_antibiotics, t0_with_antibiotics)\n",
    "without_antibiotics_allway = (t0_without_antibiotics, t0_without_antibiotics)\n",
    "\n",
    "# Load the two additional policy tables, transposing each like `optimal_policy_st`.\n",
    "loaded_policy_tables = []\n",
    "for policy_path in ('data/optimal_policy_80_st.pkl', 'data/optimal_policy_60_st.pkl'):\n",
    "    with open(policy_path, 'rb') as f:\n",
    "        loaded_policy_tables.append(pickle.load(f).transpose([1, 0]))\n",
    "optimal_policy_80_st, optimal_policy_60_st = loaded_policy_tables\n",
    "\n",
    "sub_opt_policy_80 = (optimal_policy_80_st, optimal_policy_80_st)\n",
    "sub_opt_policy_60 = (optimal_policy_60_st, optimal_policy_60_st)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Generation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**config:** the config dictionary contains the following entries:\n",
    "\n",
    "- `Gamma`: amount of confounding in the data generation process ($\\Gamma^\\star$)\n",
    "- `num_itrs`: number of simulation trajectories\n",
    "- `max_horizon`: maximum number of timesteps in the sim\n",
    "- `discount`: discount factor in the MDP ($\\gamma$)\n",
    "- `confounding_threshold`: confounding threshold (refer to Appendix D)\n",
    "- `nS`: number of states (including the terminal state)\n",
    "- `nA`: number of actions\n",
    "- `p_diabetes`: probability of a diabetic patient\n",
    "- `n_bootstrap`: number of bootstrap samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of simulated trajectories: historical (train) set, then online test set\n",
    "num_itrs, num_itrs_test = 5000, 5000\n",
    "# maximum number of timesteps per trajectory and the MDP discount factor\n",
    "horizon, discount = 50, 0.99"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# historical data\n",
    "config = {\n",
    "    'Gamma': 2.0,\n",
    "    'num_itrs': num_itrs,\n",
    "    'max_horizon': horizon,\n",
    "    'discount': discount,\n",
    "    'confounding_threshold': 0.75,\n",
    "    'nS': State.NUM_FULL_STATES + 1,\n",
    "    'nA': Action.NUM_ACTIONS_TOTAL,\n",
    "    'p_diabetes': 0.2,\n",
    "    'n_bootstrap': 500,\n",
    "}\n",
    "\n",
    "# the generator receives one policy table plus a stacked pair of the same table\n",
    "generator_policies = (optimal_policy_st, np.asarray([optimal_policy_st, optimal_policy_st]))\n",
    "dgen = DGEN.conf_data_generator(\n",
    "    transitions=(tx, tr),\n",
    "    policies=generator_policies,\n",
    "    value_fn=value_function,\n",
    "    config=config,\n",
    ")\n",
    "\n",
    "# simulate as many trajectories as the config asks for\n",
    "trajectories, returns = dgen.simulate(config['num_itrs'], use_tqdm=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read from existing data\n",
    "# NOTE: this rebinds `trajectories` and `returns`, replacing whatever an\n",
    "# earlier cell stored under those names.\n",
    "# The `with` blocks close each file handle as soon as it has been read\n",
    "# (the handles were previously left open).\n",
    "with open('processed_data/trajectories_pitree_H_50_beh_optimal.pkl', 'rb') as file:\n",
    "    trajectories = pickle.load(file)\n",
    "\n",
    "with open('processed_data/returns_pitree_H_50_beh_optimal.pkl', 'rb') as file:\n",
    "    returns = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one estimator over the full dataset, then one `compute` call per policy\n",
    "wis_all = CWIS.conf_wis(trajectories=trajectories, returns=returns, k=1, config=config)\n",
    "\n",
    "wis_all_estimates = [\n",
    "    wis_all.compute(policy)\n",
    "    for policy in (\n",
    "        with_antibiotics_init,\n",
    "        without_antibiotics_init,\n",
    "        optimal_policy,\n",
    "        sub_opt_policy_80,\n",
    "        sub_opt_policy_60,\n",
    "        with_antibiotics_allway,\n",
    "        without_antibiotics_allway,\n",
    "    )\n",
    "]\n",
    "\n",
    "# keep the individual names: later cells refer to them directly\n",
    "(wis_all_with_anti,\n",
    " wis_all_without_anti,\n",
    " wis_all_optimal,\n",
    " wis_all_sub_opt_policy_80,\n",
    " wis_all_sub_opt_policy_60,\n",
    " wis_all_with_anti_allway,\n",
    " wis_all_without_allway) = wis_all_estimates\n",
    "\n",
    "# mean of every estimate, printed on one line in the order above\n",
    "print(*[estimate.mean() for estimate in wis_all_estimates])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getSE(returns, n=None):\n",
    "    \"\"\"Standard-error style estimate: ``returns.std() / sqrt(n)``.\n",
    "\n",
    "    Args:\n",
    "        returns: object exposing ``.std()`` (the notebook passes the outputs\n",
    "            of ``conf_wis.compute``; for a NumPy array ``.std()`` uses ddof=0).\n",
    "        n: sample count used in the denominator. Defaults to the notebook-level\n",
    "            ``num_itrs``, which is the value this function always used before;\n",
    "            pass it explicitly when the estimate was built from a different\n",
    "            number of samples (e.g. data loaded from disk).\n",
    "\n",
    "    Returns:\n",
    "        ``returns.std()`` divided by ``sqrt(n)``.\n",
    "    \"\"\"\n",
    "    if n is None:\n",
    "        n = num_itrs  # fall back to the global training-set size\n",
    "    return returns.std() / math.sqrt(n)\n",
    "\n",
    "# one standard error per evaluated policy, printed on a single line\n",
    "se_inputs = (wis_all_with_anti, wis_all_without_anti, wis_all_optimal,\n",
    "             wis_all_sub_opt_policy_80, wis_all_sub_opt_policy_60,\n",
    "             wis_all_with_anti_allway, wis_all_without_allway)\n",
    "print(*[getSE(estimate) for estimate in se_inputs])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# death patients check:\n",
    "# positions in `returns` whose value is negative\n",
    "idx_death = list(locate(returns, lambda ret: ret < 0.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "##### check if patterns exist\n",
    "from prefixspan import PrefixSpan\n",
    "\n",
    "max_pattern = 10\n",
    "\n",
    "def _state_prefixes(trajs):\n",
    "    \"\"\"Column 2 of the first `max_pattern` rows of each trajectory, as plain lists.\"\"\"\n",
    "    return [list(traj[:max_pattern, 2]) for traj in trajs]\n",
    "\n",
    "# frequent sequences over all trajectories\n",
    "state_traj = _state_prefixes(trajectories)\n",
    "ps_all = PrefixSpan(state_traj)\n",
    "print(ps_all.topk(50))\n",
    "print('--------------')\n",
    "\n",
    "# same mining restricted to the trajectories indexed by `idx_death`\n",
    "death_trajectories = [trajectories[i] for i in idx_death]\n",
    "state_death_traj = _state_prefixes(death_trajectories)\n",
    "ps = PrefixSpan(state_death_traj)\n",
    "print(ps.topk(50))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# death+undischarged\n",
    "pattern_list = [\n",
    "    463, 372, 219, 373, 388,\n",
    "    389, 703, 149, 144, 133,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate, for every pattern subgroup and every start step h, the value of\n",
    "# switching to each target policy from step h onwards.\n",
    "#\n",
    "# Intended use of these estimates:\n",
    "#   - check if the signal is strong enough (e.g., a pattern appears)\n",
    "#   - estimate the return of switching to policy pi at step t\n",
    "#   - if it beats pi_b by enough, switch to policy pi\n",
    "\n",
    "# NOTE(review): `sub_opt_policy_5` is not assigned in any cell of this notebook;\n",
    "# presumably it comes from `from utils.utils import *` or a removed cell -- verify.\n",
    "target_policies = [\n",
    "    with_antibiotics_init,\n",
    "    without_antibiotics_init,\n",
    "    optimal_policy,\n",
    "    sub_opt_policy_80,\n",
    "    sub_opt_policy_60,\n",
    "    sub_opt_policy_5,\n",
    "    with_antibiotics_allway,\n",
    "    without_antibiotics_allway,\n",
    "]\n",
    "\n",
    "num_subgroup = len(pattern_list)\n",
    "num_target_policy = len(target_policies)\n",
    "\n",
    "# potential_policy_returns[g, h, p]: mean estimate for pattern g, start step h, policy p\n",
    "potential_policy_returns = np.zeros((num_subgroup, horizon, num_target_policy))\n",
    "\n",
    "for g, pattern in enumerate(pattern_list):\n",
    "    print('pattern: ', g)\n",
    "\n",
    "    ## 1. get death/undischarged traj with the patterns\n",
    "    # NOTE(review): `in` tests every entry of the trajectory, not only column 2\n",
    "    # (the column the patterns were mined from) -- confirm this is intended.\n",
    "    idx_death_with_pattern = [\n",
    "        idx_death[pos]\n",
    "        for pos, death_traj in enumerate(death_trajectories)\n",
    "        if pattern in death_traj\n",
    "    ]\n",
    "    # the returns do not depend on the start step, so gather them once per pattern\n",
    "    return_with_patterns = [returns[i] for i in idx_death_with_pattern]\n",
    "\n",
    "    for h in range(horizon):\n",
    "        # only use the subtrajectories after the step h\n",
    "        death_traj_with_patterns = [trajectories[i][h:] for i in idx_death_with_pattern]\n",
    "\n",
    "        ## 2. ope for target policies from time step h\n",
    "        wis = CWIS.conf_wis(\n",
    "            trajectories=np.asarray(death_traj_with_patterns),\n",
    "            returns=np.asarray(return_with_patterns),\n",
    "            k=1,\n",
    "            config=config,\n",
    "        )\n",
    "\n",
    "        for p, target_p in enumerate(target_policies):\n",
    "            wis_target_p = wis.compute(target_p)\n",
    "            est_mean = wis_target_p.mean()\n",
    "\n",
    "            ## 3. save the estimated values\n",
    "            potential_policy_returns[g, h, p] = est_mean\n",
    "\n",
    "            print('est return after step {} for policy {}: {} std {}'.format(h, p, est_mean, wis_target_p.std()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# replace NaN estimates with -100 so every entry is a usable number\n",
    "potential_policy_returns_clean = np.nan_to_num(\n",
    "    potential_policy_returns,\n",
    "    nan=-100.,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set the threshold as the behavior policy value for the subgroup\n",
    "# target_policies_threshold[g]: mean observed return of the death/undischarged\n",
    "# trajectories that contain pattern g\n",
    "target_policies_threshold = np.zeros(num_subgroup)\n",
    "\n",
    "for g, pattern in enumerate(pattern_list):\n",
    "    print('pattern: ', g)\n",
    "\n",
    "    ## 1. get death/undischarged traj with the patterns\n",
    "    idx_death_with_pattern = [\n",
    "        idx_death[pos]\n",
    "        for pos, death_traj in enumerate(death_trajectories)\n",
    "        if pattern in death_traj\n",
    "    ]\n",
    "\n",
    "    # Only the returns are needed here; the list of matching trajectories that\n",
    "    # used to be built alongside them was never read, so it is no longer created.\n",
    "    return_with_patterns = [returns[i] for i in idx_death_with_pattern]\n",
    "\n",
    "    # mean observed return; NaN (with a NumPy warning) if no trajectory matched\n",
    "    target_policies_threshold[g] = np.asarray(return_with_patterns).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run online testing\n",
    "from core import generator_confounded_HBO as DGEN_HBO\n",
    "\n",
    "\n",
    "# Effectively remove target policies 4 and 5 (positions in `target_policies`)\n",
    "# by overwriting their estimates with -1000. for every pattern and time step.\n",
    "# The cleaned array is copied first so it stays untouched.\n",
    "excluded_policy_idx = [4, 5]\n",
    "potential_policy_returns_test = potential_policy_returns_clean.copy()\n",
    "potential_policy_returns_test[:, :, excluded_policy_idx] = -1000.\n",
    "\n",
    "\n",
    "# Same settings as the historical-data config except for the number of\n",
    "# trajectories; `discount` is the shared notebook constant rather than a\n",
    "# repeated literal, so the two configs cannot drift apart.\n",
    "test_config = {'Gamma': 2.0, 'num_itrs': num_itrs_test, 'max_horizon': horizon, 'discount': discount,\n",
    "          'confounding_threshold': 0.75, 'nS': State.NUM_FULL_STATES + 1, \n",
    "          'nA': Action.NUM_ACTIONS_TOTAL, 'p_diabetes': 0.2, 'n_bootstrap': 500}\n",
    "\n",
    "# The generator receives the per-pattern estimates and thresholds together\n",
    "# with the pattern list and the candidate target policies.\n",
    "dgen_hbo = DGEN_HBO.conf_data_generator(transitions=(tx, tr),\n",
    "            policies=(optimal_policy_st, np.asarray([optimal_policy_st,optimal_policy_st])), \n",
    "            value_fn=value_function, config=test_config,\n",
    "            pattern_list=pattern_list, potential_policy_returns=potential_policy_returns_test,\n",
    "            target_policies_threshold=target_policies_threshold,\n",
    "            target_policies=target_policies\n",
    "            )\n",
    "trajectories_test, returns_test = dgen_hbo.simulate(test_config['num_itrs'], transitions=(tx, tr), value_function=value_function, use_tqdm=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
