{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import torch\n",
    "import numpy as np\n",
    "from collections import deque\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import numpy as np\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pandas import DataFrame\n",
    "import pickle as pickle\n",
    "import math\n",
    "import copy\n",
    "import tf_slim\n",
    "import tensorflow.compat.v1 as tf\n",
    "tf.disable_v2_behavior()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "train = pd.read_csv('./data/rl_train_data_final_cont_reward2_cost_noauto.csv')\n",
    "val = pd.read_csv(\"./data/rl_val_data_final_cont_reward2_cost_noauto.csv\")\n",
    "test = pd.read_csv('./data/rl_test_data_final_cont_reward2_cost_noauto.csv')\n",
    "\n",
    "train_df = train\n",
    "val_df = val\n",
    "test_df = test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#state_features = [str(i) for i in range(16)]\n",
    "with open('./data/state_features.txt') as f:\n",
    "    state_features = f.read().split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dtdata_val_50and50(df1):\n",
    "    trajectories=[]\n",
    "    obs = []\n",
    "    next_states = []\n",
    "    actions = []\n",
    "    rewards = []\n",
    "    dones = []\n",
    "    dieds =[]\n",
    "    length=[]\n",
    "    path_r=[]\n",
    "    costs = []\n",
    "    si  = 0\n",
    "    huo = 0\n",
    "    for i in df1.index:\n",
    "        ob = df1.loc[i,state_features]\n",
    "        r = df1.loc[i,'reward']\n",
    "        iv = df1.loc[i, 'iv_input']/2000\n",
    "        vaso = df1.loc[i, 'vaso_input']\n",
    "        action = [iv,vaso]\n",
    "        die = df1.loc[i,'died_in_hosp']\n",
    "        cost = df1.loc[i,'cost']\n",
    "        if i != df1.index[-1]:\n",
    "            # if not terminal step in trajectory             \n",
    "            if df1.loc[i, 'icustayid'] == df1.loc[i+1, 'icustayid']:\n",
    "                next_state = df1.loc[i + 1, state_features]\n",
    "                done = 0\n",
    "            else:\n",
    "                # trajectory is finished\n",
    "                next_state = np.zeros(len(ob))\n",
    "                done = 1\n",
    "        else:\n",
    "            # last entry in df is the final state of that trajectory\n",
    "            next_state = np.zeros(len(ob))\n",
    "            done = 1\n",
    "        obs.append(ob)\n",
    "        next_states.append(next_state)\n",
    "        actions.append(action)\n",
    "        rewards.append(r)\n",
    "        dones.append(done)\n",
    "        costs.append(cost)\n",
    "        dieds.append(die)\n",
    "        if done == 1 and len(actions)>0:\n",
    "            if die == 0 and huo < 50:\n",
    "                huo+=1\n",
    "                path = dict({'observations': np.array(obs),'next_observations': np.array(next_states),'actions': np.array(actions),'rewards': np.array(rewards),'terminals': np.array(dones),'dieds':np.array(dieds),'costs':np.array(costs)})\n",
    "                trajectories.append(path)\n",
    "                path_r.append(sum(rewards))\n",
    "                length.append(len(obs))\n",
    "            elif die == 1 and si < 50:\n",
    "                si+=1\n",
    "                path = dict({'observations': np.array(obs),'next_observations': np.array(next_states),'actions': np.array(actions),'rewards': np.array(rewards),'terminals': np.array(dones),'dieds':np.array(dieds),'costs':np.array(costs)})\n",
    "                trajectories.append(path)\n",
    "                path_r.append(sum(rewards))\n",
    "                length.append(len(obs))\n",
    "            obs = []\n",
    "            next_states = []\n",
    "            actions = []\n",
    "            rewards = []\n",
    "            dones = []\n",
    "            dieds = []\n",
    "            costs = []\n",
    "            done =0\n",
    "        elif done == 1:\n",
    "            obs = []\n",
    "            next_states = []\n",
    "            actions = []\n",
    "            rewards = []\n",
    "            dones = []\n",
    "            dieds =[]\n",
    "            costs = []\n",
    "            done = 0\n",
    "            \n",
    "    return trajectories,length,path_r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dtdata_val(df1):\n",
    "    trajectories=[]\n",
    "    obs = []\n",
    "    next_states = []\n",
    "    actions = []\n",
    "    rewards = []\n",
    "    dones = []\n",
    "    dieds =[]\n",
    "    length=[]\n",
    "    path_r=[]\n",
    "    costs = []\n",
    "    for i in df1.index:\n",
    "        ob = df1.loc[i,state_features]\n",
    "        r = df1.loc[i,'reward']\n",
    "        iv = df1.loc[i, 'iv_input']/2000\n",
    "        vaso = df1.loc[i, 'vaso_input']\n",
    "        action = [iv,vaso]\n",
    "        die = df1.loc[i,'died_in_hosp']\n",
    "        cost = df1.loc[i,'cost']\n",
    "        if i != df1.index[-1]:\n",
    "            # if not terminal step in trajectory             \n",
    "            if df1.loc[i, 'icustayid'] == df1.loc[i+1, 'icustayid']:\n",
    "                next_state = df1.loc[i + 1, state_features]\n",
    "                done = 0\n",
    "            else:\n",
    "                # trajectory is finished\n",
    "                next_state = np.zeros(len(ob))\n",
    "                done = 1\n",
    "        else:\n",
    "            # last entry in df is the final state of that trajectory\n",
    "            next_state = np.zeros(len(ob))\n",
    "            done = 1\n",
    "        obs.append(ob)\n",
    "        next_states.append(next_state)\n",
    "        actions.append(action)\n",
    "        rewards.append(r)\n",
    "        dones.append(done)\n",
    "        costs.append(cost)\n",
    "        dieds.append(die)\n",
    "        if done == 1 and len(actions)>0:\n",
    "            path = dict({'observations': np.array(obs),'next_observations': np.array(next_states),'actions': np.array(actions),'rewards': np.array(rewards),'terminals': np.array(dones),'dieds':np.array(dieds),'costs':np.array(costs)})\n",
    "            trajectories.append(path)\n",
    "            path_r.append(sum(rewards))\n",
    "            length.append(len(obs))\n",
    "            obs = []\n",
    "            next_states = []\n",
    "            actions = []\n",
    "            rewards = []\n",
    "            dones = []\n",
    "            dieds = []\n",
    "            costs = []\n",
    "            done =0\n",
    "        elif done == 1:\n",
    "            obs = []\n",
    "            next_states = []\n",
    "            actions = []\n",
    "            rewards = []\n",
    "            dones = []\n",
    "            dieds =[]\n",
    "            costs = []\n",
    "            done = 0\n",
    "            \n",
    "    return trajectories,length,path_r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cdt_data(df):\n",
    "    actions = []\n",
    "    costs = []\n",
    "    next_observations = []\n",
    "    observations = []\n",
    "    rewards = []\n",
    "    terminals = []\n",
    "    dieds = []\n",
    "    for i in df.index:\n",
    "        o = df.loc[i,state_features]\n",
    "        c = df.loc[i,'cost']\n",
    "        r = df.loc[i,'reward']\n",
    "        iv = df.loc[i, 'iv_input']/2000\n",
    "        vaso = df.loc[i, 'vaso_input']\n",
    "        action = [iv,vaso]\n",
    "        die = df.loc[i,'died_in_hosp']\n",
    "\n",
    "        if i != df.index[-1]:\n",
    "            if df.loc[i,'icustayid'] == df.loc[i+1,'icustayid']:\n",
    "                n_o = df.loc[i+1,state_features]\n",
    "                term = False \n",
    "            else:\n",
    "                n_o = np.zeros(len(o))\n",
    "                term = True\n",
    "        else:\n",
    "            n_o = np.zeros(len(o))\n",
    "            term = True\n",
    "        next_observations.append(n_o)\n",
    "        observations.append(o)\n",
    "        actions.append(action)\n",
    "        rewards.append(r)\n",
    "        terminals.append(term)\n",
    "        costs.append(c)\n",
    "        dieds.append(die)\n",
    "    path = dict({'actions': np.array(actions),'next_observations': np.array(next_observations),'observations': np.array(observations),'rewards': np.array(rewards),'terminals': np.array(terminals),'costs':np.array(costs),'dieds':np.array(dieds)})\n",
    "    return path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cdt_data_val(df):\n",
    "    actions = []\n",
    "    costs = []\n",
    "    next_observations = []\n",
    "    observations = []\n",
    "    rewards = []\n",
    "    terminals = []\n",
    "    dieds = []\n",
    "    path = []\n",
    "    for i in df.index:\n",
    "        o = df.loc[i,state_features]\n",
    "        c = df.loc[i,'cost']\n",
    "        r = df.loc[i,'reward']\n",
    "        iv = df.loc[i, 'iv_input']/2000\n",
    "        vaso = df.loc[i, 'vaso_input']\n",
    "        action = [iv,vaso]\n",
    "        die = df.loc[i,'died_in_hosp']\n",
    "\n",
    "        if i != df.index[-1]:\n",
    "            if df.loc[i,'icustayid'] == df.loc[i+1,'icustayid']:\n",
    "                n_o = df.loc[i+1,state_features]\n",
    "                term = False \n",
    "            else:\n",
    "                n_o = np.zeros(len(o))\n",
    "                term = True\n",
    "        else:\n",
    "            n_o = np.zeros(len(o))\n",
    "            term = True\n",
    "        \n",
    "        if term == True:\n",
    "            next_observations.append(n_o)\n",
    "            observations.append(o)\n",
    "            actions.append(action)\n",
    "            rewards.append(r)\n",
    "            terminals.append(term)\n",
    "            costs.append(c)\n",
    "            dieds.append(die)\n",
    "            path.append(dict({'actions': np.array(actions),\n",
    "                              'next_observations': np.array(next_observations),\n",
    "                              'observations': np.array(observations),'rewards': np.array(rewards),\n",
    "                              'terminals': np.array(terminals),'costs':np.array(costs),'dieds':np.array(dieds)}))\n",
    "            actions = []\n",
    "            costs = []\n",
    "            next_observations = []\n",
    "            observations = []\n",
    "            rewards = []\n",
    "            terminals = []\n",
    "            dieds = []\n",
    "        else:\n",
    "            next_observations.append(n_o)\n",
    "            observations.append(o)\n",
    "            actions.append(action)\n",
    "            rewards.append(r)\n",
    "            terminals.append(term)\n",
    "            costs.append(c)\n",
    "            dieds.append(die)\n",
    "    \n",
    "    return path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# path_train = get_cdt_data(train_df)\n",
    "# #path_test = get_cdt_data(test_df)\n",
    "# path_val = get_cdt_data(val_df)\n",
    "\n",
    "path_val_p,length,path_r = dtdata_val(val_df)\n",
    "# path_train_p,length,path_r = dtdata_val(train_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_die_alive_100(path,seed):\n",
    "    num_trajectories = len(path)\n",
    "    np.random.seed(seed)\n",
    "    batch_inds = np.random.choice(\n",
    "            np.arange(num_trajectories),\n",
    "            size=10000,\n",
    "            replace=True,\n",
    "        )\n",
    "    paths = []\n",
    "    die = 0 \n",
    "    nodie = 0\n",
    "    num = 50\n",
    "    for i in range(num_trajectories):\n",
    "        if die < num and path[batch_inds[i]]['dieds'][0] == 1: \n",
    "            traj = path[batch_inds[i]]\n",
    "            paths.append(traj)\n",
    "            die += 1\n",
    "        if nodie< num and path[batch_inds[i]]['dieds'][0] == 0:\n",
    "            traj = path[batch_inds[i]]\n",
    "            paths.append(traj)\n",
    "            nodie += 1\n",
    "        if die > num and nodie > num:\n",
    "            break\n",
    "    return paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 数据IV/2000，reward和cost都在[0,1]\n",
    "with open('./my_cdt_data_val_noauto_p.pkl','wb') as f:  #my_dt_test_2,my_cql_new_df_10\n",
    "    pickle.dump(path_val_p,f)\n",
    "# 数据IV/2000，reward和cost都在[0,1]\n",
    "with open('./my_cdt_data_train_noauto_p.pkl','wb') as f:  #my_dt_test_2,my_cql_new_df_10\n",
    "    pickle.dump(path_train_p,f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('./my_cdt_data_train_noauto.pkl','wb') as f:  #my_dt_test_2,my_cql_new_df_10\n",
    "    pickle.dump(path_train,f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('./my_cdt_data_val_noauto.pkl','wb') as f:  #my_dt_test_2,my_cql_new_df_10\n",
    "    pickle.dump(path_val,f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
