import numpy as np
import argparse
import os
import imp
import re
import math
from sklearn.cluster import KMeans
from mimic3models.in_hospital_mortality import utils
from mimic3benchmark.readers import InHospitalMortalityReader
from mimic3models.preprocessing import Discretizer, Normalizer
from mimic3models import metrics
from mimic3models import keras_utils
from mimic3models import common_utils
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.python.eager import backprop
import pickle
from sklearn import metrics as skmetrics
from tqdm import tqdm
from tensorflow.keras.models import Model

def str2bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value -- including ``"False"`` -- turns into ``True``.
    This converter interprets the usual textual spellings instead.

    Args:
        value: the raw argument text, or an actual ``bool``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept falsy, as it was with ``bool("")``.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))


def argparser():
    """Build the command-line parser, parse ``sys.argv`` and return the namespace.

    The shared options are registered by ``common_utils.add_common_arguments``;
    the options specific to this script are added on top of them.  The parsed
    arguments are echoed to stdout before being returned.
    """
    parser = argparse.ArgumentParser()
    common_utils.add_common_arguments(parser)
    parser.add_argument('--target_repl_coef', type=float, default=0.0)
    parser.add_argument('--data', type=str, help='Path to the data of in-hospital mortality task',
                        default=os.path.join(os.path.dirname(__file__), '../../data/in-hospital-mortality/'))
    parser.add_argument('--output_dir', type=str, help='Directory relative which all output files are stored',
                        default='.')
    parser.add_argument('--num_clients', type=int, help='Number of clients to split data between vertically',
                        default=2)
    parser.add_argument('--no_clusters', type=int, help='Number of clusters of embeddings in KMeans',
                        default=10)
    parser.add_argument('--local_epochs', type=int, help='Number of local epochs to run at each client before synchronizing',
                        default=1)
    # str2bool instead of bool: "--if_cluster False" must really disable clustering.
    parser.add_argument('--if_cluster', type=str2bool, help='If workers want to cluster embeddings before sending to server',
                        default=False)

    args = parser.parse_args()
    print("*"*80, "\n\n", args, "\n\n", "*"*80)
    return args




if __name__ == "__main__":
    # Script entry point: parse the arguments, build one "bottom" network per
    # client plus one "top" (server) network, load the cached data splits and
    # set up the values shared by the training helpers defined further down.

    args = argparser()

    num_clients = args.num_clients
    local_epochs = args.local_epochs

    if args.small_part:
        args.save_every = 2**30

    # Target replication is only enabled while training.
    target_repl = (args.target_repl_coef > 0.0 and args.mode == 'train')

    # Build readers, discretizers, normalizers
    train_reader = InHospitalMortalityReader(dataset_dir=os.path.join(args.data, 'train'),
                                            listfile=os.path.join(args.data, 'train_listfile.csv'),
                                            period_length=48.0)

    val_reader = InHospitalMortalityReader(dataset_dir=os.path.join(args.data, 'train'),
                                        listfile=os.path.join(args.data, 'val_listfile.csv'),
                                        period_length=48.0)

    test_reader = InHospitalMortalityReader(dataset_dir=os.path.join(args.data, 'test'),
                                            listfile=os.path.join(args.data, 'test_listfile.csv'),
                                            period_length=48.0)

    discretizer = Discretizer(timestep=float(args.timestep),
                            store_masks=True,
                            impute_strategy='previous',
                            start_time='zero')

    # Header of the discretized example; columns whose name contains "->" are
    # excluded from the list of channels handed to the normalizer.
    discretizer_header = discretizer.transform(train_reader.read_example(0)["X"])[1].split(',')
    cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]

    normalizer = Normalizer(fields=cont_channels)  # choose here which columns to standardize
    normalizer_state = args.normalizer_state
    if normalizer_state is None:
        # Default normalizer file lives next to this script.
        normalizer_state = 'ihm_ts{}.input_str:{}.start_time:zero.normalizer'.format(args.timestep, args.imputation)
        normalizer_state = os.path.join(os.path.dirname(__file__), normalizer_state)
    normalizer.load_params(normalizer_state)

    # Keyword arguments forwarded to every Network constructor.
    args_dict = dict(args._get_kwargs())
    args_dict['header'] = discretizer_header
    args_dict['task'] = 'ihm'
    args_dict['target_repl'] = target_repl
    args_dict['downstream_clients'] = num_clients # total number of vertical partitions present

    models = []
    # Make models for each client: indices 0..num_clients-1 are bottom (client)
    # networks, the last index is the top (server) network.
    for i in range(num_clients+1):
        # Build the model
        args.network = "mimic3models/keras_models/lstm"
        if i < num_clients:
            args.network += "_bottom.py"
        else:
            args.network += "_top.py"

        print("==> using model {}".format(args.network))
        # NOTE(review): the stdlib `imp` module is deprecated and was removed in
        # Python 3.12; this line needs an importlib-based replacement there.
        model_module = imp.load_source(os.path.basename(args.network), args.network)
        model = model_module.Network(input_dim=76 // num_clients, **args_dict)

        # Compile the model
        print("==> compiling the model")
        optimizer_config = tf.keras.optimizers.Adam(
                    learning_rate=args.lr, beta_1=args.beta_1)
        if target_repl:
            loss = ['binary_crossentropy'] * 2
            loss_weights = [1 - args.target_repl_coef, args.target_repl_coef]
        else:
            loss = 'binary_crossentropy'
            loss_weights = None

        model.compile(optimizer=optimizer_config,
                    loss=loss,
                    loss_weights=loss_weights)
        model.summary()
        models.append(model)

    # Load model weights
    # NOTE(review): `model` is the last network built by the loop above (the top
    # network), so only that one receives the stored weights -- confirm intent.
    n_trained_chunks = 0
    if args.load_state != "":
        model.load_weights(args.load_state)
        n_trained_chunks = int(re.match(".*epoch([0-9]+).*", args.load_state).group(1))


    # The cached pickles read below were produced once with:
    # train_raw = utils.load_data(train_reader, discretizer, normalizer, args.small_part)
    # pickle.dump(train_raw, open('train_raw'+str(num_clients)+'.pkl', 'wb'))
    # val_raw = utils.load_data(val_reader, discretizer, normalizer, args.small_part)
    # pickle.dump(val_raw, open('val_raw'+str(num_clients)+'.pkl', 'wb'))
    # test_raw = utils.load_data(test_reader, discretizer, normalizer, args.small_part)
    # pickle.dump(test_raw, open('test_raw'+str(num_clients)+'.pkl', 'wb'))

    # Read data.  Files are opened with `with` so the handles are closed right
    # after loading.  NOTE: unpickling can execute arbitrary code, so these
    # files must come from a trusted source.
    with open('train_raw'+str(num_clients)+'.pkl', 'rb') as pkl_file:
        train_raw = pickle.load(pkl_file)
    with open('val_raw'+str(num_clients)+'.pkl', 'rb') as pkl_file:
        val_raw = pickle.load(pkl_file)
    with open('test_raw'+str(num_clients)+'.pkl', 'rb') as pkl_file:
        test_raw = pickle.load(pkl_file)

    if target_repl:
        T = train_raw[0][0].shape[0]

        def extend_labels(data):
            """Return `data` as a list whose labels entry is [labels, labels repeated T times]."""
            data = list(data)
            labels = np.array(data[1])  # (B,)
            data[1] = [labels, None]
            data[1][1] = np.expand_dims(labels, axis=-1).repeat(T, axis=1)  # (B, T)
            data[1][1] = np.expand_dims(data[1][1], axis=-1)  # (B, T, 1)
            return data

        train_raw = extend_labels(train_raw)
        val_raw = extend_labels(val_raw)

    if args.mode == 'train':

        # Prepare training

        print("==> training")

        activation = tf.keras.activations.sigmoid
        # Number of input columns (out of 76) given to each client.
        coords_per = 76 // num_clients

        #@tf.function
        def get_grads(x, y, H, model, server_model, i):
            """Compute the loss and its gradients for client ``i``.

            The client's output for ``x`` replaces entry ``i`` in a copy of ``H``;
            the entries are concatenated along axis 1 and passed through
            ``server_model``, whose compiled loss is evaluated against ``y``.

            Args:
                x: input batch for this client's ``model``.
                y: targets passed to ``server_model.compiled_loss``.
                H: list of per-client embeddings; it is copied, not modified.
                model: the client network.
                server_model: the network applied to the concatenated embeddings.
                i: position of this client inside ``H``.

            Returns:
                ``(grads, loss_value)`` where ``grads`` follows the order
                ``model.trainable_variables + server_model.trainable_variables``.
            """
            embeddings = H.copy()
            with backprop.GradientTape() as tape:
                embeddings[i] = model(x, training=True)
                server_input = tf.concat(embeddings, axis=1)
                logits = server_model(server_input, training=True)
                loss_value = server_model.compiled_loss(y, logits)
            watched = model.trainable_variables + server_model.trainable_variables
            grads = tape.gradient(loss_value, watched)
            return grads, loss_value

        def train_step(x, y, model, server_model, H, local, i):
            """Run ``local`` optimisation steps on client ``i`` and return the last loss.

            Args:
                x: input batch for this client's ``model``.
                y: targets for the batch.
                model: the client network whose variables are updated.
                server_model: network used by ``get_grads`` to compute the loss;
                    its variables are not updated here.
                H: array of per-client embeddings; ``H.tolist()`` is passed on.
                local: number of update steps to perform.
                i: position of this client inside ``H``.

            Returns:
                The loss of the final step, or ``0`` when ``local`` is 0.
            """
            loss_value = 0
            for _ in range(local):
                grads, loss_value = get_grads(x, y, H.tolist(), model, server_model, i)
                grads = model.optimizer._clip_gradients(grads)    # pylint: disable=protected-access
                # get_grads orders gradients as client variables followed by
                # server variables, so the client's share is the first
                # len(model.trainable_variables) entries.  A fixed slice of 2
                # would skip every client variable beyond the second.
                client_vars = model.trainable_variables
                client_grads = grads[:len(client_vars)]
                model.optimizer.apply_gradients(zip(client_grads, client_vars))
            return loss_value

        #@tf.function
        def getserver_grads(y, H, server_model):
            """Compute the server loss and its gradients for the embeddings in ``H``.

            Args:
                y: targets passed to ``server_model.compiled_loss``.
                H: list of embeddings; a copy is concatenated along axis 1.
                server_model: the network applied to the concatenated embeddings.

            Returns:
                ``(grads, loss_value)`` with one gradient per entry of
                ``server_model.trainable_variables``.
            """
            embeddings = H.copy()
            with backprop.GradientTape() as tape:
                server_input = tf.concat(embeddings, axis=1)
                logits = server_model(server_input, training=True)
                loss_value = server_model.compiled_loss(y, logits)
            server_vars = server_model.trainable_variables
            return tape.gradient(loss_value, server_vars), loss_value

        def trainserver_step(y, server_model, H, local):
            """Run ``local`` optimisation steps on the server model and return the last loss.

            Args:
                y: targets for the batch.
                server_model: the network whose variables are updated.
                H: list of embeddings forwarded to ``getserver_grads``.
                local: number of update steps to perform.

            Returns:
                The loss of the final step, or ``0`` when ``local`` is 0.
            """
            loss_value = 0
            for _ in range(local):
                grads, loss_value = getserver_grads(y, H, server_model)
                grads = server_model.optimizer._clip_gradients(grads)    # pylint: disable=protected-access
                # getserver_grads returns exactly one gradient per trainable
                # variable of the server model, so all of them are applied;
                # slicing to the first two would leave the rest untrained.
                server_model.optimizer.apply_gradients(
                    zip(grads, server_model.trainable_variables))
            return loss_value

        #@tf.function
        def forward(x, y, model):
            """Return ``model(x)`` evaluated with ``training=False``; ``y`` is unused."""
            return model(x, training=False)

        #@tf.function
        def predict(x, y, models):
            """Run the full split model on ``x`` in inference mode.

            Every model except the last receives its own ``coords_per``-wide slice
            of the last axis of ``x``; the outputs are concatenated along axis 1
            and passed to the last model.

            Args:
                x: input batch indexed as ``x[:, :, columns]``.
                y: forwarded to ``forward`` (which does not use it).
                models: client models followed by the server model.

            Returns:
                Output of the last model for the concatenated client outputs.
            """
            embeddings = [
                forward(x[:, :, coords_per * k:coords_per * (k + 1)], y, client_model)
                for k, client_model in enumerate(models[:-1])
            ]
            return models[-1](tf.concat(embeddings, axis=1), training=False)

        # Build the batched datasets.  The vertical split itself happens per
        # batch: client i reads columns [coords_per*i, coords_per*(i+1)) of the
        # last axis.
        train_dataset = tf.data.Dataset.from_tensor_slices((
                                        train_raw[0],
                                        train_raw[1].reshape(-1,1)))

        train_dataset = train_dataset.batch(args.batch_size)

        # Shuffled view used only for optimisation.  `train_dataset` itself is
        # kept in its original order so that the predictions computed after
        # training line up with train_raw[1].  The shuffle is applied once
        # (it reshuffles on every iteration) instead of being stacked per epoch.
        shuffled_train_dataset = train_dataset.shuffle(
            buffer_size=train_raw[0].shape[0], reshuffle_each_iteration=True)

        test_dataset = tf.data.Dataset.from_tensor_slices((
                                        test_raw[0],
                                        test_raw[1].reshape(-1,1)))

        test_dataset = test_dataset.batch(args.batch_size)

        # Common suffix of the three result files written below.
        run_tag = (f'multisplit_BS{args.batch_size}_NC{args.num_clients}_Clus{args.no_clusters}'
                   f'_GE{args.epochs}_LE{args.local_epochs}_lr{args.lr}_ifclust{args.if_cluster}.pkl')

        def dump_pickle(obj, path):
            """Pickle `obj` to `path`, closing the file afterwards."""
            with open(path, 'wb') as out_file:
                pickle.dump(obj, out_file)

        def predict_dataset(dataset, num_examples):
            """Return a flat array with the model output for every example of `dataset`, in order."""
            predictions = np.zeros(num_examples)
            left = 0
            for x_batch, y_batch in dataset:
                logits = predict(x_batch, y_batch, models)
                predictions[left: left + len(x_batch)] = tf.reshape(tf.identity(logits),-1)
                left = left + len(x_batch)
            return predictions

        losses = []
        accs_train = []
        accs_test = []
        for epoch in tqdm(range(args.epochs)):
            print("\nStart of epoch %d" % (epoch,))

            # Iterate over the batches of the dataset.
            for step, (x_batch_train, y_batch_train) in enumerate(shuffled_train_dataset):
                # Embeddings of every client for this batch, computed before
                # any client is updated.  Only the current batch is needed, so
                # nothing is kept across steps.
                H = np.empty((num_clients), dtype=object)
                for i in range(num_clients):
                    x_local = x_batch_train[:,:,coords_per*i:coords_per*(i+1)]
                    H[i] = forward(x_local, y_batch_train, models[i])

                # Train each client in turn; train_step returns its last loss.
                client_losses = []
                for i in range(num_clients):
                    x_local = x_batch_train[:,:,coords_per*i:coords_per*(i+1)]
                    client_losses.append(
                        train_step(x_local, y_batch_train, models[i], models[-1], H, local_epochs, i))
                #----------------------------------------------------------------------------
                # NOTE(review): the list is rebuilt on every iteration, so after
                # the loop it holds only the embedding of the LAST client.  If
                # the server is meant to see all clients' embeddings this should
                # collect them instead -- behaviour left unchanged, confirm.
                for i in range(num_clients):
                    x_local = x_batch_train[:,:,coords_per*i:coords_per*(i+1)]
                    H_most_recent = [forward(x_local, y_batch_train, models[i])]

                if args.if_cluster:
                    # Cluster the embedding with KMeans and train the server on
                    # the cluster centres instead of the individual rows.
                    kmeans = KMeans(n_clusters=args.no_clusters, random_state=0).fit(H_most_recent[0].numpy())
                    cluster_centers = kmeans.cluster_centers_
                    # Each centre is labelled with the most frequent y among the
                    # rows assigned to it; centres without rows keep label 0.
                    y_batch_train_np = y_batch_train.numpy()
                    y_batch_clustered = np.zeros((args.no_clusters, 1))
                    # Index by the actual label value: np.unique may return
                    # fewer labels than clusters, in which case the enumeration
                    # position no longer matches the label.
                    for lab in np.unique(kmeans.labels_):
                        y_cluster = y_batch_train_np[kmeans.labels_ == lab]
                        values, counts = np.unique(y_cluster, return_counts=True)
                        y_batch_clustered[lab] = values[np.argmax(counts)]

                    loss_final = trainserver_step(tf.Variable(y_batch_clustered), models[-1], [tf.Variable(cluster_centers)], local_epochs)
                else:
                    # Not using clustering
                    loss_final = trainserver_step(y_batch_train, models[-1], H_most_recent, local_epochs)
                #----------------------------------------------------------------------------
                print(float(np.asarray(loss_final)), loss_final)
                losses.append(float(np.asarray(loss_final)))
                # Rewritten after every batch so partial results survive an interruption.
                dump_pickle(losses, f'losses_{run_tag}')

                # Log every 20 batches.
                if step % 20 == 0:
                    # Both values are floats; "%.4d" would print the loss as an integer.
                    print(
                        "Training loss (for one batch) at step %d: %.4f, %.4f"
                        % (step, float(loss_final), float(np.average(client_losses[0])))
                    )
                    print("Seen so far: %s samples" % ((step + 1) * args.batch_size))

        print("==> predicting")
        # Evaluate on the training set, iterating the UNSHUFFLED dataset so the
        # predictions are in the same order as train_raw[1].
        predictions = predict_dataset(train_dataset, train_raw[0].shape[0])
        ret = metrics.print_metrics_binary(train_raw[1], predictions)
        accs_train.append(ret["acc"])
        dump_pickle(accs_train, f'accs_train_{run_tag}')

        # Evaluate on the test set.
        predictions = predict_dataset(test_dataset, test_raw[0].shape[0])
        ret = metrics.print_metrics_binary(test_raw[1], predictions)
        accs_test.append(ret["acc"])
        dump_pickle(accs_test, f'accs_test_{run_tag}')

    elif args.mode == 'test':

        # ensure that the code uses test_reader
        del train_reader
        del val_reader
        del train_raw
        del val_raw

        test_reader = InHospitalMortalityReader(dataset_dir=os.path.join(args.data, 'test'),
                                                listfile=os.path.join(args.data, 'test_listfile.csv'),
                                                period_length=48.0)
        ret = utils.load_data(test_reader, discretizer, normalizer, args.small_part,
                            return_names=True)

        data = ret["data"][0]
        labels = ret["data"][1]
        names = ret["names"]

        predictions = model.predict(data, batch_size=args.batch_size, verbose=1)
        predictions = np.array(predictions)[:, 0]
        metrics.print_metrics_binary(labels, predictions)

        path = os.path.join(args.output_dir, "test_predictions", os.path.basename(args.load_state)) + ".csv"
        utils.save_results(names, predictions, labels, path)

    else:
        raise ValueError("Wrong value for args.mode")
