{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import torch\n",
    "import numpy as np\n",
    "from collections import deque\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import numpy as np\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pandas import DataFrame\n",
    "import pickle as pickle\n",
    "import math\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the training and validation tables from their CSV exports.\n",
    "train_df = pd.read_csv('/home/ubuntu/Data/rl_train_data_final_cont_reward2_cost5.csv')\n",
    "val_df = pd.read_csv('/home/ubuntu/Data/rl_val_data_final_cont_reward2_cost5.csv')\n",
    "\n",
    "# Short aliases; each one is bound to the same DataFrame object as its *_df name.\n",
    "train = train_df\n",
    "val = val_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature names are read from a text file. The commented-out line below is an\n",
    "# alternative that would use the column names '0'..'15' instead.\n",
    "#state_features = [str(i) for i in range(16)]\n",
    "with open('/home/ubuntu/Data/state_features.txt') as f:\n",
    "    # split() without arguments breaks on any run of whitespace (spaces or newlines).\n",
    "    state_features = f.read().split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dtdata_val_50and50(df1, max_per_class=50):\n",
    "    \"\"\"Cut df1 into per-stay trajectories and keep a capped number per outcome.\n",
    "\n",
    "    Rows are taken in frame order, and consecutive rows that share the same\n",
    "    'icustayid' form one trajectory. A finished trajectory is kept when the\n",
    "    'died_in_hosp' value on its last row is 0 or 1 and fewer than\n",
    "    `max_per_class` trajectories with that value have been kept so far; all\n",
    "    other trajectories are discarded.\n",
    "\n",
    "    Neighbouring rows are compared by position rather than by `index + 1`,\n",
    "    so df1 may carry any index (for example after filtering or sorting).\n",
    "\n",
    "    Args:\n",
    "        df1: DataFrame holding the columns named in the global\n",
    "            `state_features` plus 'icustayid', 'reward', 'iv_input',\n",
    "            'vaso_input', 'died_in_hosp' and 'cost'.\n",
    "        max_per_class: Cap on kept trajectories for each outcome value.\n",
    "\n",
    "    Returns:\n",
    "        (trajectories, length, path_r): the kept trajectory dicts, the number\n",
    "        of steps in each of them, and the summed reward of each of them.\n",
    "    \"\"\"\n",
    "    # Extract every column once instead of doing a label lookup per row.\n",
    "    features = df1[state_features].to_numpy()\n",
    "    stay_ids = df1['icustayid'].to_numpy()\n",
    "    reward_col = df1['reward'].to_numpy()\n",
    "    # Action = [iv_input / 2000, vaso_input].\n",
    "    action_col = np.column_stack((df1['iv_input'].to_numpy() / 2000, df1['vaso_input'].to_numpy()))\n",
    "    died_col = df1['died_in_hosp'].to_numpy()\n",
    "    cost_col = df1['cost'].to_numpy()\n",
    "    n_rows = len(df1)\n",
    "\n",
    "    trajectories, length, path_r = [], [], []\n",
    "    alive_kept = 0\n",
    "    dead_kept = 0\n",
    "    start = 0\n",
    "    for pos in range(n_rows):\n",
    "        # Still inside the current stay: a next row exists and has the same id.\n",
    "        if pos + 1 < n_rows and stay_ids[pos] == stay_ids[pos + 1]:\n",
    "            continue\n",
    "        stop = pos + 1\n",
    "        outcome = died_col[pos]\n",
    "        if outcome == 0 and alive_kept < max_per_class:\n",
    "            alive_kept += 1\n",
    "            keep = True\n",
    "        elif outcome == 1 and dead_kept < max_per_class:\n",
    "            dead_kept += 1\n",
    "            keep = True\n",
    "        else:\n",
    "            keep = False\n",
    "        if keep:\n",
    "            steps = stop - start\n",
    "            # Each step's successor is the following row; the last step gets zeros.\n",
    "            next_obs = np.array(list(features[start + 1:stop]) + [np.zeros(features.shape[1])])\n",
    "            trajectories.append({\n",
    "                'observations': features[start:stop].copy(),\n",
    "                'next_observations': next_obs,\n",
    "                'actions': action_col[start:stop].copy(),\n",
    "                'rewards': reward_col[start:stop].copy(),\n",
    "                'terminals': np.array([0] * (steps - 1) + [1]),\n",
    "                'dieds': died_col[start:stop].copy(),\n",
    "                'costs': cost_col[start:stop].copy(),\n",
    "            })\n",
    "            path_r.append(sum(reward_col[start:stop]))\n",
    "            length.append(steps)\n",
    "        start = stop\n",
    "        if alive_kept >= max_per_class and dead_kept >= max_per_class:\n",
    "            # Both quotas are full, so no later trajectory could be kept.\n",
    "            break\n",
    "\n",
    "    return trajectories, length, path_r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dtdata_val(df1):\n",
    "    \"\"\"Cut df1 into one trajectory dict per run of rows sharing an 'icustayid'.\n",
    "\n",
    "    Rows are taken in frame order. A trajectory ends on the last row of the\n",
    "    frame or on a row whose successor has a different 'icustayid'. Neighbouring\n",
    "    rows are compared by position rather than by `index + 1`, so df1 may carry\n",
    "    any index (for example after filtering or sorting).\n",
    "\n",
    "    Args:\n",
    "        df1: DataFrame holding the columns named in the global\n",
    "            `state_features` plus 'icustayid', 'reward', 'iv_input',\n",
    "            'vaso_input', 'died_in_hosp' and 'cost'.\n",
    "\n",
    "    Returns:\n",
    "        (trajectories, length, path_r): the trajectory dicts, the number of\n",
    "        steps in each of them, and the summed reward of each of them. Every\n",
    "        dict has the keys 'observations', 'next_observations', 'actions',\n",
    "        'rewards', 'terminals' (0 for every step except a final 1), 'dieds'\n",
    "        and 'costs'.\n",
    "    \"\"\"\n",
    "    # Extract every column once instead of doing a label lookup per row.\n",
    "    features = df1[state_features].to_numpy()\n",
    "    stay_ids = df1['icustayid'].to_numpy()\n",
    "    reward_col = df1['reward'].to_numpy()\n",
    "    # Action = [iv_input / 2000, vaso_input].\n",
    "    action_col = np.column_stack((df1['iv_input'].to_numpy() / 2000, df1['vaso_input'].to_numpy()))\n",
    "    died_col = df1['died_in_hosp'].to_numpy()\n",
    "    cost_col = df1['cost'].to_numpy()\n",
    "    n_rows = len(df1)\n",
    "\n",
    "    trajectories, length, path_r = [], [], []\n",
    "    start = 0\n",
    "    for pos in range(n_rows):\n",
    "        # Still inside the current stay: a next row exists and has the same id.\n",
    "        if pos + 1 < n_rows and stay_ids[pos] == stay_ids[pos + 1]:\n",
    "            continue\n",
    "        stop = pos + 1\n",
    "        steps = stop - start\n",
    "        # Each step's successor is the following row; the last step gets zeros.\n",
    "        next_obs = np.array(list(features[start + 1:stop]) + [np.zeros(features.shape[1])])\n",
    "        trajectories.append({\n",
    "            'observations': features[start:stop].copy(),\n",
    "            'next_observations': next_obs,\n",
    "            'actions': action_col[start:stop].copy(),\n",
    "            'rewards': reward_col[start:stop].copy(),\n",
    "            'terminals': np.array([0] * (steps - 1) + [1]),\n",
    "            'dieds': died_col[start:stop].copy(),\n",
    "            'costs': cost_col[start:stop].copy(),\n",
    "        })\n",
    "        path_r.append(sum(reward_col[start:stop]))\n",
    "        length.append(steps)\n",
    "        start = stop\n",
    "\n",
    "    return trajectories, length, path_r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cdt_data(df):\n",
    "    \"\"\"Flatten df into a single dict of step-aligned arrays covering every row.\n",
    "\n",
    "    A row is terminal when it is the last row of the frame or when the next row\n",
    "    has a different 'icustayid'. The next observation of a terminal row is an\n",
    "    all-zero vector; for every other row it is the state of the following row.\n",
    "    Neighbouring rows are compared by position rather than by `index + 1`, so\n",
    "    df may carry any index (for example after filtering or sorting).\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame holding the columns named in the global `state_features`\n",
    "            plus 'icustayid', 'reward', 'iv_input', 'vaso_input',\n",
    "            'died_in_hosp' and 'cost'.\n",
    "\n",
    "    Returns:\n",
    "        dict with the keys 'actions', 'next_observations', 'observations',\n",
    "        'rewards', 'terminals' (booleans), 'costs' and 'dieds', each a NumPy\n",
    "        array with one entry per row of df.\n",
    "    \"\"\"\n",
    "    # Extract the columns once instead of doing a label lookup per row.\n",
    "    features = df[state_features].to_numpy()\n",
    "    stay_ids = df['icustayid'].to_numpy()\n",
    "    n_rows = len(df)\n",
    "\n",
    "    terminals = [\n",
    "        pos + 1 == n_rows or bool(stay_ids[pos] != stay_ids[pos + 1])\n",
    "        for pos in range(n_rows)\n",
    "    ]\n",
    "    zero_state = np.zeros(features.shape[1])\n",
    "    next_observations = [\n",
    "        zero_state if is_last else features[pos + 1]\n",
    "        for pos, is_last in enumerate(terminals)\n",
    "    ]\n",
    "\n",
    "    return {\n",
    "        # Action = [iv_input / 2000, vaso_input].\n",
    "        'actions': np.column_stack((df['iv_input'].to_numpy() / 2000, df['vaso_input'].to_numpy())),\n",
    "        'next_observations': np.array(next_observations),\n",
    "        # Copies keep the returned arrays independent of the DataFrame's own buffers.\n",
    "        'observations': features.copy(),\n",
    "        'rewards': df['reward'].to_numpy().copy(),\n",
    "        'terminals': np.array(terminals),\n",
    "        'costs': df['cost'].to_numpy().copy(),\n",
    "        'dieds': df['died_in_hosp'].to_numpy().copy(),\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cdt_data_val(df):\n",
    "    \"\"\"Cut df into a list with one dict of arrays per run of rows sharing an 'icustayid'.\n",
    "\n",
    "    Rows are taken in frame order. A trajectory ends on the last row of the\n",
    "    frame or on a row whose successor has a different 'icustayid'. Neighbouring\n",
    "    rows are compared by position rather than by `index + 1`, so df may carry\n",
    "    any index (for example after filtering or sorting).\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame holding the columns named in the global `state_features`\n",
    "            plus 'icustayid', 'reward', 'iv_input', 'vaso_input',\n",
    "            'died_in_hosp' and 'cost'.\n",
    "\n",
    "    Returns:\n",
    "        list of dicts, in order of appearance. Each dict has the keys\n",
    "        'actions', 'next_observations', 'observations', 'rewards',\n",
    "        'terminals' (False for every step except a final True), 'costs' and\n",
    "        'dieds'.\n",
    "    \"\"\"\n",
    "    # Extract every column once instead of doing a label lookup per row.\n",
    "    features = df[state_features].to_numpy()\n",
    "    stay_ids = df['icustayid'].to_numpy()\n",
    "    reward_col = df['reward'].to_numpy()\n",
    "    # Action = [iv_input / 2000, vaso_input].\n",
    "    action_col = np.column_stack((df['iv_input'].to_numpy() / 2000, df['vaso_input'].to_numpy()))\n",
    "    died_col = df['died_in_hosp'].to_numpy()\n",
    "    cost_col = df['cost'].to_numpy()\n",
    "    n_rows = len(df)\n",
    "\n",
    "    path = []\n",
    "    start = 0\n",
    "    for pos in range(n_rows):\n",
    "        # Still inside the current stay: a next row exists and has the same id.\n",
    "        if pos + 1 < n_rows and stay_ids[pos] == stay_ids[pos + 1]:\n",
    "            continue\n",
    "        stop = pos + 1\n",
    "        steps = stop - start\n",
    "        # Each step's successor is the following row; the last step gets zeros.\n",
    "        next_obs = np.array(list(features[start + 1:stop]) + [np.zeros(features.shape[1])])\n",
    "        path.append({\n",
    "            'actions': action_col[start:stop].copy(),\n",
    "            'next_observations': next_obs,\n",
    "            'observations': features[start:stop].copy(),\n",
    "            'rewards': reward_col[start:stop].copy(),\n",
    "            'terminals': np.array([False] * (steps - 1) + [True]),\n",
    "            'costs': cost_col[start:stop].copy(),\n",
    "            'dieds': died_col[start:stop].copy(),\n",
    "        })\n",
    "        start = stop\n",
    "\n",
    "    return path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flat datasets: one dict of arrays covering every row of the frame.\n",
    "path_train = get_cdt_data(train_df)\n",
    "#path_test = get_cdt_data(test_df)\n",
    "path_val = get_cdt_data(val_df)\n",
    "\n",
    "# Per-trajectory datasets. These two calls used to be commented out, yet the cells\n",
    "# below read path_val_p and path_train_p, so a fresh Restart & Run All hit a NameError.\n",
    "path_val_p, length, path_r = dtdata_val(val_df)\n",
    "path_train_p, length, path_r = dtdata_val(train_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_die_alive_100(path, seed, num=100, n_draws=10000):\n",
    "    \"\"\"Draw trajectories at random until `num` of each outcome have been collected.\n",
    "\n",
    "    `n_draws` indices into `path` are sampled with replacement, so the same\n",
    "    trajectory can be selected more than once. The draws are scanned in order:\n",
    "    a trajectory whose first 'dieds' entry equals 1 is taken while fewer than\n",
    "    `num` such trajectories are held, and likewise for a first entry equal to 0.\n",
    "    Scanning stops as soon as both groups are full.\n",
    "\n",
    "    The indices come from a private generator seeded with `seed`. It yields the\n",
    "    same sequence as seeding NumPy's global generator with `seed`, but leaves the\n",
    "    global random state untouched.\n",
    "\n",
    "    Args:\n",
    "        path: Sequence of trajectory dicts, each with a non-empty 'dieds' array.\n",
    "        seed: Seed for the index sampling.\n",
    "        num: Number of trajectories wanted for each outcome.\n",
    "        n_draws: Number of random indices to draw.\n",
    "\n",
    "    Returns:\n",
    "        (ok, paths): `ok` is True when both groups reached `num`; `paths` holds\n",
    "        the selected trajectories in the order they were drawn. An empty `path`\n",
    "        yields (num <= 0, []) instead of raising.\n",
    "    \"\"\"\n",
    "    num_trajectories = len(path)\n",
    "    if num_trajectories == 0:\n",
    "        # Sampling from an empty population raises inside NumPy; report failure instead.\n",
    "        print(0, 0)\n",
    "        return num <= 0, []\n",
    "\n",
    "    rng = np.random.RandomState(seed)\n",
    "    batch_inds = rng.choice(\n",
    "        np.arange(num_trajectories),\n",
    "        size=n_draws,\n",
    "        replace=True,\n",
    "    )\n",
    "    print(batch_inds)\n",
    "\n",
    "    paths = []\n",
    "    die = 0\n",
    "    nodie = 0\n",
    "    for ind in batch_inds:\n",
    "        traj = path[ind]\n",
    "        outcome = traj['dieds'][0]\n",
    "        if outcome == 1 and die < num:\n",
    "            paths.append(traj)\n",
    "            die += 1\n",
    "        elif outcome == 0 and nodie < num:\n",
    "            paths.append(traj)\n",
    "            nodie += 1\n",
    "        if die >= num and nodie >= num:\n",
    "            break\n",
    "\n",
    "    print(die, nodie)\n",
    "    return (die >= num and nodie >= num), paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1283\n",
      "189\n"
     ]
    }
   ],
   "source": [
    "# Number of validation trajectories, then how many of them start with dieds == 1.\n",
    "print(len(path_val_p))\n",
    "t = sum(1 for traj in path_val_p if traj['dieds'][0] == 1)\n",
    "print(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 684  559 1216 ...  182  922 1081]\n",
      "100 100\n",
      "[1149  527 1147 ...  613    8  728]\n",
      "100 100\n",
      "[ 474  271 1247 ...  138 1200  387]\n",
      "100 100\n",
      "[ 421  941  500 ...  793  574 1009]\n",
      "100 100\n",
      "[219   7 165 ... 496  80 290]\n",
      "100 100\n",
      "[109 132  70 ... 676 765  81]\n",
      "100 100\n",
      "[ 205 1089  399 ...  529  386  279]\n",
      "100 100\n",
      "[ 214 1138 1083 ...  846  885 1192]\n",
      "100 100\n",
      "[1199 1212  934 ...  859  617  230]\n",
      "100 100\n",
      "[669 671 679 ... 351  87 438]\n",
      "100 100\n",
      "[ 792   79  350 ...  986 1246  885]\n",
      "100 100\n"
     ]
    }
   ],
   "source": [
    "# One sample is drawn and pickled per seed; the loop stops at the first seed for\n",
    "# which random_die_alive_100 reports that it could not fill both groups.\n",
    "seed = list(range(0, 101, 10))\n",
    "for i in seed:\n",
    "    t, path_100 = random_die_alive_100(path_val_p, i)\n",
    "    if not t:\n",
    "        print('stop', i)\n",
    "        break\n",
    "    fail = f'/home/fn/OSRL/examples/train/my_cdt_data_val_noauto_seed200{i}.pkl'\n",
    "    with open(fail, 'wb') as f:\n",
    "        pickle.dump(path_100, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data note: IV is divided by 2000; reward and cost both lie in [0, 1].\n",
    "# Writes the per-trajectory validation list to disk.\n",
    "with open('/home/fn/OSRL/examples/train/my_cdt_data_val_noauto_p.pkl','wb') as f:  #my_dt_test_2,my_cql_new_df_10\n",
    "    pickle.dump(path_val_p,f)\n",
    "# Data note: IV is divided by 2000; reward and cost both lie in [0, 1].\n",
    "# Writes the per-trajectory training list to disk.\n",
    "with open('/home/fn/OSRL/examples/train/my_cdt_data_train_noauto_p.pkl','wb') as f:  #my_dt_test_2,my_cql_new_df_10\n",
    "    pickle.dump(path_train_p,f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist path_train, the flat training dict returned by get_cdt_data(train_df).\n",
    "with open('/home/ubuntu/OSRL/examples/data/my_cdt_data_train_noauto_c5.pkl','wb') as f:  #my_dt_test_2,my_cql_new_df_10\n",
    "    pickle.dump(path_train,f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist path_val, the flat validation dict returned by get_cdt_data(val_df).\n",
    "with open('/home/ubuntu/OSRL/examples/data/my_cdt_data_val_noauto_c5.pkl','wb') as f:  #my_dt_test_2,my_cql_new_df_10\n",
    "    pickle.dump(path_val,f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
