{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from tensorflow.python.ops.rnn import rnn_cell_impl, _should_cache, nest, vs, tensor_shape, _is_keras_rnn_cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0530 18:49:21.992362 140003711117120 lazy_loader.py:50] \n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "W0530 18:49:21.992774 140003711117120 module_wrapper.py:139] From qnetwork.py:50: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import os \n",
    "import multiprocessing as mp\n",
    "from qnetwork import *\n",
    "from utils import *\n",
    "import pandas as pd\n",
    "from sklearn.metrics import roc_auc_score, average_precision_score\n",
    "import scipy.stats as stats\n",
    "rnn = tf.contrib.rnn\n",
    "slim = tf.contrib.slim\n",
    "import random\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "SEED = 2599\n",
    "np.random.seed(SEED)\n",
    "tf.set_random_seed(SEED)\n",
    "random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_shock_train = pd.read_csv(\"./data/df_shock_train.csv\", index_col=\"TrainSampleIdx\")\n",
    "df_shock_test = pd.read_csv(\"./data/df_shock_test.csv\", index_col=\"TrainSampleIdx\")\n",
    "df_non_shock_train = pd.read_csv(\"./data/df_non_shock_train.csv\", index_col=\"TrainSampleIdx\")\n",
    "df_non_shock_test = pd.read_csv(\"./data/df_non_shock_test.csv\", index_col=\"TrainSampleIdx\")\n"
   ]
  },
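  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick sanity check on the loaded splits (a sketch; it assumes each\n",
    "# TrainSampleIdx value indexes one patient episode, so rows sharing an\n",
    "# index form that episode's time series)\n",
    "for name, _df in [(\"shock_train\", df_shock_train), (\"non_shock_train\", df_non_shock_train),\n",
    "                  (\"shock_test\", df_shock_test), (\"non_shock_test\", df_non_shock_test)]:\n",
    "    print(\"%s: %d episodes, %d rows, %d features\" % (name, _df.index.nunique(), len(_df), _df.shape[1]))"
   ]
  },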
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# determine a numerical value to represent nan values\n",
    "_max = -np.infty\n",
    "_min = np.infty\n",
    "for _df in [df_shock_train, df_non_shock_train]:\n",
    "    _df_values = np.copy(_df.values)\n",
    "    _df_values[np.isnan(_df.values)] = 0.\n",
    "    if np.max(_df_values) > _max:\n",
    "        _max = np.max(_df_values)\n",
    "    if np.min(_df_values) < _min:\n",
    "        _min = np.min(_df_values)\n",
    "\n",
    "nan_replacement = 3*_max\n",
    "# nan_replacement = 0."
   ]
  },
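  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch of the sentinel idea on toy data (illustration only): 3x the\n",
    "# max observed value lies strictly outside the data range, so imputed entries\n",
    "# remain distinguishable from real measurements\n",
    "_toy = np.array([[1.5, np.nan], [0.2, 4.0]])\n",
    "_toy_sentinel = 3 * np.nanmax(_toy)\n",
    "_toy[np.isnan(_toy)] = _toy_sentinel\n",
    "print(_toy)  # the NaN entry is now 12.0, clearly outside [0.2, 4.0]\n",
    "del _toy, _toy_sentinel"
   ]
  },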
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# determine the max sequence length\n",
    "max_seq_len = -np.infty\n",
    "for _df in [df_shock_train, df_non_shock_train, df_shock_test, df_non_shock_test]:\n",
    "    max_for_current_df = np.max(np.unique(_df.index.values, return_counts=True)[1])\n",
    "    if max_for_current_df > max_seq_len:\n",
    "        max_seq_len = max_for_current_df\n",
    "\n",
    "\n",
    "# replace nan values\n",
    "for _df in [df_shock_train, df_non_shock_train, df_shock_test, df_non_shock_test]:\n",
    "    _df[_df.isna()]=nan_replacement\n",
    "\n",
    "def seq_length(sequence):\n",
    "    used = tf.sign(tf.reduce_max(tf.abs(sequence), 2))\n",
    "    length = tf.reduce_sum(used, 1)\n",
    "    length = tf.cast(length, tf.int32)\n",
    "    return length\n",
    "\n",
    "def gen_train():\n",
    "    # Output mask's dimensions correspond to [num_timesteps, batch_size, num_input/sequence_length]\n",
    "    for i in df_shock_train.index.unique():\n",
    "        current_df = df_shock_train.loc[i]\n",
    "        if isinstance(current_df, pd.core.frame.DataFrame):\n",
    "            current_values = df_shock_train.loc[i].values\n",
    "            out = np.vstack([current_values, np.zeros((max_seq_len-current_values.shape[0], current_values.shape[1]))])\n",
    "            mask = out == nan_replacement\n",
    "            mask = mask.astype(np.int)\n",
    "            label = np.array([0., 1.])\n",
    "            yield out, label, mask\n",
    "    for i in df_non_shock_train.index.unique():\n",
    "        current_df = df_non_shock_train.loc[i]\n",
    "        if isinstance(current_df, pd.core.frame.DataFrame):\n",
    "            current_values = df_non_shock_train.loc[i].values\n",
    "            out = np.vstack([current_values, np.zeros((max_seq_len-current_values.shape[0], current_values.shape[1]))])\n",
    "            mask = out == nan_replacement\n",
    "            mask = mask.astype(np.int)\n",
    "            label = np.array([1., 0.])\n",
    "            yield out, label, mask\n",
    "        \n",
    "\n",
    "def gen_test():\n",
    "    # Output mask's dimensions correspond to [num_timesteps, batch_size, num_input/sequence_length]\n",
    "    for i in df_shock_test.index.unique():\n",
    "        current_df = df_shock_test.loc[i]\n",
    "        if isinstance(current_df, pd.core.frame.DataFrame):\n",
    "            current_values = df_shock_test.loc[i].values\n",
    "            out = np.vstack([current_values, np.zeros((max_seq_len-current_values.shape[0], current_values.shape[1]))])\n",
    "            mask = out == nan_replacement\n",
    "            mask = mask.astype(np.int)\n",
    "            label = np.array([0., 1.])\n",
    "            yield out, label, mask\n",
    "    for i in df_non_shock_test.index.unique():\n",
    "        current_df = df_non_shock_test.loc[i]\n",
    "        if isinstance(current_df, pd.core.frame.DataFrame):\n",
    "            current_values = df_non_shock_test.loc[i].values\n",
    "            out = np.vstack([current_values, np.zeros((max_seq_len-current_values.shape[0], current_values.shape[1]))])\n",
    "            mask = out == nan_replacement\n",
    "            mask = mask.astype(np.int)\n",
    "            label = np.array([1., 0.])\n",
    "            yield out, label, mask\n",
    "        \n",
    "\n",
    "# Setting up the truncated normal distribution for exploration\n",
    "\n",
    "lower, upper = 0, 1\n",
    "mu, sigma = 0, 0.2\n",
    "left_truncnorm = stats.truncnorm(\n",
    "    (lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)\n",
    "right_truncnorm = stats.truncnorm(\n",
    "    (lower - 1.) / sigma, (upper - 1.) / sigma, loc=1., scale=sigma)\n",
    "\n",
    "# fig, ax = plt.subplots(1, sharex=True)\n",
    "# ax.hist(np.concatenate([left_truncnorm.rvs(10000),right_truncnorm.rvs(10000)]), normed=True)\n",
    "\n"
   ]
  },
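  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity-check one padded training sample (a sketch; note the generators skip\n",
    "# single-row episodes, since .loc then returns a Series rather than a DataFrame)\n",
    "_out, _label, _mask = next(gen_train())\n",
    "print(_out.shape)   # (max_seq_len, 15): episode zero-padded to the max length\n",
    "print(_label)       # one-hot label: [0., 1.] = shock, [1., 0.] = non-shock\n",
    "print(_mask.sum())  # number of entries that were NaN before imputation\n",
    "del _out, _label, _mask"
   ]
  },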
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0530 18:51:05.907478 140003711117120 deprecation.py:323] From <ipython-input-12-8bcd202987d1>:68: make_one_shot_iterator (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.\n",
      "W0530 18:51:05.966027 140003711117120 deprecation.py:323] From <ipython-input-12-8bcd202987d1>:45: __init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.\n",
      "W0530 18:51:05.984174 140003711117120 deprecation.py:323] From /home/gaoqitong/anaconda2/lib/python2.7/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:735: add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.add_weight` method instead.\n",
      "W0530 18:51:05.994344 140003711117120 deprecation.py:506] From /home/gaoqitong/anaconda2/lib/python2.7/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:739: calling __init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "W0530 18:51:07.333360 140003711117120 deprecation.py:323] From /home/gaoqitong/anaconda2/lib/python2.7/site-packages/tensorflow_core/contrib/layers/python/layers/layers.py:1866: apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "W0530 18:51:07.363034 140003711117120 deprecation.py:323] From <ipython-input-12-8bcd202987d1>:78: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "\n",
      "Future major versions of TensorFlow will allow gradients to flow\n",
      "into the labels input on backprop by default.\n",
      "\n",
      "See `tf.nn.softmax_cross_entropy_with_logits_v2`.\n",
      "\n",
      "W0530 18:51:33.353984 140003711117120 deprecation.py:323] From qnetwork.py:37: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Deprecated in favor of operator or tf.math.divide.\n",
      "W0530 18:51:33.437069 140003711117120 module_wrapper.py:139] From qnetwork.py:47: The name tf.train.Saver is deprecated. Please use tf.compat.v1.train.Saver instead.\n",
      "\n",
      "W0530 18:51:33.791659 140003711117120 module_wrapper.py:139] From qnetwork.py:137: The name tf.losses.mean_squared_error is deprecated. Please use tf.compat.v1.losses.mean_squared_error instead.\n",
      "\n",
      "W0530 18:51:33.801971 140003711117120 deprecation.py:323] From /home/gaoqitong/anaconda2/lib/python2.7/site-packages/tensorflow_core/python/ops/losses/losses_impl.py:121: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    }
   ],
   "source": [
    "SEED = 2599\n",
    "np.random.seed(SEED)\n",
    "tf.set_random_seed(SEED)\n",
    "random.seed(SEED)\n",
    "\n",
    "# RL learning rates\n",
    "actor_lr, critic_lr = 0.0005, 0.0001\n",
    "\n",
    "# Prediction Model Parameters\n",
    "start_learning_rate = 0.0005\n",
    "decay_step = 500\n",
    "decay_rate = 1.\n",
    "num_hidden = 1024 \n",
    "\n",
    "# Threshold for decaying RL learning rates\n",
    "rl_reward_thres_for_decay = -20\n",
    "\n",
    "session_config = tf.ConfigProto(log_device_placement=False)\n",
    "session_config.gpu_options.allow_growth = True\n",
    "\n",
    "training_steps = 2000\n",
    "batch_size = 128\n",
    "\n",
    "num_input = 15 \n",
    "timesteps = max_seq_len # timesteps\n",
    "num_classes = 2 \n",
    "\n",
    "display_step = 10\n",
    "\n",
    "gpu = 0\n",
    "\n",
    "graph = tf.Graph()\n",
    "\n",
    "file_appendix = \"MIMIC_LSTMRL_MaskGradients_\" + str(start_learning_rate) + \"_\" + str(decay_step) + \"_\" + str(decay_rate) + \"_\" + str(num_hidden) + \"_\" + str(actor_lr) + \"_\" + str(critic_lr)\n",
    "\n",
    "\n",
    "def build_net(x, is_training=True, reuse=tf.AUTO_REUSE, graph=graph):\n",
    "\n",
    "        with graph.as_default():\n",
    "            with tf.variable_scope(\"lstm\", reuse=reuse) as scope:\n",
    "                # LSTM Encoder\n",
    "                seq_len = seq_length(x)\n",
    "                enumerated_last_idxs = tf.cast(tf.stack([seq_len-1, tf.range(tf.shape(seq_len)[0])], axis=1), tf.int32)\n",
    "                x = tf.unstack(x, timesteps, 1)\n",
    "                lstm_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0, reuse=reuse)\n",
    "                outputs, state, all_states = my_static_rnn(lstm_cell, x, dtype=tf.float32)\n",
    "                last_outputs = tf.gather_nd(outputs, enumerated_last_idxs)\n",
    "                # Output Layer\n",
    "                with slim.arg_scope([slim.fully_connected], \n",
    "                                        activation_fn=tf.nn.relu,\n",
    "                                        weights_initializer=tf.random_uniform_initializer(0.001, 0.01),\n",
    "                                        weights_regularizer=slim.l2_regularizer(0.005),\n",
    "                                        biases_regularizer=slim.l2_regularizer(0.005),\n",
    "                                        normalizer_fn = slim.batch_norm,\n",
    "                                        normalizer_params = {\"is_training\": is_training},\n",
    "                                        reuse = reuse,\n",
    "                                        scope = scope):\n",
    "\n",
    "                    logits = slim.fully_connected(last_outputs,num_classes,activation_fn=None, weights_regularizer=None, normalizer_fn=None, scope='logits')\n",
    "                    pred = slim.softmax(logits, scope='pred')\n",
    "\n",
    "                    return logits, pred, outputs, x, all_states, seq_len\n",
    "\n",
    "\n",
    "with graph.as_default():\n",
    "\n",
    "    dataset_train = tf.data.Dataset.from_generator(gen_train, (tf.float32, tf.float32, tf.int32), ([ timesteps, 15],[ 2],[timesteps, 15])).repeat(1000).shuffle(5000).batch(batch_size)\n",
    "    input_train, label_train, mask_train = dataset_train.make_one_shot_iterator().get_next()\n",
    "\n",
    "    dataset_test = tf.data.Dataset.from_generator(gen_test, (tf.float32, tf.float32, tf.int32), ([ timesteps, 15],[ 2],[timesteps, 15])).repeat(10000).batch(len(df_shock_test.index.unique())+len(df_non_shock_test.index.unique()))\n",
    "    input_test, label_test, mask_test = dataset_test.make_one_shot_iterator().get_next()\n",
    "\n",
    "    input_train_holder = tf.placeholder(shape=[batch_size, timesteps, num_input], dtype=tf.float32)\n",
    "    label_train_holder = tf.placeholder(shape=[batch_size, 2], dtype=tf.float32)\n",
    "    mask_train_holder = tf.placeholder(shape=[batch_size, timesteps, num_input], dtype=tf.int32)\n",
    "\n",
    "    logits, prediction, outs, xs, states, seq_lens = build_net(input_train_holder)\n",
    "    loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=label_train_holder) + tf.reduce_mean(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope=\"lstm\")), axis=0)\n",
    "    learning_rate = tf.train.exponential_decay(start_learning_rate, tf.train.get_or_create_global_step(), decay_steps=decay_step, decay_rate=decay_rate)\n",
    "    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n",
    "\n",
    "    missing_idxs = tf.where_v2(mask_train)\n",
    "    missing_idxs = tf.stack([missing_idxs[:,1], missing_idxs[:,0], missing_idxs[:,2]], axis=-1)\n",
    "\n",
    "    # Get the encoding LSTM weights and obtain the gradients using regular SGD sovler\n",
    "\n",
    "    i_gates = [graph.get_tensor_by_name(\"lstm/rnn/basic_lstm_cell/split_\"+str(t)+\":0\") if t>0 else graph.get_tensor_by_name(\"lstm/rnn/basic_lstm_cell/split:0\") for t in range(timesteps)]\n",
    "    j_gates = [graph.get_tensor_by_name(\"lstm/rnn/basic_lstm_cell/split_\"+str(t)+\":1\") if t>0 else graph.get_tensor_by_name(\"lstm/rnn/basic_lstm_cell/split:1\") for t in range(timesteps)]\n",
    "    f_gates = [graph.get_tensor_by_name(\"lstm/rnn/basic_lstm_cell/split_\"+str(t)+\":2\") if t>0 else graph.get_tensor_by_name(\"lstm/rnn/basic_lstm_cell/split:2\") for t in range(timesteps)]\n",
    "    o_gates = [graph.get_tensor_by_name(\"lstm/rnn/basic_lstm_cell/split_\"+str(t)+\":3\") if t>0 else graph.get_tensor_by_name(\"lstm/rnn/basic_lstm_cell/split:3\") for t in range(timesteps)]\n",
    "\n",
    "    grads_i = optimizer.compute_gradients(loss_op,i_gates)\n",
    "    grads_i = [g[0] for g in grads_i]\n",
    "    grads_j = optimizer.compute_gradients(loss_op,j_gates)\n",
    "    grads_j = [g[0] for g in grads_j]\n",
    "    grads_f = optimizer.compute_gradients(loss_op,f_gates)\n",
    "    grads_f = [g[0] for g in grads_f]\n",
    "    grads_o = optimizer.compute_gradients(loss_op,o_gates)\n",
    "    grads_o = [g[0] for g in grads_o]\n",
    "\n",
    "    grads_i_j_f_o = [tf.concat([grads_i[t], grads_j[t], grads_f[t], grads_o[t]], axis=1) for t in range(timesteps)]\n",
    "    \n",
    "    # Apply importance to the gradients calculated from regular SGD solver\n",
    "    \n",
    "    grad_attention = tf.placeholder(shape=[timesteps, batch_size, num_input], dtype=tf.float32)\n",
    "    xs_for_grads = tf.multiply(xs, grad_attention)\n",
    "    W_grads = tf.tensordot(xs_for_grads, grads_i_j_f_o, axes=[[0,1],[0,1]])/batch_size\n",
    "\n",
    "    enumerated_seq_lens = tf.cast(tf.stack([seq_lens, tf.range(tf.shape(seq_lens)[0])], axis=1), tf.int32)\n",
    "\n",
    "    def cond(i, e, o):\n",
    "        return i < batch_size\n",
    "    def body(i, e, o):\n",
    "        o = tf.concat([o,tf.stack([tf.range(e[i,0]),tf.repeat(e[i,1],e[i,0])],axis=-1)],axis=0)\n",
    "        return i+1, e, o\n",
    "\n",
    "    _,_,nonzero_out_idxs = tf.while_loop(cond,body,[tf.constant(1, dtype=tf.int32), enumerated_seq_lens, tf.stack([tf.range(enumerated_seq_lens[0,0]),tf.repeat(enumerated_seq_lens[0,1],enumerated_seq_lens[0,0])],axis=-1)], shape_invariants=[tf.TensorShape([]),tf.TensorShape([None,2]),tf.TensorShape([None,2])])\n",
    "\n",
    "    outs_non_zero = tf.gather_nd(outs,nonzero_out_idxs)\n",
    "    outs_updates = tf.scatter_nd(indices=nonzero_out_idxs, updates=outs_non_zero, shape=[timesteps, batch_size, num_hidden])\n",
    "    outs = tf.zeros((timesteps,batch_size,num_hidden)) + outs_updates\n",
    "    U_grads = tf.tensordot(outs, grads_i_j_f_o, axes=[[0,1],[0,1]])/batch_size\n",
    "    lstm_kernel_grads = tf.concat([W_grads,U_grads],axis=0)     \n",
    "\n",
    "    logits_final, pred_final, _, _, _, _ = build_net(input_test, is_training=False)\n",
    "\n",
    "\n",
    "    grads = optimizer.compute_gradients(loss_op, [v for v in tf.trainable_variables() if v.name.find(\"lstm\")!=-1])\n",
    "    grads = [g[0] for g in grads]\n",
    "\n",
    "    grads[0] = lstm_kernel_grads\n",
    "\n",
    "\n",
    "    grads_update_op = optimizer.apply_gradients(zip(grads, [v for v in tf.trainable_variables() if v.name.find(\"lstm\")!=-1]))\n",
    "    \n",
    "    # Setting up metrics\n",
    "    \n",
    "    train_correct_pred = tf.equal(tf.cast(tf.argmax(prediction, 1),tf.float32), tf.cast(tf.argmax(label_train_holder, 1),tf.float32) )\n",
    "    train_accuracy = tf.reduce_mean(tf.cast(train_correct_pred, tf.float32))\n",
    "    train_kld = tf.keras.losses.KLDivergence()(prediction, label_train_holder)\n",
    "\n",
    "    final_correct_pred = tf.equal(tf.cast(tf.argmax(pred_final, 1), tf.float32), tf.cast(tf.argmax(label_test, 1),tf.float32))\n",
    "    final_accuracy = tf.reduce_mean(tf.cast(final_correct_pred, tf.float32))\n",
    "    final_kld = tf.keras.losses.KLDivergence()(pred_final, label_test)\n",
    "\n",
    "    final_score = pred_final[:,1]\n",
    "\n",
    "    max_final_acc = tf.Variable(0, dtype=tf.float32, name=\"max_final_acc\", trainable=False)\n",
    "    assign_max_final_acc = max_final_acc.assign(final_accuracy)\n",
    "\n",
    "with graph.as_default():\n",
    "    actor = Actor(graph=graph, state_dim=num_input*2+num_hidden*2, action_dim=num_input, learning_rate=actor_lr, tau=0.001, batch_size=batch_size, save_path=\"./saved_model/\"+file_appendix+\"/actor.ckpt\")\n",
    "    critic = Critic(graph=graph, state_dim=num_input*2+num_hidden*2, action_dim=num_input, learning_rate=critic_lr, tau=0.001, gamma=0.99, save_path=\"./saved_model/\"+file_appendix+\"/critic.ckpt\")\n",
    "    init = tf.global_variables_initializer()\n",
    "    saver = tf.train.Saver()"
   ]
  },
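  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NumPy sketch of the attention-weighted kernel gradient assembled above with\n",
    "# tf.tensordot (the sizes _T, _B, _D, _H here are small hypothetical stand-ins\n",
    "# for timesteps, batch_size, num_input, num_hidden): contracting the time and\n",
    "# batch axes of the attention-weighted inputs against the concatenated i/j/f/o\n",
    "# gate gradients yields a [D, 4H] input-to-gates kernel gradient, as in W_grads\n",
    "_T, _B, _D, _H = 4, 3, 5, 7\n",
    "_xs = np.random.randn(_T, _B, _D)              # stacked inputs, like `xs`\n",
    "_att = np.random.rand(_T, _B, _D)              # per-entry importance from the actor\n",
    "_gate_grads = np.random.randn(_T, _B, 4 * _H)  # like the stacked grads_i_j_f_o\n",
    "_W = np.tensordot(_xs * _att, _gate_grads, axes=[[0, 1], [0, 1]]) / _B\n",
    "print(_W.shape)  # (5, 28) == (D, 4H)\n",
    "del _T, _B, _D, _H, _xs, _att, _gate_grads, _W"
   ]
  },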
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Start training\n",
    "with tf.Session(config=session_config, graph=graph) as sess:\n",
    "    sess.run(init)\n",
    "    \n",
    "    # Probability of random exploration (p3 in Appendix D) in the behavioral policy\n",
    "    ## This probability will be decayed exponentially during training\n",
    "    EXPLORATION_RATE = 0.6\n",
    "    \n",
    "    # Probability of following the heuristic (p2 in Appendix D) in the behavioral policy\n",
    "    ## This probability will be decayed exponentially during training\n",
    "    GUIDE_RATE = 0.15\n",
    "    \n",
    "    \n",
    "    \n",
    "    ep_reward = 0\n",
    "    ep_ave_max_q = 0\n",
    "\n",
    "    data_in, label_in, s_mask = sess.run([input_train, label_train, mask_train])\n",
    "\n",
    "\n",
    "    s_1, s_2 = sess.run([states, outs], feed_dict = {input_train_holder:data_in, label_train_holder:label_in, mask_train_holder:s_mask})\n",
    "    s = np.concatenate([np.asarray(np.split(data_in,timesteps,axis=1)).reshape(timesteps,batch_size,num_input),\n",
    "                   np.asarray(np.split(s_mask,timesteps,axis=1)).reshape(timesteps,batch_size,num_input)\n",
    "                        ,s_1,s_2], axis=-1)\n",
    "\n",
    "\n",
    "    reward_list = []\n",
    "    ave_max_q_list = []\n",
    "    replay_buffer = ReplayBuffer(10**5, random_seed=SEED)\n",
    "\n",
    "    # Run the initializer\n",
    "\n",
    "\n",
    "    max_auc = 0.\n",
    "    max_ap = 0.\n",
    "\n",
    "    actor.update_target_network(sess)\n",
    "    critic.update_target_network(sess)\n",
    "\n",
    "    for step in range(training_steps):\n",
    "        rand_num = np.random.rand(1)\n",
    "\n",
    "        if rand_num <= EXPLORATION_RATE:\n",
    "            a = np.concatenate([left_truncnorm.rvs(timesteps*batch_size*num_input/2),right_truncnorm.rvs(timesteps*batch_size*num_input/2)])\n",
    "            np.random.shuffle(a)\n",
    "            a = a.reshape(timesteps, batch_size, num_input).astype(np.float32)\n",
    "\n",
    "        elif rand_num <= GUIDE_RATE+EXPLORATION_RATE and rand_num > EXPLORATION_RATE:\n",
    "            a = np.asarray(np.split((1-s_mask).astype(np.float32), timesteps, axis=1)).reshape(timesteps,batch_size,num_input)\n",
    "\n",
    "        else:\n",
    "            a = actor.predict(s.reshape(-1,num_input*2+num_hidden*2), sess)\n",
    "            a = a.reshape(timesteps, batch_size, num_input)\n",
    "\n",
    "        _, kld = sess.run([grads_update_op, train_kld], feed_dict={grad_attention:a, input_train_holder:data_in, label_train_holder:label_in, mask_train_holder:s_mask})\n",
    "        acc, score = sess.run([final_accuracy, final_score])\n",
    "        data_in, label_in, s2_mask = sess.run([input_train, label_train, mask_train])\n",
    "        s2_1, s2_2 = sess.run([states, outs], feed_dict = {input_train_holder:data_in, label_train_holder:label_in})\n",
    "        s2 = np.concatenate([np.asarray(np.split(data_in,timesteps,axis=1)).reshape(timesteps,batch_size,num_input),\n",
    "                   np.asarray(np.split(s_mask,timesteps,axis=1)).reshape(timesteps,batch_size,num_input)\n",
    "                        ,s2_1,s2_2], axis=-1)\n",
    "        r = np.repeat(-kld, batch_size)\n",
    "        replay_buffer.add_batch([list(i) for i in zip(s.reshape(-1,num_input*2+num_hidden*2),a.reshape(-1,num_input),r,s2.reshape(-1,num_input*2+num_hidden*2))])\n",
    "\n",
    "        if replay_buffer.size() > batch_size:\n",
    "            s_batch, a_batch, r_batch, s2_batch = replay_buffer.sample_batch(batch_size)\n",
    "\n",
    "            # Calculate targets\n",
    "            target_q = critic.predict_target(\n",
    "                s2_batch, actor.predict_target(s2_batch, sess), sess)\n",
    "\n",
    "            y_i = []\n",
    "            for k in range(batch_size):\n",
    "                y_i.append(r_batch[k] + critic.gamma * target_q[k])\n",
    "\n",
    "            # Update the critic given the targets\n",
    "            predicted_q_value, _ = critic.train(\n",
    "                s_batch, a_batch, np.reshape(y_i, (batch_size, 1)), step, sess)\n",
    "\n",
    "            ave_max_q = np.amax(predicted_q_value)\n",
    "            ave_max_q_list += [ave_max_q]\n",
    "\n",
    "            # Update the actor policy using the sampled gradient\n",
    "            a_outs = actor.predict(s_batch, sess)\n",
    "            grads = critic.action_gradients(s_batch, a_outs, sess)\n",
    "            actor.train(s_batch, grads[0], step, sess)\n",
    "\n",
    "            # Update target networks\n",
    "            actor.update_target_network(sess)\n",
    "            critic.update_target_network(sess)\n",
    "\n",
    "        s = s2\n",
    "        s_mask = s2_mask\n",
    "\n",
    "        reward_list += [r[0]]\n",
    "        \n",
    "        # Decay p2 and p3 in the behavioral policy\n",
    "        EXPLORATION_RATE = EXPLORATION_RATE * 0.95\n",
    "        GUIDE_RATE = GUIDE_RATE * 0.95\n",
    "\n",
    "\n",
    "        if step % display_step == 0 and step > 0:\n",
    "            # Calculate batch loss and accuracy\n",
    "            loss, acc, train_acc = sess.run([loss_op, final_accuracy, train_accuracy], feed_dict = {input_train_holder:data_in, label_train_holder:label_in})\n",
    "            auc = roc_auc_score(np.argmax(sess.run(label_test), axis=1), final_score.eval())\n",
    "            ap = average_precision_score(np.argmax(sess.run(label_test), axis=1), final_score.eval())\n",
    "            if np.mean(reward_list[-display_step:]) >= rl_reward_thres_for_decay:\n",
    "                actor.decay_learning_rate(0.965, sess)\n",
    "                critic.decay_learning_rate(0.965, sess)\n",
    "            if acc > max_final_acc.eval():\n",
    "                max_auc = auc\n",
    "                max_ap = ap\n",
    "                sess.run(assign_max_final_acc)\n",
    "                saver.save(sess, \"./saved_model/\"+file_appendix+\"/best.ckpt\")\n",
    "            print \"Step \" + str(step) + \", Reward=\" + str(np.sum(reward_list[-display_step:])) + \", Minibatch Loss= \" + \\\n",
    "                  \"{:.4f}\".format(loss) + \", Training Accuracy= \" + \\\n",
    "                  \"{:.3f}\".format(train_acc) + \\\n",
    "                  \", Max Testing Accuracy= \", \"{:6f}\".format(max_final_acc.eval()) + \\\n",
    "                  \", Max AUC= \", \"{:6f}\".format(max_auc) + \\\n",
    "                  \", Max AP= \", \"{:6f}\".format(max_ap) + \\\n",
    "                  \", Max Q= \", \"{:6f}\".format(np.mean(ave_max_q_list[-display_step:]))\n",
    "            with open(\"./stats/rl_log/\" + file_appendix + \".txt\", \"ab\") as myfile:\n",
    "                myfile.write(\"Step \" + str(step) + \", Reward=\" + str(np.sum(reward_list[-display_step:])) + \", Minibatch Loss= \" + \"{:.4f}\".format(loss) + \", Training Accuracy= \" + \"{:.3f}\".format(train_acc) + \", Max Final Accuracy= \" + \"{:6f}\".format(max_final_acc.eval()) + \", Max AUC= \" + \"{:6f}\".format(max_auc) + \", Max AP= \" + \"{:6f}\".format(max_ap) + \"\\n\")\n",
    "\n",
    "    print \"Optimization Finished!\"\n",
    "\n",
    "    print \"Testing Accuracy:\", sess.run(max_final_acc)\n",
    "    with open(\"./stats/MIMIC_LSTMRL_maskGradients.txt\", \"ab\") as myfile:\n",
    "        myfile.write(\"%.6f\\t%i\\t%.3f\\t%.6f\\t%.6f\\t%i\\t%.6f\\t%.6f\\t%.6f\\n\" %(start_learning_rate, decay_step, decay_rate, actor_lr, critic_lr, num_hidden, max_final_acc.eval(), max_auc, max_ap))"
   ]
  },
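  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch of the behavioral-policy mixture schedule from the training loop: with\n",
    "# the 0.95-per-step decay, random exploration (p3, start 0.6) falls below 1%\n",
    "# after ~80 steps and the heuristic (p2, start 0.15) after ~53 steps, so the\n",
    "# learned actor supplies the attention for most of the 2000 training steps\n",
    "_p3, _p2 = 0.6, 0.15\n",
    "for _step in [0, 10, 50, 100]:\n",
    "    print(\"step %3d: p_explore=%.4f, p_guide=%.4f\" % (_step, _p3 * 0.95 ** _step, _p2 * 0.95 ** _step))\n",
    "del _p3, _p2, _step"
   ]
  },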
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.93317974\n",
      "0.9609462931894922\n"
     ]
    }
   ],
   "source": [
    "with tf.Session(config=session_config, graph=graph) as sess:\n",
    "    saver.restore(sess, \"../saved_model/MIMIC_LSTMRL_MaskGradients_0.0005_500_1.0_1024_0.0005_0.0001_BEST/best.ckpt\")\n",
    "    print sess.run(final_accuracy)\n",
    "    auc = roc_auc_score(np.argmax(sess.run(label_test), axis=1), final_score.eval())\n",
    "    print auc"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
