{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['/home/fn/A-Sepsis-RL/Our-Model', '/home/fn/anaconda3/envs/py37/lib/python37.zip', '/home/fn/anaconda3/envs/py37/lib/python3.7', '/home/fn/anaconda3/envs/py37/lib/python3.7/lib-dynload', '', '/home/fn/anaconda3/envs/py37/lib/python3.7/site-packages', '/home/fn/d4rl', '/home/fn/anaconda3/envs/py37/lib/python3.7/site-packages/mujoco_py-2.1.2.14-py3.7.egg', '/home/fn/anaconda3/envs/py37/lib/python3.7/site-packages/IPython/extensions', '/home/fn/.ipython']\n"
     ]
    }
   ],
   "source": [
    "import gym\n",
    "import numpy as np\n",
    "import torch\n",
    "import wandb\n",
    "from torch.nn import functional as F  # noqa\n",
    "import argparse\n",
    "import dill as pickle\n",
    "import random\n",
    "import sys\n",
    "print(sys.path)\n",
    "from ICRL.models.DT_ICRL import DecisionTransformer_icrl\n",
    "from ICRL.models.mlp_bc import MLPBCModel\n",
    "from ICRL.training.act_trainer import ActTrainer\n",
    "from ICRL.training.seq_trainer import SequenceTrainer\n",
    "from ICRL.evaluation.evaluate_episodes import evaluation_cost\n",
    "import pandas as pd\n",
    "\n",
    "dataset_path_val = f'.data/sepsis_data/violate_data_s.pkl'\n",
    "\n",
    "with open(dataset_path_val, 'rb') as f:\n",
    "    trajectories_val = pickle.load(f)\n",
    "\n",
    "dataset_path_expert = f'./data/sepsis_data/expert_data_s.pkl'\n",
    "\n",
    "with open(dataset_path_expert, 'rb') as f:\n",
    "    trajectories_expert = pickle.load(f)\n",
    "device = 'cuda'\n",
    "\n",
    "dataset_path_test = f'./data/sepsis_data/val_data_s.pkl'\n",
    "\n",
    "with open(dataset_path_test, 'rb') as f:\n",
    "    trajectories_test = pickle.load(f)\n",
    "device = 'cuda'\n",
    "\n",
    "dataset_path_expert_val = f'./data/sepsis_data/expert_data_val_s.pkl'\n",
    "\n",
    "with open(dataset_path_expert_val, 'rb') as f:\n",
    "    trajectories_expert_val = pickle.load(f)\n",
    "device = 'cuda'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
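     "# Alternative (disabled): load the custom-cost training split into `trajectories_expert`\n",
     "# instead of the expert dataset loaded in the first cell.\n",
     "# dataset_path_train_my_cost = f'/home/fn/A-Sepsis-RL/Process_data/my_cost_data/my_cost_data_train.pkl'\n",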
    "# with open(dataset_path_train_my_cost, 'rb') as f:\n",
    "#     trajectories_expert = pickle.load(f)\n",
    "# device = 'cuda'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6146\n"
     ]
    }
   ],
   "source": [
    "print(len(trajectories_expert_val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = DecisionTransformer_icrl(state_dim=48,act_dim=2,\n",
    "        max_length=10,\n",
    "        max_ep_len=10,\n",
    "        hidden_size=64,\n",
    "        n_layer=3,\n",
    "        n_head=8,\n",
    "        n_inner=4*64,\n",
    "        activation_function='relu',\n",
    "        n_positions=1024,\n",
    "        resid_pdrop=0.9,\n",
    "        attn_pdrop=0.9,\n",
    "        pre_attn_embd_dim = 64,\n",
    "        use_weighted_sum = True,)\n",
    "model.load_state_dict(torch.load(\"./My_model/icrl3.pt\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_data(trajectories,mode):\n",
    "    states, traj_lens, returns = [], [], []\n",
    "    for path in trajectories:\n",
    "        states.append(path['observations'])\n",
    "        traj_lens.append(len(path['observations']))\n",
    "        returns.append(path['rewards'].sum())\n",
    "    traj_lens, returns = np.array(traj_lens), np.array(returns)\n",
    "\n",
    "    states = np.concatenate(states, axis=0)\n",
    "\n",
    "    num_timesteps = sum(traj_lens)\n",
    "\n",
    "    print('=' * 50)\n",
    "    print(f'Starting new experiment:{mode}')\n",
    "    print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')\n",
    "    print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')\n",
    "    print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')\n",
    "    print('=' * 50)\n",
    "\n",
    "    K = 10\n",
    "    pct_traj = 1\n",
    "\n",
    "    # only train on top pct_traj trajectories (for %BC experiment)\n",
    "    num_timesteps = max(int(pct_traj*num_timesteps), 1)\n",
    "    sorted_inds = np.argsort(returns)  # lowest to highest \n",
    "    num_trajectories = 1\n",
    "    timesteps = traj_lens[sorted_inds[-1]]\n",
    "    ind = len(trajectories) - 2\n",
    "    while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] <= num_timesteps:\n",
    "        timesteps += traj_lens[sorted_inds[ind]]\n",
    "        num_trajectories += 1\n",
    "        ind -= 1\n",
    "    sorted_inds = sorted_inds[-num_trajectories:]\n",
    "\n",
    "    # used to reweight sampling so we sample according to timesteps instead of trajectories\n",
    "    p_sample = traj_lens[sorted_inds] / sum(traj_lens[sorted_inds])\n",
    "    \n",
    "    return p_sample, traj_lens, returns,num_trajectories,sorted_inds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def discount_cumsum(x, gamma):\n",
    "    discount_cumsum = np.zeros_like(x)\n",
    "    discount_cumsum[-1] = x[-1]\n",
    "    for t in reversed(range(x.shape[0]-1)):\n",
    "        discount_cumsum[t] = x[t] + gamma * discount_cumsum[t+1]\n",
    "    return discount_cumsum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_batch_o(trajectories_obs_val,\n",
    "                sorted_inds_obs_val,num_trajectories_obs_val,\n",
    "                p_sample_obs_val,device,batch_size=256,seed=10,\n",
    "                batch_inds=None,max_len=20,batch_inds_val=None):\n",
    "    max_ep_len = 10\n",
    "    act_dim = 2\n",
    "    state_dim = 48\n",
    "    trajs = trajectories_obs_val\n",
    "    inds = sorted_inds_obs_val\n",
    "    np.random.seed(seed)\n",
    "    if batch_inds_val is None:\n",
    "        batch_inds = np.random.choice(\n",
    "            np.arange(num_trajectories_obs_val),\n",
    "            size=batch_size,\n",
    "            replace=True,\n",
    "            p=p_sample_obs_val,  # reweights so we sample according to timesteps\n",
    "        )\n",
    "    else:\n",
    "        batch_inds = batch_inds_val\n",
    "\n",
    "    s,s_next, a, r, d, rtg, timesteps, mask = [],[], [], [], [], [], [], []\n",
    "    die = []\n",
    "    for i in range(batch_size):\n",
    "        traj = trajs[int(inds[batch_inds[i]])]\n",
    "        si = random.randint(0, traj['rewards'].shape[0] - 1)\n",
    "\n",
    "        # get sequences from dataset\n",
    "        s.append(traj['observations'][si:si + max_len].reshape(1, -1, state_dim).astype(np.float64))\n",
    "        #s_next.append(traj['observations'][si+1:si + max_len+1].reshape(1, -1, state_dim))\n",
    "        a.append(traj['actions'][si:si + max_len].reshape(1, -1, act_dim))\n",
    "        r.append(traj['rewards'][si:si + max_len].reshape(1, -1, 1))\n",
    "        die.append(traj['dieds'][si:si + max_len].reshape(1, -1, 1))\n",
    "        #r_next.append(traj['rewards'][si+1:si + max_len + 1].reshape(1, -1, 1))\n",
    "        if 'terminals' in traj:\n",
    "            d.append(traj['terminals'][si:si + max_len].reshape(1, -1))\n",
    "        else:\n",
    "            d.append(traj['dones'][si:si + max_len].reshape(1, -1))\n",
    "        timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1))\n",
    "        timesteps[-1][timesteps[-1] >= max_ep_len] = max_ep_len-1  # padding cutoff\n",
    "        rtg.append(discount_cumsum(traj['rewards'][si:], gamma=1.)[:s[-1].shape[1] + 1].reshape(1, -1, 1))\n",
    "        if rtg[-1].shape[1] <= s[-1].shape[1]:\n",
    "            rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)\n",
    "\n",
    "        # padding and state + reward normalization\n",
    "        tlen = s[-1].shape[1]\n",
    "        s[-1] = np.concatenate([np.zeros((1, max_len - tlen, state_dim)), s[-1]], axis=1)\n",
    "        a[-1] = np.concatenate([np.ones((1, max_len - tlen, act_dim)) * -10., a[-1]], axis=1)\n",
    "        r[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), r[-1]], axis=1)\n",
    "        die[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), die[-1]], axis=1)\n",
    "        d[-1] = np.concatenate([np.ones((1, max_len - tlen)) * 2, d[-1]], axis=1)\n",
    "        rtg[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), rtg[-1]], axis=1)\n",
    "        timesteps[-1] = np.concatenate([np.zeros((1, max_len - tlen)), timesteps[-1]], axis=1)\n",
    "        mask.append(np.concatenate([np.zeros((1, max_len - tlen)), np.ones((1, tlen))], axis=1))\n",
    "\n",
    "    s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=device)\n",
    "    a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=device)\n",
    "    r = torch.from_numpy(np.concatenate(r, axis=0)).to(dtype=torch.float32, device=device)\n",
    "    die = torch.from_numpy(np.concatenate(die, axis=0)).to(dtype=torch.float32, device=device)\n",
    "    d = torch.from_numpy(np.concatenate(d, axis=0)).to(dtype=torch.long, device=device)\n",
    "    rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device)\n",
    "    timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=device)\n",
    "    mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)\n",
    "\n",
    "    return s, a, r, d, rtg, timesteps, mask,die"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation_cost(model,use_weighted_sum,\n",
    "                    train_type,states_e, actions_e,\n",
    "                    timesteps_e,attention_mask_e,):\n",
    "        \n",
    "    B,T,_ = actions_e.shape\n",
    "\n",
    "    trans_pred_e,_ = model.forward(\n",
    "        states_e, actions_e, timesteps_e, attention_mask=attention_mask_e,training=False\n",
    "    )\n",
    "    #print(trans_pred_e)\n",
    "    if use_weighted_sum:\n",
    "        trans_pred_e = trans_pred_e[\"weighted_sum\"]\n",
    "    else:\n",
    "        trans_pred_e = trans_pred_e[\"value\"]\n",
    "\n",
    "    if train_type == \"mean\":\n",
    "        results = torch.mean(trans_pred_e.reshape(B, T), axis=1).reshape(-1, 1)\n",
    "    elif train_type == \"sum\":\n",
    "        results = torch.sum(trans_pred_e.reshape(B, T), axis=1).reshape(-1, 1)\n",
    "    elif train_type == \"last\":\n",
    "        results = trans_pred_e.reshape(B, T)[:, -1].reshape(-1, 1)\n",
    "    elif train_type == \"every\":\n",
    "        results = trans_pred_e.reshape(B, T)\n",
    "        \n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_no_mask_tolist(data,mask):\n",
    "\n",
    "    first_one_index = torch.argmax(mask, dim=1)\n",
    "\n",
    "    a_masked = [row[index:].view(-1).tolist() for row, index in zip(data, first_one_index)]\n",
    "    flattened_list = [item for sublist in a_masked for item in sublist]\n",
    "    return flattened_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize_list(lst,max_val=-1000,min_val=1000):\n",
    "    if max_val == -1000:\n",
    "        min_val = min(lst)\n",
    "        max_val = max(lst)\n",
    "\n",
    "    normalized_list = [(x - min_val) / (max_val - min_val) for x in lst]\n",
    "    return normalized_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_id(matrix):\n",
    "    encoded_matrix = []\n",
    "    num = 1\n",
    "    for row in matrix:\n",
    "        encoded_row = []\n",
    "        for elem in row:\n",
    "            if elem == 1:\n",
    "                encoded_row.append(num)\n",
    "        encoded_matrix.append(encoded_row)\n",
    "        num+=1\n",
    "    \n",
    "    flattened_list = [item for sublist in encoded_matrix for item in sublist]\n",
    "\n",
    "    return flattened_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_iv_vaso(action,mask):\n",
    "    #print(action)\n",
    "    first_one_index = torch.argmax(mask, dim=1)\n",
    "    \n",
    "    a_masked = [row[index:] for row, index in zip(action, first_one_index)]\n",
    "    #print(a_masked)\n",
    "    iv = [item[0].cpu().item() for sublist in a_masked for item in sublist]\n",
    "    vaso = [item[1].cpu().item() for sublist in a_masked for item in sublist]\n",
    "    #print(iv)\n",
    "    #print(vaso)\n",
    "    return iv,vaso"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_no_mask_state(state,mask):\n",
    "    first_one_index = torch.argmax(mask, dim=1)\n",
    "    a_masked = [row[index:] for row, index in zip(state, first_one_index)]\n",
    "    with open('./data/state_features copy.txt') as f:\n",
    "        state_features = f.read().split()\n",
    "    df_agent_state = pd.DataFrame(columns=state_features)\n",
    "    for i in range(len(state_features)):\n",
    "        t = [item[i].cpu().item() for sublist in a_masked for item in sublist]\n",
    "        df_agent_state[state_features[i]] = t\n",
    "    return df_agent_state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "Starting new experiment:expert_data\n",
      "14313 trajectories, 164325 timesteps found\n",
      "Average return: 6.31, std: 2.49\n",
      "Max return: 10.79, min: 0.26\n",
      "==================================================\n"
     ]
    }
   ],
   "source": [
    "use_weighted_sum = True\n",
    "train_type = 'every'\n",
    "p_sample_expert, traj_lens_expert, returns_expert,num_trajectories_expert,sorted_inds_expert = process_data(trajectories_expert,'expert_data')\n",
    "states_expert, actions_expert, rewards_expert, dones_expert, rtg_expert, timesteps_expert, attention_mask_expert,die_expert = get_batch_o(\n",
    "                trajectories_expert,\n",
    "                sorted_inds_expert,num_trajectories_expert,\n",
    "                p_sample_expert,device,batch_size=5000,seed=10) \n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    sum_pred_expert = evaluation_cost(\n",
    "        model.to(device=device),\n",
    "        use_weighted_sum,   \n",
    "        train_type,\n",
    "        states_expert, \n",
    "        actions_expert,\n",
    "        timesteps_expert,\n",
    "        attention_mask_expert,)\n",
    "df_expert = pd.DataFrame(columns=['reward','cost','die'])\n",
    "    \n",
    "df_expert['reward'] = get_no_mask_tolist(rewards_expert,attention_mask_expert)\n",
    "df_expert['cost'] = get_no_mask_tolist(sum_pred_expert,attention_mask_expert)\n",
    "df_expert['die'] =  get_no_mask_tolist(die_expert,attention_mask_expert)\n",
    "df_expert['id'] = get_id(attention_mask_expert)\n",
    "a,b = get_iv_vaso(actions_expert,attention_mask_expert)\n",
    "df_expert['iv'] = a \n",
    "df_expert['vaso'] = b \n",
    "\n",
    "df_state = get_no_mask_state(states_expert,attention_mask_expert)\n",
    "df_expert = pd.concat([df_expert, df_state], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>reward</th>\n",
       "      <th>cost</th>\n",
       "      <th>die</th>\n",
       "      <th>id</th>\n",
       "      <th>iv</th>\n",
       "      <th>vaso</th>\n",
       "      <th>Albumin</th>\n",
       "      <th>Arterial_BE</th>\n",
       "      <th>Arterial_lactate</th>\n",
       "      <th>Arterial_pH</th>\n",
       "      <th>...</th>\n",
       "      <th>age</th>\n",
       "      <th>elixhauser</th>\n",
       "      <th>gender</th>\n",
       "      <th>mechvent</th>\n",
       "      <th>output_4hourly</th>\n",
       "      <th>output_total</th>\n",
       "      <th>paCO2</th>\n",
       "      <th>paO2</th>\n",
       "      <th>re_admission</th>\n",
       "      <th>bloc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.492307</td>\n",
       "      <td>-3.878531</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.04500</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.612245</td>\n",
       "      <td>0.610390</td>\n",
       "      <td>0.034130</td>\n",
       "      <td>0.545455</td>\n",
       "      <td>...</td>\n",
       "      <td>0.394326</td>\n",
       "      <td>0.214286</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.722235</td>\n",
       "      <td>0.789443</td>\n",
       "      <td>0.366382</td>\n",
       "      <td>0.048904</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.877971</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.676733</td>\n",
       "      <td>-3.878532</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.07000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.510204</td>\n",
       "      <td>0.417440</td>\n",
       "      <td>0.023891</td>\n",
       "      <td>0.533911</td>\n",
       "      <td>...</td>\n",
       "      <td>0.394326</td>\n",
       "      <td>0.214286</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.709944</td>\n",
       "      <td>0.791666</td>\n",
       "      <td>0.129162</td>\n",
       "      <td>0.340882</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.901554</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.553846</td>\n",
       "      <td>-3.878525</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.08375</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.510204</td>\n",
       "      <td>0.558442</td>\n",
       "      <td>0.023891</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>...</td>\n",
       "      <td>0.394326</td>\n",
       "      <td>0.214286</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.722235</td>\n",
       "      <td>0.794151</td>\n",
       "      <td>0.183191</td>\n",
       "      <td>0.185497</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.923625</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.553843</td>\n",
       "      <td>-3.878512</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.01575</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.408163</td>\n",
       "      <td>0.558442</td>\n",
       "      <td>0.023891</td>\n",
       "      <td>0.686869</td>\n",
       "      <td>...</td>\n",
       "      <td>0.394326</td>\n",
       "      <td>0.214286</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.768060</td>\n",
       "      <td>0.798235</td>\n",
       "      <td>0.171372</td>\n",
       "      <td>0.102867</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.944365</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.544677</td>\n",
       "      <td>-3.033291</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2</td>\n",
       "      <td>0.02000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.510204</td>\n",
       "      <td>0.532468</td>\n",
       "      <td>0.020478</td>\n",
       "      <td>0.595960</td>\n",
       "      <td>...</td>\n",
       "      <td>0.882943</td>\n",
       "      <td>0.714286</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.671419</td>\n",
       "      <td>0.625773</td>\n",
       "      <td>0.212738</td>\n",
       "      <td>0.089376</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.589582</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 54 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     reward      cost  die  id       iv  vaso   Albumin  Arterial_BE  \\\n",
       "0  0.492307 -3.878531  0.0   1  0.04500   0.0  0.612245     0.610390   \n",
       "1  0.676733 -3.878532  0.0   1  0.07000   0.0  0.510204     0.417440   \n",
       "2  0.553846 -3.878525  0.0   1  0.08375   0.0  0.510204     0.558442   \n",
       "3  0.553843 -3.878512  0.0   1  0.01575   0.0  0.408163     0.558442   \n",
       "4  0.544677 -3.033291  0.0   2  0.02000   0.0  0.510204     0.532468   \n",
       "\n",
       "   Arterial_lactate  Arterial_pH  ...       age  elixhauser  gender  mechvent  \\\n",
       "0          0.034130     0.545455  ...  0.394326    0.214286     1.0       0.0   \n",
       "1          0.023891     0.533911  ...  0.394326    0.214286     1.0       0.0   \n",
       "2          0.023891     0.666667  ...  0.394326    0.214286     1.0       0.0   \n",
       "3          0.023891     0.686869  ...  0.394326    0.214286     1.0       0.0   \n",
       "4          0.020478     0.595960  ...  0.882943    0.714286     1.0       0.0   \n",
       "\n",
       "   output_4hourly  output_total     paCO2      paO2  re_admission      bloc  \n",
       "0        0.722235      0.789443  0.366382  0.048904           0.0  0.877971  \n",
       "1        0.709944      0.791666  0.129162  0.340882           0.0  0.901554  \n",
       "2        0.722235      0.794151  0.183191  0.185497           0.0  0.923625  \n",
       "3        0.768060      0.798235  0.171372  0.102867           0.0  0.944365  \n",
       "4        0.671419      0.625773  0.212738  0.089376           1.0  0.589582  \n",
       "\n",
       "[5 rows x 54 columns]"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_expert.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_cost = max(df_expert['cost'])\n",
    "min_cost = min(df_expert['cost'])\n",
    "df_expert['cost'] = (df_expert['cost']-min_cost)/(max_cost-min_cost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_r = max(df_expert['reward'])\n",
    "min_r = min(df_expert['reward'])\n",
    "df_expert['reward'] = (df_expert['reward']-min_r)/(max_r-min_r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../Process_data/state_features.txt') as f:\n",
    "    state_features = f.read().split()\n",
    "# state_features = [str(i) for i in range(16)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "def voce_new_data(df):\n",
    "    obs = []\n",
    "    next_states = []\n",
    "    actions = []\n",
    "    rewards = []\n",
    "    dones = []\n",
    "    died = []\n",
    "    costs = []\n",
    "    for i in df.index:\n",
    "        ob = df.loc[i,state_features]\n",
    "        r = df.loc[i,'reward']\n",
    "        c = 0 #df.loc[i,'cost']\n",
    "        iv = df.loc[i, 'iv']\n",
    "        vaso = df.loc[i, 'vaso']\n",
    "        action = [iv,vaso]\n",
    "        die = df.loc[i,'die']\n",
    "        if i != df.index[-1]:\n",
    "            # if not terminal step in trajectory             \n",
    "            if df.loc[i, 'id'] == df.loc[i+1, 'id']:\n",
    "                next_state = df.loc[i + 1, state_features]\n",
    "                done = 0\n",
    "            else:\n",
    "                # trajectory is finished\n",
    "                next_state = np.zeros(len(ob))\n",
    "                done = 1\n",
    "        else:\n",
    "            # last entry in df is the final state of that trajectory\n",
    "            next_state = np.zeros(len(ob))\n",
    "            done = 1\n",
    "        obs.append(np.array(ob))\n",
    "        next_states.append(np.array(next_state))\n",
    "        actions.append(np.array(action))\n",
    "        rewards.append(np.array(r))\n",
    "        dones.append(np.array(done))\n",
    "        costs.append(np.array(c))\n",
    "        died.append(np.array(die))\n",
    "    \n",
    "    path = dict({'observations': np.array(obs),'next_observations': np.array(next_states),'actions': np.array(actions),'rewards': np.array(rewards),'terminals': np.array(dones),'dieds':np.array(died),'costs':np.array(costs)})\n",
    "\n",
    "    return path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = voce_new_data(df_expert)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "print(sum(path['costs']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../CDT/examples_my_cost/data/cost0_data.pkl','wb') as f:  #my_dt_test_2,my_cql_new_df_10\n",
    "    pickle.dump(path,f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
