import argparse
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
import os 
import multiprocessing as mp
from qnetwork import *
from utils import *
import pandas as pd
from sklearn.metrics import roc_auc_score, average_precision_score
import scipy.stats as stats
import random

# Short aliases for the TF 1.x contrib modules used when building the model.
rnn, slim = tf.contrib.rnn, tf.contrib.slim

# Command-line options for device selection, model size, optimisation schedule
# and seeding. Every option has a default, so the script runs with no arguments.
parser = argparse.ArgumentParser()
parser.add_argument("-no_gpu", dest='no_gpu', action='store_true', help="Train w/o using GPUs")
parser.add_argument("-gpu", "--gpu_idx", type=int, help="Select which GPU to use DEFAULT=0", default=0)
parser.add_argument("-lstm_hidden_size", type=int, help="Set the size of LSTM hidden states DEFAULT=1024", default=1024)
parser.add_argument("-lr_prediction_model", type=float, help="Set learning rate for training the LSTM prediction model DEFAULT=0.005", default=0.005)
parser.add_argument("-decay_step", type=int, help="Set exponential decay step DEFAULT=500", default=500)
parser.add_argument("-decay_rate", type=float, help="Set exponential decay rate DEFAULT=0.95", default=0.95)
# The training loop runs this many minibatch updates (not passes over the data).
parser.add_argument("-training_steps", type=int, help="Set max number of training steps (minibatch updates) DEFAULT=2000", default=2000)
parser.add_argument("-seed", type=int, help="Set random seed DEFAULT=2599", default=2599)


if __name__ == '__main__':
    args = parser.parse_args()

    # Device selection: expose only the requested GPU, or hide every GPU when
    # -no_gpu is given. This is set before any TF session is created.
    session_config = tf.ConfigProto(log_device_placement=False)
    if args.no_gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_idx)
        # Allocate GPU memory on demand instead of reserving it all up front.
        session_config.gpu_options.allow_growth = True

    # Seed every RNG in use. NOTE: tf.set_random_seed only seeds the *current
    # default* graph; a model built in a separate tf.Graph needs its own call.
    SEED = args.seed
    np.random.seed(SEED)
    tf.set_random_seed(SEED)
    random.seed(SEED)

    # Output directories (parents listed before their children).
    for _dir in ["./saved_model", "./stats", "./stats/rl_log"]:
        if not os.path.exists(_dir):
            os.mkdir(_dir)

    df_shock_train = pd.read_csv("./data/df_shock_train.csv", index_col="TrainSampleIdx")
    df_shock_test = pd.read_csv("./data/df_shock_test.csv", index_col="TrainSampleIdx")
    df_non_shock_train = pd.read_csv("./data/df_non_shock_train.csv", index_col="TrainSampleIdx")
    df_non_shock_test = pd.read_csv("./data/df_non_shock_test.csv", index_col="TrainSampleIdx")

    # Numerical stand-in for NaN: three times the largest training value (NaNs
    # count as 0 when taking that max). When that max is positive, the sentinel
    # lies above every observed value.
    # NOTE(review): if the max is <= 0 the sentinel is <= 0 and can coincide
    # with the zero rows used to pad samples to max_seq_len.
    # (np.inf rather than np.infty: the latter alias was removed in NumPy 2.0.)
    _max = -np.inf
    for _df in [df_shock_train, df_non_shock_train]:
        _df_values = np.copy(_df.values)
        _df_values[np.isnan(_df_values)] = 0.
        _max = max(_max, np.max(_df_values))

    nan_replacement = 3*_max

    all_dfs = [df_shock_train, df_non_shock_train, df_shock_test, df_non_shock_test]

    # Longest sequence = the largest number of rows sharing one index value in
    # any of the four splits.
    max_seq_len = max(np.max(np.unique(_df.index.values, return_counts=True)[1]) for _df in all_dfs)

    # Replace NaNs in place (the sample generators read these same frames).
    for _df in all_dfs:
        _df.fillna(nan_replacement, inplace=True)

    def seq_length(sequence):
        """Return, as int32, how many timesteps of each sample are in use.

        A timestep counts as used when at least one of its features is non-zero
        (max of absolute values over axis 2). The used timesteps are then summed
        over axis 1. A genuinely all-zero timestep is therefore counted as unused,
        exactly like zero padding.
        """
        step_is_used = tf.sign(tf.reduce_max(tf.abs(sequence), axis=2))
        return tf.cast(tf.reduce_sum(step_is_used, axis=1), tf.int32)

    def _gen_samples(df, label):
        """Yield ``(values, label, mask)`` for every multi-row sample in ``df``.

        Rows sharing the same index value form one sample, taken in order of
        first appearance. ``values`` is zero-padded at the end to ``max_seq_len``
        rows. ``mask`` has the same ``(max_seq_len, num_columns)`` shape; it is 1
        where the entry equals ``nan_replacement`` (a missing measurement) and 0
        elsewhere. ``label`` is a fresh float array built from the given one-hot
        sequence.

        Single-row samples are skipped. NOTE(review): this keeps the long-standing
        behaviour of only yielding samples with >= 2 rows; confirm that dropping
        length-1 sequences is intended.
        """
        # groupby visits each index value once (O(n) overall), unlike repeated
        # .loc lookups on a non-unique index.
        for _, sample_df in df.groupby(level=0, sort=False):
            if len(sample_df) < 2:
                continue
            values = sample_df.values
            padding = np.zeros((max_seq_len - values.shape[0], values.shape[1]))
            out = np.vstack([values, padding])
            # Builtin int: the np.int alias was removed in NumPy 1.24.
            mask = (out == nan_replacement).astype(int)
            yield out, np.array(label), mask

    def gen_train():
        """Yield training samples: shock samples (label [0, 1]), then non-shock ones (label [1, 0])."""
        for sample in _gen_samples(df_shock_train, [0., 1.]):
            yield sample
        for sample in _gen_samples(df_non_shock_train, [1., 0.]):
            yield sample

    def gen_test():
        """Yield test samples: shock samples (label [0, 1]), then non-shock ones (label [1, 0])."""
        for sample in _gen_samples(df_shock_test, [0., 1.]):
            yield sample
        for sample in _gen_samples(df_non_shock_test, [1., 0.]):
            yield sample


    # Truncated normal distributions on [lower, upper] = [0, 1] for exploration:
    # `left_truncnorm` is centred at 0 and `right_truncnorm` at 1, both with
    # scale sigma = 0.2. NOTE(review): neither is referenced elsewhere in this file.
    lower, upper = 0, 1
    mu, sigma = 0, 0.2
    left_truncnorm = stats.truncnorm(
        (lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
    right_truncnorm = stats.truncnorm(
        (lower - 1.) / sigma, (upper - 1.) / sigma, loc=1., scale=sigma)

    np.random.seed(SEED)

    # Prediction Model Parameters
    start_learning_rate = args.lr_prediction_model
    decay_step = args.decay_step
    decay_rate = args.decay_rate
    num_hidden = args.lstm_hidden_size

    training_steps = args.training_steps
    batch_size = 128

    num_input = 15
    timesteps = max_seq_len  # longest sample sequence across all splits
    num_classes = 2

    display_step = 10  # evaluate and log every `display_step` training steps

    # The model lives in its own graph. tf.set_random_seed seeds only the
    # current default graph, so it must run inside graph.as_default() for the
    # initializers and dataset shuffling built in `graph` to be reproducible.
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(SEED)

    # Run tag; used as the checkpoint sub-directory name.
    file_appendix = "MIMIC_LSTM_GIL-H_" + str(start_learning_rate) + "_" + str(decay_step) + "_" + str(decay_rate) + "_" + str(num_hidden)

    def build_net(x, is_training=True, reuse=tf.AUTO_REUSE, graph=graph):
        """Build the LSTM classifier on input tensor ``x`` inside ``graph``.

        ``x`` is rank 3. It is unstacked along axis 1 into ``timesteps`` tensors
        and fed through a BasicLSTMCell via ``my_static_rnn``. That helper comes
        from a star-imported project module (presumably ``qnetwork``/``utils``;
        TODO confirm). The RNN output at each sample's last used timestep (from
        ``seq_length``) goes through one fully connected layer to produce
        ``num_classes`` logits.

        Args:
            x: rank-3 input; axis 1 must have length ``timesteps``. Presumably
                (batch, timesteps, num_input). TODO confirm.
            is_training: only forwarded inside ``normalizer_params``. Because
                ``normalizer_fn`` is None it has no effect here.
            reuse: variable-reuse mode for the LSTM cell and the FC layer. The
                default AUTO_REUSE lets a second call share the first call's weights.
            graph: graph to build into (defaults to the module-level ``graph``).

        Returns:
            (logits, pred, outputs, x, all_states, seq_len):
            - ``pred`` is the softmax of ``logits``.
            - ``outputs``/``all_states`` come from ``my_static_rnn``.
            - ``x`` is the list of per-timestep input tensors.
            - ``seq_len`` is the per-sample length from ``seq_length``.
        """

        with graph.as_default():
            seq_len = seq_length(x)
            # (last used timestep, sample index) pairs that select each sample's
            # final RNN output. NOTE(review): a sample with seq_len == 0 yields
            # timestep index -1 here.
            enumerated_last_idxs = tf.cast(tf.stack([seq_len-1, tf.range(tf.shape(seq_len)[0])], axis=1), tf.int32)
            x = tf.unstack(x, timesteps, 1)
            # The training code assumes tf.trainable_variables()[0] is this cell's
            # kernel (it overwrites grads[0]). Keep the LSTM built before the FC layer.
            lstm_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0, reuse=reuse)
            # Ops created here are later looked up by name
            # ("rnn/basic_lstm_cell/split_<t>") to get per-timestep gate
            # pre-activations, so scope/op names must not change.
            outputs, state, all_states = my_static_rnn(lstm_cell, x, dtype=tf.float32)
            last_outputs = tf.gather_nd(outputs, enumerated_last_idxs)
            # is_training only reaches normalizer_params, which is unused
            # because normalizer_fn is None.
            with slim.arg_scope([slim.fully_connected], 
                                    activation_fn=tf.nn.relu,
                                    weights_initializer=tf.random_uniform_initializer(0.001, 0.01),
                                    weights_regularizer=slim.l2_regularizer(0.01),
                                    biases_regularizer=slim.l2_regularizer(0.01),
                                    normalizer_fn = None,
                                    normalizer_params = {"is_training": is_training},
                                    reuse = reuse):
                # Linear output layer: no activation and no weights regularizer;
                # the arg_scope's biases regularizer still applies.
                logits = slim.fully_connected(last_outputs,num_classes,activation_fn=None, weights_regularizer=None, normalizer_fn=None, scope='logits')
                pred = slim.softmax(logits, scope='pred')

                return logits, pred, outputs, x, all_states, seq_len


    with graph.as_default():

        # ---- Input pipelines --------------------------------------------------
        # The custom LSTM-kernel gradient below hard-codes `batch_size` (the
        # tf.scatter_nd shapes and the tf.while_loop bound). Every training batch
        # must therefore be full, so a short final batch is dropped.
        dataset_train = tf.data.Dataset.from_generator(gen_train, (tf.float32, tf.float32, tf.int32), ([ timesteps, 15],[ 2],[timesteps, 15])).repeat(1000).shuffle(5000).batch(batch_size, drop_remainder=True)
        input_train, label_train, mask_train = dataset_train.make_one_shot_iterator().get_next()

        # One test batch must be exactly one pass over the test set, so that every
        # evaluation scores the same samples in the same order. Size it by what
        # gen_test really yields: it need not yield one sample per unique index
        # value, and an oversized batch would straddle test-set repetitions.
        num_test_samples = sum(1 for _ in gen_test())
        dataset_test = tf.data.Dataset.from_generator(gen_test, (tf.float32, tf.float32, tf.int32), ([ timesteps, 15],[ 2],[timesteps, 15])).repeat(10000).batch(num_test_samples)
        input_test, label_test, mask_test = dataset_test.make_one_shot_iterator().get_next()

        # ---- Training model and loss -------------------------------------------
        logits, prediction, outs, xs, states, seq_lens = build_net(input_train)
        # Batch-mean cross-entropy plus the mean of the collected regularization losses.
        loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=label_train) + tf.reduce_mean(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), axis=0)
        # The same step variable drives the decay schedule and is incremented by
        # apply_gradients below. Without that increment the rate never decays.
        global_step = tf.train.get_or_create_global_step()
        learning_rate = tf.train.exponential_decay(start_learning_rate, global_step, decay_steps=decay_step, decay_rate=decay_rate)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

        # ---- Custom gradient for the LSTM kernel --------------------------------
        # mask_train is 1 where an input entry is missing. Reorder its
        # (sample, time, feature) coordinates to (time, sample, feature) so they
        # index the per-timestep input list `xs`.
        missing_idxs = tf.where_v2(mask_train)
        missing_idxs = tf.stack([missing_idxs[:,1], missing_idxs[:,0], missing_idxs[:,2]], axis=-1)

        # tensor names for <i,j,f,o> -- rnn/basic_lstm_cell/split_{}:<0,1,2,3>
        # (the timestep-0 op is named plain "split"). These come from the
        # training build_net call above.
        i_gates = [graph.get_tensor_by_name("rnn/basic_lstm_cell/split_"+str(t)+":0") if t>0 else graph.get_tensor_by_name("rnn/basic_lstm_cell/split:0") for t in range(timesteps)]
        j_gates = [graph.get_tensor_by_name("rnn/basic_lstm_cell/split_"+str(t)+":1") if t>0 else graph.get_tensor_by_name("rnn/basic_lstm_cell/split:1") for t in range(timesteps)]
        f_gates = [graph.get_tensor_by_name("rnn/basic_lstm_cell/split_"+str(t)+":2") if t>0 else graph.get_tensor_by_name("rnn/basic_lstm_cell/split:2") for t in range(timesteps)]
        o_gates = [graph.get_tensor_by_name("rnn/basic_lstm_cell/split_"+str(t)+":3") if t>0 else graph.get_tensor_by_name("rnn/basic_lstm_cell/split:3") for t in range(timesteps)]

        # Gradient of the loss w.r.t. each gate tensor, per timestep.
        grads_i = optimizer.compute_gradients(loss_op,i_gates)
        grads_i = [g[0] for g in grads_i]
        grads_j = optimizer.compute_gradients(loss_op,j_gates)
        grads_j = [g[0] for g in grads_j]
        grads_f = optimizer.compute_gradients(loss_op,f_gates)
        grads_f = [g[0] for g in grads_f]
        grads_o = optimizer.compute_gradients(loss_op,o_gates)
        grads_o = [g[0] for g in grads_o]

        # Concatenated in the same i, j, f, o order as the split outputs.
        grads_i_j_f_o = [tf.concat([grads_i[t], grads_j[t], grads_f[t], grads_o[t]], axis=1) for t in range(timesteps)]

        # Input part of the kernel gradient, with missing inputs zeroed out so
        # they contribute nothing.
        xs_need_to_be_zero = tf.gather_nd(xs,missing_idxs)
        xs_updates = tf.scatter_nd(indices=missing_idxs, updates=-xs_need_to_be_zero, shape=[timesteps, batch_size, num_input])
        xs_for_grads = xs + xs_updates
        # NOTE(review): loss_op is already a batch mean. The extra /batch_size
        # rescales this gradient relative to the others; confirm it is intended.
        W_grads = tf.tensordot(xs_for_grads, grads_i_j_f_o, axes=[[0,1],[0,1]])/batch_size

        # Build (t, b) index pairs for every t < seq_lens[b], i.e. the RNN outputs
        # that lie within each sample's length. Sample 0 seeds the loop state and
        # samples 1..batch_size-1 are appended.
        enumerated_seq_lens = tf.cast(tf.stack([seq_lens, tf.range(tf.shape(seq_lens)[0])], axis=1), tf.int32)

        def cond(i, e, o):
            return i < batch_size
        def body(i, e, o):
            o = tf.concat([o,tf.stack([tf.range(e[i,0]),tf.repeat(e[i,1],e[i,0])],axis=-1)],axis=0)
            return i+1, e, o

        _,_,nonzero_out_idxs = tf.while_loop(cond,body,[tf.constant(1, dtype=tf.int32), enumerated_seq_lens, tf.stack([tf.range(enumerated_seq_lens[0,0]),tf.repeat(enumerated_seq_lens[0,1],enumerated_seq_lens[0,0])],axis=-1)], shape_invariants=[tf.TensorShape([]),tf.TensorShape([None,2]),tf.TensorShape([None,2])])

        # Keep outputs inside each sequence and zero the rest.
        outs_non_zero = tf.gather_nd(outs,nonzero_out_idxs)
        outs_updates = tf.scatter_nd(indices=nonzero_out_idxs, updates=outs_non_zero, shape=[timesteps, batch_size, num_hidden])
        outs = tf.zeros((timesteps,batch_size,num_hidden)) + outs_updates
        # NOTE(review): this pairs the output at step t with the gate gradients at
        # step t. For a standard LSTM the recurrent-kernel gradient uses the
        # previous step's output; confirm against my_static_rnn.
        U_grads = tf.tensordot(outs, grads_i_j_f_o, axes=[[0,1],[0,1]])/batch_size
        # Rows ordered [input; hidden], matching BasicLSTMCell's kernel layout.
        lstm_kernel_grads = tf.concat([W_grads,U_grads],axis=0)

        # ---- Evaluation model (shares variables via AUTO_REUSE) -------------------
        logits_final, pred_final, _, _, _, _ = build_net(input_test, is_training=False)

        grads = optimizer.compute_gradients(loss_op, tf.trainable_variables())
        grads = [g[0] for g in grads]

        # Assumes trainable variable 0 is the LSTM kernel (built first in build_net).
        grads[0] = lstm_kernel_grads

        grads_update_op = optimizer.apply_gradients(zip(grads, tf.trainable_variables()), global_step=global_step)

        train_correct_pred = tf.equal(tf.cast(tf.argmax(prediction, 1),tf.float32), tf.cast(tf.argmax(label_train, 1),tf.float32) )
        train_accuracy = tf.reduce_mean(tf.cast(train_correct_pred, tf.float32))

        final_correct_pred = tf.equal(tf.argmax(pred_final, 1), tf.argmax(label_test, 1))
        final_accuracy = tf.reduce_mean(tf.cast(final_correct_pred, tf.float32))

        # Best test accuracy so far, stored as a variable so checkpoints include it.
        # It is assigned from a fed value: assigning from `final_accuracy` would
        # pull and score a fresh test batch instead of recording the measured value.
        max_final_acc = tf.Variable(0, dtype=tf.float32, name="max_final_acc", trainable=False)
        new_max_final_acc = tf.placeholder(tf.float32, shape=[], name="new_max_final_acc")
        assign_max_final_acc = max_final_acc.assign(new_max_final_acc)

        # Predicted probability of class 1 (the shock label [0, 1]).
        final_score = pred_final[:, 1]

        init = tf.global_variables_initializer()

        saver = tf.train.Saver()
        # Saver.save does not create missing parent directories.
        save_dir = os.path.join("./saved_model", file_appendix)

        max_acc = 0.
        max_auc = 0.
        max_ap = 0.

        with tf.Session(config=session_config, graph=graph) as sess:
            sess.run(init)

            for step in range(1, training_steps+1):
                # Fetch the loss in the same run as the update, so it belongs to
                # the batch that was actually trained on.
                _, train_acc, loss = sess.run([grads_update_op, train_accuracy, loss_op])
                if step % display_step == 0 or step == 1:
                    # Fetch labels, scores and accuracy in ONE run. Each separate
                    # run of a test tensor pulls a new batch from the iterator, so
                    # separate fetches could pair labels with another batch's scores.
                    test_labels, test_scores, acc = sess.run([label_test, final_score, final_accuracy])
                    test_classes = np.argmax(test_labels, axis=1)
                    auc = roc_auc_score(test_classes, test_scores)
                    ap = average_precision_score(test_classes, test_scores)
                    if acc > max_acc:
                        max_acc = acc
                        max_auc = auc
                        max_ap = ap
                        sess.run(assign_max_final_acc, feed_dict={new_max_final_acc: acc})
                        if not os.path.isdir(save_dir):
                            os.makedirs(save_dir)
                        saver.save(sess, os.path.join(save_dir, "best.ckpt"))
                    print("Step {}, Minibatch Loss= {:.4f}, Training Accuracy= {:.3f}, "
                          "Max Final Accuracy= {:.6f}, Max AUC= {:.6f}, Max AP= {:.6f}".format(
                              step, loss, train_acc, max_acc, max_auc, max_ap))

        # Report from the Python copies: the session is closed at this point.
        print("Optimization Finished!")

        print("Testing Accuracy: {}".format(max_acc))
        print("Testing AUC: {}".format(max_auc))

        # Record the hyper-parameters tried along with their performances
        # (columns: lr, decay_step, decay_rate, num_hidden, acc, AUC, AP).
        with open("./stats/MIMIC_LSTM_GIL-H.txt", "a") as myfile:
            myfile.write("%.6f\t%i\t%.3f\t%i\t%.6f\t%.6f\t%.6f\n" % (start_learning_rate, decay_step, decay_rate, num_hidden, max_acc, max_auc, max_ap))






















