import numpy as np
import argparse
import os
import imp
import re
import math
from sklearn.cluster import KMeans
from mimic3models.in_hospital_mortality import utils
from mimic3benchmark.readers import InHospitalMortalityReader
from mimic3models.preprocessing import Discretizer, Normalizer
from mimic3models import metrics
from mimic3models import keras_utils
from mimic3models import common_utils
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.python.eager import backprop
import pickle
from sklearn import metrics as skmetrics
from tqdm import tqdm
from tensorflow.keras.models import Model

def _str2bool(value):
    """Convert a command-line string such as 'True'/'false'/'1'/'no' to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"), so an explicit converter is required.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def argparser():
    """Build the CLI for the vertical split-learning IHM experiment and parse sys.argv.

    Adds the project's common arguments (via ``common_utils.add_common_arguments``)
    plus the split-learning specific options, prints the parsed namespace between
    banner lines, and returns it.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser()
    common_utils.add_common_arguments(parser)
    parser.add_argument('--seed', type=int, nargs='?', default=42,
                        help='Random seed to be used.')
    parser.add_argument('--target_repl_coef', type=float, default=0.0)
    parser.add_argument('--data', type=str, help='Path to the data of in-hospital mortality task',
                        default=os.path.join(os.path.dirname(__file__), '../../data/in-hospital-mortality/'))
    parser.add_argument('--output_dir', type=str, help='Directory relative which all output files are stored',
                        default='.')
    parser.add_argument('--num_clients', type=int, help='Number of clients to split data between vertically',
                        default=2)
    parser.add_argument('--no_clusters', type=int, help='Number of clusters of embeddings in KMeans',
                        default=10)
    parser.add_argument('--local_epochs', type=int, help='Number of local epochs to run at each client before synchronizing',
                        default=1)
    # Accepts `--if_cluster`, `--if_cluster True` and `--if_cluster False`;
    # the previous `type=bool` turned any explicit value (even "False") into True.
    parser.add_argument('--if_cluster', type=_str2bool, nargs='?', const=True,
                        help='If workers want to cluster embeddings before sending to server',
                        default=False)

    args = parser.parse_args()
    print("*"*80, "\n\n", args, "\n\n", "*"*80)
    return args




if __name__ == "__main__":
    # Vertical split learning for in-hospital mortality: each of `num_clients`
    # bottom models sees a contiguous slice of the feature columns, and one top
    # ("server") model consumes the concatenation of the bottom models' outputs.

    args = argparser()
    np.random.seed(args.seed)
    tf.random.set_seed(args.seed)

    num_clients = args.num_clients
    local_epochs = args.local_epochs

    # Number of input feature columns (hard-coded in this script). Each client
    # owns `coords_per` consecutive columns.
    # NOTE(review): when n_features is not divisible by num_clients, the trailing
    # columns are not assigned to any client.
    n_features = 76
    coords_per = n_features // num_clients

    if args.small_part:
        # save_every is not read in this script; presumably consumed elsewhere.
        args.save_every = 2**30

    target_repl = (args.target_repl_coef > 0.0 and args.mode == 'train')

    # Build readers, discretizers, normalizers
    train_reader = InHospitalMortalityReader(dataset_dir=os.path.join(args.data, 'train'),
                                             listfile=os.path.join(args.data, 'train_listfile.csv'),
                                             period_length=48.0)

    val_reader = InHospitalMortalityReader(dataset_dir=os.path.join(args.data, 'train'),
                                           listfile=os.path.join(args.data, 'val_listfile.csv'),
                                           period_length=48.0)

    test_reader = InHospitalMortalityReader(dataset_dir=os.path.join(args.data, 'test'),
                                            listfile=os.path.join(args.data, 'test_listfile.csv'),
                                            period_length=48.0)

    discretizer = Discretizer(timestep=float(args.timestep),
                              store_masks=True,
                              impute_strategy='previous',
                              start_time='zero')

    discretizer_header = discretizer.transform(train_reader.read_example(0)["X"])[1].split(',')
    # Columns without "->" in their header name are the ones standardised by the normalizer.
    cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]

    normalizer = Normalizer(fields=cont_channels)  # choose here which columns to standardize
    normalizer_state = args.normalizer_state
    if normalizer_state is None:
        normalizer_state = 'ihm_ts{}.input_str:{}.start_time:zero.normalizer'.format(args.timestep, args.imputation)
        normalizer_state = os.path.join(os.path.dirname(__file__), normalizer_state)
    normalizer.load_params(normalizer_state)

    args_dict = dict(args._get_kwargs())
    args_dict['header'] = discretizer_header
    args_dict['task'] = 'ihm'
    args_dict['target_repl'] = target_repl
    args_dict['downstream_clients'] = num_clients  # total number of vertical partitions present

    # Models 0..num_clients-1 are the client (bottom) networks; the last one is the server (top) network.
    models = []
    for i in range(num_clients + 1):
        # Build the model
        args.network = "mimic3models/keras_models/lstm"
        args.network += "_bottom.py" if i < num_clients else "_top.py"

        print("==> using model {}".format(args.network))
        model_module = imp.load_source(os.path.basename(args.network), args.network)
        model = model_module.Network(input_dim=coords_per, **args_dict)

        # Compile the model
        print("==> compiling the model")
        optimizer_config = tf.keras.optimizers.Adam(
            learning_rate=args.lr, beta_1=args.beta_1)
        if target_repl:
            loss = ['binary_crossentropy'] * 2
            loss_weights = [1 - args.target_repl_coef, args.target_repl_coef]
        else:
            loss = 'binary_crossentropy'
            loss_weights = None

        model.compile(optimizer=optimizer_config,
                      loss=loss,
                      loss_weights=loss_weights)
        model.summary()
        models.append(model)

    server_model = models[-1]

    # Load model weights
    n_trained_chunks = 0
    if args.load_state != "":
        # NOTE(review): only the server (top) model is restored; the client models
        # keep their fresh initialisation. This matches the previous behaviour.
        server_model.load_weights(args.load_state)
        epoch_match = re.match(r".*epoch([0-9]+).*", args.load_state)
        if epoch_match is not None:
            n_trained_chunks = int(epoch_match.group(1))

    def load_pickle(path):
        """Load one pickled object from `path`, closing the file afterwards.

        Pickle can execute arbitrary code on load; only use trusted local files.
        """
        with open(path, 'rb') as f:
            return pickle.load(f)

    def dump_pickle(obj, path):
        """Pickle `obj` to `path` (overwriting it), closing the file afterwards."""
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    # Read data. The pickles are expected to hold the output of
    # `utils.load_data(reader, discretizer, normalizer, args.small_part)` for each split.
    train_raw = load_pickle('train_raw1.pkl')
    val_raw = load_pickle('val_raw1.pkl')
    test_raw = load_pickle('test_raw1.pkl')

    if target_repl:
        # NOTE(review): the custom training loop below reshapes labels as one flat
        # array, so the nested label structure built here is not supported by it.
        T = train_raw[0][0].shape[0]

        def extend_labels(data):
            data = list(data)
            labels = np.array(data[1])  # (B,)
            data[1] = [labels, None]
            data[1][1] = np.expand_dims(labels, axis=-1).repeat(T, axis=1)  # (B, T)
            data[1][1] = np.expand_dims(data[1][1], axis=-1)  # (B, T, 1)
            return data

        train_raw = extend_labels(train_raw)
        val_raw = extend_labels(val_raw)

    def client_slice(x, i):
        """Return the feature columns (last axis of a 3-D batch) owned by client `i`."""
        return x[:, :, coords_per * i:coords_per * (i + 1)]

    def forward(x, y, model):
        """Run `model` on `x` in inference mode (`y` is unused)."""
        return model(x, training=False)

    def predict(x, y, models):
        """Run the full split model on a batch.

        Every client model embeds its own column slice; the server model consumes
        the concatenated embeddings (axis 1).

        Returns:
            (server outputs, server compiled loss against `y`)
        """
        out = [forward(client_slice(x, i), y, models[i]) for i in range(len(models) - 1)]
        logits = models[-1](tf.concat(out, axis=1), training=False)
        loss = models[-1].compiled_loss(y, logits)
        return logits, loss

    def collect_predictions(dataset, n_examples):
        """Run `predict` over a batched (x, y) dataset; return a flat array of n_examples outputs."""
        predictions = np.zeros(n_examples)
        left = 0
        for x_batch, y_batch in dataset:
            logits, _ = predict(x_batch, y_batch, models)
            batch_len = len(x_batch)
            predictions[left:left + batch_len] = tf.reshape(logits, [-1]).numpy()
            left += batch_len
        return predictions

    if args.mode == 'train':

        print("==> training")

        def get_grads(x, y, H, model, server_model, i):
            """Loss and gradients for client `i`'s own variables.

            Client `i`'s embedding is recomputed under the tape; the other clients'
            embeddings in `H` are used as given. Gradients are taken only with
            respect to `model`'s variables, because the server is not updated in
            the client phase.
            """
            Hnew = list(H)
            with backprop.GradientTape() as tape:
                Hnew[i] = model(x, training=True)
                logits = server_model(tf.concat(Hnew, axis=1), training=True)
                loss_value = server_model.compiled_loss(y, logits)
            grads = tape.gradient(loss_value, model.trainable_variables)
            return grads, loss_value

        def train_step(x, y, model, server_model, H, local, i):
            """Apply `local` optimizer updates to client `i`'s model; return the last loss."""
            loss_value = tf.constant(0.0)
            for _ in range(local):
                grads, loss_value = get_grads(x, y, H, model, server_model, i)
                grads = model.optimizer._clip_gradients(grads)    # pylint: disable=protected-access
                model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
            return loss_value

        def getserver_grads(y, H, server_model):
            """Loss and gradients of the server model on the concatenated embeddings `H`."""
            with backprop.GradientTape() as tape:
                logits = server_model(tf.concat(H, axis=1), training=True)
                loss_value = server_model.compiled_loss(y, logits)
            grads = tape.gradient(loss_value, server_model.trainable_variables)
            return grads, loss_value

        def trainserver_step(y, server_model, H, local):
            """Apply `local` optimizer updates to the server model; return the last loss."""
            loss_value = tf.constant(0.0)
            for _ in range(local):
                grads, loss_value = getserver_grads(y, H, server_model)
                grads = server_model.optimizer._clip_gradients(grads)    # pylint: disable=protected-access
                server_model.optimizer.apply_gradients(zip(grads, server_model.trainable_variables))
            return loss_value

        def cluster_compress(embedding):
            """Replace each row of `embedding` with the centre of its KMeans cluster.

            This is analogous to a client sending only the cluster centres plus
            each row's cluster index, from which the server rebuilds the batch.
            The dtype of the input is kept.
            """
            values = embedding.numpy()
            kmeans = KMeans(n_clusters=args.no_clusters, random_state=args.seed)
            kmeans.fit(values)
            return kmeans.cluster_centers_[kmeans.labels_].astype(values.dtype)

        n_train = train_raw[0].shape[0]
        n_test = test_raw[0].shape[0]

        train_examples = tf.data.Dataset.from_tensor_slices((
            train_raw[0],
            np.asarray(train_raw[1]).reshape(-1, 1)))
        # Shuffle individual examples before batching. The dataset is reshuffled on
        # every pass, so no shuffle stage has to be re-added inside the epoch loop.
        train_dataset = (train_examples
                         .shuffle(buffer_size=n_train, seed=args.seed, reshuffle_each_iteration=True)
                         .batch(args.batch_size))
        train_dataset_static_for_logging = train_examples.batch(args.batch_size)

        test_dataset_static_for_logging = tf.data.Dataset.from_tensor_slices((
            test_raw[0],
            np.asarray(test_raw[1]).reshape(-1, 1))).batch(args.batch_size)

        run_suffix = (f'BS{args.batch_size}_NC{args.num_clients}_Clus{args.no_clusters}'
                      f'_GE{args.epochs}_LE{args.local_epochs}_lr{args.lr}_ifclust{args.if_cluster}')

        losses = []
        accs_train = []
        accs_test = []
        for epoch in tqdm(range(args.epochs)):
            print("\nStart of epoch %d" % (epoch,))

            # Iterate over the batches of the dataset.
            for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
                x_locals = [client_slice(x_batch_train, i) for i in range(num_clients)]

                # Phase 1: each client embeds the batch with its current weights.
                H = [forward(x_locals[i], y_batch_train, models[i]) for i in range(num_clients)]

                # Phase 2: each client trains its own model against the (not updated)
                # server and the other clients' phase-1 embeddings.
                client_losses = [
                    train_step(x_locals[i], y_batch_train, models[i], server_model, H, local_epochs, i)
                    for i in range(num_clients)
                ]

                # Phase 3: the clients re-embed with their updated weights, optionally
                # compressing the embeddings with KMeans, and the server trains on them.
                H_most_recent = []
                for i in range(num_clients):
                    embedding = forward(x_locals[i], y_batch_train, models[i])
                    if args.if_cluster and len(embedding) >= args.no_clusters:
                        if step == 0:
                            print(f"Clustering {epoch}")
                        embedding = cluster_compress(embedding)
                    H_most_recent.append(embedding)

                # Compressed rows stay aligned with y_batch_train: row r still belongs to
                # example r, even if that example falls in different clusters per client.
                loss_final = trainserver_step(y_batch_train, server_model, H_most_recent, local_epochs)
                _, loss_aggregated_model = predict(x_batch_train, y_batch_train, models)

                losses.append([client_loss.numpy() for client_loss in client_losses]
                              + [loss_final.numpy(), loss_aggregated_model.numpy()])
                dump_pickle(losses, f'losses_multisplit_{run_suffix}.pkl')

                # Log every 20 batches.
                if step % 20 == 0:
                    print(
                        "Training loss (for one batch) at step %d: %.4f, %.4f"
                        % (step, float(loss_final), float(np.average(client_losses)))
                    )
                    print("Seen so far: %s samples" % ((step + 1) * args.batch_size))

            print("==> predicting")
            predictions = collect_predictions(train_dataset_static_for_logging, n_train)
            ret = metrics.print_metrics_binary(train_raw[1], predictions)
            accs_train.append(list(ret.items()))
            dump_pickle(accs_train, f'accs_train_multisplit_{run_suffix}.pkl')
            print("Train acc", ret['acc'])

            predictions = collect_predictions(test_dataset_static_for_logging, n_test)
            ret = metrics.print_metrics_binary(test_raw[1], predictions)
            accs_test.append(list(ret.items()))
            dump_pickle(accs_test, f'accs_test_multisplit_{run_suffix}.pkl')
            print("test acc", ret['acc'])
    elif args.mode == 'test':

        # ensure that the code uses test_reader
        del train_reader
        del val_reader
        del train_raw
        del val_raw

        test_reader = InHospitalMortalityReader(dataset_dir=os.path.join(args.data, 'test'),
                                                listfile=os.path.join(args.data, 'test_listfile.csv'),
                                                period_length=48.0)
        ret = utils.load_data(test_reader, discretizer, normalizer, args.small_part,
                              return_names=True)

        data = ret["data"][0]
        labels = ret["data"][1]
        names = ret["names"]

        # The server model consumes concatenated client embeddings, not raw features,
        # so predict through the same split forward pass used during training.
        # NOTE(review): only the server model's weights were restored from
        # --load_state above; the client models are freshly initialised here.
        test_dataset = tf.data.Dataset.from_tensor_slices((
            data,
            np.asarray(labels).reshape(-1, 1))).batch(args.batch_size)
        predictions = collect_predictions(test_dataset, len(data))
        metrics.print_metrics_binary(labels, predictions)

        path = os.path.join(args.output_dir, "test_predictions", os.path.basename(args.load_state)) + ".csv"
        utils.save_results(names, predictions, labels, path)

    else:
        raise ValueError("Wrong value for args.mode")
