{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import asdict, dataclass\n",
    "from typing import Any, DefaultDict, Dict, List, Optional, Tuple\n",
    "import sys\n",
    "sys.path.append(\"../Our-Model/CDT\")\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from CDT.osrl_cdt.algorithms import CDT, CDTTrainer\n",
    "from CDT.osrl_cdt.common.exp_util import load_config_and_model, seed_all\n",
    "import pickle\n",
    "from CDT.osrl_cdt.algorithms import CDT, CDTTrainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "def get_val_path(path_,batch_size=200):\n",
    "    num_trajectories = len(path_)\n",
    "    \n",
    "    paths = []\n",
    "    die = 0 \n",
    "    nodie = 0\n",
    "    num = batch_size/2\n",
    "    for i in range(num_trajectories):\n",
    "        if die < num and path_[i]['dieds'][0] == 1: \n",
    "            traj = path_[i]\n",
    "            paths.append(traj)\n",
    "            die += 1\n",
    "        if nodie< num and path_[i]['dieds'][0] == 0:\n",
    "            traj = path_[i]\n",
    "            paths.append(traj)\n",
    "            nodie += 1\n",
    "        if die > num and nodie > num:\n",
    "            break\n",
    "    actions = []\n",
    "    next_observations = []\n",
    "    observations = []\n",
    "    rewards = []\n",
    "    dieds = []\n",
    "    costs = []\n",
    "    terminals = []\n",
    "    #print(len(paths))\n",
    "    for i in range(len(paths)):\n",
    "        path = paths[i]\n",
    "        for j in range(len(path['dieds'])):\n",
    "            path['actions'][j][0] = path['actions'][j][0]\n",
    "            actions.append(path['actions'][j])\n",
    "            next_observations.append(path['next_observations'][j])\n",
    "            rewards.append(path['rewards'][j])\n",
    "            dieds.append(path['dieds'][j])\n",
    "            costs.append(path['costs'][j])\n",
    "            observations.append(path['observations'][j])\n",
    "            terminals.append(path['terminals'][j])\n",
    "\n",
    "    out = dict({'actions': np.array(actions),'next_observations': \n",
    "                 np.array(next_observations),'observations': np.array(observations),\n",
    "                 'rewards': np.array(rewards),'terminals': np.array(terminals),'costs':np.array(costs),\n",
    "                 'dieds':np.array(dieds)})\n",
    "    print(len((paths)))\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "def get_random1000(path_,seed):\n",
    "    num_trajectories = len(path_['dieds'])\n",
    "    np.random.seed(seed)\n",
    "    batch_inds = np.random.choice(\n",
    "            np.arange(num_trajectories),\n",
    "            size=num_trajectories,\n",
    "            replace=True\n",
    "        )\n",
    "\n",
    "    actions = []\n",
    "    next_observations = []\n",
    "    observations = []\n",
    "    rewards = []\n",
    "    dieds = []\n",
    "    costs = []\n",
    "    terminals = []\n",
    "    #print(len(paths))\n",
    "    die_num, alive_num = 0,0\n",
    "    for i in range(len(batch_inds)):\n",
    "        j = batch_inds[i]\n",
    "        if die_num < 500 and path_['dieds'][j] == 1:\n",
    "            actions.append(path_['actions'][j])\n",
    "            next_observations.append(path_['next_observations'][j])\n",
    "            rewards.append(path_['rewards'][j])\n",
    "            dieds.append(path_['dieds'][j])\n",
    "            costs.append(path_['costs'][j])\n",
    "            observations.append(path_['observations'][j])\n",
    "            terminals.append(path_['terminals'][j])\n",
    "            die_num+=1\n",
    "        elif alive_num<500 and path_['dieds'][j] == 0:\n",
    "            actions.append(path_['actions'][j])\n",
    "            next_observations.append(path_['next_observations'][j])\n",
    "            rewards.append(path_['rewards'][j])\n",
    "            dieds.append(path_['dieds'][j])\n",
    "            costs.append(path_['costs'][j])\n",
    "            observations.append(path_['observations'][j])\n",
    "            terminals.append(path_['terminals'][j])\n",
    "            alive_num+=1\n",
    "        if die_num + alive_num == 1000:\n",
    "            break\n",
    "\n",
    "    out = dict({'actions': np.array(actions),'next_observations': \n",
    "                 np.array(next_observations),'observations': np.array(observations),\n",
    "                 'rewards': np.array(rewards),'terminals': np.array(terminals),'costs':np.array(costs),\n",
    "                 'dieds':np.array(dieds)})\n",
    "    if die_num+alive_num <1000:\n",
    "        print(num_trajectories,die_num,alive_num)\n",
    "        return True,out\n",
    "    return False,out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval(returns,costs,seed):\n",
    "\n",
    "    \n",
    "    rets = returns\n",
    "    costs = costs\n",
    "    eval_episodes = 1\n",
    "    # assert len(rets) == len(\n",
    "    #     costs\n",
    "    # ), f\"The length of returns {len(rets)} should be equal to costs {len(costs)}!\"\n",
    "    \n",
    "    cdt_actions = []\n",
    "    phy_actions = []\n",
    "    dieds = []\n",
    "    val_df = []\n",
    "    t = True\n",
    "    for target_ret, target_cost in zip(rets, costs):\n",
    "        seed_all(seed)\n",
    "        while t:\n",
    "            t,data_v = get_random1000(data_val,seed)\n",
    "        agent_action,phy_action,action_ems = trainer.evaluate(data_v,eval_episodes,\n",
    "                                             target_ret * 0.1,\n",
    "                                             target_cost * 1)\n",
    "        \n",
    "        cdt_actions=agent_action\n",
    "        phy_actions=phy_action\n",
    "        dieds=data_v['dieds']\n",
    "        #print(\"target ret:\",target_ret,\"target cost:\",target_cost)\n",
    "        val_df=data_v\n",
    "    return val_df,cdt_actions,phy_actions,dieds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_term(num,term,actions_a,actions_p,dieds):\n",
    "    t = 0\n",
    "    a = []\n",
    "    p = []\n",
    "    die = -1\n",
    "    for i in range(len(term)):\n",
    "        if term[i]==1:\n",
    "            t+=1\n",
    "        if t==num:\n",
    "            a.append(actions_a[i])\n",
    "            p.append(actions_p[i])\n",
    "            die = dieds[i]\n",
    "        if t>num:\n",
    "            break\n",
    "    return a,p,die"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_iv_vaso(batch_size,actions_a,actions_p,die,term):\n",
    "    iv_diff = []\n",
    "    diff = []\n",
    "    vaso_diff = []\n",
    "    dies = []\n",
    "    actions_a = actions_a[0]\n",
    "    actions_p = actions_p[0]\n",
    "    print(len(actions_a))\n",
    "    print(len(actions_p))\n",
    "    for j in range(len(actions_p)):\n",
    "        w_a = actions_a[j]\n",
    "        w_p= actions_p[j]\n",
    "        #print(w_a[0],w_p[0])\n",
    "        iv_diff.append((torch.mean((torch.tensor(w_a[0])-torch.tensor(w_p[0]))**2)).detach().cpu().item())\n",
    "        vaso_diff.append((torch.mean((torch.tensor(w_a[1])-torch.tensor(w_p[1]))**2)).detach().cpu().item())\n",
    "        diff.append((torch.mean((torch.tensor([w_a[0],w_a[1]])-torch.tensor([w_p[0],w_p[1]]))**2)).detach().cpu().item())\n",
    "        dies.append(die[j])\n",
    "    return iv_diff,vaso_diff,diff,dies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import random\n",
    "def get_gaps(returns,costs):\n",
    "    seed = [0,10,20,30]\n",
    "    ac_iv = []\n",
    "    ac_vaso = []\n",
    "    ac_diff = []\n",
    "    for i in seed:\n",
    "        val_df,cdt_actions,phy_actions,dieds = eval(returns,costs,i)\n",
    "        # time = 0\n",
    "        # cdt_actions = cdt_actions[time][0]\n",
    "        # phy_actions = phy_actions[time][0]\n",
    "        # dieds = dieds[time]\n",
    "        term = val_df['terminals']\n",
    "        iv_diff,vaso_diff,diff,die=get_iv_vaso(1000,cdt_actions,phy_actions,dieds,term)\n",
    "        # iv_diff_,vaso_diff_,diff_,die_=get_iv_vaso(200,cdt_actions,phy_actions,dieds,term)\n",
    "\n",
    "        # die_0 = [i for i, value in enumerate(die_) if value==0]\n",
    "        # die_1 = [i for i, value in enumerate(die_) if value==1]\n",
    "\n",
    "        # random.seed(i)\n",
    "\n",
    "        # selected_0 = random.sample(die_0,500)\n",
    "        # selected_1 = random.sample(die_1,500)\n",
    "\n",
    "        # iv_diff =[iv_diff_[i] for i in selected_0] + [iv_diff_[i] for i in selected_1]\n",
    "\n",
    "        # vaso_diff =[vaso_diff_[i] for i in selected_0] + [vaso_diff_[i] for i in selected_1]\n",
    "\n",
    "        # diff =[diff_[i] for i in selected_0] + [diff_[i] for i in selected_1]\n",
    "\n",
    "        # die =[die_[i] for i in selected_0] + [die_[i] for i in selected_1]\n",
    "\n",
    "        # print(sum(die))\n",
    "\n",
    "        diff_df = pd.DataFrame(columns=['iv','vaso','diff','die'])\n",
    "        diff_df['iv'] = iv_diff\n",
    "        diff_df['vaso'] = vaso_diff\n",
    "        diff_df['diff'] = diff\n",
    "        diff_df['die'] = die\n",
    "        l = ['iv','vaso','diff']\n",
    "        #print(\"seed:\",i,\"_CDT排序正确的比率\")\n",
    "        for j in range(len(l)):\n",
    "            df = diff_df.sort_values(l[j])\n",
    "            acc = 0\n",
    "            num = 0\n",
    "            for t in df.index:\n",
    "                if df.loc[t,'die']==0:\n",
    "                    acc=acc+1\n",
    "                num+=1\n",
    "                if num>=500:\n",
    "                    break\n",
    "            acc = acc/num\n",
    "            #print(\"num:\",num)\n",
    "            if j == 0:\n",
    "                ac_iv.append(acc)\n",
    "            elif j == 1:\n",
    "                ac_vaso.append(acc)\n",
    "            else:\n",
    "                ac_diff.append(acc)\n",
    "            #print(diff_df)\n",
    "    return np.mean(ac_iv),np.std(ac_iv),np.mean(ac_vaso),np.std(ac_vaso),np.mean(ac_diff),np.std(ac_diff),diff_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_val_some_path(path,batch_size=64,seed=10):\n",
    "    num_trajectories = len(path['rewards'])\n",
    "    np.random.seed(seed)\n",
    "    batch_inds = np.random.choice(\n",
    "            np.arange(num_trajectories),\n",
    "            size=batch_size,\n",
    "            replace=True,\n",
    "        )\n",
    "    paths = []\n",
    "\n",
    "    actions = []\n",
    "    costs = []\n",
    "    next_observations = []\n",
    "    observations = []\n",
    "    rewards = []\n",
    "    terminals = []\n",
    "    dieds = []\n",
    "\n",
    "    for i in range(batch_size):\n",
    "        actions.append(path['actions'][batch_inds[i]])\n",
    "        next_observations.append(path['next_observations'][batch_inds[i]])\n",
    "        observations.append(path['observations'][batch_inds[i]])\n",
    "        terminals.append(path['terminals'][batch_inds[i]])\n",
    "        rewards.append(path['rewards'][batch_inds[i]])\n",
    "        dieds.append(path['dieds'][batch_inds[i]])\n",
    "        costs.append(path['costs'][batch_inds[i]])\n",
    "\n",
    "\n",
    "    paths = dict({'actions': np.array(actions),'next_observations': \n",
    "                 np.array(next_observations),'observations': np.array(observations),\n",
    "                 'rewards': np.array(rewards),'terminals': np.array(terminals),'costs':np.array(costs),\n",
    "                 'dieds':np.array(dieds)})\n",
    "        \n",
    "    return paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "#path  = \"./My_model/cdt_without_atten.pt\"\n",
    "#path =f'./Mymodel/cdt_c6.pt'\n",
    "path = \"./My_model/cdt3.pt\"\n",
    "noise_scale = None\n",
    "eval_episodes = 20\n",
    "best = False\n",
    "device = \"cpu\"\n",
    "threads = 4\n",
    "state_dim = 48\n",
    "action_dim = 2\n",
    "max_action = [1.0,1.0]\n",
    "    \n",
    "if device == \"cpu\":\n",
    "    torch.set_num_threads(threads)\n",
    "\n",
    "target_entropy = -2\n",
    "\n",
    "dataset_path_val = f'../CDT/examples_my_cost/data/cost0_data_val.pkl'\n",
    "with open(dataset_path_val,'rb') as f:\n",
    "        data_val = pickle.load(f)\n",
    "# model & optimizer & scheduler setup\n",
    "cdt_model = CDT(\n",
    "        state_dim=48,\n",
    "        action_dim=2,\n",
    "        max_action=[1,1],\n",
    "        embedding_dim=128,\n",
    "        seq_len=10,\n",
    "        episode_len=300,\n",
    "        num_layers=3,\n",
    "        num_heads=8,\n",
    "        attention_dropout=0.1,\n",
    "        residual_dropout=0.1,\n",
    "        embedding_dropout=0.1,\n",
    "        time_emb=True,\n",
    "        use_rew=True,\n",
    "        use_cost=True,\n",
    "        cost_transform=True,\n",
    "        add_cost_feat=False,\n",
    "        mul_cost_feat=False,\n",
    "        cat_cost_feat=False,\n",
    "        action_head_layers=1,\n",
    "        cost_prefix=False,\n",
    "        stochastic=True,\n",
    "        init_temperature=0.1,\n",
    "        target_entropy=target_entropy,\n",
    "    )\n",
    "cdt_model.load_state_dict(torch.load(path))\n",
    "cdt_model.to(device)\n",
    "\n",
    "trainer = CDTTrainer(cdt_model,\n",
    "                    #costnet,\n",
    "                    reward_scale=0.1,\n",
    "                    cost_scale=1,\n",
    "                    cost_reverse=False,\n",
    "                    device=device)\n",
    "\n",
    "target_ret = 24\n",
    "target_cost = 0\n",
    "eval_episodes = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "mort vs diff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_iv_vaso_all(action1,action2):\n",
    "    l = len(action1)\n",
    "    iv1,iv2,vaso1,vaso2 = [],[],[],[]\n",
    "    for i in range(l):\n",
    "        iv1.append(action1[i][0])\n",
    "        iv2.append(action2[i][0])\n",
    "        vaso1.append(action1[i][1])\n",
    "        vaso2.append(action2[i][1])\n",
    "    return iv1,iv2,vaso1,vaso2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2f147a8b04994b6989e72a0db92cfb5c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Evaluating...:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "seed = 23\n",
    "cdt_actions = []\n",
    "phy_actions = []\n",
    "dieds = []\n",
    "val_df = []\n",
    "terminals = []\n",
    "\n",
    "agent_action,phy_action,action_ems = trainer.evaluate(data_val,eval_episodes,\n",
    "                                        target_ret * 0.1,\n",
    "                                        target_cost * 1)\n",
    "        \n",
    "cdt_actions.append(agent_action)\n",
    "phy_actions.append(phy_action)\n",
    "dieds.append(data_val['dieds'])\n",
    "val_df.append(data_val)\n",
    "terminals.append(data_val['terminals'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_state(state):\n",
    "    # first_one_index = torch.argmax(mask, dim=1)\n",
    "    # a_masked = [row[index:] for row, index in zip(state, first_one_index)]\n",
    "    with open('./data/state_features copy.txt') as f:\n",
    "        state_features = f.read().split()\n",
    "    df_agent_state = pd.DataFrame(columns=state_features)\n",
    "    for i in range(len(state_features)):\n",
    "        t = [state[j][i] for j in range(len(state))]\n",
    "        df_agent_state[state_features[i]] = t\n",
    "    return df_agent_state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_agent_state = get_state(data_val['observations'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "cdt_actions = agent_action[0]\n",
    "phy_actions = phy_action[0]\n",
    "dieds = data_val['dieds']\n",
    "term = data_val['terminals']\n",
    "\n",
    "iv1,iv2,vaso1,vaso2 =get_iv_vaso_all(cdt_actions,phy_actions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_agent_state['iv_cdt'] = iv1\n",
    "df_agent_state['vaso_cdt'] = vaso1 \n",
    "df_agent_state['iv_phy'] = iv2\n",
    "df_agent_state['vaso_phy'] = vaso2 \n",
    "df_agent_state['die'] = data_val['dieds']\n",
    "df_agent_state['term'] = data_val['terminals']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "df_log_path = f'../Process_Expert_Data/df_log.csv'\n",
    "df_norm_path = f'../Process_Expert_Data/df_norm.csv'\n",
    "df_log = pd.read_csv(df_log_path)\n",
    "df_norm = pd.read_csv(df_norm_path)\n",
    "binary_fields = ['gender','mechvent','re_admission']\n",
    "norm_fields= ['age','Weight_kg','GCS','HR','SysBP','MeanBP','DiaBP','RR','Temp_C','FiO2_1',\n",
    "    'Potassium','Sodium','Chloride','Glucose','Magnesium','Calcium',\n",
    "    'Hb','WBC_count','Platelets_count','PTT','PT','Arterial_pH','paO2','paCO2',\n",
    "    'Arterial_BE','HCO3','Arterial_lactate','SOFA','SIRS','Shock_Index',\n",
    "    'PaO2_FiO2', 'elixhauser', 'Albumin', u'CO2_mEqL', 'Ionised_Ca']\n",
    "log_fields = ['SpO2','BUN','Creatinine','SGOT','SGPT','Total_bili','INR',\n",
    "            'output_total','output_4hourly', 'bloc']\n",
    "import copy\n",
    "scalable_fields = copy.deepcopy(binary_fields)\n",
    "scalable_fields.extend(norm_fields)\n",
    "scalable_fields.extend(log_fields)\n",
    "dfminmax = pd.read_csv('../Process_Expert_Data/dfminmax.csv')\n",
    "for col in scalable_fields:\n",
    "    minimum = dfminmax.loc[0,col]\n",
    "    maximum = dfminmax.loc[1,col]\n",
    "    df_agent_state[col] = df_agent_state[col]*(maximum-minimum)+minimum\n",
    "\n",
    "# normalise binary fields\n",
    "df_agent_state[binary_fields] = df_agent_state[binary_fields] + 0.5 \n",
    "# normal distn fields\n",
    "for item in norm_fields:\n",
    "    av = df_norm.loc[0,item]\n",
    "    std = df_norm.loc[1,item]\n",
    "    df_agent_state[item] = df_agent_state[item]*std + av\n",
    "# log normal fields、\n",
    "for item in log_fields:\n",
    "    av = df_log.loc[0,item]\n",
    "    std = df_log.loc[1,item]\n",
    "    df_agent_state[item] = df_agent_state[item]*std + av\n",
    "# 修改\n",
    "df_agent_state[log_fields] = np.exp(df_agent_state[log_fields])-0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_agent_state['iv_cdt']*=2000\n",
    "df_agent_state['iv_phy']*=2000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Albumin</th>\n",
       "      <th>Arterial_BE</th>\n",
       "      <th>Arterial_lactate</th>\n",
       "      <th>Arterial_pH</th>\n",
       "      <th>BUN</th>\n",
       "      <th>CO2_mEqL</th>\n",
       "      <th>Calcium</th>\n",
       "      <th>Chloride</th>\n",
       "      <th>Creatinine</th>\n",
       "      <th>DiaBP</th>\n",
       "      <th>...</th>\n",
       "      <th>paCO2</th>\n",
       "      <th>paO2</th>\n",
       "      <th>re_admission</th>\n",
       "      <th>bloc</th>\n",
       "      <th>iv_cdt</th>\n",
       "      <th>vaso_cdt</th>\n",
       "      <th>iv_phy</th>\n",
       "      <th>vaso_phy</th>\n",
       "      <th>die</th>\n",
       "      <th>term</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4.0</td>\n",
       "      <td>5.999998e+00</td>\n",
       "      <td>1.0</td>\n",
       "      <td>7.310000</td>\n",
       "      <td>6.000001</td>\n",
       "      <td>27.000000</td>\n",
       "      <td>7.900000</td>\n",
       "      <td>103.000000</td>\n",
       "      <td>0.400000</td>\n",
       "      <td>61.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>72.000002</td>\n",
       "      <td>45.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>14.000001</td>\n",
       "      <td>250.087143</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>90.000004</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3.5</td>\n",
       "      <td>-8.857143e+00</td>\n",
       "      <td>0.7</td>\n",
       "      <td>7.298571</td>\n",
       "      <td>6.000001</td>\n",
       "      <td>30.999999</td>\n",
       "      <td>7.900000</td>\n",
       "      <td>103.000000</td>\n",
       "      <td>0.400000</td>\n",
       "      <td>63.222218</td>\n",
       "      <td>...</td>\n",
       "      <td>31.857143</td>\n",
       "      <td>218.142851</td>\n",
       "      <td>0.0</td>\n",
       "      <td>14.999999</td>\n",
       "      <td>283.453186</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>140.000001</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3.5</td>\n",
       "      <td>2.000002e+00</td>\n",
       "      <td>0.7</td>\n",
       "      <td>7.430000</td>\n",
       "      <td>6.000001</td>\n",
       "      <td>28.000000</td>\n",
       "      <td>7.900000</td>\n",
       "      <td>103.000000</td>\n",
       "      <td>0.400000</td>\n",
       "      <td>61.500002</td>\n",
       "      <td>...</td>\n",
       "      <td>41.000001</td>\n",
       "      <td>126.000004</td>\n",
       "      <td>0.0</td>\n",
       "      <td>16.000001</td>\n",
       "      <td>279.962433</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>167.500004</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3.0</td>\n",
       "      <td>2.000002e+00</td>\n",
       "      <td>0.7</td>\n",
       "      <td>7.450000</td>\n",
       "      <td>6.000001</td>\n",
       "      <td>28.000000</td>\n",
       "      <td>7.544444</td>\n",
       "      <td>102.111111</td>\n",
       "      <td>0.400000</td>\n",
       "      <td>57.777778</td>\n",
       "      <td>...</td>\n",
       "      <td>39.000000</td>\n",
       "      <td>76.999998</td>\n",
       "      <td>0.0</td>\n",
       "      <td>17.000000</td>\n",
       "      <td>282.454803</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>31.500001</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3.5</td>\n",
       "      <td>8.940697e-07</td>\n",
       "      <td>0.6</td>\n",
       "      <td>7.360000</td>\n",
       "      <td>39.000001</td>\n",
       "      <td>27.000000</td>\n",
       "      <td>8.500000</td>\n",
       "      <td>110.000002</td>\n",
       "      <td>1.600000</td>\n",
       "      <td>41.600002</td>\n",
       "      <td>...</td>\n",
       "      <td>46.000001</td>\n",
       "      <td>69.000000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>6.000000</td>\n",
       "      <td>207.211288</td>\n",
       "      <td>0.000161</td>\n",
       "      <td>39.999999</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35550</th>\n",
       "      <td>1.6</td>\n",
       "      <td>8.250002e+00</td>\n",
       "      <td>1.8</td>\n",
       "      <td>7.478750</td>\n",
       "      <td>5.000000</td>\n",
       "      <td>25.999999</td>\n",
       "      <td>7.300000</td>\n",
       "      <td>105.000000</td>\n",
       "      <td>0.600000</td>\n",
       "      <td>57.200001</td>\n",
       "      <td>...</td>\n",
       "      <td>43.375000</td>\n",
       "      <td>140.749996</td>\n",
       "      <td>1.0</td>\n",
       "      <td>8.000000</td>\n",
       "      <td>482.765930</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35551</th>\n",
       "      <td>4.2</td>\n",
       "      <td>-4.000001e+00</td>\n",
       "      <td>1.8</td>\n",
       "      <td>7.410000</td>\n",
       "      <td>5.000000</td>\n",
       "      <td>25.999999</td>\n",
       "      <td>7.300000</td>\n",
       "      <td>104.000000</td>\n",
       "      <td>0.516667</td>\n",
       "      <td>54.499998</td>\n",
       "      <td>...</td>\n",
       "      <td>28.000000</td>\n",
       "      <td>92.000002</td>\n",
       "      <td>1.0</td>\n",
       "      <td>9.000001</td>\n",
       "      <td>424.817413</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35552</th>\n",
       "      <td>4.3</td>\n",
       "      <td>-4.000001e+00</td>\n",
       "      <td>1.8</td>\n",
       "      <td>7.410000</td>\n",
       "      <td>5.000000</td>\n",
       "      <td>25.999999</td>\n",
       "      <td>7.300000</td>\n",
       "      <td>104.000000</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>51.333335</td>\n",
       "      <td>...</td>\n",
       "      <td>28.000000</td>\n",
       "      <td>92.000002</td>\n",
       "      <td>1.0</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>327.633392</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35553</th>\n",
       "      <td>4.3</td>\n",
       "      <td>-4.000001e+00</td>\n",
       "      <td>1.8</td>\n",
       "      <td>7.410000</td>\n",
       "      <td>5.000000</td>\n",
       "      <td>25.999999</td>\n",
       "      <td>7.300000</td>\n",
       "      <td>104.000000</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>51.333335</td>\n",
       "      <td>...</td>\n",
       "      <td>28.000000</td>\n",
       "      <td>92.000002</td>\n",
       "      <td>1.0</td>\n",
       "      <td>12.000000</td>\n",
       "      <td>223.401154</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35554</th>\n",
       "      <td>3.7</td>\n",
       "      <td>8.940697e-07</td>\n",
       "      <td>2.6</td>\n",
       "      <td>7.240000</td>\n",
       "      <td>6.000001</td>\n",
       "      <td>25.999999</td>\n",
       "      <td>9.300000</td>\n",
       "      <td>107.000000</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>34.599998</td>\n",
       "      <td>...</td>\n",
       "      <td>35.000000</td>\n",
       "      <td>112.000001</td>\n",
       "      <td>1.0</td>\n",
       "      <td>14.999999</td>\n",
       "      <td>300.635284</td>\n",
       "      <td>0.024528</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>35555 rows × 54 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       Albumin   Arterial_BE  Arterial_lactate  Arterial_pH        BUN  \\\n",
       "0          4.0  5.999998e+00               1.0     7.310000   6.000001   \n",
       "1          3.5 -8.857143e+00               0.7     7.298571   6.000001   \n",
       "2          3.5  2.000002e+00               0.7     7.430000   6.000001   \n",
       "3          3.0  2.000002e+00               0.7     7.450000   6.000001   \n",
       "4          3.5  8.940697e-07               0.6     7.360000  39.000001   \n",
       "...        ...           ...               ...          ...        ...   \n",
       "35550      1.6  8.250002e+00               1.8     7.478750   5.000000   \n",
       "35551      4.2 -4.000001e+00               1.8     7.410000   5.000000   \n",
       "35552      4.3 -4.000001e+00               1.8     7.410000   5.000000   \n",
       "35553      4.3 -4.000001e+00               1.8     7.410000   5.000000   \n",
       "35554      3.7  8.940697e-07               2.6     7.240000   6.000001   \n",
       "\n",
       "        CO2_mEqL   Calcium    Chloride  Creatinine      DiaBP  ...      paCO2  \\\n",
       "0      27.000000  7.900000  103.000000    0.400000  61.000000  ...  72.000002   \n",
       "1      30.999999  7.900000  103.000000    0.400000  63.222218  ...  31.857143   \n",
       "2      28.000000  7.900000  103.000000    0.400000  61.500002  ...  41.000001   \n",
       "3      28.000000  7.544444  102.111111    0.400000  57.777778  ...  39.000000   \n",
       "4      27.000000  8.500000  110.000002    1.600000  41.600002  ...  46.000001   \n",
       "...          ...       ...         ...         ...        ...  ...        ...   \n",
       "35550  25.999999  7.300000  105.000000    0.600000  57.200001  ...  43.375000   \n",
       "35551  25.999999  7.300000  104.000000    0.516667  54.499998  ...  28.000000   \n",
       "35552  25.999999  7.300000  104.000000    0.500000  51.333335  ...  28.000000   \n",
       "35553  25.999999  7.300000  104.000000    0.500000  51.333335  ...  28.000000   \n",
       "35554  25.999999  9.300000  107.000000    0.500000  34.599998  ...  35.000000   \n",
       "\n",
       "             paO2  re_admission       bloc      iv_cdt  vaso_cdt      iv_phy  \\\n",
       "0       45.000000           0.0  14.000001  250.087143  0.000000   90.000004   \n",
       "1      218.142851           0.0  14.999999  283.453186  0.000000  140.000001   \n",
       "2      126.000004           0.0  16.000001  279.962433  0.000000  167.500004   \n",
       "3       76.999998           0.0  17.000000  282.454803  0.000000   31.500001   \n",
       "4       69.000000           1.0   6.000000  207.211288  0.000161   39.999999   \n",
       "...           ...           ...        ...         ...       ...         ...   \n",
       "35550  140.749996           1.0   8.000000  482.765930  0.000000    0.000000   \n",
       "35551   92.000002           1.0   9.000001  424.817413  0.000000    0.000000   \n",
       "35552   92.000002           1.0  10.000000  327.633392  0.000000    0.000000   \n",
       "35553   92.000002           1.0  12.000000  223.401154  0.000000    0.000000   \n",
       "35554  112.000001           1.0  14.999999  300.635284  0.024528    0.000000   \n",
       "\n",
       "       vaso_phy  die  term  \n",
       "0           0.0  0.0     0  \n",
       "1           0.0  0.0     0  \n",
       "2           0.0  0.0     0  \n",
       "3           0.0  0.0     1  \n",
       "4           0.0  0.0     1  \n",
       "...         ...  ...   ...  \n",
       "35550       0.0  0.0     0  \n",
       "35551       0.0  0.0     0  \n",
       "35552       0.0  0.0     0  \n",
       "35553       0.0  0.0     0  \n",
       "35554       0.0  0.0     1  \n",
       "\n",
       "[35555 rows x 54 columns]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_agent_state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_agent_state.to_csv('Questionnaire.csv')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
