"""
Async-VFL training with the MIMIC-III dataset
"""

import numpy as np
import argparse
import os
import imp
import re
import math
import copy
import random

from sklearn.cluster import KMeans
from mimic3models.in_hospital_mortality import utils
from mimic3benchmark.readers import InHospitalMortalityReader
from mimic3models.preprocessing import Discretizer, Normalizer
from mimic3models import metrics
from mimic3models import keras_utils
from mimic3models import common_utils
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.python.eager import backprop
import pickle
from sklearn import metrics as skmetrics
from tqdm import tqdm
from tensorflow.keras.models import Model

from scipy.optimize import minimize
from scipy.optimize import Bounds
from scipy.optimize import NonlinearConstraint
from scipy.optimize import BFGS

def argparser():
    """
    Parse input arguments

    Combines the shared options registered by
    ``common_utils.add_common_arguments`` with the options specific to this
    script, prints the parsed namespace between two rows of asterisks and
    returns it.

    ``--prob`` expects exactly ``num_clients + 1`` values, so the value of
    ``--num_clients`` is read in a preliminary pass before the full parser is
    built (it may appear anywhere on the command line). When ``--prob`` is
    omitted, every entry defaults to 1.0 so the result can always be indexed.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    # Preliminary pass: only `--num_clients` is recognised here; every other
    # option (including -h/--help) is left for the full parser below.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument('--num_clients', type=int, default=2)
    workers = pre_parser.parse_known_args()[0].num_clients

    parser = argparse.ArgumentParser()
    common_utils.add_common_arguments(parser)
    parser.add_argument('--seed', type=int, nargs='?', default=42,
                            help='Random seed to be used.')
    parser.add_argument('--target_repl_coef', type=float, default=0.0)
    parser.add_argument('--data', type=str, help='Path to the data of in-hospital mortality task',
                        default=os.path.join(os.path.dirname(__file__), '../../data/in-hospital-mortality/'))
    parser.add_argument('--output_dir', type=str, help='Directory relative which all output files are stored',
                        default='.')
    parser.add_argument('--num_clients', type=int, help='Number of clients to split data between vertically',
                        default=2)
    # The default is a list (not a scalar) because callers index it per client.
    parser.add_argument('--prob', type=float, nargs=workers+1, default=[1.0] * (workers + 1),
                            help='Indicates probability distribution to use for workers.')
    parser.add_argument('--lr_type', type=str, nargs='?', default="hetero",
                            help='Indicates if learning rates should be heterogeneous or homogeneous.')
    parser.add_argument('--server_time', type=int, nargs='?', default=1,
                            help='Number of local iteration time it takes for server to communicate.')

    args = parser.parse_args()
    print("*"*80, "\n\n", args, "\n\n", "*"*80)
    return args


if __name__ == "__main__":

    # Parse input arguments
    args = argparser()
    # Seed every random number generator this script uses from --seed.
    # The stdlib `random` module drives the per-client participation draw in
    # the training loop, so it must follow --seed too for reproducible runs.
    np.random.seed(args.seed)
    tf.random.set_seed(args.seed)
    random.seed(args.seed)
    num_clients = args.num_clients
    lr = args.lr
    prob = args.prob

    # With --small_part, push the save interval out to 2**30
    if args.small_part:
        args.save_every = 2**30

    # Target replication is only active in train mode with a positive coefficient
    target_repl = (args.target_repl_coef > 0.0 and args.mode == 'train')

    # Build readers, discretizers, normalizers
    # Note: the validation reader uses the 'train' directory with its own listfile.
    train_reader = InHospitalMortalityReader(dataset_dir=os.path.join(args.data, 'train'),
                                            listfile=os.path.join(args.data, 'train_listfile.csv'),
                                            period_length=48.0)

    val_reader = InHospitalMortalityReader(dataset_dir=os.path.join(args.data, 'train'),
                                        listfile=os.path.join(args.data, 'val_listfile.csv'),
                                        period_length=48.0)

    test_reader = InHospitalMortalityReader(dataset_dir=os.path.join(args.data, 'test'),
                                            listfile=os.path.join(args.data, 'test_listfile.csv'),
                                            period_length=48.0)

    discretizer = Discretizer(timestep=float(args.timestep),
                            store_masks=True,
                            impute_strategy='previous',
                            start_time='zero')

    # Column names of the discretized output, taken from the first training example;
    # columns whose name contains "->" are excluded from normalization.
    discretizer_header = discretizer.transform(train_reader.read_example(0)["X"])[1].split(',')
    cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]

    normalizer = Normalizer(fields=cont_channels)  # choose here which columns to standardize
    normalizer_state = args.normalizer_state
    if normalizer_state is None:
        # Fall back to the normalizer file stored next to this script
        normalizer_state = 'ihm_ts{}.input_str:{}.start_time:zero.normalizer'.format(args.timestep, args.imputation)
        normalizer_state = os.path.join(os.path.dirname(__file__), normalizer_state)
    normalizer.load_params(normalizer_state)

    # Keyword arguments forwarded to every Network constructor
    args_dict = dict(args._get_kwargs())
    args_dict['header'] = discretizer_header
    args_dict['task'] = 'ihm'
    args_dict['target_repl'] = target_repl
    args_dict['downstream_clients'] = num_clients # total number of vertical partitions present

    # Build one bottom network per client followed by a single top (server)
    # network; `models[-1]` is therefore the top network.
    models = []
    for idx in range(num_clients + 1):
        is_top = idx == num_clients
        args.network = "mimic3models/keras_models/lstm" + ("_top.py" if is_top else "_bottom.py")

        print("==> using model {}".format(args.network))
        model_module = imp.load_source(os.path.basename(args.network), args.network)
        model = model_module.Network(input_dim=int(76/num_clients), **args_dict)

        # Compile the freshly built network
        print("==> compiling the model")
        optimizer_config = tf.keras.optimizers.Adam(learning_rate=lr,
                                                    beta_1=args.beta_1)
        if target_repl:
            loss, loss_weights = (['binary_crossentropy'] * 2,
                                  [1 - args.target_repl_coef, args.target_repl_coef])
        else:
            loss, loss_weights = 'binary_crossentropy', None

        model.compile(optimizer=optimizer_config, loss=loss, loss_weights=loss_weights)
        model.summary()
        models.append(model)

        if not is_top:
            continue

        # Give every client its own replica of the top network, initialised
        # with the top network's weights. `model` is deliberately rebound in
        # this loop, so after it the name refers to the last replica built.
        orig_weights = model.get_weights()
        server_models = []
        for _ in range(num_clients):
            model_module = imp.load_source(os.path.basename(args.network), args.network)
            model = model_module.Network(input_dim=int(76/num_clients), **args_dict)
            model.compile(optimizer=optimizer_config, loss=loss, loss_weights=loss_weights)
            model.set_weights(orig_weights)
            server_models.append(model)

    # Load model weights
    # By this point the name `model` refers to whichever network was built
    # last (a per-client replica of the top network), so the checkpoint is
    # loaded into the top network explicitly and then copied to every replica
    # to keep them in sync.
    # NOTE(review): assumes --load_state is a checkpoint of the top network — confirm.
    n_trained_chunks = 0
    if args.load_state != "":
        models[-1].load_weights(args.load_state)
        restored_weights = models[-1].get_weights()
        for server_copy in server_models:
            server_copy.set_weights(restored_weights)
        # The epoch number is recovered from the file name when present;
        # a name without "epoch<N>" leaves the counter at 0 instead of crashing.
        epoch_match = re.match(r".*epoch([0-9]+).*", args.load_state)
        if epoch_match is not None:
            n_trained_chunks = int(epoch_match.group(1))


    """
    Uncomment first block if first time running.
    Use second block after running once for faster startup.
    """
    # Read data from file and save to pickle.
    # Each cache file is written inside a `with` block so the handle is
    # flushed and closed before the next split is loaded.
    train_raw = utils.load_data(train_reader, discretizer, normalizer, args.small_part)
    with open('train_raw1.pkl', 'wb') as cache_file:
        pickle.dump(train_raw, cache_file)
    val_raw = utils.load_data(val_reader, discretizer, normalizer, args.small_part)
    with open('val_raw1.pkl', 'wb') as cache_file:
        pickle.dump(val_raw, cache_file)
    test_raw = utils.load_data(test_reader, discretizer, normalizer, args.small_part)
    with open('test_raw1.pkl', 'wb') as cache_file:
        pickle.dump(test_raw, cache_file)

    # Read data from pickle
    # (only unpickle cache files written by this script: unpickling runs arbitrary code)
    # train_raw = pickle.load(open('train_raw1.pkl', 'rb'))
    # val_raw = pickle.load(open('val_raw1.pkl', 'rb'))
    # test_raw = pickle.load(open('test_raw1.pkl', 'rb'))
    
    if target_repl:
        # Length of the first axis of the first training sample
        T = train_raw[0][0].shape[0]

        def extend_labels(data):
            """
            Return a list copy of `data` whose element 1 is replaced by the
            pair [labels, labels repeated T times along a new axis].
            """
            expanded = list(data)
            per_sample = np.array(expanded[1])  # (B,)
            per_step = np.repeat(np.expand_dims(per_sample, axis=-1), T, axis=1)  # (B, T)
            expanded[1] = [per_sample, per_step[..., np.newaxis]]  # second entry: (B, T, 1)
            return expanded

        train_raw, val_raw = extend_labels(train_raw), extend_labels(val_raw)

    # Training setup
    print("==> training")

    # Width of the feature slice handed to each client: the 76 input columns
    # are divided evenly between the clients.
    coords_per = int(76 / num_clients)
    # Handle to the Keras sigmoid activation
    activation = tf.keras.activations.sigmoid

    # Training functions
    # @tf.function
    def get_grads(x, y, H, model, server_model, i):
        """
        Compute the loss and its gradients for client `i` on one batch.

        A shallow copy of the embedding list `H` gets entry `i` replaced by
        the fresh output of `model(x)`; the entries are concatenated along
        axis 1 and passed through `server_model`. Gradients are returned for
        the client variables first, followed by the server variables.

        Returns:
            (grads, loss_value)
        """
        embeddings = H.copy()
        with backprop.GradientTape() as tape:
            embeddings[i] = model(x, training=True)
            joined = tf.concat(embeddings, axis=1)
            logits = server_model(joined, training=True)
            loss_value = server_model.compiled_loss(y, logits)
        targets = model.trainable_variables + server_model.trainable_variables
        grads = tape.gradient(loss_value, targets)
        return grads, loss_value

    # @tf.function
    def train_step(x, y, model, server_model, H, local, i):
        """
        Run `local` optimisation steps on client `i`'s model for one batch.

        Args:
            x: the client's slice of the input batch.
            y: labels of the batch.
            model: the client model to update.
            server_model: the client's replica of the top model; it takes part
                in the forward pass but is not updated here.
            H: array of per-client embeddings for this batch (must provide
                `.tolist()`).
            local: number of local iterations to run.
            i: index of this client within `H`.

        Returns:
            The loss of the last iteration, or 0 when `local` is 0.
        """
        loss_value = 0
        for _ in range(local):
            grads, loss_value = get_grads(x, y, H.tolist(), model, server_model, i)
            grads = model.optimizer._clip_gradients(grads)    # pylint: disable=protected-access
            # get_grads returns the client gradients first and the server
            # gradients after them. Only the client part is applied, and the
            # cut is taken from the model itself rather than a fixed index so
            # that every client variable is updated whatever the architecture.
            client_vars = model.trainable_variables
            model.optimizer.apply_gradients(zip(grads[:len(client_vars)],
                                                client_vars))
        return loss_value

    # @tf.function
    def getserver_grads(y, H, server_model):
        """
        Compute the loss of `server_model` on the concatenated embeddings in
        `H` and the gradients of that loss with respect to the server
        model's trainable variables only.

        Returns:
            (grads, loss_value)
        """
        embeddings = H.copy()
        with backprop.GradientTape() as tape:
            joined = tf.concat(embeddings, axis=1)
            logits = server_model(joined, training=True)
            loss_value = server_model.compiled_loss(y, logits)
        server_grads = tape.gradient(loss_value, server_model.trainable_variables)
        return server_grads, loss_value

    # @tf.function
    def trainserver_step(y, server_model, H, local):
        """
        Run `local` optimisation steps on the server model using the stored
        embeddings `H` and labels `y`.

        Returns:
            The loss of the last iteration, or 0 when `local` is 0.
        """
        loss_value = 0
        for _ in range(local):
            raw_grads, loss_value = getserver_grads(y, H, server_model)
            clipped = server_model.optimizer._clip_gradients(raw_grads)    # pylint: disable=protected-access
            # The gradients cover exactly the server variables, so the full
            # list is paired with them one-to-one.
            updates = zip(clipped, server_model.trainable_variables)
            server_model.optimizer.apply_gradients(updates)
        return loss_value

    # Client embedding for one batch
    # @tf.function
    def forward(x, y, model):
        """Return `model(x)` evaluated in inference mode; `y` is accepted but not used."""
        return model(x, training=False)

    # Predictions and loss of the full client + server pipeline
    # @tf.function
    def predict(x, y, models):
        """
        Run every client model (all but the last entry of `models`) on its
        own feature slice of `x`, concatenate the embeddings along axis 1 and
        feed them to the last model.

        Returns:
            (logits, loss) as produced by the last model.
        """
        client_models = models[:-1]
        embeddings = [
            forward(x[:, :, coords_per * idx:coords_per * (idx + 1)], y, client)
            for idx, client in enumerate(client_models)
        ]
        top_model = models[-1]
        logits = top_model(tf.concat(embeddings, axis=1), training=False)
        loss = top_model.compiled_loss(y, logits)
        return logits, loss

    # Batched datasets. Each split is built twice: one copy is iterated by
    # the training code, the other by the logging/evaluation code.
    train_features, train_labels = train_raw[0], train_raw[1].reshape(-1, 1)
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (train_features, train_labels)).batch(args.batch_size)
    train_dataset_static_for_logging = tf.data.Dataset.from_tensor_slices(
        (train_features, train_labels)).batch(args.batch_size)

    test_features, test_labels = test_raw[0], test_raw[1].reshape(-1, 1)
    test_dataset = tf.data.Dataset.from_tensor_slices(
        (test_features, test_labels)).batch(args.batch_size)
    test_dataset_static_for_logging = tf.data.Dataset.from_tensor_slices(
        (test_features, test_labels)).batch(args.batch_size)


    workers = num_clients
    # Snapshot of the participation probabilities; used in output file names
    orig_prob = copy.deepcopy(prob)

    # One learning rate per network (clients followed by the top network),
    # all set to the --lr value.
    lrs = [args.lr for _ in range(num_clients + 1)]
    for network, rate in zip(models, lrs):
        network.optimizer.lr = rate

    # Get embeddings for all batches
    # Every table has one row per batch and one column per client. `Hs` and
    # `server_H` hold the initial and the server-side embeddings; `client_H[k]`
    # is client k's own (possibly stale) view of everybody's embeddings.
    num_batch_slots = math.ceil(train_raw[0].shape[0] / args.batch_size)
    table_shape = (num_batch_slots, num_clients)
    Hs = np.empty(table_shape, dtype=object)
    server_H = np.empty(table_shape, dtype=object)
    client_H = [np.empty(table_shape, dtype=object) for _ in range(num_clients)]
    Hs.fill([])
    num_batches = 0
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        for i in range(num_clients):
            x_local = x_batch_train[:,:,coords_per*i:coords_per*(i+1)]
            # The inference-mode forward pass gives the same result for the
            # same input and weights, so it is computed once and the result
            # is stored in every table instead of being recomputed per table.
            embedding = forward(x_local, y_batch_train, models[i])
            Hs[step,i] = embedding
            server_H[step,i] = embedding
            for k in range(num_clients):
                client_H[k][step,i] = embedding
        num_batches += 1
    
    # Metric histories; each gets one entry now (before training) and one per epoch
    losses = []
    accs_train = []
    accs_test = []

    # Common part of the output file names, identifying this run's settings
    run_tag = f'BS{args.batch_size}_NC{args.num_clients}_lr{args.lr_type}_prob{orig_prob}_server{args.server_time}'

    # Calculate Training Loss
    predictions = np.zeros(train_raw[0].shape[0])
    left = 0  # index of the first prediction slot not yet filled
    total_loss = 0
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset_static_for_logging):
        logits, loss_aggregate_model = predict(x_batch_train, y_batch_train, models)
        total_loss += loss_aggregate_model
        predictions[left: left + len(x_batch_train)] = tf.reshape(tf.identity(logits),-1)
        left = left + len(x_batch_train)
    # Mean of the per-batch losses
    losses.append(total_loss/num_batches)
    print(f"************Loss = {losses[-1]}***************")
    # Files are written inside `with` blocks so each handle is closed right away
    with open(f'losses_vafl_{run_tag}.pkl', 'wb') as out_file:
        pickle.dump(losses, out_file)

    # Calculate Training Accuracy
    ret = metrics.print_metrics_binary(train_raw[1], predictions, verbose=0)
    accs_train.append(list(ret.items()))
    with open(f'accs_train_vafl_{run_tag}.pkl', 'wb') as out_file:
        pickle.dump(accs_train, out_file)

    # Calculate Test Accuracy
    predictions = np.zeros(test_raw[0].shape[0])
    left = 0
    for step, (x_batch_test, y_batch_test) in enumerate(test_dataset_static_for_logging):
        logits, _ = predict(x_batch_test, y_batch_test, models)
        predictions[left: left + len(x_batch_test)] = tf.reshape(tf.identity(logits),-1)
        left = left + len(x_batch_test)
    ret = metrics.print_metrics_binary(test_raw[1], predictions, verbose=0)
    accs_test.append(list(ret.items()))
    with open(f'accs_test_vafl_{run_tag}.pkl', 'wb') as out_file:
        pickle.dump(accs_test, out_file)

    # Simulation state for the asynchronous schedule:
    #   client_le[i]   - number of local training steps client i has run
    #   client_busy[i] - countdown of batch steps before client i may train again
    #   server_busy    - countdown of batch steps before the server may train again
    #   queue          - client index -> [batch step, new embedding, labels],
    #                    for clients waiting on the server
    client_le = [0] * num_clients
    client_busy = [0] * num_clients
    server_busy = 0
    queue = {}

    # Common part of the output file names, identifying this run's settings
    run_tag = f'BS{args.batch_size}_NC{args.num_clients}_lr{args.lr_type}_prob{orig_prob}_server{args.server_time}'

    # Main training loop
    for epoch in tqdm(range(args.epochs)):

        # Iterate over the batches of the dataset.
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            client_losses = [0]*num_clients
            for i in range(num_clients):
                # Is client still waiting on server?
                if i in queue:
                    continue
                client_busy[i] -= 1
                if client_busy[i] > 0:
                    continue

                # Train for each client with probability
                if random.random() < 1-prob[i]:
                    continue

                # Train client for one local iteration, using the client's own
                # view of the embeddings and its own replica of the top model
                x_local = x_batch_train[:,:,coords_per*i:coords_per*(i+1)]
                H = copy.deepcopy(client_H[i][step])
                client_losses[i] = train_step(x_local, y_batch_train, models[i], server_models[i], H, 1, i)
                client_le[i] += 1

                # Enter server queue if training completed
                H_new = forward(x_local, y_batch_train, models[i])
                queue[i] = [step, H_new, copy.deepcopy(y_batch_train)]

            server_busy -= 1
            # If server isn't busy and there is someone in the queue
            if server_busy < 0 and len(queue) > 0:
                # Train server on latest embeddings: one server step per
                # queued client, on the batch that client trained on
                for k in queue:
                    k_step = queue[k][0]
                    server_H[k_step,k] = queue[k][1]
                    y_local = queue[k][2]
                    client_busy[k] = args.server_time+len(queue)
                    loss_final = trainserver_step(y_local, models[-1],
                                    server_H[k_step].tolist(), 1)
                server_busy = args.server_time+len(queue)

                # Update client embeddings and server model: every served
                # client receives a copy of the server's embedding table and
                # the current top-model weights
                orig_weights = models[-1].get_weights()
                for k in queue:
                    client_H[k] = copy.deepcopy(server_H)
                    server_models[k].set_weights(orig_weights)
                queue.clear()

        # Iterate over the batches of the dataset to calculate loss/accuracy
        # Calculate Training Loss
        predictions = np.zeros(train_raw[0].shape[0])
        left = 0  # index of the first prediction slot not yet filled
        total_loss = 0
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset_static_for_logging):
            logits, loss_aggregate_model = predict(x_batch_train, y_batch_train, models)
            total_loss += loss_aggregate_model
            predictions[left: left + len(x_batch_train)] = tf.reshape(tf.identity(logits),-1)
            left = left + len(x_batch_train)
        losses.append(total_loss/num_batches)
        print(f"************Loss = {losses[-1]}***************")
        # Histories are rewritten in full after every epoch; `with` closes
        # each handle immediately instead of leaking one per epoch
        with open(f'losses_vafl_{run_tag}.pkl', 'wb') as out_file:
            pickle.dump(losses, out_file)

        # Calculate Training Accuracy
        ret = metrics.print_metrics_binary(train_raw[1], predictions, verbose=0)
        accs_train.append(list(ret.items()))
        with open(f'accs_train_vafl_{run_tag}.pkl', 'wb') as out_file:
            pickle.dump(accs_train, out_file)

        # Calculate Test Accuracy
        predictions = np.zeros(test_raw[0].shape[0])
        left = 0
        for step, (x_batch_test, y_batch_test) in enumerate(test_dataset_static_for_logging):
            logits, _ = predict(x_batch_test, y_batch_test, models)
            predictions[left: left + len(x_batch_test)] = tf.reshape(tf.identity(logits),-1)
            left = left + len(x_batch_test)
        ret = metrics.print_metrics_binary(test_raw[1], predictions, verbose=0)
        accs_test.append(list(ret.items()))
        with open(f'accs_test_vafl_{run_tag}.pkl', 'wb') as out_file:
            pickle.dump(accs_test, out_file)
