{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "#from critics import *\n",
    "#from actors import *\n",
    "#from actor_critic import *\n",
    "from helper_functions import *\n",
    "from DDPG_torch import DDPG\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import numpy as np\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pandas import DataFrame\n",
    "import pickle as pickle\n",
    "import math\n",
    "import copy\n",
    "import tf_slim\n",
    "import tensorflow.compat.v1 as tf\n",
    "tf.disable_v2_behavior()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# train = pd.read_csv('../Process_data/train_data.csv')\n",
    "# val = pd.read_csv(\"../Process_data/val_data.csv\")\n",
    "# test = pd.read_csv('../Process_data/test_data.csv')\n",
    "# # train = pd.read_csv('../Process_data/expert_data_train.csv')\n",
    "# # val = pd.read_csv(\"../Process_data/expert_data_val.csv\") # max iv = 2000,vaso=1,die[0,1]\n",
    "# df = train\n",
    "# val_df = val\n",
    "# #test_df = test\n",
    "\n",
    "# REWARD_THRESHOLD =1 \n",
    "# reg_lambda = 5 \n",
    "# df['iv_input']/=2000\n",
    "# val_df['iv_input']/=2000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Load the '*_auto' train/val/test tables used for training and evaluation.\n",
    "train = pd.read_csv('../Process_data/rl_train_data_final_auto.csv')\n",
    "val = pd.read_csv(\"../Process_data/rl_val_data_final_auto.csv\")\n",
    "test = pd.read_csv('../Process_data/rl_test_data_final_auto.csv')\n",
    "# [-15,15]\n",
    "# The '*_cont' tables are read only to copy their 'died_in_hosp' column (see below).\n",
    "train_orig = pd.read_csv('../Process_data/rl_train_data_final_cont.csv')\n",
    "val_orig = pd.read_csv('../Process_data/rl_val_data_final_cont.csv')\n",
    "test_orig = pd.read_csv('../Process_data/rl_test_data_final_cont.csv')\n",
    "\n",
    "\n",
    "# NOTE(review): column assignment aligns on the row index, so this assumes each '*_auto'\n",
    "# table has the same rows in the same order as its '*_cont' counterpart - TODO confirm.\n",
    "train['died_in_hosp'] = train_orig['died_in_hosp']\n",
    "val['died_in_hosp'] = val_orig['died_in_hosp']\n",
    "test['died_in_hosp'] = test_orig['died_in_hosp']\n",
    "# Aliases read as module-level globals by the sampling / batching helpers below.\n",
    "df = train\n",
    "val_df = val\n",
    "test_df = test\n",
    "\n",
    "# NOTE(review): REWARD_THRESHOLD and reg_lambda are not used in the visible part of this notebook.\n",
    "REWARD_THRESHOLD =15 \n",
    "reg_lambda = 5 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # normalization action\n",
    "# def normalization_action(df):\n",
    "#     max_iv = max(df['iv_input'])\n",
    "#     min_iv = min(df['iv_input'])\n",
    "#     max_vaso = max(df['vaso_input'])\n",
    "#     min_vaso = min(df['vaso_input'])\n",
    "#     df['vaso_input'] = (df['vaso_input']-min_vaso)/(max_vaso-min_vaso)\n",
    "#     df['iv_input'] = (df['iv_input']-min_iv)/(max_iv-min_iv)\n",
    "#     return df,max_iv,max_vaso,min_iv,min_vaso"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# df,max_iv,max_vaso,min_iv,min_vaso = normalization_action(df)\n",
    "# val_df,val_max_iv,val_max_vaso,val_min_iv,val_min_vaso = normalization_action(val_df)\n",
    "# test_df,test_max_iv,test_max_vaso,test_min_iv,test_min_vaso = normalization_action(test_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(test_max_iv,test_max_vaso,test_min_iv,test_min_vaso)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max(val_df['iv_input'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column labels of the 200 state features: the strings '0' through '199'.\n",
    "state_features = list(map(str, range(200)))\n",
    "\n",
    "# with open(\"../Process_data/state_features.txt\") as f:\n",
    "#     feat = f.read()\n",
    "# state_features = feat.split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters and action bounds for the DDPG experiment.\n",
    "Replay_mem_size = 20000\n",
    "Train_batch_size = 64\n",
    "# NOTE(review): Actor_Learning_rate, Critic_Learning_rate, explore_rate, tau and State_dim are\n",
    "# not referenced again in the visible part of this notebook; the DDPG(...) constructor call is\n",
    "# given numeric literals instead - confirm which values are actually intended.\n",
    "Actor_Learning_rate = 1e-3\n",
    "Critic_Learning_rate = 1e-3\n",
    "Gamma = 0.99\n",
    "explore_rate = 10\n",
    "tau = 1e-3\n",
    "State_dim = len(state_features)\n",
    "Action_dim = 2\n",
    "# Per-dimension lower / upper action bounds, ordered (iv, vaso).\n",
    "action_low = [0,0]\n",
    "action_high = [2000,1] # iv,vaso"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ounoise = OUNoise(Action_dim, 8, 3 , 0.9995)\n",
    "gsnoise = GaussNoise(10,2,0.9995)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(max(df['iv_input']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_one_state(train=True, eval_type = None):\n",
    "    \"\"\"Pick one random row and return ``(index_label, state_vector)``.\n",
    "\n",
    "    With ``train=True`` (the default) the row is drawn from the global ``df`` and\n",
    "    ``eval_type`` is ignored. Otherwise ``eval_type`` selects ``df`` ('train'),\n",
    "    ``val_df`` ('val') or ``test_df`` ('test').\n",
    "\n",
    "    Raises:\n",
    "        Exception: if ``train`` is false and ``eval_type`` is missing or unknown.\n",
    "    \"\"\"\n",
    "    if train:\n",
    "        picked = df.sample(n=1)\n",
    "    elif eval_type is None:\n",
    "        raise Exception('Provide eval_type to process_batch')\n",
    "    elif eval_type == 'train':\n",
    "        picked = df.sample(n=1)\n",
    "    elif eval_type == 'val':\n",
    "        picked = val_df.sample(n=1)\n",
    "    elif eval_type == 'test':\n",
    "        picked = test_df.sample(n=1)\n",
    "    else:\n",
    "        raise Exception('Unknown eval_type')\n",
    "    # sample(n=1) yields a single-row frame; its only index label identifies the row.\n",
    "    row_label = picked.index[0]\n",
    "    return row_label, np.array(picked.loc[row_label, state_features])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_next_step(idex,train=True, eval_type = None):\n",
    "    \"\"\"Return the logged transition that starts at row ``idex``.\n",
    "\n",
    "    With ``train=True`` (the default) the transition is read from the global ``df``\n",
    "    and ``eval_type`` is ignored. Otherwise ``eval_type`` selects ``df`` ('train'),\n",
    "    ``val_df`` ('val') or ``test_df`` ('test').\n",
    "\n",
    "    Every field (action, reward, successor state, end-of-stay check) is read from\n",
    "    that one selected frame. Previously the iv dose, the successor state and the\n",
    "    end-of-stay check were always taken from ``df``, so 'val' / 'test' transitions\n",
    "    were mixed with training rows.\n",
    "\n",
    "    Returns:\n",
    "        (idex, next_state, action, reward, done) where ``action`` is ``[iv, vaso]``.\n",
    "        While the next row belongs to the same 'icustayid', ``done`` is 0, ``idex``\n",
    "        is advanced by one and ``next_state`` is that row's state features. At the\n",
    "        end of a stay (or of the frame) ``done`` is 1, ``idex`` is unchanged and\n",
    "        ``next_state`` is an all-zero vector.\n",
    "\n",
    "    Raises:\n",
    "        Exception: if ``train`` is false and ``eval_type`` is missing or unknown.\n",
    "    \"\"\"\n",
    "    if train:\n",
    "        a = df\n",
    "    elif eval_type is None:\n",
    "        raise Exception('Provide eval_type to process_batch')\n",
    "    elif eval_type == 'train':\n",
    "        a = df\n",
    "    elif eval_type == 'val':\n",
    "        a = val_df\n",
    "    elif eval_type == 'test':\n",
    "        a = test_df\n",
    "    else:\n",
    "        raise Exception('Unknown eval_type')\n",
    "\n",
    "    action = [a.loc[idex, 'iv_input'], a.loc[idex, 'vaso_input']]\n",
    "    reward = a.loc[idex, 'reward']\n",
    "\n",
    "    # `idex + 1` addresses the following row, which assumes consecutive integer row labels.\n",
    "    has_successor = (\n",
    "        idex != a.index[-1]\n",
    "        and a.loc[idex, 'icustayid'] == a.loc[idex + 1, 'icustayid']\n",
    "    )\n",
    "    if has_successor:\n",
    "        next_state = a.loc[idex + 1, state_features]\n",
    "        done = 0\n",
    "        idex = idex + 1\n",
    "    else:\n",
    "        # Terminal transition: a zero vector stands in for the missing successor state.\n",
    "        next_state = np.zeros(len(state_features))\n",
    "        done = 1\n",
    "    return idex, next_state, action, reward, done"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class observation():\n",
    "    \"\"\"Object exposing ``shape`` and ``sample()``; handed to Featurize_state via ``env``.\n",
    "\n",
    "    NOTE(review): presumably mimics a gym observation space - confirm which of these\n",
    "    attributes Featurize_state actually reads.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        # One entry per state feature label ('0' ... '199').\n",
    "        self.shape = [str(i) for i in range(200)]\n",
    "    def sample(self):\n",
    "        \"\"\"Return the state vector of one randomly drawn training row.\"\"\"\n",
    "        # `self` was missing before, so calling sample() on the instance raised TypeError.\n",
    "        _,state = get_one_state()\n",
    "        return state\n",
    "class env():\n",
    "    \"\"\"Thin container that exposes the observation object as ``observation_space``.\"\"\"\n",
    "    def __init__(self,observation):\n",
    "        self.observation_space = observation\n",
    "# The class names are deliberately rebound to their single instances, as before.\n",
    "observation = observation()\n",
    "env = env(observation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state_featurize = Featurize_state(env, True)\n",
    "After_featurize_state_dim = len(state_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# \n",
    "def play(agent,num_epoach,epoach_step):\n",
    "    \"\"\"Roll through logged validation stays and compare physician and agent Q-values.\n",
    "\n",
    "    For each of ``num_epoach`` episodes a random validation row is picked and followed\n",
    "    for at most ``epoach_step`` logged steps (stopping early at the end of the stay).\n",
    "    At every step the critic is queried for the logged action and for the agent's\n",
    "    noise-free action in the same state.\n",
    "\n",
    "    Returns:\n",
    "        (mean physician Q, mean agent Q, physician actions, physician Qs,\n",
    "         agent actions, agent Qs) - the last four are per-step lists.\n",
    "    \"\"\"\n",
    "    phys_q_total = 0\n",
    "    agent_q_total = 0\n",
    "    n_steps = 0\n",
    "    phys_actions, phys_qs, agent_acts, agent_qs = [], [], [], []\n",
    "    for _episode in range(num_epoach):\n",
    "        idex, pre_state = get_one_state(train=False, eval_type = 'val')\n",
    "        for _step in range(epoach_step):\n",
    "            idex, next_state, action, reward, done = get_next_step(idex,train=False, eval_type = 'val')\n",
    "\n",
    "            # Critic value of what the physician actually did in this state.\n",
    "            phys_q = agent.get_value(state_featurize.transfer(pre_state),action)\n",
    "            # Critic value of what the agent would do in the same state.\n",
    "            agent_action = agent.action(state_featurize.transfer(pre_state), add_noise = False).squeeze(0).numpy()\n",
    "            agent_q = agent.get_value(state_featurize.transfer(pre_state),agent_action)\n",
    "\n",
    "            phys_q_total += phys_q\n",
    "            agent_q_total += agent_q\n",
    "            n_steps = n_steps+1\n",
    "            phys_actions.append(action)\n",
    "            phys_qs.append(phys_q)\n",
    "            agent_acts.append(agent_action)\n",
    "            agent_qs.append(agent_q)\n",
    "            if done:\n",
    "                break\n",
    "            pre_state = next_state\n",
    "\n",
    "    mean_phys_q = phys_q_total/(n_steps+0.0)\n",
    "    mean_agent_q = agent_q_total/(n_steps+0.0)\n",
    "    return mean_phys_q, mean_agent_q,phys_actions,phys_qs,agent_acts,agent_qs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generates batches for the Q network - depending on train and eval_type, can select data from train/val/test sets.\n",
    "def process_batch(size=None, train=True, eval_type = None):\n",
    "    \"\"\"Build transitions ``[state, action, reward, next_state, done]`` from logged rows.\n",
    "\n",
    "    With ``train=True`` a random sample of ``size`` rows is drawn from the global\n",
    "    ``df`` (``eval_type`` is ignored). With ``train=False`` every row of the frame\n",
    "    named by ``eval_type`` ('train', 'val' or 'test') is used.\n",
    "\n",
    "    The successor row and the end-of-stay check are looked up in the same frame the\n",
    "    rows were taken from. Previously they were always read from ``df``, so 'val' and\n",
    "    'test' batches got training-set successors and wrong ``done`` flags.\n",
    "\n",
    "    Returns:\n",
    "        (mem, a): a deque of transitions and the frame of rows they were built from.\n",
    "\n",
    "    Raises:\n",
    "        Exception: if ``train`` is false and ``eval_type`` is missing or unknown.\n",
    "    \"\"\"\n",
    "    # Local import: the notebook's import cell does not import deque explicitly.\n",
    "    from collections import deque\n",
    "\n",
    "    if train:\n",
    "        source = df\n",
    "        a = source.sample(n=size)\n",
    "    elif eval_type is None:\n",
    "        raise Exception('Provide eval_type to process_batch')\n",
    "    elif eval_type == 'train':\n",
    "        source = df\n",
    "        a = source.copy()\n",
    "    elif eval_type == 'val':\n",
    "        source = val_df\n",
    "        a = source.copy()\n",
    "    elif eval_type == 'test':\n",
    "        source = test_df\n",
    "        a = source.copy()\n",
    "    else:\n",
    "        raise Exception('Unknown eval_type')\n",
    "\n",
    "    mem = deque()\n",
    "    last_label = source.index[-1] if len(source.index) else None\n",
    "\n",
    "    for i in a.index:\n",
    "        cur_state = a.loc[i,state_features]\n",
    "        action = [a.loc[i, 'iv_input'], a.loc[i, 'vaso_input']]\n",
    "        reward = a.loc[i,'reward']\n",
    "\n",
    "        # `i + 1` addresses the following row of the full frame, which assumes\n",
    "        # consecutive integer row labels.\n",
    "        same_stay = (\n",
    "            i != last_label\n",
    "            and source.loc[i, 'icustayid'] == source.loc[i + 1, 'icustayid']\n",
    "        )\n",
    "        if same_stay:\n",
    "            next_state = source.loc[i + 1, state_features]\n",
    "            done = 0\n",
    "        else:\n",
    "            # End of the stay (or of the frame): zero vector as terminal successor.\n",
    "            next_state = np.zeros(len(cur_state))\n",
    "            done = 1\n",
    "\n",
    "        mem.append([cur_state, action, reward, next_state, done])\n",
    "    return mem,a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generates batches for the Q network - depending on train and eval_type, can select data from train/val/test sets.\n",
    "def process_batch2(size=None, train=True, eval_type = None):\n",
    "    \"\"\"Build transitions that also carry the stay id and the mortality flag.\n",
    "\n",
    "    Every row of the selected frame becomes\n",
    "    ``[state, action, reward, next_state, done, icustayid, died_in_hosp]``.\n",
    "\n",
    "    ``eval_type`` selects the frame: 'val' or ``None`` (the previous, only behaviour)\n",
    "    uses ``val_df``, 'train' uses ``df`` and 'test' uses ``test_df``. ``size`` and\n",
    "    ``train`` are accepted for signature compatibility and are not used.\n",
    "\n",
    "    Returns:\n",
    "        (mem, a): a deque of transitions and a copy of the frame they were built from.\n",
    "\n",
    "    Raises:\n",
    "        Exception: if ``eval_type`` is not one of None, 'train', 'val', 'test'.\n",
    "    \"\"\"\n",
    "    # Local import: the notebook's import cell does not import deque explicitly.\n",
    "    from collections import deque\n",
    "\n",
    "    if eval_type is None or eval_type == 'val':\n",
    "        a = val_df.copy()\n",
    "    elif eval_type == 'train':\n",
    "        a = df.copy()\n",
    "    elif eval_type == 'test':\n",
    "        a = test_df.copy()\n",
    "    else:\n",
    "        raise Exception('Unknown eval_type')\n",
    "\n",
    "    mem = deque()\n",
    "    last_label = a.index[-1] if len(a.index) else None\n",
    "\n",
    "    for i in a.index:\n",
    "        cur_state = a.loc[i,state_features]\n",
    "        action = [a.loc[i, 'iv_input'], a.loc[i, 'vaso_input']]\n",
    "        reward = a.loc[i,'reward']\n",
    "\n",
    "        stay_id = a.loc[i,'icustayid']\n",
    "        died = a.loc[i,'died_in_hosp']\n",
    "\n",
    "        # `i + 1` addresses the following row, which assumes consecutive integer row labels.\n",
    "        if i != last_label and stay_id == a.loc[i + 1, 'icustayid']:\n",
    "            next_state = a.loc[i + 1, state_features]\n",
    "            done = 0\n",
    "        else:\n",
    "            # End of the stay (or of the frame): zero vector as terminal successor.\n",
    "            next_state = np.zeros(len(cur_state))\n",
    "            done = 1\n",
    "\n",
    "        mem.append([cur_state, action, reward, next_state, done,stay_id,died])\n",
    "    return mem,a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# \n",
    "def play_presents(agent,eval_type,batch_size=None):\n",
    "    \"\"\"Score logged and agent actions with the critic, keeping stay id and mortality.\n",
    "\n",
    "    Transitions come from ``process_batch2(train=False, eval_type=eval_type)``. With\n",
    "    ``batch_size=None`` all of them are used; otherwise ``batch_size`` of them are\n",
    "    drawn at random and kept in their original row order (this case previously\n",
    "    failed because ``batch`` was never assigned).\n",
    "\n",
    "    Returns:\n",
    "        (phys_q, action_batch, agent_q, agent_actions, ids, died): critic values of\n",
    "        the logged actions, the logged actions, critic values of the agent's\n",
    "        noise-free actions, those actions, and per-row stay ids / mortality flags.\n",
    "    \"\"\"\n",
    "    device = 'cpu'\n",
    "    batch,_ = process_batch2(train=False,eval_type=eval_type)\n",
    "    batch = list(batch)\n",
    "    if batch_size is not None:\n",
    "        # Sample positions, then sort them so rows of one stay remain contiguous.\n",
    "        keep = sorted(random.sample(range(len(batch)), batch_size))\n",
    "        batch = [batch[k] for k in keep]\n",
    "\n",
    "    # Stack into one float32 array first; building a tensor from a list of pandas\n",
    "    # Series goes through slow element-wise indexing.\n",
    "    states = np.array([np.asarray(x[0], dtype=np.float32) for x in batch])\n",
    "    actions = np.array([x[1] for x in batch], dtype=np.float32)\n",
    "    pre_state_batch = torch.tensor(states, dtype=torch.float, device = device)\n",
    "    action_batch = torch.tensor(actions, dtype = torch.float, device = device)\n",
    "\n",
    "    stay_ids = [x[5] for x in batch]\n",
    "    died = [x[6] for x in batch]\n",
    "\n",
    "    phys_q = agent.get_batch_value(pre_state_batch,action_batch)\n",
    "\n",
    "    agent_actions = agent.get_batch_action(pre_state_batch, add_noise = False)\n",
    "    # These are the agent's actions; the label used to read \"phy action\".\n",
    "    print(\"agent action\",agent_actions)\n",
    "\n",
    "    agent_q = agent.get_batch_value(pre_state_batch,agent_actions)\n",
    "\n",
    "    return phys_q,action_batch,agent_q,agent_actions,stay_ids,died"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# \n",
    "def play2(agent,eval_type,batch_size=None):\n",
    "    \"\"\"Score logged and agent actions with the critic on one evaluation split.\n",
    "\n",
    "    Transitions come from ``process_batch(train=False, eval_type=eval_type)``. With\n",
    "    ``batch_size=None`` all of them are used, otherwise a random sample of that size.\n",
    "\n",
    "    Returns:\n",
    "        (phys_q, action_batch, agent_q, agent_actions): critic values of the logged\n",
    "        actions, the logged actions, critic values of the agent's noise-free actions,\n",
    "        and those actions.\n",
    "    \"\"\"\n",
    "    device = 'cpu'\n",
    "    batch,_ = process_batch(train=False,eval_type=eval_type)\n",
    "    if batch_size is not None:\n",
    "        batch = random.sample(list(batch), batch_size)\n",
    "\n",
    "    # Only states and actions are needed here. Stack them into float32 arrays first;\n",
    "    # building a tensor from a list of pandas Series goes through slow element-wise indexing.\n",
    "    states = np.array([np.asarray(x[0], dtype=np.float32) for x in batch])\n",
    "    actions = np.array([x[1] for x in batch], dtype=np.float32)\n",
    "    pre_state_batch = torch.tensor(states, dtype=torch.float, device = device)\n",
    "    action_batch = torch.tensor(actions, dtype = torch.float, device = device)\n",
    "\n",
    "    phys_q = agent.get_batch_value(pre_state_batch,action_batch)\n",
    "    agent_actions = agent.get_batch_action(pre_state_batch, add_noise = False)\n",
    "    agent_q = agent.get_batch_value(pre_state_batch,agent_actions)\n",
    "\n",
    "    return phys_q,action_batch,agent_q,agent_actions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train epoach   30 000 * 20\n",
    "def train2(agent,train_epoach,epoach_step):\n",
    "    \"\"\"Train on ``train_epoach`` randomly sampled start rows of the training data.\n",
    "\n",
    "    Each sampled row starts an episode that follows the logged stay for at most\n",
    "    ``epoach_step`` transitions (stopping at the end of the stay) and feeds every\n",
    "    transition to ``agent.train``. After every 1000th episode the critic is\n",
    "    evaluated on 128 random validation transitions and the mean Q-values are printed.\n",
    "\n",
    "    Returns:\n",
    "        The same ``agent`` object.\n",
    "    \"\"\"\n",
    "    _,a = process_batch(train_epoach,train=True, eval_type = 'train')\n",
    "\n",
    "    for epoach, idex in enumerate(a.index, start=1):\n",
    "        pre_state = a.loc[idex,state_features]\n",
    "        for _step in range(epoach_step):\n",
    "            idex, next_state,action, reward, done = get_next_step(idex)\n",
    "\n",
    "            agent.train(state_featurize.transfer(pre_state), action, reward, state_featurize.transfer(next_state), done)\n",
    "\n",
    "            if done:\n",
    "                break\n",
    "            pre_state = next_state\n",
    "\n",
    "        if epoach % 1000 == 0:\n",
    "            print(\"epoach:\",epoach)\n",
    "            phys_q,action_batch,agent_q,agent_actions = play2(agent,'val',128)\n",
    "            print(\"----physician's Q-value-----\",torch.mean(phys_q))\n",
    "            print(\"----Agent's Q-value-----\",torch.mean(agent_q))\n",
    "    return agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(agent,train_epoach,epoach_step):\n",
    "    \"\"\"Train for ``train_epoach`` episodes, each starting at a random training row.\n",
    "\n",
    "    An episode follows the logged stay for at most ``epoach_step`` transitions\n",
    "    (stopping at the end of the stay) and feeds every transition to ``agent.train``.\n",
    "    Every 100 episodes (except episode 0) the critic is evaluated on 128 random\n",
    "    validation transitions and the mean Q-values are printed.\n",
    "\n",
    "    Returns:\n",
    "        The same ``agent`` object.\n",
    "    \"\"\"\n",
    "    for epoach in range(train_epoach):\n",
    "        idex, pre_state = get_one_state()\n",
    "\n",
    "        for _step in range(epoach_step):\n",
    "            idex, next_state,action, reward, done = get_next_step(idex)\n",
    "\n",
    "            agent.train(state_featurize.transfer(pre_state), action, reward, state_featurize.transfer(next_state), done)\n",
    "\n",
    "            if done:\n",
    "                break\n",
    "            pre_state = next_state\n",
    "\n",
    "        if epoach % 100 == 0 and epoach > 0:\n",
    "            print(\"epoach:\",epoach)\n",
    "            phys_q,action_batch,agent_q,agent_actions = play2(agent,'val',128)\n",
    "            print(\"----physician's Q-value-----：\",torch.mean(phys_q))\n",
    "            print(\"----Agent's Q-value-----：\",torch.mean(agent_q))\n",
    "    return agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "import random\n",
    "# NOTE(review): numsteps is only recorded in the run config here; the train3(...) call\n",
    "# further down passes its own step count - keep the two in sync.\n",
    "numsteps = 20000\n",
    "state_dim = len(state_features)\n",
    "# A random float keeps run names distinct between launches.\n",
    "runs = random.random()\n",
    "wandb.init(\n",
    "    project='DDPG',\n",
    "    name=f\"experiment_{runs}\",\n",
    "    config={\n",
    "        # Log the actual state dimensionality instead of a stale hard-coded 48.\n",
    "        \"State_dim\":state_dim,\n",
    "        \"numsteps\":numsteps,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train epoach  \n",
    "# Running per-batch mean losses, appended to by train3 after every trained transition.\n",
    "critic_losses=[]\n",
    "actor_losses=[]\n",
    "\n",
    "def train3(agent,num_steps=60000,train_epoach=32):\n",
    "    \"\"\"Train on ``num_steps`` randomly sampled batches of single logged transitions.\n",
    "\n",
    "    Each step samples ``train_epoach`` training rows, feeds the transition starting at\n",
    "    each row to ``agent.train``, reads the losses from ``agent.get_loss()``, appends\n",
    "    the running batch means to the global ``critic_losses`` / ``actor_losses`` lists\n",
    "    and logs the individual losses to wandb. Every 1000 steps (except step 0) the\n",
    "    critic is evaluated on 128 random validation transitions and a summary is printed.\n",
    "\n",
    "    The sampled rows are independent of each other, so a terminal transition no\n",
    "    longer aborts the rest of the batch (it used to ``break`` out of the loop, which\n",
    "    dropped the remaining rows and skipped logging for the terminal one).\n",
    "\n",
    "    Returns:\n",
    "        (agent, critic_loss, actor_loss) with the losses of the last trained\n",
    "        transition, or ``None`` for both when nothing was trained.\n",
    "    \"\"\"\n",
    "    critic_loss = None\n",
    "    actor_loss = None\n",
    "\n",
    "    for i in range(num_steps):\n",
    "        _,a = process_batch(train_epoach,train=True, eval_type = 'train')\n",
    "\n",
    "        cc = []\n",
    "        aa = []\n",
    "\n",
    "        for idex in a.index:\n",
    "            pre_state = a.loc[idex,state_features]\n",
    "\n",
    "            # The advanced row label returned by get_next_step is not needed here.\n",
    "            _, next_state,action, reward, done = get_next_step(idex)\n",
    "\n",
    "            agent.train(state_featurize.transfer(pre_state), action, reward, state_featurize.transfer(next_state), done)\n",
    "\n",
    "            critic_loss, actor_loss = agent.get_loss()\n",
    "\n",
    "            cc.append(critic_loss)\n",
    "            aa.append(actor_loss)\n",
    "\n",
    "            critic_losses.append(np.mean(cc))\n",
    "            actor_losses.append(np.mean(aa))\n",
    "            wandb.log(\n",
    "                {\n",
    "                    \"critic_losses\":critic_loss,\n",
    "                    \"actor_losses\":actor_loss,\n",
    "                }\n",
    "            )\n",
    "\n",
    "        if i % 1000 == 0 and i > 0:\n",
    "            print(\"num_steps:\",i)\n",
    "            phys_q,action_batch,agent_q,agent_actions = play2(agent,'val',128)\n",
    "            print('----DDPG---physician\\'s Q-values------：',torch.mean(phys_q))\n",
    "            print('----DDPG---angent\\'s Q-values-----：',torch.mean(agent_q))\n",
    "            print('DDPG-physician', np.mean(action_batch[:,0].numpy()),np.mean(action_batch[:,1].numpy()))\n",
    "            print('DDPG-Agent',np.mean(agent_actions[:,0].numpy()),np.mean(agent_actions[:,1].numpy()))\n",
    "            print('mean_critic_loss:',np.mean(cc))\n",
    "            print('mean_actor_loss:',np.mean(aa))\n",
    "\n",
    "    return agent,critic_loss,actor_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Action_dim = 2\n",
    "Replay_mem_size = 20000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ddpg = DDPG(After_featurize_state_dim, Action_dim, Replay_mem_size, Train_batch_size,\n",
    "             Gamma, 1e-3, 1e-4, action_low, action_high, 0.1, gsnoise, False) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "agent,critic_loss,actor_loss = train3(ddpg,10000,32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the target folder first: torch.save does not create missing directories.\n",
    "os.makedirs('./Save_model', exist_ok=True)\n",
    "torch.save(agent,'./Save_model/ddpg_model.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "phys_q,action_batch,agent_q,agent_actions,id,die = play_presents(agent,'val')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per evaluated transition: critic values and (iv, vaso) actions of the physician\n",
    "# and of the agent, plus the stay id and mortality flag returned by play_presents.\n",
    "# NOTE(review): this rebinds the global `df`, which get_one_state / get_next_step /\n",
    "# process_batch read as the training frame. Running any of them after this cell will\n",
    "# operate on this results table instead - consider a distinct name.\n",
    "df = pd.DataFrame({'phy_q':phys_q.numpy().ravel(),'phy_action_iv':action_batch[:,0].numpy(),\n",
    "                   'phy_action_vaso':action_batch[:,1].numpy(),'agent_q':agent_q.numpy().ravel(),'agent_action_iv':agent_actions[:,0].numpy(),\n",
    "                   'agent_action_vaso':agent_actions[:,1].numpy(),'id':id,'die':die})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head(n=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Ranking accuracy based on differences with doctors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ddpg_iv = agent_actions[:,0].numpy()\n",
    "ddpg_vaso = agent_actions[:,1].numpy()\n",
    "phy_iv = action_batch[:,0].numpy()\n",
    "phy_vaso = action_batch[:,1].numpy()\n",
    "die = die"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Squared gap between the agent's and the physician's dose for every evaluated row,\n",
    "# computed array-wise instead of building four torch tensors per element.\n",
    "batch_size = len(id)\n",
    "\n",
    "# Plain-list copies of the per-row actions (same names as before).\n",
    "ac_ddpg_iv = list(ddpg_iv)\n",
    "ac_ddpg_vaso = list(ddpg_vaso)\n",
    "ac_ph_iv = list(phy_iv)\n",
    "ac_ph_vaso = list(phy_vaso)\n",
    "action_ddpg = [[iv, vaso] for iv, vaso in zip(ddpg_iv, ddpg_vaso)]\n",
    "action_ph = [[iv, vaso] for iv, vaso in zip(phy_iv, phy_vaso)]\n",
    "\n",
    "iv_sq_gap = (ddpg_iv - phy_iv) ** 2\n",
    "vaso_sq_gap = (ddpg_vaso - phy_vaso) ** 2\n",
    "iv_diff = iv_sq_gap.tolist()\n",
    "vaso_diff = vaso_sq_gap.tolist()\n",
    "# Mean of the two squared components, i.e. the per-row MSE over (iv, vaso).\n",
    "diff = ((iv_sq_gap + vaso_sq_gap) / 2).tolist()\n",
    "dies = list(die)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# iv_diff = []\n",
    "# diff = []\n",
    "# vaso_diff = []\n",
    "# dies = []\n",
    "# batch_size = len(id)\n",
    "# ac_ddpg_iv = []\n",
    "# ac_ddpg_vaso = []\n",
    "# ac_ph_iv = []\n",
    "# ac_ph_vaso = []\n",
    "# action_ddpg = []\n",
    "# action_ph = []\n",
    "# for w in range(batch_size):\n",
    "#     ac_ddpg_iv.append(ddpg_iv[w])\n",
    "#     ac_ddpg_vaso.append(ddpg_vaso[w])\n",
    "#     ac_ph_iv.append(phy_iv[w])\n",
    "#     ac_ph_vaso.append(phy_vaso[w])\n",
    "#     action_ddpg.append([ddpg_iv[w],ddpg_vaso[w]])\n",
    "#     action_ph.append([phy_iv[w],phy_vaso[w]])\n",
    "#     if w!=batch_size-1 and id[w] != id[w+1]:\n",
    "#         iv_diff.append((torch.mean((torch.tensor(ac_ddpg_iv)-torch.tensor(ac_ph_iv))**2)).detach().cpu().item())\n",
    "#         vaso_diff.append((torch.mean((torch.tensor(ac_ddpg_vaso)-torch.tensor(ac_ph_vaso))**2)).detach().cpu().item())\n",
    "#         diff.append((torch.mean((torch.tensor(action_ddpg)-torch.tensor(action_ph))**2)).detach().cpu().item())\n",
    "#         dies.append(die[w])\n",
    "#         ac_ddpg_iv = []\n",
    "#         ac_ddpg_vaso = []\n",
    "#         ac_ph_iv = []\n",
    "#         ac_ph_vaso = []\n",
    "#         action_ddpg = []\n",
    "#         action_ph = []\n",
    "# iv_diff.append((torch.mean((torch.tensor(ac_ddpg_iv)-torch.tensor(ac_ph_iv))**2)).detach().cpu().item())\n",
    "# vaso_diff.append((torch.mean((torch.tensor(ac_ddpg_vaso)-torch.tensor(ac_ph_vaso))**2)).detach().cpu().item())\n",
    "# diff.append((torch.mean((torch.tensor(action_ddpg)-torch.tensor(action_ph))**2)).detach().cpu().item())\n",
    "# dies.append(die[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-row squared action gaps together with the mortality flag.\n",
    "diff_df = pd.DataFrame({\n",
    "    'iv': iv_diff,\n",
    "    'vaso': vaso_diff,\n",
    "    'diff': diff,\n",
    "    'die': dies,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sum(diff_df['die'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(diff_df))\n",
    "unique_elements = set(id)\n",
    "num_unique_elements = len(unique_elements)\n",
    "print(num_unique_elements)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rank(diff_df, n_per_class=500, seeds=(0, 1, 2, 3)):\n",
    "    \"\"\"Measure how well small agent-vs-physician gaps single out survivors.\n",
    "\n",
    "    For every seed, ``n_per_class`` rows with ``die == 1`` and ``n_per_class`` rows\n",
    "    with ``die == 0`` are sampled and pooled. The pooled rows are sorted ascending by\n",
    "    each of 'iv', 'vaso' and 'diff'; the accuracy for that metric is the share of\n",
    "    ``die == 0`` rows among the first ``n_per_class`` rows.\n",
    "\n",
    "    Args:\n",
    "        diff_df: frame with columns 'iv', 'vaso', 'diff' and 'die'.\n",
    "        n_per_class: rows sampled per class and size of the inspected top slice\n",
    "            (default 500, the previously hard-coded value).\n",
    "        seeds: ``random_state`` values, one balanced sample per seed\n",
    "            (default 0..3, the previously hard-coded values).\n",
    "\n",
    "    Returns:\n",
    "        (iv mean, iv std, vaso mean, vaso std, diff mean, diff std, diff_df), where\n",
    "        mean / std are taken over the seeds and ``diff_df`` is returned unchanged.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: from ``DataFrame.sample`` if a class has fewer than\n",
    "            ``n_per_class`` rows.\n",
    "    \"\"\"\n",
    "    metrics = ['iv', 'vaso', 'diff']\n",
    "    accuracies = {metric: [] for metric in metrics}\n",
    "    died_rows = diff_df[diff_df['die'] == 1]\n",
    "    survived_rows = diff_df[diff_df['die'] == 0]\n",
    "\n",
    "    for seed in seeds:\n",
    "        balanced = pd.concat([\n",
    "            died_rows.sample(n=n_per_class, random_state=seed),\n",
    "            survived_rows.sample(n=n_per_class, random_state=seed),\n",
    "        ])\n",
    "        for metric in metrics:\n",
    "            # Rows where the agent is closest to the physician come first.\n",
    "            closest = balanced.sort_values(metric).head(n_per_class)\n",
    "            accuracies[metric].append((closest['die'] == 0).sum() / len(closest))\n",
    "\n",
    "    return (np.mean(accuracies['iv']), np.std(accuracies['iv']),\n",
    "            np.mean(accuracies['vaso']), np.std(accuracies['vaso']),\n",
    "            np.mean(accuracies['diff']), np.std(accuracies['diff']),\n",
    "            diff_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a,a_std,b,b_std,c,c_std,diff_df = rank(diff_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('iv gaps:',a,a_std,'vaso gaps:',b,b_std,'diff gaps:',c,c_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head(n=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# df['agent_action_iv'] = df['agent_action_iv']*(val_max_iv-val_min_iv)+val_min_iv\n",
    "# df['agent_action_vaso'] = df['agent_action_vaso']*(val_max_vaso-val_min_vaso)+val_min_vaso\n",
    "\n",
    "# df['phy_action_iv'] = df['phy_action_iv']*(val_max_iv-val_min_iv)+val_min_iv\n",
    "# df['phy_action_vaso'] = df['phy_action_vaso']*(val_max_vaso-val_min_vaso)+val_min_vaso"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scale both iv columns by 2000 (the iv upper bound in action_high).\n",
    "# NOTE(review): the matching '/=2000' normalisation of iv_input is commented out earlier in\n",
    "# this notebook - confirm the loaded data is already in [0, 1] before rescaling. The update is\n",
    "# in place, so re-running this cell multiplies the columns again.\n",
    "df['agent_action_iv']*=2000\n",
    "df['phy_action_iv']*=2000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    " \n",
    "# Mean agent / physician dose per surviving stay, and the absolute gap between the two.\n",
    "# The previous row loop divided each stay's sums by (rows + 1), never emitted the last\n",
    "# stay, and emitted an all-zero entry when the first stay in the table was a death.\n",
    "action_cols = ['agent_action_iv', 'agent_action_vaso', 'phy_action_iv', 'phy_action_vaso']\n",
    "\n",
    "# A new stay starts wherever the id differs from the previous row's id.\n",
    "stay_run = (df['id'] != df['id'].shift()).cumsum()\n",
    "survivor_rows = df['die'] != 1\n",
    "stay_means = (\n",
    "    df.loc[survivor_rows, action_cols]\n",
    "    .groupby(stay_run[survivor_rows], sort=False)\n",
    "    .mean()\n",
    ")\n",
    "\n",
    "mean_per_agent_iv = stay_means['agent_action_iv'].tolist()\n",
    "mean_per_agent_vaso = stay_means['agent_action_vaso'].tolist()\n",
    "mean_per_phy_iv = stay_means['phy_action_iv'].tolist()\n",
    "mean_per_phy_vaso = stay_means['phy_action_vaso'].tolist()\n",
    "gap_iv = (stay_means['agent_action_iv'] - stay_means['phy_action_iv']).abs().tolist()\n",
    "gap_vaso = (stay_means['agent_action_vaso'] - stay_means['phy_action_vaso']).abs().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of per-patient gaps above a fixed threshold: 500 for IV, 0.75 for vaso.\n",
    "t = sum(1 for gap in gap_iv if gap > 500)\n",
    "print(t/(len(gap_iv)))\n",
    "t = sum(1 for gap in gap_vaso if gap > 0.75)\n",
    "print(t/(len(gap_vaso)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest absolute step-to-step change of the agent's dose within each patient.\n",
    "#   delta_iv / delta_vaso : one entry per patient boundary met while scanning df\n",
    "#   a / b                 : agent vaso / agent iv doses of every row that continues a patient\n",
    "# Rows with die == 1 and the row labelled 0 are skipped.\n",
    "# NOTE: a patient's maximum is only appended when the *next* patient starts, so the\n",
    "# last patient in df never reaches delta_iv / delta_vaso.\n",
    "# NOTE(review): the boundary test reads row i-1 even when that row was skipped -- confirm\n",
    "# this is intended. It also assumes df.index is a gap-free integer range.\n",
    "\n",
    "delta_iv = []\n",
    "delta_vaso = []\n",
    "a = []\n",
    "b = []\n",
    "a_iv_delta,a_vaso_delta = 0.0,0.0\n",
    "j=0\n",
    "for i in df.index:\n",
    "    if df.loc[i,'die'] ==1  or i == 0:\n",
    "        continue\n",
    "    if df.loc[i-1,'id']==df.loc[i,'id']:\n",
    "\n",
    "        a_vaso_delta = max(abs(df.loc[i,'agent_action_vaso'] - df.loc[i-1,'agent_action_vaso']),a_vaso_delta)\n",
    "        a_iv_delta = max(abs(df.loc[i,'agent_action_iv']- df.loc[i-1,'agent_action_iv']),a_iv_delta)\n",
    "        a.append(df.loc[i,'agent_action_vaso'])\n",
    "        b.append(df.loc[i,'agent_action_iv'])\n",
    "\n",
    "    elif df.loc[i-1,'id']!=df.loc[i,'id']: \n",
    "        delta_iv.append(a_iv_delta)\n",
    "        delta_vaso.append(a_vaso_delta)\n",
    "        a_iv_delta,a_vaso_delta = 0.0,0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(sum(value > 0.75 for value in delta_vaso)/len(delta_vaso))\n",
    "print(sum(value > 750 for value in delta_iv)/len(delta_iv))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "print(sum(value > 0.75 for value in a)/len(a))\n",
    "\n",
    "print(sum(value > 1000 for value in b)/len(b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "plt.scatter(range(len(gap_iv)),gap_iv)\n",
    "plt.title(\" IV Gap Between Agent & Physician\")\n",
    "plt.xlabel(\"per patient\")\n",
    "plt.ylabel(\"IV gap\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Doses that are too high and sudden dose changes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sudden change\n",
    "\n",
    "def sudden_change_and_too_high(df,item_vaso,item_iv,max_sudden_change_vaso_data,max_sudden_change_iv_data,too_high_iv,too_high_vaso):\n",
    "    \"\"\"Report how often doses jump sharply within a patient and how often they are high.\n",
    "\n",
    "    Rows are read in order; consecutive rows with the same 'id' form one patient.\n",
    "    The first row of a patient is only the baseline for the following row and is\n",
    "    not part of the 'too high' statistics.\n",
    "\n",
    "    Args:\n",
    "        df: frame with an 'id' column plus the two dose columns.\n",
    "        item_vaso: name of the vasopressor dose column.\n",
    "        item_iv: name of the IV dose column.\n",
    "        max_sudden_change_vaso_data: threshold on a patient's largest vaso increase.\n",
    "        max_sudden_change_iv_data: threshold on a patient's largest IV increase.\n",
    "        too_high_iv: a single IV dose above this value counts as too high.\n",
    "        too_high_vaso: a single vaso dose above this value counts as too high.\n",
    "\n",
    "    Returns:\n",
    "        (delta_iv, delta_vaso, a, b): per-patient largest step-to-step increase of the\n",
    "        IV and vaso dose (0.0 when the dose never rises), then the vaso doses and the\n",
    "        IV doses of every row that continues a patient.\n",
    "    \"\"\"\n",
    "    def _fraction_above(values, threshold):\n",
    "        # Share of values strictly above threshold; NaN (not a crash) when there are none.\n",
    "        if len(values) == 0:\n",
    "            return float('nan')\n",
    "        return sum(value > threshold for value in values)/len(values)\n",
    "\n",
    "    # Work on positions so the previous-row lookup does not depend on the index labels.\n",
    "    ids = df['id'].to_numpy()\n",
    "    vaso = df[item_vaso].to_numpy()\n",
    "    iv = df[item_iv].to_numpy()\n",
    "\n",
    "    delta_iv = []\n",
    "    delta_vaso = []\n",
    "    a = []\n",
    "    b = []\n",
    "    a_iv_delta,a_vaso_delta = 0.0,0.0\n",
    "\n",
    "    for k in range(1, len(ids)):\n",
    "        if ids[k-1] == ids[k]:\n",
    "            # Signed difference: only increases can raise the running maximum above 0.0.\n",
    "            a_vaso_delta = max(vaso[k] - vaso[k-1],a_vaso_delta)\n",
    "            a_iv_delta = max(iv[k] - iv[k-1],a_iv_delta)\n",
    "            a.append(vaso[k])\n",
    "            b.append(iv[k])\n",
    "        elif ids[k-1] != ids[k]:\n",
    "            delta_iv.append(a_iv_delta)\n",
    "            delta_vaso.append(a_vaso_delta)\n",
    "            a_iv_delta,a_vaso_delta = 0.0,0.0\n",
    "\n",
    "    # Flush the patient that was still open when the rows ran out.\n",
    "    if len(ids) > 0:\n",
    "        delta_iv.append(a_iv_delta)\n",
    "        delta_vaso.append(a_vaso_delta)\n",
    "\n",
    "    print(\"change_vaso>\",max_sudden_change_vaso_data,\":\",_fraction_above(delta_vaso, max_sudden_change_vaso_data))\n",
    "    print(\"change_iv>\",max_sudden_change_iv_data,\":\",_fraction_above(delta_iv, max_sudden_change_iv_data))\n",
    "    print(\"-------\"*10)\n",
    "    print(\"too high vaso>:\",too_high_vaso,\":\",_fraction_above(a, too_high_vaso))\n",
    "\n",
    "    print(\"too high iv>:\",too_high_iv,\":\",_fraction_above(b, too_high_iv))\n",
    "    print(\"-****---\"*10)\n",
    "\n",
    "    return delta_iv,delta_vaso,a,b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_iv,delta_vaso,a,b = sudden_change_and_too_high(df,'agent_action_vaso','agent_action_iv',0.9,1000,1000,0.75)\n",
    "delta_iv,delta_vaso,a,b = sudden_change_and_too_high(df,'agent_action_vaso','agent_action_iv',0.75,1500,1000,0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_iv,delta_vaso,a,b = sudden_change_and_too_high(df,'phy_action_vaso','phy_action_iv',0.9,1000,1000,0.75)\n",
    "delta_iv,delta_vaso,a,b = sudden_change_and_too_high(df,'phy_action_vaso','phy_action_iv',0.75,1500,1000,0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head(n=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_csv('../data/df_test_jijin.csv',index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils_fqe.net.common import MLP\n",
    "from utils_fqe.net.continuous import DistributionalCritic\n",
    "from copy import deepcopy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def FQE(policy):\n",
    "    \"\"\"Fit a critic for `policy` with Fitted Q Evaluation on the 'val' transitions.\n",
    "\n",
    "    Every step samples a mini-batch from the memory returned by\n",
    "    process_batch(train=False, eval_type='val'), builds the clipped target\n",
    "    r + discount * (1 - terminal) * Q_target(s', policy(s')) and regresses the critic\n",
    "    on it with a mean-squared-error loss. Each transition is indexed as\n",
    "    (state, action, reward, next_state, terminal).\n",
    "\n",
    "    Args:\n",
    "        policy: object exposing get_batch_action(states) -> actions.\n",
    "\n",
    "    Returns:\n",
    "        The trained critic, which scores concatenated (state, action) rows.\n",
    "    \"\"\"\n",
    "    init_critic = None\n",
    "    discount = 0.99\n",
    "    target_update_period = 100\n",
    "    critic_lr = 1e-4\n",
    "    num_steps = 100000  #250000\n",
    "    polyak = 0.0  # 0.0 -> the target network becomes an exact copy of the critic\n",
    "    batch_size = 256\n",
    "    critic_hidden_features = 1024\n",
    "    critic_hidden_layers = 4\n",
    "    device = 'cpu'\n",
    "    verbose = False\n",
    "\n",
    "    mem,_ = process_batch(train=False,eval_type='val')\n",
    "\n",
    "    # Bounds for clipping the bootstrapped target, estimated from one random batch of rewards.\n",
    "    reward_sample = torch.tensor([x[2] for x in random.sample(mem, batch_size)], dtype=torch.float, device=device)\n",
    "    min_reward = torch.min(reward_sample).item()\n",
    "    max_reward = torch.max(reward_sample).item()\n",
    "    # min_value must stay <= max_value: torch.clamp with min > max returns max for every element,\n",
    "    # which would turn the regression target into a constant.\n",
    "    # NOTE(review): the 1.2 / 0.8 weights are kept as found -- confirm they give the intended margin.\n",
    "    max_value = (1.2 * max_reward + 0.8 * min_reward) / (1 - discount)\n",
    "    min_value = (1.2 * min_reward + 0.8 * max_reward) / (1 - discount)\n",
    "\n",
    "    input_dim = State_dim + Action_dim\n",
    "    critic = MLP(input_dim, 1, critic_hidden_features, critic_hidden_layers).to(device)\n",
    "    if init_critic is not None: critic.load_state_dict(init_critic.state_dict())\n",
    "    critic_optimizer = torch.optim.Adam(critic.parameters(), lr=critic_lr)\n",
    "    target_critic = deepcopy(critic).to(device)\n",
    "    target_critic.requires_grad_(False)\n",
    "\n",
    "    if verbose:\n",
    "        # NOTE(review): tqdm is not imported in the visible part of this notebook;\n",
    "        # import it before switching verbose on.\n",
    "        counter = tqdm(total=num_steps)\n",
    "\n",
    "    print('Training Fqe...')\n",
    "    for t in range(num_steps):\n",
    "        batch = random.sample(mem, batch_size)\n",
    "\n",
    "        r = torch.tensor([x[2] for x in batch], dtype=torch.float, device = device).view(batch_size,1)\n",
    "        terminals = [x[4] for x in batch]\n",
    "        terminals = torch.tensor(np.array(terminals).astype(float),device = device, dtype=torch.float).view(batch_size,1)\n",
    "        o1 = torch.tensor([x[0] for x in batch], dtype=torch.float, device = device)\n",
    "        a1 = torch.tensor([x[1] for x in batch], dtype = torch.float, device = device)\n",
    "        o2 = torch.tensor([x[3] for x in batch], dtype=torch.float, device = device)\n",
    "\n",
    "        # The target is a fixed regression label, so nothing here needs a graph.\n",
    "        with torch.no_grad():\n",
    "            a2 = policy.get_batch_action(o2)\n",
    "            q_target = target_critic(torch.cat((o2, a2), -1))\n",
    "            current_discount = discount * (1 - terminals)\n",
    "            backup = r + current_discount * q_target\n",
    "            backup = torch.clamp(backup, min_value, max_value) # prevent explosion\n",
    "\n",
    "        q = critic(torch.cat((o1, a1), -1))\n",
    "        critic_loss = ((q - backup) ** 2).mean()\n",
    "\n",
    "        critic_optimizer.zero_grad()\n",
    "        critic_loss.backward()\n",
    "        critic_optimizer.step()\n",
    "\n",
    "        if t % target_update_period == 0:\n",
    "            with torch.no_grad():\n",
    "                for p, p_targ in zip(critic.parameters(), target_critic.parameters()):\n",
    "                    p_targ.data.mul_(polyak)\n",
    "                    p_targ.data.add_((1 - polyak) * p.data)\n",
    "\n",
    "        if verbose:\n",
    "            counter.update(1)\n",
    "\n",
    "    return critic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "critic = FQE(agent)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def play_fqe(critic,eval_type,batch_size=None):\n",
    "    \"\"\"Score logged transitions with a critic.\n",
    "\n",
    "    Args:\n",
    "        critic: callable mapping concatenated (state, action) rows to Q-values.\n",
    "        eval_type: dataset selector forwarded to process_batch.\n",
    "        batch_size: number of transitions to sample at random; None uses all of them.\n",
    "\n",
    "    Returns:\n",
    "        (mean Q-value of the logged state/action pairs, mean logged reward),\n",
    "        both as 0-dim tensors.\n",
    "    \"\"\"\n",
    "    device = 'cpu'\n",
    "    mem,_ = process_batch(train=False,eval_type=eval_type)\n",
    "    batch = mem if batch_size is None else random.sample(mem, batch_size)\n",
    "    batch_size = len(batch)\n",
    "\n",
    "    pre_state_batch = torch.tensor([x[0] for x in batch], dtype=torch.float, device = device)\n",
    "    action_batch = torch.tensor([x[1] for x in batch], dtype = torch.float, device = device)\n",
    "    # reshape to a column, matching the (batch, 1) layout used for rewards elsewhere\n",
    "    reward_batch = torch.tensor([x[2] for x in batch], dtype=torch.float, device = device).view(batch_size,1)\n",
    "\n",
    "    # Evaluation only: skip graph construction so callers that collect many results\n",
    "    # do not keep autograd state alive.\n",
    "    with torch.no_grad():\n",
    "        q = critic(torch.cat((pre_state_batch, action_batch), -1))\n",
    "\n",
    "    return torch.mean(q),torch.mean(reward_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the two values returned by play_fqe over many random test batches of 128 transitions.\n",
    "num_fqe = 10000\n",
    "aq = []\n",
    "pq = []\n",
    "for i in range(num_fqe):\n",
    "    agent_q,phy_q = play_fqe(critic,'test',128)\n",
    "    aq.append(agent_q)\n",
    "    pq.append(phy_q)\n",
    "\n",
    "# torch.mean accepts a tensor, not a Python list of 0-dim tensors, so stack first.\n",
    "print('agent',torch.stack(aq).mean())\n",
    "print('physician',torch.stack(pq).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "l = [str(x) for x in range(len(critic_losses))]\n",
    "\n",
    "f, (ax1) = plt.subplots(1, 1, sharex='col', sharey='row', figsize = (7.5,4))\n",
    "ax1.plot(l, critic_losses, color='r')\n",
    "ax1.set_title('DDPG - critic_losses')\n",
    "x_r = [i for i in range(0,30000,5000)]\n",
    "ax1.set_xticks(x_r)\n",
    "ax1.grid()\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, (ax1) = plt.subplots(1, 1, sharex='col', sharey='row', figsize = (7.5,4))\n",
    "plt.plot(l,actor_losses,'b')\n",
    "x_iv = [i for i in range(0,30000,5000)]\n",
    "ax1.set_xticks(x_iv)\n",
    "ax1.grid()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect physician/agent Q-values and actions (plus id and die) for the val and test splits.\n",
    "# NOTE(review): later cells read phys_q_train, action_batch_train, agent_q_train and\n",
    "# agent_actions_train, but the only assignment in this section is the commented-out\n",
    "# line below -- confirm where they come from before a Restart & Run All.\n",
    "#phys_q_train,action_batch_train,agent_q_train,agent_actions_train = play2(agent,'train')\n",
    "phys_q_val,action_batch_val,agent_q_val,agent_actions_val,id_val,die_val = play_presents(agent,'val')\n",
    "phys_q_test,action_batch_test,agent_q_test,agent_actions_test,id_test,die_test = play_presents(agent,'test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the agent and physician action tensors into NumPy arrays: column 0 -> iv, column 1 -> vaso.\n",
    "# NOTE: despite the 'train' in the names, these arrays are taken from the *_test tensors.\n",
    "action_train_iv_agent = agent_actions_test[:,0].numpy()\n",
    "action_train_vaso_agent = agent_actions_test[:,1].numpy()\n",
    "action_train_iv_phy = action_batch_test[:,0].numpy()\n",
    "action_train_vaso_phy = action_batch_test[:,1].numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.Series(action_train_iv_agent).plot.hist(range=[0,2000],bins=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.Series(action_train_iv_phy).plot.hist(range=[0,2000],bins=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.Series(action_train_vaso_agent).plot.hist(range=[0,1],bins=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.Series(action_train_vaso_phy).plot.hist(range=[0,1],bins=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.Series(agent_q_train.numpy().squeeze(1)).plot.hist(range=[-10000,20000],bins=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.Series(phys_q_train.numpy().squeeze(1)).plot.hist(range=[-20000,20000],bins=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(np.mean(agent_q_test.numpy().squeeze(1)))\n",
    "print(np.mean(phys_q_test.numpy().squeeze(1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle as pickle\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vq = np.array([0.08,0.2,0.45])\n",
    "ivq = np.array([48,150,500])\n",
    "_agent = pd.DataFrame(columns=['iv','vaso'])\n",
    "_phy = pd.DataFrame(columns=['iv','vaso','died_in_hosp'])\n",
    "\n",
    "_agent['iv']= agent_actions_test[:,0]\n",
    "_agent['vaso']= agent_actions_test[:,1]\n",
    "\n",
    "_phy['iv']= action_batch_test[:,0]\n",
    "_phy['vaso']= action_batch_test[:,1]\n",
    "_phy['died_in_hosp'] = test['died_in_hosp']\n",
    "_phy.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Bucket the vasopressor doses of both frames into five levels:\n",
    "#   0 -> 0,  (0, vq[0]) -> 1,  [vq[0], vq[1]) -> 2,  [vq[1], vq[2]) -> 3,  >= vq[2] -> 4\n",
    "# Every level is decided from the raw doses in a single pass. Relabelling the column\n",
    "# step by step is wrong here: the labels 1, 2 and 3 are themselves >= vq[2], so a final\n",
    "# '>= vq[2]' step would turn all of them into 4.\n",
    "for doses in (_agent, _phy):\n",
    "    raw_vaso = doses['vaso'].to_numpy()\n",
    "    doses['vaso'] = np.select(\n",
    "        [raw_vaso == 0.0,\n",
    "         (raw_vaso > 0.0) & (raw_vaso < vq[0]),\n",
    "         (raw_vaso >= vq[0]) & (raw_vaso < vq[1]),\n",
    "         (raw_vaso >= vq[1]) & (raw_vaso < vq[2]),\n",
    "         raw_vaso >= vq[2]],\n",
    "        [0.0, 1.0, 2.0, 3.0, 4.0],\n",
    "        default=raw_vaso)  # values matching no rule (negative, NaN) are left as they are"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bucket the IV doses of both frames into five levels:\n",
    "#   0 -> 0,  (0, ivq[0]) -> 1,  [ivq[0], ivq[1]) -> 2,  [ivq[1], ivq[2]) -> 3,  >= ivq[2] -> 4\n",
    "# The levels are computed from the raw doses in one pass and written back with a single\n",
    "# column assignment, so the result depends neither on the order of the rules nor on\n",
    "# chained indexing (frame[col][mask] = value), which pandas does not guarantee to write through.\n",
    "for doses in (_agent, _phy):\n",
    "    raw_iv = doses['iv'].to_numpy()\n",
    "    doses['iv'] = np.select(\n",
    "        [raw_iv == 0.0,\n",
    "         (raw_iv > 0.0) & (raw_iv < ivq[0]),\n",
    "         (raw_iv >= ivq[0]) & (raw_iv < ivq[1]),\n",
    "         (raw_iv >= ivq[1]) & (raw_iv < ivq[2]),\n",
    "         raw_iv >= ivq[2]],\n",
    "        [0.0, 1.0, 2.0, 3.0, 4.0],\n",
    "        default=raw_iv)  # values matching no rule (negative, NaN) are left as they are"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_iv = _agent['iv']\n",
    "agent_vaso = _agent['vaso']\n",
    "phy_iv = _phy['iv']\n",
    "phy_vaso = _phy['vaso']\n",
    "\n",
    "# One unit-wide bin per dose level 0..4: edges at -0.5, 0.5, ..., 4.5.\n",
    "x_edges = np.arange(-0.5,5)\n",
    "y_edges = np.arange(-0.5,5)\n",
    "\n",
    "# Pass the edges explicitly. bins=5 would spread five bins between the observed minimum\n",
    "# and maximum, so a level that nobody uses would shift the counts into the wrong cells.\n",
    "# Axis 0 of each histogram is the IV level, axis 1 the vaso level.\n",
    "hist1, _, _ = np.histogram2d(phy_iv, phy_vaso, bins=[x_edges, y_edges])\n",
    "hist2, _, _ = np.histogram2d(agent_iv, agent_vaso, bins=[x_edges, y_edges])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side-by-side heatmaps of the action counts: physician on the left, DDPG on the right.\n",
    "f, (ax1, ax2) = plt.subplots(1, 2, figsize=(16,4))\n",
    "\n",
    "# (axes, counts, colour map, grid colour, title) for each panel\n",
    "panels = [\n",
    "    (ax1, hist1, 'Blues', 'b', 'Physician policy'),\n",
    "    (ax2, hist2, 'OrRd', 'r', 'DDPG policy'),\n",
    "]\n",
    "level_ticks = np.arange(0, 5, 1)      # tick at the centre of every dose level\n",
    "cell_borders = np.arange(-.5, 5, 1)   # minor ticks on the cell borders, used for the grid\n",
    "extent = [x_edges[0], x_edges[-1],  y_edges[0],y_edges[-1]]\n",
    "\n",
    "for ax, counts, cmap, grid_color, title in panels:\n",
    "    ax.imshow(np.flipud(counts), cmap=cmap, extent=extent)\n",
    "\n",
    "    ax.set_xticks(level_ticks)\n",
    "    ax.set_yticks(level_ticks)\n",
    "    ax.set_xticklabels(level_ticks)\n",
    "    ax.set_yticklabels(level_ticks)\n",
    "    ax.set_xticks(cell_borders, minor=True)\n",
    "    ax.set_yticks(cell_borders, minor=True)\n",
    "    ax.grid(which='minor', color=grid_color, linestyle='-', linewidth=1)\n",
    "\n",
    "    mesh = ax.pcolormesh(x_edges, y_edges, counts, cmap=cmap)\n",
    "    f.colorbar(mesh, ax=ax, label='Action counts')\n",
    "\n",
    "    ax.set_ylabel('IV fluid dose')\n",
    "    ax.set_xlabel('Vaso fluid dose')\n",
    "    ax.set_title(title)\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_df_diff(_agent):\n",
    "    \"\"\"Build a table of mortality and agent-minus-physician dose differences.\n",
    "\n",
    "    The physician doses and the 'died_in_hosp' outcome are read from the notebook-level\n",
    "    `_phy` frame; rows of `_agent` and `_phy` are matched by position.\n",
    "\n",
    "    Args:\n",
    "        _agent: frame with the agent's 'iv' and 'vaso' columns.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with the columns 'mort', 'iv_diff' and 'vaso_diff'.\n",
    "    \"\"\"\n",
    "    physician = _phy\n",
    "    return pd.DataFrame({\n",
    "        'mort': np.array(physician['died_in_hosp']),\n",
    "        'iv_diff': np.array(_agent['iv']) - np.array(physician['iv']),\n",
    "        'vaso_diff': np.array(_agent['vaso']) - np.array(physician['vaso']),\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import sem\n",
    "def make_iv_plot_data(df_diff):\n",
    "    \"\"\"Mortality as a function of the IV dose difference, in bins of width 100.\n",
    "\n",
    "    Bin centres run from -800 to 900 in steps of 100; a row belongs to a bin when its\n",
    "    'iv_diff' lies strictly between centre - 50 and centre + 50. Bins holding fewer\n",
    "    than two rows are left out.\n",
    "\n",
    "    Returns:\n",
    "        (bin centres, mean of 'mort' per bin, standard error of 'mort' per bin)\n",
    "    \"\"\"\n",
    "    bin_medians_iv, mort_iv, mort_std_iv = [], [], []\n",
    "    for center in range(-800, 901, 100):\n",
    "        in_bin = (df_diff['iv_diff']>center-50) & (df_diff['iv_diff']<center+50)\n",
    "        count = df_diff.loc[in_bin]\n",
    "        if len(count) < 2:\n",
    "            continue\n",
    "        bin_medians_iv.append(center)\n",
    "        mort_iv.append(sum(count['mort'])/float(len(count)))\n",
    "        mort_std_iv.append(sem(count['mort']))\n",
    "    return bin_medians_iv, mort_iv, mort_std_iv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import sem\n",
    "def make_vaso_plot_data(df_diff):\n",
    "    \"\"\"Mortality as a function of the vasopressor dose difference, in bins of width 0.1.\n",
    "\n",
    "    Bin centres are -0.6, -0.5, ..., 0.8; a row belongs to a bin when its 'vaso_diff'\n",
    "    lies strictly between centre - 0.05 and centre + 0.05. Bins holding fewer than\n",
    "    two rows are left out.\n",
    "\n",
    "    Returns:\n",
    "        (bin centres, mean of 'mort' per bin, standard error of 'mort' per bin)\n",
    "    \"\"\"\n",
    "    bin_medians_vaso = []\n",
    "    mort_vaso= []\n",
    "    mort_std_vaso= []\n",
    "    # Derive every centre from an integer instead of adding 0.1 repeatedly: accumulated\n",
    "    # rounding error would otherwise yield centres such as -2.8e-17 in place of 0.0 and\n",
    "    # make the presence of the last bin depend on how the error happens to fall.\n",
    "    for tenth in range(-6, 9):\n",
    "        i = tenth / 10\n",
    "        count =df_diff.loc[(df_diff['vaso_diff']>i-0.05) & (df_diff['vaso_diff']<i+0.05)]\n",
    "        if len(count) >=2:\n",
    "            bin_medians_vaso.append(i)\n",
    "            mort_vaso.append(sum(count['mort'])/float(len(count)))\n",
    "            mort_std_vaso.append(sem(count['mort'])) # standard error\n",
    "    return bin_medians_vaso, mort_vaso, mort_std_vaso # dose difference, mortality rate, standard error of mortality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sliding_mean(data_array, window=1):\n",
    "    \"\"\"Smooth a sequence with a moving average.\n",
    "\n",
    "    Element i of the result is the mean of the items at positions\n",
    "    i - window + 1 ... i + window, clipped to the bounds of the sequence.\n",
    "\n",
    "    Returns:\n",
    "        NumPy array with one smoothed value per input element.\n",
    "    \"\"\"\n",
    "    size = len(data_array)\n",
    "    smoothed = []\n",
    "    for pos in range(size):\n",
    "        span = range(max(pos - window + 1, 0), min(pos + window + 1, size))\n",
    "        total = sum(data_array[k] for k in span)\n",
    "        smoothed.append(total / float(len(span)))\n",
    "    return np.array(smoothed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vq = np.array([0.08,0.2,0.45])\n",
    "ivq = np.array([48,150,500])\n",
    "_agent = pd.DataFrame(columns=['iv','vaso'])\n",
    "_phy = pd.DataFrame(columns=['iv','vaso','died_in_hosp'])\n",
    "\n",
    "_agent['iv']= agent_actions_test[:,0]\n",
    "_agent['vaso']= agent_actions_test[:,1]\n",
    "\n",
    "_phy['iv']= action_batch_test[:,0]\n",
    "_phy['vaso']= action_batch_test[:,1]\n",
    "_phy['died_in_hosp'] = test['died_in_hosp']\n",
    "_phy.head()\n",
    "df_diff_DDPG = make_df_diff(_agent)\n",
    "bin_med_iv_DDPG, mort_iv_DDPG, mort_std_iv_DDPG = make_iv_plot_data(df_diff_DDPG)\n",
    "bin_med_vaso_DDPG, mort_vaso_DDPG, mort_std_vaso_DDPG = make_vaso_plot_data(df_diff_DDPG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "f, (ax1,ax2) = plt.subplots(1, 2, sharex='col', sharey='row', figsize = (7.5,4))\n",
    "ax1.plot(bin_med_vaso_DDPG, sliding_mean(mort_vaso_DDPG), color='r')\n",
    "ax1.fill_between(bin_med_vaso_DDPG, sliding_mean(mort_vaso_DDPG) - 1*mort_std_vaso_DDPG,  \n",
    "                 sliding_mean(mort_vaso_DDPG) + 1*mort_std_vaso_DDPG, color='tomato')\n",
    "ax1.set_title('DDPG - Vasopressors')\n",
    "x_r = [i/10.0 for i in range(-6,10,2)]\n",
    "y_r = [i/20.0 for i in range(0,20,1)]\n",
    "ax1.set_xticks(x_r)\n",
    "ax1.set_yticks(y_r)\n",
    "ax1.grid()\n",
    "\n",
    "ax2.plot(bin_med_iv_DDPG, sliding_mean(mort_iv_DDPG), color='r')\n",
    "ax2.fill_between(bin_med_iv_DDPG, sliding_mean(mort_iv_DDPG) - 1*mort_std_iv_DDPG,  \n",
    "                 sliding_mean(mort_iv_DDPG) + 1*mort_std_iv_DDPG, color='tomato')\n",
    "ax2.set_title('DDPG - IV fluids')\n",
    "x_iv = [i for i in range(-800,900,400)]\n",
    "ax2.set_xticks(x_iv)\n",
    "ax2.grid()\n",
    "\n",
    "plt.tight_layout()\n",
    "f.text(0.225, -0.03, 'Difference between optimal and physician vasopressor dose', ha='center', fontsize=10)\n",
    "f.text(0.775, -0.03, 'Difference between optimal and physician IV dose', ha='center', fontsize=10)\n",
    "# f.text(-0.02, 0.5, 'Observed Mortality', va='center', rotation='vertical', fontsize = 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair each test-set Q-value (column 0) with the in-hospital mortality flag (column 'mort').\n",
    "# NOTE(review): the frame is called phys_df but is filled from agent_q_test, not\n",
    "# phys_q_test -- confirm which Q-values the mortality plot is meant to use.\n",
    "phys_df = pd.DataFrame(agent_q_test.numpy().squeeze(1))\n",
    "import copy\n",
    "phys_df['mort'] = copy.deepcopy(np.array(_phy['died_in_hosp']))\n",
    "phys_df.tail()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mortality rate and its standard error inside a window of +/-250 around every\n",
    "# candidate return value i in -15000 ... 0. Windows holding fewer than two rows are skipped.\n",
    "# NOTE(review): i advances by 1 while the window is 500 wide, so neighbouring windows\n",
    "# overlap almost completely and phys_df is filtered 15001 times -- confirm that a step\n",
    "# of 1 (rather than something near the window width) is intended.\n",
    "bin_medians = []\n",
    "mort = []\n",
    "mort_std = []\n",
    "i = -15000\n",
    "while i <= 0:\n",
    "    count =phys_df.loc[(phys_df[0]>i-250) & (phys_df[0]<i+250)]\n",
    "    try:\n",
    "        res = sum(count['mort'])/float(len(count))\n",
    "        if len(count) >=2:\n",
    "            bin_medians.append(i)\n",
    "            mort.append(res)\n",
    "            mort_std.append(sem(count['mort']))\n",
    "    except ZeroDivisionError:\n",
    "        # empty window: nothing to record\n",
    "        pass\n",
    "    i += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(6, 4.5))\n",
    "plt.plot(bin_medians, sliding_mean(mort))\n",
    "plt.fill_between(bin_medians, sliding_mean(mort) - 1*sliding_mean(mort_std),  \n",
    "                 sliding_mean(mort) + 1*sliding_mean(mort_std), color='#ADD8E6')\n",
    "plt.grid()\n",
    "plt.xticks(range(-15000,20,15000))\n",
    "r = [float(i)/10 for i in range(0,11,1)]\n",
    "_ = plt.yticks(r)\n",
    "_ = plt.title(\"Mortality vs Expected Return\", fontsize=15)  \n",
    "_ = plt.ylabel(\"Proportion Mortality\")\n",
    "_ = plt.xlabel(\"Expected Return\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_dir = './DDPG_results/'\n",
    "# Create the output folder up front: open(..., 'wb') fails when the folder is missing.\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "# File name -> object to pickle: agent actions and agent Q-values for every split.\n",
    "# NOTE(review): the *_train objects must already exist in the kernel; this section of\n",
    "# the notebook does not assign them.\n",
    "agent_outputs = {\n",
    "    'DDPG_autoencode_agent_actions_train.p': agent_actions_train,\n",
    "    'DDPG_autoencode_agent_actions_val.p': agent_actions_val,\n",
    "    'DDPG_autoencode_agent_actions_test.p': agent_actions_test,\n",
    "    'DDPG_autoencode_agent_q_train.p': agent_q_train,\n",
    "    'DDPG_autoencode_agent_q_val.p': agent_q_val,\n",
    "    'DDPG_autoencode_agent_q_test.p': agent_q_test,\n",
    "}\n",
    "for file_name, payload in agent_outputs.items():\n",
    "    with open(save_dir + file_name, 'wb') as f:\n",
    "        pickle.dump(payload, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pickle the physician actions and physician Q-values of every split into save_dir.\n",
    "phy_outputs = [\n",
    "    ('DDPG_autoencode_phy_actions_train.p', action_batch_train),\n",
    "    ('DDPG_autoencode_phy_actions_val.p', action_batch_val),\n",
    "    ('DDPG_autoencode_phy_actions_test.p', action_batch_test),\n",
    "    ('DDPG_autoencode_phy_q_train.p', phys_q_train),\n",
    "    ('DDPG_autoencode_phy_q_val.p', phys_q_val),\n",
    "    ('DDPG_autoencode_phy_q_test.p', phys_q_test),\n",
    "]\n",
    "for file_name, payload in phy_outputs:\n",
    "    with open(save_dir + file_name, 'wb') as f:\n",
    "        pickle.dump(payload, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  },
  "vscode": {
   "interpreter": {
    "hash": "e18c91ba94b3e8dba19516f2d61dd1bbd78801f86b94afb4022d4b178189e8b2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
