{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-11-05 01:28:23.689129: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "/home/fn/anaconda3/envs/py37/lib/python3.7/site-packages/flatbuffers/compat.py:19: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses\n",
      "  import imp\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/fn/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow/python/compat/v2_compat.py:107: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "non-resource variables are not supported in the long term\n"
     ]
    }
   ],
   "source": [
    "import gym\n",
    "import torch\n",
    "import numpy as np\n",
    "from collections import deque\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import numpy as np\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pandas import DataFrame\n",
    "import pickle as pickle\n",
    "import math\n",
    "import copy\n",
    "import tf_slim\n",
    "import tensorflow.compat.v1 as tf\n",
    "tf.disable_v2_behavior()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Non-autoencoded (raw-feature) splits: the `*_noauto.csv` files.\n",
    "# NOTE(review): the original note said this data is 'sourced from autoencoder',\n",
    "# presumably meaning the autoencoder preprocessing pipeline; confirm provenance.\n",
    "DATA_DIR = '/home/fn/Mynew_Spesis/sepsisrl-master/data'\n",
    "\n",
    "train = pd.read_csv(os.path.join(DATA_DIR, 'rl_train_data_final_cont_reward2_cost_noauto.csv'))\n",
    "val = pd.read_csv(os.path.join(DATA_DIR, 'rl_val_data_final_cont_reward2_cost_noauto.csv'))\n",
    "test = pd.read_csv(os.path.join(DATA_DIR, 'rl_test_data_final_cont_reward2_cost_noauto.csv'))\n",
    "\n",
    "# `*_df` aliases are the names passed to the trajectory-building functions.\n",
    "train_df = train\n",
    "val_df = val\n",
    "test_df = test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ordered state-feature column names, one whitespace-separated token per name.\n",
    "# (An older variant used the placeholder names '0'..'15': [str(i) for i in range(16)].)\n",
    "with open('/home/fn/Mynew_Spesis/sepsisrl-master/data/state_features.txt') as f:\n",
    "    state_features = f.read().split(sep=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>bloc</th>\n",
       "      <th>icustayid</th>\n",
       "      <th>charttime</th>\n",
       "      <th>gender</th>\n",
       "      <th>age</th>\n",
       "      <th>elixhauser</th>\n",
       "      <th>re_admission</th>\n",
       "      <th>died_in_hosp</th>\n",
       "      <th>died_within_48h_of_out_time</th>\n",
       "      <th>mortality_90d</th>\n",
       "      <th>...</th>\n",
       "      <th>input_4hourly</th>\n",
       "      <th>output_total</th>\n",
       "      <th>output_4hourly</th>\n",
       "      <th>cumulated_balance</th>\n",
       "      <th>SOFA</th>\n",
       "      <th>SIRS</th>\n",
       "      <th>vaso_input</th>\n",
       "      <th>iv_input</th>\n",
       "      <th>reward</th>\n",
       "      <th>cost</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>94.0</td>\n",
       "      <td>7.209781e+09</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.566876</td>\n",
       "      <td>0.538462</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>0.511369</td>\n",
       "      <td>0.304348</td>\n",
       "      <td>0.50</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.066388</td>\n",
       "      <td>0.001724</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.222560</td>\n",
       "      <td>94.0</td>\n",
       "      <td>7.209796e+09</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.566876</td>\n",
       "      <td>0.538462</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>0.511369</td>\n",
       "      <td>0.565217</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.392434</td>\n",
       "      <td>0.026633</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.527957</td>\n",
       "      <td>94.0</td>\n",
       "      <td>7.209839e+09</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.566876</td>\n",
       "      <td>0.538462</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>0.511369</td>\n",
       "      <td>0.521739</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.222774</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.589582</td>\n",
       "      <td>94.0</td>\n",
       "      <td>7.209853e+09</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.566876</td>\n",
       "      <td>0.538462</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>0.511369</td>\n",
       "      <td>0.521739</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.053111</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.641832</td>\n",
       "      <td>94.0</td>\n",
       "      <td>7.209868e+09</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.566876</td>\n",
       "      <td>0.538462</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.780569</td>\n",
       "      <td>0.545574</td>\n",
       "      <td>0.71113</td>\n",
       "      <td>0.512054</td>\n",
       "      <td>0.565217</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.0</td>\n",
       "      <td>799.0</td>\n",
       "      <td>0.392434</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 63 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       bloc  icustayid     charttime  gender       age  elixhauser  \\\n",
       "0  0.000000       94.0  7.209781e+09     0.0  0.566876    0.538462   \n",
       "1  0.222560       94.0  7.209796e+09     0.0  0.566876    0.538462   \n",
       "2  0.527957       94.0  7.209839e+09     0.0  0.566876    0.538462   \n",
       "3  0.589582       94.0  7.209853e+09     0.0  0.566876    0.538462   \n",
       "4  0.641832       94.0  7.209868e+09     0.0  0.566876    0.538462   \n",
       "\n",
       "   re_admission  died_in_hosp  died_within_48h_of_out_time  mortality_90d  \\\n",
       "0           1.0           0.0                          0.0            1.0   \n",
       "1           1.0           0.0                          0.0            1.0   \n",
       "2           1.0           0.0                          0.0            1.0   \n",
       "3           1.0           0.0                          0.0            1.0   \n",
       "4           1.0           0.0                          0.0            1.0   \n",
       "\n",
       "   ...  input_4hourly  output_total  output_4hourly  cumulated_balance  \\\n",
       "0  ...       0.000000      0.000000         0.00000           0.511369   \n",
       "1  ...       0.000000      0.000000         0.00000           0.511369   \n",
       "2  ...       0.000000      0.000000         0.00000           0.511369   \n",
       "3  ...       0.000000      0.000000         0.00000           0.511369   \n",
       "4  ...       0.780569      0.545574         0.71113           0.512054   \n",
       "\n",
       "       SOFA  SIRS  vaso_input  iv_input    reward      cost  \n",
       "0  0.304348  0.50         0.0       0.0  0.066388  0.001724  \n",
       "1  0.565217  0.25         0.0       0.0  0.392434  0.026633  \n",
       "2  0.521739  0.25         0.0       0.0  0.222774  0.000000  \n",
       "3  0.521739  0.25         0.0       0.0  0.053111  0.000000  \n",
       "4  0.565217  0.25         0.0     799.0  0.392434  0.000000  \n",
       "\n",
       "[5 rows x 63 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the first five rows of the validation split.\n",
    "val.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dtdata_val_50and50(df1, n_per_outcome=50, iv_scale=2000):\n",
    "    \"\"\"Split `df1` into per-ICU-stay trajectories, keeping an outcome-balanced subset.\n",
    "\n",
    "    Consecutive rows that share an `icustayid` form one trajectory. A finished\n",
    "    trajectory is kept only while fewer than `n_per_outcome` trajectories with the\n",
    "    same `died_in_hosp` value (read from its last row; only exact 0 or 1 count)\n",
    "    have been kept, so by default at most 50 + 50 trajectories are returned.\n",
    "\n",
    "    Args:\n",
    "        df1: DataFrame in which each stay's rows are contiguous. Needs the\n",
    "            columns in the notebook-level `state_features` list plus\n",
    "            'icustayid', 'reward', 'iv_input', 'vaso_input', 'died_in_hosp'\n",
    "            and 'cost'. Any index works: rows are read by position.\n",
    "        n_per_outcome: maximum trajectories kept per outcome (default 50).\n",
    "        iv_scale: divisor applied to 'iv_input' to form the first action\n",
    "            component (default 2000).\n",
    "\n",
    "    Returns:\n",
    "        (trajectories, length, path_r): a list of dicts of numpy arrays with\n",
    "        keys 'observations', 'next_observations', 'actions', 'rewards',\n",
    "        'terminals', 'dieds', 'costs'; the step count of each kept trajectory;\n",
    "        and the sum of its rewards. Each trajectory's last next-observation\n",
    "        row is all zeros and its last `terminals` entry is 1.\n",
    "    \"\"\"\n",
    "    # Pull every column out once and index by position. The former\n",
    "    # `df1.loc[i + 1, ...]` assumed a contiguous 0..n-1 index and raised\n",
    "    # KeyError (e.g. 17224) otherwise; per-row `.loc` was also very slow.\n",
    "    states = df1[state_features].to_numpy()\n",
    "    stay_ids = df1['icustayid'].to_numpy()\n",
    "    reward_col = df1['reward'].to_numpy()\n",
    "    action_col = np.column_stack([df1['iv_input'].to_numpy() / iv_scale,\n",
    "                                  df1['vaso_input'].to_numpy()])\n",
    "    died_col = df1['died_in_hosp'].to_numpy()\n",
    "    cost_col = df1['cost'].to_numpy()\n",
    "    n_rows = len(df1)\n",
    "\n",
    "    trajectories, length, path_r = [], [], []\n",
    "    kept = {0: 0, 1: 0}  # trajectories kept so far, keyed by died_in_hosp\n",
    "    start = 0  # position of the first row of the current trajectory\n",
    "    for end in range(1, n_rows + 1):\n",
    "        # Keep extending while row `end` belongs to the same stay as row `end - 1`.\n",
    "        if end < n_rows and stay_ids[end - 1] == stay_ids[end]:\n",
    "            continue\n",
    "        last_died = died_col[end - 1]\n",
    "        outcome = 0 if last_died == 0 else (1 if last_died == 1 else None)\n",
    "        if outcome is not None and kept[outcome] < n_per_outcome:\n",
    "            kept[outcome] += 1\n",
    "            n_steps = end - start\n",
    "            next_obs = np.zeros((n_steps, states.shape[1]))\n",
    "            next_obs[:-1] = states[start + 1:end]  # terminal row stays zero\n",
    "            terminals = np.zeros(n_steps, dtype=int)\n",
    "            terminals[-1] = 1\n",
    "            rewards = reward_col[start:end].copy()\n",
    "            trajectories.append({'observations': states[start:end].copy(),\n",
    "                                 'next_observations': next_obs,\n",
    "                                 'actions': action_col[start:end].copy(),\n",
    "                                 'rewards': rewards,\n",
    "                                 'terminals': terminals,\n",
    "                                 'dieds': died_col[start:end].copy(),\n",
    "                                 'costs': cost_col[start:end].copy()})\n",
    "            path_r.append(sum(rewards))\n",
    "            length.append(n_steps)\n",
    "        start = end\n",
    "        if kept[0] >= n_per_outcome and kept[1] >= n_per_outcome:\n",
    "            break  # both quotas are full; later rows cannot add anything\n",
    "\n",
    "    return trajectories, length, path_r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dtdata_val(df1, iv_scale=2000):\n",
    "    \"\"\"Split `df1` into one trajectory per contiguous run of an `icustayid`.\n",
    "\n",
    "    Args:\n",
    "        df1: DataFrame in which each stay's rows are contiguous. Needs the\n",
    "            columns in the notebook-level `state_features` list plus\n",
    "            'icustayid', 'reward', 'iv_input', 'vaso_input', 'died_in_hosp'\n",
    "            and 'cost'. Any index works: rows are read by position.\n",
    "        iv_scale: divisor applied to 'iv_input' to form the first action\n",
    "            component (default 2000).\n",
    "\n",
    "    Returns:\n",
    "        (trajectories, length, path_r): a list of dicts of numpy arrays with\n",
    "        keys 'observations', 'next_observations', 'actions', 'rewards',\n",
    "        'terminals', 'dieds', 'costs'; the step count of each trajectory;\n",
    "        and the sum of its rewards. Each trajectory's last next-observation\n",
    "        row is all zeros and its last `terminals` entry is 1.\n",
    "    \"\"\"\n",
    "    # Pull every column out once and index by position. The former\n",
    "    # `df1.loc[i + 1, ...]` assumed a contiguous 0..n-1 index and raised\n",
    "    # KeyError otherwise; per-row `.loc` was also very slow.\n",
    "    states = df1[state_features].to_numpy()\n",
    "    stay_ids = df1['icustayid'].to_numpy()\n",
    "    reward_col = df1['reward'].to_numpy()\n",
    "    action_col = np.column_stack([df1['iv_input'].to_numpy() / iv_scale,\n",
    "                                  df1['vaso_input'].to_numpy()])\n",
    "    died_col = df1['died_in_hosp'].to_numpy()\n",
    "    cost_col = df1['cost'].to_numpy()\n",
    "    n_rows = len(df1)\n",
    "\n",
    "    trajectories, length, path_r = [], [], []\n",
    "    start = 0  # position of the first row of the current trajectory\n",
    "    for end in range(1, n_rows + 1):\n",
    "        # Keep extending while row `end` belongs to the same stay as row `end - 1`.\n",
    "        if end < n_rows and stay_ids[end - 1] == stay_ids[end]:\n",
    "            continue\n",
    "        n_steps = end - start\n",
    "        next_obs = np.zeros((n_steps, states.shape[1]))\n",
    "        next_obs[:-1] = states[start + 1:end]  # terminal row stays zero\n",
    "        terminals = np.zeros(n_steps, dtype=int)\n",
    "        terminals[-1] = 1\n",
    "        rewards = reward_col[start:end].copy()\n",
    "        trajectories.append({'observations': states[start:end].copy(),\n",
    "                             'next_observations': next_obs,\n",
    "                             'actions': action_col[start:end].copy(),\n",
    "                             'rewards': rewards,\n",
    "                             'terminals': terminals,\n",
    "                             'dieds': died_col[start:end].copy(),\n",
    "                             'costs': cost_col[start:end].copy()})\n",
    "        path_r.append(sum(rewards))\n",
    "        length.append(n_steps)\n",
    "        start = end\n",
    "\n",
    "    return trajectories, length, path_r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cdt_data(df, iv_scale=2000):\n",
    "    \"\"\"Flatten every row of `df` into one transition dataset.\n",
    "\n",
    "    A row is terminal when it is the last row or the next row belongs to a\n",
    "    different `icustayid`; terminal rows get an all-zero next observation.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame in which each stay's rows are contiguous. Needs the\n",
    "            columns in the notebook-level `state_features` list plus\n",
    "            'icustayid', 'reward', 'iv_input', 'vaso_input', 'died_in_hosp'\n",
    "            and 'cost'. Any index works: rows are read by position.\n",
    "        iv_scale: divisor applied to 'iv_input' to form the first action\n",
    "            component (default 2000).\n",
    "\n",
    "    Returns:\n",
    "        dict of numpy arrays, one entry per row, with keys 'actions',\n",
    "        'next_observations', 'observations', 'rewards', 'terminals' (bool),\n",
    "        'costs' and 'dieds'.\n",
    "    \"\"\"\n",
    "    # Whole-column extraction replaces the per-row `.loc` loop, which was slow\n",
    "    # and raised KeyError via `df.loc[i + 1, ...]` on a non-contiguous index.\n",
    "    observations = df[state_features].to_numpy(copy=True)\n",
    "    stay_ids = df['icustayid'].to_numpy()\n",
    "    n_rows = len(df)\n",
    "\n",
    "    terminals = np.ones(n_rows, dtype=bool)  # the last row is always terminal\n",
    "    terminals[:-1] = stay_ids[:-1] != stay_ids[1:]\n",
    "\n",
    "    next_observations = np.zeros(observations.shape)\n",
    "    next_observations[:-1] = observations[1:]\n",
    "    next_observations[terminals] = 0.0  # no successor across stay boundaries\n",
    "\n",
    "    actions = np.column_stack([df['iv_input'].to_numpy() / iv_scale,\n",
    "                               df['vaso_input'].to_numpy()])\n",
    "    return {'actions': actions,\n",
    "            'next_observations': next_observations,\n",
    "            'observations': observations,\n",
    "            'rewards': df['reward'].to_numpy(copy=True),\n",
    "            'terminals': terminals,\n",
    "            'costs': df['cost'].to_numpy(copy=True),\n",
    "            'dieds': df['died_in_hosp'].to_numpy(copy=True)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cdt_data_val(df, iv_scale=2000):\n",
    "    \"\"\"Split `df` into one transition dict per contiguous run of an `icustayid`.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame in which each stay's rows are contiguous. Needs the\n",
    "            columns in the notebook-level `state_features` list plus\n",
    "            'icustayid', 'reward', 'iv_input', 'vaso_input', 'died_in_hosp'\n",
    "            and 'cost'. Any index works: rows are read by position.\n",
    "        iv_scale: divisor applied to 'iv_input' to form the first action\n",
    "            component (default 2000).\n",
    "\n",
    "    Returns:\n",
    "        list of dicts, one per trajectory in row order, each mapping\n",
    "        'actions', 'next_observations', 'observations', 'rewards',\n",
    "        'terminals' (bool, True only on the last step), 'costs' and 'dieds'\n",
    "        to numpy arrays. The last next-observation row is all zeros.\n",
    "    \"\"\"\n",
    "    # Pull every column out once and index by position. The former\n",
    "    # `df.loc[i + 1, ...]` assumed a contiguous 0..n-1 index and raised\n",
    "    # KeyError otherwise; per-row `.loc` was also very slow.\n",
    "    states = df[state_features].to_numpy()\n",
    "    stay_ids = df['icustayid'].to_numpy()\n",
    "    reward_col = df['reward'].to_numpy()\n",
    "    action_col = np.column_stack([df['iv_input'].to_numpy() / iv_scale,\n",
    "                                  df['vaso_input'].to_numpy()])\n",
    "    died_col = df['died_in_hosp'].to_numpy()\n",
    "    cost_col = df['cost'].to_numpy()\n",
    "    n_rows = len(df)\n",
    "\n",
    "    path = []\n",
    "    start = 0  # position of the first row of the current trajectory\n",
    "    for end in range(1, n_rows + 1):\n",
    "        # Keep extending while row `end` belongs to the same stay as row `end - 1`.\n",
    "        if end < n_rows and stay_ids[end - 1] == stay_ids[end]:\n",
    "            continue\n",
    "        n_steps = end - start\n",
    "        next_observations = np.zeros((n_steps, states.shape[1]))\n",
    "        next_observations[:-1] = states[start + 1:end]  # terminal row stays zero\n",
    "        terminals = np.zeros(n_steps, dtype=bool)\n",
    "        terminals[-1] = True\n",
    "        path.append({'actions': action_col[start:end].copy(),\n",
    "                     'next_observations': next_observations,\n",
    "                     'observations': states[start:end].copy(),\n",
    "                     'rewards': reward_col[start:end].copy(),\n",
    "                     'terminals': terminals,\n",
    "                     'costs': cost_col[start:end].copy(),\n",
    "                     'dieds': died_col[start:end].copy()})\n",
    "        start = end\n",
    "\n",
    "    return path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyError",
     "evalue": "17224",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
      "\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/pandas/core/indexes/base.py\u001b[0m in \u001b[0;36mget_loc\u001b[0;34m(self, key, method, tolerance)\u001b[0m\n\u001b[1;32m   3360\u001b[0m             \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3361\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_engine\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_loc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcasted_key\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   3362\u001b[0m             \u001b[0;32mexcept\u001b[0m \u001b[0mKeyError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0merr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/pandas/_libs/index.pyx\u001b[0m in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/pandas/_libs/index.pyx\u001b[0m in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
      "\u001b[0;32mpandas/_libs/hashtable_class_helper.pxi\u001b[0m in \u001b[0;36mpandas._libs.hashtable.Int64HashTable.get_item\u001b[0;34m()\u001b[0m\n",
      "\u001b[0;32mpandas/_libs/hashtable_class_helper.pxi\u001b[0m in \u001b[0;36mpandas._libs.hashtable.Int64HashTable.get_item\u001b[0;34m()\u001b[0m\n",
      "\u001b[0;31mKeyError\u001b[0m: 17224",
      "\nThe above exception was the direct cause of the following exception:\n",
      "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_343468/925982080.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpath_val_p\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlength\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mpath_r\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdtdata_val_50and50\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_df\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/home/fn/OSRL/examples/train/my_cdt_data_val_noauto_p100.pkl'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m'wb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m#my_dt_test_2,my_cql_new_df_10\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath_val_p\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_343468/3826147881.py\u001b[0m in \u001b[0;36mdtdata_val_50and50\u001b[0;34m(df1)\u001b[0m\n\u001b[1;32m     23\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mdf1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     24\u001b[0m             \u001b[0;31m# if not terminal step in trajectory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m             \u001b[0;32mif\u001b[0m \u001b[0mdf1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'icustayid'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mdf1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'icustayid'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     26\u001b[0m                 \u001b[0mnext_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdf1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate_features\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m                 \u001b[0mdone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m    923\u001b[0m                 \u001b[0;32mwith\u001b[0m \u001b[0msuppress\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mKeyError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIndexError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    924\u001b[0m                     \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtakeable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_takeable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 925\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_getitem_tuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    926\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    927\u001b[0m             \u001b[0;31m# we by definition only have the 0th axis\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m_getitem_tuple\u001b[0;34m(self, tup)\u001b[0m\n\u001b[1;32m   1098\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_getitem_tuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtup\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1099\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0msuppress\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mIndexingError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1100\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_getitem_lowerdim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtup\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1102\u001b[0m         \u001b[0;31m# no multi-index, so validate all of the indexers\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m_getitem_lowerdim\u001b[0;34m(self, tup)\u001b[0m\n\u001b[1;32m    836\u001b[0m                 \u001b[0;31m# We don't need to check for tuples here because those are\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    837\u001b[0m                 \u001b[0;31m#  caught by the _is_nested_tuple_indexer check above.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 838\u001b[0;31m                 \u001b[0msection\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_getitem_axis\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    839\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    840\u001b[0m                 \u001b[0;31m# We should never have a scalar section here, because\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m_getitem_axis\u001b[0;34m(self, key, axis)\u001b[0m\n\u001b[1;32m   1162\u001b[0m         \u001b[0;31m# fall thru to straight lookup\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1163\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_key\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1164\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_label\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1165\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1166\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_get_slice_axis\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mslice_obj\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mslice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m_get_label\u001b[0;34m(self, label, axis)\u001b[0m\n\u001b[1;32m   1111\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_get_label\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1112\u001b[0m         \u001b[0;31m# GH#5667 this will fail if the label is not present in the axis.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1113\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mxs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlabel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1114\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1115\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_handle_lowerdim_multi_index_axis0\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtup\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/pandas/core/generic.py\u001b[0m in \u001b[0;36mxs\u001b[0;34m(self, key, axis, level, drop_level)\u001b[0m\n\u001b[1;32m   3774\u001b[0m                 \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Expected label or tuple of labels, got {key}\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3775\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3776\u001b[0;31m             \u001b[0mloc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_loc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   3777\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3778\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/pandas/core/indexes/base.py\u001b[0m in \u001b[0;36mget_loc\u001b[0;34m(self, key, method, tolerance)\u001b[0m\n\u001b[1;32m   3361\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_engine\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_loc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcasted_key\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3362\u001b[0m             \u001b[0;32mexcept\u001b[0m \u001b[0mKeyError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0merr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3363\u001b[0;31m                 \u001b[0;32mraise\u001b[0m \u001b[0mKeyError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0merr\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   3364\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3365\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mis_scalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0misna\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhasnans\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyError\u001b[0m: 17224"
     ]
    }
   ],
   "source": [
    "# path_val_p,length,path_r = dtdata_val_50and50(val_df)\n",
    "# with open('/home/fn/OSRL/examples/train/my_cdt_data_val_noauto_p100.pkl','wb') as f:  #my_dt_test_2,my_cql_new_df_10\n",
    "#     pickle.dump(path_val_p,f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the validation trajectories with dtdata_val (defined elsewhere in this\n",
    "# notebook); its three return values are unpacked as-is.\n",
    "path_val_p, length, path_r = dtdata_val(val_df)\n",
    "\n",
    "# Disabled alternatives, kept for reference (the save cells further down read some of these names):\n",
    "# path_train = get_cdt_data(train_df)\n",
    "# path_test = get_cdt_data(test_df)\n",
    "# path_val = get_cdt_data(val_df)\n",
    "# path_train_p, length, path_r = dtdata_val(train_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_die_alive_100(path, seed, per_class=50, max_draws=10000):\n",
    "    \"\"\"Return a class-balanced sample of trajectories: `per_class` died + `per_class` alive.\n",
    "\n",
    "    Trajectory indices are drawn uniformly *with replacement* (the same\n",
    "    trajectory may be picked more than once) and scanned in draw order; a draw\n",
    "    is kept while its outcome class, read from ``traj['dieds'][0]``\n",
    "    (1 = died, 0 = alive), is still below its quota. Draws whose ``dieds[0]``\n",
    "    is neither 0 nor 1 are skipped.\n",
    "\n",
    "    Args:\n",
    "        path: indexable sequence of trajectory dicts, each with a ``'dieds'``\n",
    "            sequence whose first element encodes the outcome.\n",
    "        seed: seed for a private ``np.random.RandomState``; the same seed gives\n",
    "            the same sample. The global NumPy RNG is left untouched.\n",
    "        per_class: quota per outcome class (default 50, i.e. 100 in total).\n",
    "        max_draws: number of random draws scanned before giving up.\n",
    "\n",
    "    Returns:\n",
    "        List of trajectory dicts taken from ``path``, in draw order. It holds\n",
    "        fewer than ``2 * per_class`` items only if a class could not be filled\n",
    "        within ``max_draws`` draws (e.g. ``path`` lacks that outcome).\n",
    "    \"\"\"\n",
    "    if len(path) == 0:\n",
    "        return []\n",
    "    # A private RandomState yields the same stream as np.random.seed(seed)\n",
    "    # followed by np.random.choice, without reseeding the global generator.\n",
    "    rng = np.random.RandomState(seed)\n",
    "    batch_inds = rng.choice(np.arange(len(path)), size=max_draws, replace=True)\n",
    "\n",
    "    paths = []\n",
    "    n_died = 0\n",
    "    n_alive = 0\n",
    "    # Scan every draw (not only the first len(path) of them) and stop as soon\n",
    "    # as both quotas are met.\n",
    "    for idx in batch_inds:\n",
    "        traj = path[idx]\n",
    "        outcome = traj['dieds'][0]\n",
    "        if outcome == 1 and n_died < per_class:\n",
    "            paths.append(traj)\n",
    "            n_died += 1\n",
    "        elif outcome == 0 and n_alive < per_class:\n",
    "            paths.append(traj)\n",
    "            n_alive += 1\n",
    "        if n_died >= per_class and n_alive >= per_class:\n",
    "            break\n",
    "    return paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.23278802 0.23278802 0.23278802 0.01111601 0.22279616 0.44446817\n",
      " 0.40245115 0.28915906 0.28915906 0.50391904 0.7133169  0.50123798]\n",
      "[0.44754799 0.72290999 0.55665598 0.71599795 0.93935108 0.77832799\n",
      " 0.99308796 0.56491982 0.56356802 0.49855692 0.51178333 0.60866486\n",
      " 0.93935108 0.56356802 0.88298003 0.93935108 0.60866486 0.50123798]\n",
      "[0.94934293 0.5569482  0.9349889  0.5578463  0.9349889  0.77951831\n",
      " 0.77951831 0.60985518 0.77832799 0.55570373 0.0665348  0.22279616\n",
      " 0.50123798]\n",
      "[0.51178333 0.99308796 0.94934293 0.61003852 0.61001666 0.77951831\n",
      " 0.60985518 0.55665598 0.50123798]\n",
      "[0.51083107 0.76968794 0.98444792 0.77951831 0.77951831 0.77951831\n",
      " 0.56475834 0.50123798]\n",
      "[0.22279616 0.3924593  0.05329451 0.44542043 0.28915906 0.06638812\n",
      " 0.44444335 0.44462966 0.11949593 0.23278802 0.01111601 0.22279616\n",
      " 0.43755613 0.05448484 0.22295765 0.05329451 0.43755613 0.22414797\n",
      " 0.009388   0.50123798]\n",
      "[0.22295765 0.22295765 0.44558191 0.33157485 0.28915906 0.4588222\n",
      " 0.50123798 0.28647801 0.40245115 0.28915906 0.28915906 0.28915906\n",
      " 0.28915906 0.4588222  0.50123798 0.50123798 0.50123798 0.50123798\n",
      " 0.33157485 0.50123798]\n",
      "[0.77970166 0.5569482  0.7133169  0.7133169  0.92807687 0.94799112\n",
      " 0.60985518 0.99308796 0.56491982 0.77832799 0.50123798]\n",
      "[0.28915906 0.4588222  0.50123798 0.71599795 0.76968794 0.76968794\n",
      " 0.54691463 0.50123798]\n",
      "[0.72386224 0.77951831 0.5578463  0.49164489 0.23278802 0.23278802\n",
      " 0.23278802 0.23278802 0.01111601 0.22279616 0.50123798]\n",
      "[0.50391904 0.88298003 0.60002481 0.7133169  0.50123798]\n"
     ]
    }
   ],
   "source": [
    "# For each seed 0, 10, ..., 100: draw one died/alive-balanced subset of the\n",
    "# validation trajectories, peek at its first trajectory's rewards, and pickle it.\n",
    "for i in range(0, 101, 10):\n",
    "    path_100 = random_die_alive_100(path_val_p, i)\n",
    "    print(path_100[0]['rewards'])\n",
    "    out_file = f'/home/fn/OSRL/examples/train/my_cdt_data_val_noauto_seed{i}.pkl'\n",
    "    with open(out_file, 'wb') as f:\n",
    "        pickle.dump(path_100, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100\n"
     ]
    }
   ],
   "source": [
    "# Sanity check on the last subset drawn above (2 x 50 = 100 when both classes fill).\n",
    "n_sampled = len(path_100)\n",
    "print(n_sampled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Author's note (translated): data scaled as IV/2000; reward and cost both lie in [0, 1].\n",
    "with open('/home/fn/OSRL/examples/train/my_cdt_data_val_noauto_p.pkl', 'wb') as f:\n",
    "    pickle.dump(path_val_p, f)\n",
    "\n",
    "# The train-set dtdata_val call in an earlier cell is commented out, so on a fresh\n",
    "# kernel path_train_p may not exist; build it here instead of failing with NameError.\n",
    "if 'path_train_p' not in globals():\n",
    "    path_train_p = dtdata_val(train_df)[0]\n",
    "with open('/home/fn/OSRL/examples/train/my_cdt_data_train_noauto_p.pkl', 'wb') as f:\n",
    "    pickle.dump(path_train_p, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_cdt_data(train_df) is commented out in an earlier cell, so path_train may not\n",
    "# exist on a fresh kernel; build it on demand instead of failing with NameError.\n",
    "if 'path_train' not in globals():\n",
    "    path_train = get_cdt_data(train_df)\n",
    "with open('/home/fn/OSRL/examples/train/my_cdt_data_train_noauto.pkl', 'wb') as f:\n",
    "    pickle.dump(path_train, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_cdt_data(val_df) is commented out in an earlier cell, so path_val may not\n",
    "# exist on a fresh kernel; build it on demand instead of failing with NameError.\n",
    "if 'path_val' not in globals():\n",
    "    path_val = get_cdt_data(val_df)\n",
    "with open('/home/fn/OSRL/examples/train/my_cdt_data_val_noauto.pkl', 'wb') as f:\n",
    "    pickle.dump(path_val, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
