import argparse
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
import os 
import multiprocessing as mp
from qnetwork import *
from utils import *
import pandas as pd
from sklearn.metrics import roc_auc_score, average_precision_score
import scipy.stats as stats
import random

# Short aliases for the TF 1.x contrib modules (rnn is bound but unused in this script).
rnn, slim = tf.contrib.rnn, tf.contrib.slim

# Command-line options. Every help string states the value actually used as the default.
parser = argparse.ArgumentParser()
parser.add_argument("-no_gpu", dest='no_gpu', action='store_true', help="Train w/o using GPUs")
parser.add_argument("-gpu", "--gpu_idx", type=int, help="Select which GPU to use DEFAULT=0", default=0)
parser.add_argument("-lr_prediction_model", type=float, help="Set learning rate for training the MLP prediction model DEFAULT=0.0005", default=0.0005)
parser.add_argument("-decay_step", type=int, help="Set exponential decay step DEFAULT=1000", default=1000)
parser.add_argument("-decay_rate", type=float, help="Set exponential decay rate DEFAULT=0.8", default=0.8)
# One "training step" is one minibatch update (see the training loop), not an epoch.
parser.add_argument("-training_steps", type=int, help="Set number of training steps (minibatch updates) DEFAULT=3000", default=3000)
parser.add_argument("-seed", type=int, help="Set random seed DEFAULT=2599", default=2599)


if __name__ == '__main__':
    args = parser.parse_args()

    # Device selection: expose only the requested GPU, or hide all GPUs with -no_gpu.
    use_gpu = not args.no_gpu
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_idx) if use_gpu else ""
    session_config = tf.ConfigProto(log_device_placement=False)
    if use_gpu:
        # Grow GPU memory on demand rather than reserving it all at session start.
        session_config.gpu_options.allow_growth = True

    SEED = args.seed

    def _seed_everything():
        """Seed Python's `random`, NumPy's global RNG and TF's default-graph seed with SEED."""
        # NOTE(review): tf.set_random_seed only sets the seed of the *default* graph;
        # the model below is built in the separate `graph` created further down,
        # which this call does not reach.
        np.random.seed(SEED)
        tf.set_random_seed(SEED)
        random.seed(SEED)

    _seed_everything()

    # Output directories (parents listed before children).
    for out_dir in ("./saved_model", "./stats", "./stats/rl_log"):
        if not os.path.exists(out_dir):
            os.mkdir(out_dir)

    normal_train = np.loadtxt("./data/normal_train_all_35_missing.txt")
    abnormal_train = np.loadtxt("./data/abnormal_train_all_35_missing.txt")
    normal_test = np.loadtxt("./data/normal_test_all_35_missing.txt")
    abnormal_test = np.loadtxt("./data/abnormal_test_all_35_missing.txt")

    def _stack_split(normal_rows, abnormal_rows):
        """Stack normal rows (label 0) above abnormal rows (label 1).

        Returns (features as float32, labels as int32, mask as float32 with
        1.0 wherever the feature value is NaN).
        """
        features = np.vstack([normal_rows, abnormal_rows]).astype(np.float32)
        labels = np.concatenate([np.zeros(len(normal_rows)), np.ones(len(abnormal_rows))]).astype(np.int32)
        missing = np.isnan(features).astype(np.float32)
        return features, labels, missing

    data_train, data_label_train, data_mask_train = _stack_split(normal_train, abnormal_train)
    data_test, data_label_test, data_mask_test = _stack_split(normal_test, abnormal_test)

    # Missing entries are recorded in the masks above, then overwritten in place.
    nan_replacement = 0.
    for arr in (data_train, data_test):
        arr[np.isnan(arr)] = nan_replacement

    # Truncated normals on [lower, upper]: one centred at 0, one centred at 1
    # (labelled "for exploration"; not used elsewhere in this script).
    lower, upper = 0, 1
    mu, sigma = 0, 0.2
    left_truncnorm = stats.truncnorm((lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
    right_truncnorm = stats.truncnorm((lower - 1.) / sigma, (upper - 1.) / sigma, loc=1., scale=sigma)

    _seed_everything()

    # Prediction-model hyper-parameters.
    start_learning_rate = args.lr_prediction_model
    decay_step = args.decay_step
    decay_rate = args.decay_rate
    training_steps = args.training_steps
    batch_size = 128

    num_input = 4101
    timesteps = 1
    num_classes = 2

    display_step = 10

    # Hidden-layer widths of the MLP (fc1, fc2).
    weights = [1000, 1000]

    gpu = 0

    graph = tf.Graph()

    # Tag identifying this run's checkpoint directory.
    file_appendix = "Ophthalmic_35_missing_MLP_GIL-H_" + "_".join(
        str(v) for v in (start_learning_rate, decay_step, decay_rate))

    def build_net(x, is_training=True, reuse=tf.AUTO_REUSE, graph=graph):
        """Create (or re-use) the MLP classifier and return ``(logits, pred)``.

        All variables live under the ``NN`` variable scope, which is always
        opened with ``tf.AUTO_REUSE``, so repeated calls share one set of
        weights (layers ``fc1``, ``fc2``, ``logits``).

        Args:
            x: input fed to the first fully connected layer.
            is_training: only stored in ``normalizer_params``; since
                ``normalizer_fn`` is None it currently has no effect.
            reuse: forwarded to every ``slim.fully_connected`` layer.
            graph: graph in which the ops are created.

        Returns:
            logits: output of the final 2-unit linear layer.
            pred: softmax of ``logits``.
        """
        with graph.as_default(), tf.variable_scope("NN", reuse=tf.AUTO_REUSE) as nn_scope:
            # Defaults shared by every fully connected layer: ReLU, small positive
            # uniform init, L2(0.01) on weights and biases, no normalisation.
            layer_defaults = {
                "activation_fn": tf.nn.relu,
                "weights_initializer": tf.random_uniform_initializer(0.001, 0.01),
                "weights_regularizer": slim.l2_regularizer(0.01),
                "biases_regularizer": slim.l2_regularizer(0.01),
                "normalizer_fn": None,
                "normalizer_params": {"is_training": is_training},
                "reuse": reuse,
                "scope": nn_scope,
            }
            with slim.arg_scope([slim.fully_connected], **layer_defaults):
                hidden = slim.fully_connected(x, weights[0], scope='fc1')
                hidden = slim.fully_connected(hidden, weights[1], scope='fc2')
                # Linear output layer: no activation and no weight penalty
                # (the bias penalty from the defaults above still applies).
                out_logits = slim.fully_connected(hidden, 2, activation_fn=None,
                                                  weights_regularizer=None,
                                                  normalizer_fn=None, scope='logits')
                out_pred = slim.softmax(out_logits, scope='pred')
        return out_logits, out_pred



    def gen_train():
        """Yield ``(features, one_hot_label, missing_mask)`` for each training row, in order.

        ``one_hot_label`` is a fresh length-2 float64 vector with a 1 at the
        position given by ``data_label_train``; features and mask are the
        matching rows of ``data_train`` and ``data_mask_train``.
        """
        for row, cls, row_mask in zip(data_train, data_label_train, data_mask_train):
            one_hot = np.zeros(2)
            one_hot[cls] = 1.
            yield row, one_hot, row_mask


    with graph.as_default():
        # tf.set_random_seed only affects the graph that is the default when it is
        # called; `graph` was not the default during setup, so seed it here to make
        # weight initialisation and dataset shuffling reproducible.
        tf.set_random_seed(SEED)

        # Input pipeline: (features, one-hot label, missing mask) per example,
        # repeated, shuffled with a 5000-element buffer, then batched.
        dataset_train = tf.data.Dataset.from_generator(gen_train, (tf.float32, tf.float32, tf.int32), ([data_train.shape[1]],[2],[data_train.shape[1]])).repeat(1000).shuffle(5000).batch(batch_size)
        input_train, label_train, mask_train = dataset_train.make_one_shot_iterator().get_next()

        logits, prediction = build_net(input_train)

        # Second call re-uses the "NN" variables (AUTO_REUSE) on the whole test set.
        all_test = data_test

        logits_final, pred_final = build_net(all_test, is_training=False)

        # Per-example loss: cross entropy plus the mean of the collected regularization losses.
        loss_op = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=label_train) + tf.reduce_mean(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        loss_mean = tf.reduce_mean(loss_op, axis=0)
        global_step = tf.train.get_or_create_global_step()
        learning_rate = tf.train.exponential_decay(start_learning_rate, global_step, decay_steps=decay_step, decay_rate=decay_rate)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        # Per-example gradients: compute_gradients is vectorized over the elements
        # of loss_op; g[0] keeps the gradient (with a leading batch axis).
        grads = tf.vectorized_map(lambda x: optimizer.compute_gradients(x, tf.trainable_variables()), loss_op)
        grads = [g[0] for g in grads]

        # Apply the simple heuristic to the gradients calculated by the regular SGD solver:
        # zero the per-example gradient of the first trainable variable (presumably
        # NN/fc1/weights, the first variable build_net creates), shaped
        # [batch, num_features, weights[0]], at every (example, feature) position
        # whose input value was missing. A mask multiply handles any number of
        # missing entries and a final partial batch.
        keep_mask = tf.cast(1 - mask_train, grads[0].dtype)
        grads[0] = grads[0] * keep_mask[:, :, tf.newaxis]

        grads = [tf.reduce_mean(g,axis=0) for g in grads]
        
        with tf.control_dependencies(update_ops):
            # Passing global_step increments it on every update; without it
            # exponential_decay never advances and the learning rate stays constant.
            grads_update_op = optimizer.apply_gradients(zip(grads, tf.trainable_variables()), global_step=global_step)
        
        train_correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(label_train, 1))
        train_accuracy = tf.reduce_mean(tf.cast(train_correct_pred, tf.float32))

        final_correct_pred = tf.equal(tf.argmax(pred_final, 1), data_label_test)
        final_accuracy = tf.reduce_mean(tf.cast(final_correct_pred, tf.float32))

        # Holds the test accuracy of the best checkpoint so it is saved with it.
        max_final_acc = tf.Variable(0, dtype=tf.float32, name="max_final_acc", trainable=False)
        assign_max_final_acc = max_final_acc.assign(final_accuracy)

        # Probability of class 1 (abnormal) for each test example; used for AUC/AP.
        final_score = pred_final[:,1]

        init = tf.global_variables_initializer()

        saver = tf.train.Saver()

    # Start training
    with tf.Session(config=session_config, graph=graph) as sess:

        sess.run(init)
        max_auc = 0.
        max_ap = 0.
        max_acc = 0.

        # Ensure the checkpoint directory exists before the first saver.save call.
        ckpt_dir = os.path.join("./saved_model", file_appendix)
        if not os.path.isdir(ckpt_dir):
            os.makedirs(ckpt_dir)

        for step in range(1, training_steps+1):
            sess.run(grads_update_op)
            if step % display_step == 0 or step == 1:
                # Fetch everything in one run so the test-set forward pass
                # (final_score) is computed once per display step.
                # NOTE: loss_mean and train_accuracy read from the dataset iterator,
                # so this run consumes a fresh minibatch that is not used for training.
                loss, acc, train_acc, test_scores = sess.run(
                    [loss_mean, final_accuracy, train_accuracy, final_score])
                auc = roc_auc_score(data_label_test, test_scores)
                ap = average_precision_score(data_label_test, test_scores)
                # NOTE(review): the best checkpoint is selected on the test split
                # (accuracy + AUC); reported maxima are therefore optimistic.
                if acc + auc > max_acc + max_auc:
                    max_acc = acc
                    max_auc = auc
                    max_ap = ap
                    sess.run(assign_max_final_acc)
                    saver.save(sess, "./saved_model/"+file_appendix+"/best.ckpt")
                print("Step " + str(step) + ", Minibatch Loss= " +
                      "{:.4f}".format(loss) + ", Training Accuracy= " +
                      "{:.3f}".format(train_acc) +
                      ", Max Final Accuracy= " + "{:.6f}".format(max_final_acc.eval()) +
                      ", Max AUC= " + "{:.6f}".format(max_auc) +
                      ", Max AP= " + "{:.6f}".format(max_ap))

        print("Optimization Finished!")

        print("Testing Accuracy: " + str(sess.run(max_final_acc)))
        print("Testing AUC: " + str(max_auc))
        # Append one tab-separated result row per run (text mode: we write str).
        with open("./stats/Ophthalmic_35_missing_GIL-H.txt", "a") as myfile:
            myfile.write("%.9f\t%i\t%.3f\t%i\t%i\t%.6f\t%.6f\t%.6f\n" %(start_learning_rate, decay_step, decay_rate, weights[0], weights[1], max_final_acc.eval(), max_auc, max_ap))

