{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import torch\n",
    "import numpy as np\n",
    "from collections import deque\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import numpy as np\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pandas import DataFrame\n",
    "import pickle as pickle\n",
    "import math\n",
    "import copy\n",
    "import tf_slim\n",
    "import tensorflow.compat.v1 as tf\n",
    "tf.disable_v2_behavior()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the train / validation / test splits, in that order.\n",
    "train, val, test = (\n",
    "    pd.read_csv(csv_path)\n",
    "    for csv_path in (\n",
    "        './data/rl_train_data_final_reward2_auto.csv',\n",
    "        './data/rl_val_data_final_reward2_auto.csv',\n",
    "        './data/rl_test_data_final_reward2_auto.csv',\n",
    "    )\n",
    ")\n",
    "\n",
    "# Aliases used by the rest of the notebook; they refer to the same objects, not copies.\n",
    "df, val_df, test_df = train, val, test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rescale the IV dose column by 1/2000 in each split.\n",
    "# NOTE: the column is overwritten, so running this cell twice rescales twice;\n",
    "# reload the CSVs before re-running it.\n",
    "df['iv_input'] = df['iv_input'].div(2000)\n",
    "val_df['iv_input'] = val_df['iv_input'].div(2000)\n",
    "test_df['iv_input'] = test_df['iv_input'].div(2000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# State columns are addressed by the string names '0', '1', ..., '15'.\n",
    "# (The disabled alternative read the names from ../Our-Model/data/state_features.txt.)\n",
    "state_features = list(map(str, range(16)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cql_new_data(df):\n",
    "    \"\"\"Convert a per-timestep table into flat transition arrays.\n",
    "\n",
    "    A row is terminal when it is the last row of ``df`` or when the next row\n",
    "    (by position) has a different ``icustayid``; this presumes the rows of one\n",
    "    stay are contiguous and in time order -- TODO confirm for the input CSVs.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pandas.DataFrame\n",
    "        Must contain the ``state_features`` columns plus ``reward``,\n",
    "        ``iv_input``, ``vaso_input``, ``died_in_hosp`` and ``icustayid``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict of numpy.ndarray\n",
    "        ``observations``, ``next_observations`` (all zeros on terminal rows),\n",
    "        ``actions`` (iv, vaso), ``rewards``, ``dones`` (1 on terminal rows)\n",
    "        and ``died``, each with one entry per row of ``df``.\n",
    "    \"\"\"\n",
    "    observations = df[state_features].to_numpy(copy=True)\n",
    "    stay_ids = df['icustayid'].to_numpy()\n",
    "\n",
    "    # Neighbouring rows are compared by position instead of looking up the\n",
    "    # label ``i + 1``, so the frame does not need a contiguous integer index.\n",
    "    dones = np.ones(len(df), dtype=int)\n",
    "    dones[:-1] = stay_ids[:-1] != stay_ids[1:]\n",
    "\n",
    "    # Non-terminal rows take the following row as successor state; terminal\n",
    "    # rows keep the all-zero placeholder.\n",
    "    next_observations = np.zeros(observations.shape, dtype=float)\n",
    "    continuing = np.flatnonzero(dones == 0)\n",
    "    next_observations[continuing] = observations[continuing + 1]\n",
    "\n",
    "    return {\n",
    "        'observations': observations,\n",
    "        'next_observations': next_observations,\n",
    "        'actions': df[['iv_input', 'vaso_input']].to_numpy(copy=True),\n",
    "        'rewards': df['reward'].to_numpy(copy=True),\n",
    "        'dones': dones,\n",
    "        # Fix: the per-row flag used to be read but never stored, which left\n",
    "        # this array empty.\n",
    "        'died': df['died_in_hosp'].to_numpy(copy=True),\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dtdata_val(df1):\n",
    "    \"\"\"Split a per-timestep table into one trajectory dict per stay.\n",
    "\n",
    "    A stay is a maximal run of consecutive rows sharing one ``icustayid``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df1 : pandas.DataFrame\n",
    "        Must contain the ``state_features`` columns plus ``reward``,\n",
    "        ``iv_input``, ``vaso_input``, ``died_in_hosp`` and ``icustayid``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    trajectories : list of dict\n",
    "        Per stay: ``observations``, ``next_observations`` (all zeros on the\n",
    "        last step), ``actions`` (iv, vaso), ``rewards``, ``terminals`` (1 on\n",
    "        the last step only) and ``dieds``.\n",
    "    length : list of int\n",
    "        Number of steps of each trajectory.\n",
    "    path_r : list\n",
    "        Plain sum of the rewards of each trajectory.\n",
    "    \"\"\"\n",
    "    observations = df1[state_features].to_numpy()\n",
    "    actions = df1[['iv_input', 'vaso_input']].to_numpy()\n",
    "    rewards = df1['reward'].to_numpy()\n",
    "    dieds = df1['died_in_hosp'].to_numpy()\n",
    "    stay_ids = df1['icustayid'].to_numpy()\n",
    "\n",
    "    # A stay ends on the last row of the frame or where the next row (by\n",
    "    # position, so any index works) carries a different stay id.\n",
    "    is_last = np.ones(len(df1), dtype=bool)\n",
    "    is_last[:-1] = stay_ids[:-1] != stay_ids[1:]\n",
    "    stops = np.flatnonzero(is_last) + 1\n",
    "    starts = np.concatenate(([0], stops[:-1]))\n",
    "\n",
    "    trajectories, length, path_r = [], [], []\n",
    "    for start, stop in zip(starts, stops):\n",
    "        steps = int(stop - start)\n",
    "        # Successor states are the observations shifted by one step; the\n",
    "        # final step keeps an all-zero successor.\n",
    "        next_observations = np.zeros((steps, observations.shape[1]), dtype=float)\n",
    "        next_observations[:-1] = observations[start + 1:stop]\n",
    "        terminals = np.zeros(steps, dtype=int)\n",
    "        terminals[-1] = 1\n",
    "        trajectories.append({\n",
    "            'observations': observations[start:stop].copy(),\n",
    "            'next_observations': next_observations,\n",
    "            'actions': actions[start:stop].copy(),\n",
    "            'rewards': rewards[start:stop].copy(),\n",
    "            'terminals': terminals,\n",
    "            'dieds': dieds[start:stop].copy(),\n",
    "        })\n",
    "        length.append(steps)\n",
    "        # Built-in sum keeps the left-to-right accumulation of the row loop.\n",
    "        path_r.append(sum(rewards[start:stop]))\n",
    "\n",
    "    return trajectories, length, path_r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one trajectory dict per stay from the validation frame and pickle the list.\n",
    "# NOTE(review): the file name says 'test' but the input is val_df -- confirm the intended split.\n",
    "path_cql_df,length,path_r = dtdata_val(val_df)\n",
    "with open('my_dt_test_reward2.pkl','wb') as f:  #my_dt_test_2,my_cql_new_df_10\n",
    "    pickle.dump(path_cql_df,f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show a compact summary instead of printing every array of every trajectory.\n",
    "print(f'{len(path_cql_df)} trajectories')\n",
    "if path_cql_df:\n",
    "    # Array shapes of the first trajectory, keyed like the trajectory dict.\n",
    "    print({key: value.shape for key, value in path_cql_df[0].items()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def listdata(df1, df2=None, min_length=10):\n",
    "    \"\"\"Collect per-stay trajectories of at least ``min_length`` steps.\n",
    "\n",
    "    Each frame is split on its own into maximal runs of consecutive rows that\n",
    "    share an ``icustayid``; runs shorter than ``min_length`` are discarded.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df1 : pandas.DataFrame\n",
    "        Must contain the ``state_features`` columns plus ``reward``,\n",
    "        ``iv_input``, ``vaso_input`` and ``icustayid``.\n",
    "    df2 : pandas.DataFrame, optional\n",
    "        Second frame with the same columns; its trajectories are appended\n",
    "        after those of ``df1``. When omitted only ``df1`` is processed.\n",
    "    min_length : int, optional\n",
    "        Minimum number of steps for a trajectory to be kept. Defaults to 10.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    trajectories : list of dict\n",
    "        Per kept stay: ``observations``, ``next_observations`` (all zeros on\n",
    "        the last step), ``actions`` (iv, vaso), ``rewards`` and ``terminals``\n",
    "        (1 on the last step only).\n",
    "    length : list of int\n",
    "        Number of steps of each kept trajectory.\n",
    "    path_r : list\n",
    "        Plain sum of the rewards of each kept trajectory.\n",
    "    \"\"\"\n",
    "    trajectories, length, path_r = [], [], []\n",
    "\n",
    "    for frame in (df1, df2):\n",
    "        if frame is None:\n",
    "            continue\n",
    "\n",
    "        observations = frame[state_features].to_numpy()\n",
    "        actions = frame[['iv_input', 'vaso_input']].to_numpy()\n",
    "        rewards = frame['reward'].to_numpy()\n",
    "        stay_ids = frame['icustayid'].to_numpy()\n",
    "\n",
    "        # A stay ends on the last row of the frame or where the next row (by\n",
    "        # position, so any index works) carries a different stay id.\n",
    "        is_last = np.ones(len(frame), dtype=bool)\n",
    "        is_last[:-1] = stay_ids[:-1] != stay_ids[1:]\n",
    "        stops = np.flatnonzero(is_last) + 1\n",
    "        starts = np.concatenate(([0], stops[:-1]))\n",
    "\n",
    "        for start, stop in zip(starts, stops):\n",
    "            steps = int(stop - start)\n",
    "            if steps < min_length:\n",
    "                continue\n",
    "            # Successor states are the observations shifted by one step; the\n",
    "            # final step keeps an all-zero successor.\n",
    "            next_observations = np.zeros((steps, observations.shape[1]), dtype=float)\n",
    "            next_observations[:-1] = observations[start + 1:stop]\n",
    "            terminals = np.zeros(steps, dtype=int)\n",
    "            terminals[-1] = 1\n",
    "            trajectories.append({\n",
    "                'observations': observations[start:stop].copy(),\n",
    "                'next_observations': next_observations,\n",
    "                'actions': actions[start:stop].copy(),\n",
    "                'rewards': rewards[start:stop].copy(),\n",
    "                'terminals': terminals,\n",
    "            })\n",
    "            length.append(steps)\n",
    "            # Built-in sum keeps the left-to-right accumulation of the row loop.\n",
    "            path_r.append(sum(rewards[start:stop]))\n",
    "\n",
    "    return trajectories, length, path_r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trajectories from the training and validation frames; this rebinds\n",
    "# path_cql_df, length and path_r set by the earlier dtdata_val cell.\n",
    "path_cql_df,length,path_r = listdata(df,val_df)\n",
    "# with open('my_dt_new_df_.pkl','wb') as f:\n",
    "#     pickle.dump(path_cql_df,f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dose columns of the training frame, used by the summary cells below.\n",
    "# When the notebook is run top to bottom, iv_input has already been divided by 2000.\n",
    "iv = df['iv_input']\n",
    "vaso = df['vaso_input']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean of the iv and vaso dose columns.\n",
    "print(np.mean(iv),np.mean(vaso))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest value of the iv and vaso dose columns.\n",
    "print(np.max(iv),np.max(vaso))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of the iv doses: 100 bins over the range [0, 2].\n",
    "from matplotlib import pyplot as plt\n",
    "pd.Series(iv).plot.hist(range=[0,2],bins=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram (10 bins) of the trajectory lengths held in `length`.\n",
    "from matplotlib import pyplot as plt\n",
    "plt.hist(length,10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trajectories for the validation frame alone. This passes a single frame,\n",
    "# so it relies on listdata being callable with one positional argument.\n",
    "# The stored value is the whole (trajectories, length, path_r) tuple.\n",
    "path_val_df = listdata(val_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trajectories for the test frame alone. This passes a single frame,\n",
    "# so it relies on listdata being callable with one positional argument.\n",
    "# The stored value is the whole (trajectories, length, path_r) tuple.\n",
    "path_test_df = listdata(test_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pickle the trajectory results into the working directory, one file per split.\n",
    "# NOTE(review): path_df is not assigned in any visible cell, so the first dump\n",
    "# raises NameError on a fresh kernel. Presumably it was meant to hold the\n",
    "# listdata result for df -- confirm and define it before this cell.\n",
    "with open('my_dict_df.pkl','wb') as f:\n",
    "    pickle.dump(path_df,f)\n",
    "with open('my_dict_val_df.pkl','wb') as f:\n",
    "    pickle.dump(path_val_df,f)\n",
    "with open('my_dict_test_df.pkl','wb') as f:\n",
    "    pickle.dump(path_test_df,f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scalar torch parameter initialised to 0.0.\n",
    "# NOTE(review): `constant` is not referenced by any other visible cell.\n",
    "import torch.nn as nn\n",
    "constant = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a previously pickled object into `trajectories`.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only open files you created or trust.\n",
    "# NOTE(review): the cells below index it as a list of dicts and read the 'rewards',\n",
    "# 'dieds' and 'actions' keys; confirm this file was written with a 'dieds' entry.\n",
    "import pickle\n",
    "dataset_path = f'./data/my_cql_new_df_10.pkl'\n",
    "with open(dataset_path, 'rb') as f:\n",
    "    trajectories = pickle.load(f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the 'observations' entry of the first loaded trajectory.\n",
    "print(trajectories[0]['observations'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empty per-step table filled in by the cells below: '0' receives the reward,\n",
    "# 'mort' the died flag, 'iv' / 'vaso' the doses and 'ivt' / 'vasot' the dose bins.\n",
    "phys_df = pd.DataFrame(columns=['0','mort','iv','vaso','ivt','vasot'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten every loaded trajectory into per-step lists.\n",
    "reward, mort, iv, vaso = [], [], [], []\n",
    "for trajectory in trajectories:\n",
    "    r = trajectory['rewards']\n",
    "    die = trajectory['dieds']\n",
    "    action = trajectory['actions']\n",
    "    steps = range(len(r))\n",
    "    # The factor 2000 presumably undoes the 1/2000 rescaling of iv_input -- TODO confirm.\n",
    "    iv.extend(action[j][0]*2000 for j in steps)\n",
    "    vaso.extend(action[j][1] for j in steps)\n",
    "    reward.extend(r[j] for j in steps)\n",
    "    mort.extend(die[j] for j in steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the flattened per-step lists into the table; the '0' column holds the reward.\n",
    "phys_df['0'] = reward\n",
    "phys_df['mort'] = mort\n",
    "phys_df['iv'] = iv\n",
    "phys_df['vaso'] = vaso\n",
    "# Display the first rows as a sanity check.\n",
    "phys_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dose thresholds separating the four non-zero bins of each action dimension.\n",
    "vq = np.array([0.08,0.2,0.45])\n",
    "ivq = np.array([48,150,500])\n",
    "\n",
    "# Bin 0 means no dose and bins 1-4 are increasing dose ranges. Rows matching no\n",
    "# condition (NaN doses, negative iv) stay NaN. Each column is written in one\n",
    "# assignment; the former chained writes (phys_df[col][mask] = value) trigger\n",
    "# SettingWithCopyWarning and have no effect under pandas copy-on-write.\n",
    "vaso_dose = phys_df['vaso']\n",
    "phys_df['vasot'] = np.select(\n",
    "    [\n",
    "        vaso_dose <= 0.0,\n",
    "        (vaso_dose > 0.0) & (vaso_dose < vq[0]),\n",
    "        (vaso_dose >= vq[0]) & (vaso_dose < vq[1]),\n",
    "        (vaso_dose >= vq[1]) & (vaso_dose < vq[2]),\n",
    "        vaso_dose >= vq[2],\n",
    "    ],\n",
    "    [0, 1, 2, 3, 4],\n",
    "    default=np.nan,\n",
    ")\n",
    "\n",
    "iv_dose = phys_df['iv']\n",
    "phys_df['ivt'] = np.select(\n",
    "    [\n",
    "        iv_dose == 0.0,\n",
    "        (iv_dose > 0.0) & (iv_dose < ivq[0]),\n",
    "        (iv_dose >= ivq[0]) & (iv_dose < ivq[1]),\n",
    "        (iv_dose >= ivq[1]) & (iv_dose < ivq[2]),\n",
    "        iv_dose >= ivq[2],\n",
    "    ],\n",
    "    [0, 1, 2, 3, 4],\n",
    "    default=np.nan,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import sem\n",
    "\n",
    "# For every integer i from -15 to 20, take the rows whose '0' column lies strictly\n",
    "# inside (i - 0.5, i + 0.5) and record the mean of 'mort' and its standard error.\n",
    "# Bins holding fewer than two rows are skipped.\n",
    "bin_medians, mort, mort_std = [], [], []\n",
    "for i in range(-15, 21):\n",
    "    count = phys_df.loc[(phys_df['0'] > i - 0.5) & (phys_df['0'] < i + 0.5)]\n",
    "    if len(count) < 2:\n",
    "        continue\n",
    "    bin_medians.append(i)\n",
    "    mort.append(sum(count['mort']) / float(len(count)))\n",
    "    mort_std.append(sem(count['mort']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sliding_mean(data_array, window=2):\n",
    "    \"\"\"Smooth a sequence with a clipped moving average.\n",
    "\n",
    "    Element ``i`` of the result is the mean of ``data_array[i - window + 1]``\n",
    "    through ``data_array[i + window]``, with the span clipped to the valid\n",
    "    index range (so it reaches one element further forward than backward).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data_array : sequence of numbers\n",
    "        Values to smooth; must support ``len`` and integer indexing.\n",
    "    window : int, optional\n",
    "        Controls the span described above. Defaults to 2.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        Smoothed values, one per element of ``data_array``.\n",
    "    \"\"\"\n",
    "    size = len(data_array)\n",
    "    smoothed = []\n",
    "    for centre in range(size):\n",
    "        span = range(max(centre - window + 1, 0), min(centre + window + 1, size))\n",
    "        total = sum(data_array[k] for k in span)\n",
    "        smoothed.append(total / float(len(span)))\n",
    "    return np.array(smoothed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoothed mortality per return bin with a band of one smoothed standard error.\n",
    "smoothed_mort = sliding_mean(mort)\n",
    "smoothed_sem = sliding_mean(mort_std)\n",
    "\n",
    "plt.figure(figsize=(6, 4.5))\n",
    "plt.plot(bin_medians, smoothed_mort)\n",
    "plt.fill_between(\n",
    "    bin_medians,\n",
    "    smoothed_mort - smoothed_sem,\n",
    "    smoothed_mort + smoothed_sem,\n",
    "    color='#ADD8E6',\n",
    ")\n",
    "plt.grid()\n",
    "plt.xticks(range(-15,20,5))\n",
    "# y ticks at 0.0, 0.1, ..., 1.0\n",
    "r = [tick / 10 for tick in range(11)]\n",
    "_ = plt.yticks(r)\n",
    "_ = plt.title('Mortality vs Expected Return', fontsize=15)\n",
    "_ = plt.ylabel('Proportion Mortality')\n",
    "_ = plt.xlabel('Expected Return')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Joint histogram of the discretised iv / vaso actions (bins 0-4 on both axes).\n",
    "bin_edges = np.arange(-0.5, 5)  # -0.5, 0.5, ..., 4.5: one unit-wide cell per dose bin\n",
    "x_edges = bin_edges\n",
    "y_edges = bin_edges\n",
    "\n",
    "# Cast to float (the bin columns may be object dtype) and drop rows without a bin.\n",
    "binned = phys_df[['ivt', 'vasot']].astype(float).dropna()\n",
    "phy_iv = binned['ivt']\n",
    "phy_vaso = binned['vasot']\n",
    "\n",
    "# Explicit edges give every dose bin its own cell even when it never occurs;\n",
    "# bins=5 would instead spread five cells over the observed min..max.\n",
    "hist1, _, _ = np.histogram2d(phy_iv, phy_vaso, bins=[y_edges, x_edges])\n",
    "\n",
    "f, ax1 = plt.subplots(1, 1, figsize=(16,4))\n",
    "\n",
    "# Rows of hist1 are iv bins and pcolormesh draws rows along y, matching the labels below.\n",
    "im1 = ax1.pcolormesh(x_edges, y_edges, hist1, cmap='Blues')\n",
    "# Square cells, as the removed (fully overdrawn) imshow layer used to enforce.\n",
    "ax1.set_aspect('equal')\n",
    "f.colorbar(im1, ax=ax1, label='Action counts')\n",
    "\n",
    "# Major ticks label the bin centres; minor ticks sit on the cell borders for the grid.\n",
    "dose_bins = np.arange(0, 5, 1)\n",
    "ax1.set_xticks(dose_bins)\n",
    "ax1.set_yticks(dose_bins)\n",
    "ax1.set_xticklabels(dose_bins)\n",
    "ax1.set_yticklabels(dose_bins)\n",
    "ax1.set_xticks(bin_edges, minor=True)\n",
    "ax1.set_yticks(bin_edges, minor=True)\n",
    "ax1.grid(which='minor', color='b', linestyle='-', linewidth=1)\n",
    "\n",
    "ax1.set_ylabel('IV fluid dose')\n",
    "ax1.set_xlabel('Vaso fluid dose')\n",
    "ax1.set_title('Physician policy')\n",
    "\n",
    "plt.tight_layout()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
