"""
VFL training with the MIMIC-III dataset
"""

import numpy as np
import argparse
import os
import imp
import re
import math
import copy
import random

from sklearn.cluster import KMeans
from mimic3models.in_hospital_mortality import utils
from mimic3benchmark.readers import InHospitalMortalityReader
from mimic3models.preprocessing import Discretizer, Normalizer
from mimic3models import metrics
from mimic3models import keras_utils
from mimic3models import common_utils
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.python.eager import backprop
import pickle
from sklearn import metrics as skmetrics
from tqdm import tqdm
from tensorflow.keras.models import Model

from scipy.optimize import minimize
from scipy.optimize import Bounds
from scipy.optimize import NonlinearConstraint
from scipy.optimize import BFGS

def argparser():
    """
    Parse input arguments.

    ``--prob`` takes one value per worker plus one more, and ``--change``
    takes one value per worker, so the number of clients has to be known
    before the full parser is built. It is obtained by pre-parsing only
    ``--num_clients`` (default 2) instead of reading a fixed position of
    ``sys.argv``, so the option may appear anywhere on the command line.

    Returns:
        The ``argparse.Namespace`` produced by the full parser. Besides the
        options declared here it carries whatever
        ``common_utils.add_common_arguments`` registers.
    """
    # Preliminary parse: look at --num_clients only and ignore everything
    # else (including -h, which is left for the real parser below).
    count_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    count_parser.add_argument('--num_clients', type=int, default=2)
    workers = count_parser.parse_known_args()[0].num_clients

    parser = argparse.ArgumentParser()
    common_utils.add_common_arguments(parser)
    parser.add_argument('--seed', type=int, nargs='?', default=42,
                            help='Random seed to be used.')
    parser.add_argument('--target_repl_coef', type=float, default=0.0)
    parser.add_argument('--data', type=str, help='Path to the data of in-hospital mortality task',
                        default=os.path.join(os.path.dirname(__file__), '../../data/in-hospital-mortality/'))
    parser.add_argument('--output_dir', type=str, help='Directory relative which all output files are stored',
                        default='.')
    parser.add_argument('--num_clients', type=int, help='Number of clients to split data between vertically',
                        default=2)
    parser.add_argument('--local_epochs', type=int, help='Number of local epochs to run at each client before synchronizing',
                        default=1)
    # The default is a list (one entry per client plus one extra) because the
    # training code indexes prob[i]; a scalar default would fail there.
    parser.add_argument('--prob', type=float, nargs=workers+1,
                            default=[1.0] * (workers + 1),
                            help='Indicates probability distribution to use for workers.')
    parser.add_argument('--lr_type', type=str, nargs='?', default="hetero",
                            help='Indicates if learning rates should be heterogeneous or homogeneous.')
    parser.add_argument('--priori', type=str, nargs='?', default="a",
                            help='Adjust learning rate adaptively or not')
    parser.add_argument('--change', type=int, nargs=workers, default=None,
                            help='Use google cluster data to change probs over time or not')
    #parser.add_argument('--batch_type', type=str, nargs='?', default="hetero",
    #                        help='Indicates if batch sizes should be heterogeneous or homogeneous.')

    args = parser.parse_args()
    print("*"*80, "\n\n", args, "\n\n", "*"*80)
    return args

def opt(prob, K, Q, lr_type, params1, params2, grads1, grads2, 
        L_prev, Lk_prev, Ls, Lks):
    """
    Find best step sizes.

    Maximizes the sum of the step sizes ``eta`` (one per party, K+1 in
    total) subject to

        0 <= L*sum(prob*eta) + 2*Q^2*K*sum(L_k^2 * prob * eta^2)
          <= 4*sum(prob) - sum(prob^2) - 2*(K+1)

    with every ``eta`` bounded to [1e-15, 1]. ``L`` and ``L_k`` are fixed
    constants (3000 and 100).

    Args:
        prob: sequence of K+1 participation probabilities.
        K: number of clients; K+1 step sizes are produced.
        Q: multiplier used as ``Q**2`` in the constraint (the callers in this
            file pass the number of local epochs).
        lr_type: ``"homo"`` optimizes a single step size that is then
            repeated K+1 times; any other value optimizes K+1 separate
            step sizes.
        params1, params2, grads1, grads2, L_prev, Lk_prev, Ls, Lks: only
            used by the commented-out smoothness estimation below; ignored
            by the active code.

    Returns:
        numpy array of shape (K+1,) with the step sizes found by the solver.
    """

    # Values of L and L_k based on previous testing.
    # Sized from K so the function does not depend on a module-level global.
    L = 3000
    L_k = [100] * (K+1)

    # Original code for getting estimates of L and L_k 
    #for i in range(len(params1)):
    #    for j in range(len(params1[i])):
    #        Li = abs(np.linalg.norm(grads1[i][j] - grads2[i][j]) / 
    #                np.linalg.norm(params1[i][j] - params2[i][j]))
    #        if math.isinf(Li) or Li == 0 or Li > 10000:
    #            continue
    #        L_k[i] = max(L_k[i], Li)
    #    Lks[i].append(L_k[i])
    #    L_k[i] = max(L_k[i], Lk_prev[i])
    #L = np.max(L_k)
    ##L = Ls[-1] + 0.01*(L - Ls[-1])
    #Ls.append(L)
    #L = max(L, L_prev)
    #print(f"L = {L}")
    #print(f"L_k = {L_k}")
    #print(f"probs = {prob}")

    prob = np.array(prob)
    L_k = np.array(L_k)

    # Objective function: minimizing the negated sum maximizes the sum.
    def sum_neg(eta):
        return -1*np.sum(eta) 
        
    # Constraint (left-hand side of the inequality in the docstring)
    def cons_1(eta):
        return (L*np.sum(prob*eta)
                    + 2*(Q**2)*K*np.sum((L_k**2)*prob*(eta**2)))
    con_bound = 4*np.sum(prob) - np.sum(prob**2) - 2*(K+1)
    nonlinear_constraint = NonlinearConstraint(cons_1, 0, con_bound)
    constraints = [nonlinear_constraint]
        
    # Set bounds and starting point; "homo" collapses to a single variable
    # that broadcasts against prob and L_k inside cons_1.
    if lr_type == "homo":
        bounds = Bounds(1e-15, 1)
        eta = 0.5
    else:
        bounds = Bounds(np.zeros((K+1))+1e-15, np.ones((K+1)))
        eta = np.ones((K+1))*0.5

    # Optimize step size(s)
    res = minimize(sum_neg, eta, method='trust-constr', constraints=constraints, 
            options={'gtol': 1e-15, 'verbose': 0, 'maxiter': 500}, bounds=bounds)

    lr = res['x']
    if lr_type == "homo":
        # Expand the single shared step size to one entry per party.
        lr = np.ones((K+1))*lr
    print(f"LR = {lr}")
    return lr



if __name__ == "__main__":

    # ---- Command line and random seeds ----
    args = argparser()
    np.random.seed(args.seed)
    tf.random.set_seed(args.seed)
    # NOTE(review): Python's `random` is seeded with a fixed 42 rather than
    # args.seed -- confirm this is intended.
    random.seed(42)

    num_clients, local_epochs = args.num_clients, args.local_epochs
    lr, prob = args.lr, args.prob

    # With --small_part, save_every is pushed to a very large value.
    if args.small_part:
        args.save_every = 2**30

    target_repl = args.target_repl_coef > 0.0 and args.mode == 'train'

    # ---- Readers: train and val share the 'train' directory, test has its own ----
    def _open_reader(subdir, listfile_name):
        """Build a 48-hour in-hospital-mortality reader rooted at args.data."""
        return InHospitalMortalityReader(
            dataset_dir=os.path.join(args.data, subdir),
            listfile=os.path.join(args.data, listfile_name),
            period_length=48.0)

    train_reader = _open_reader('train', 'train_listfile.csv')
    val_reader = _open_reader('train', 'val_listfile.csv')
    test_reader = _open_reader('test', 'test_listfile.csv')

    # ---- Discretizer and the header it produces for the first example ----
    discretizer = Discretizer(timestep=float(args.timestep),
                              store_masks=True,
                              impute_strategy='previous',
                              start_time='zero')

    first_example = train_reader.read_example(0)["X"]
    discretizer_header = discretizer.transform(first_example)[1].split(',')
    # Columns whose header name contains no "->" are the ones standardized.
    cont_channels = [idx for idx, column in enumerate(discretizer_header)
                     if "->" not in column]

    # ---- Normalizer: use the given state file or the default one next to this script ----
    normalizer = Normalizer(fields=cont_channels)
    normalizer_state = args.normalizer_state
    if normalizer_state is None:
        default_state = 'ihm_ts{}.input_str:{}.start_time:zero.normalizer'.format(args.timestep, args.imputation)
        normalizer_state = os.path.join(os.path.dirname(__file__), default_state)
    normalizer.load_params(normalizer_state)

    # ---- Keyword arguments handed to every Network constructor ----
    args_dict = dict(args._get_kwargs())
    args_dict.update(
        header=discretizer_header,
        task='ihm',
        target_repl=target_repl,
        downstream_clients=num_clients,  # total number of vertical partitions present
    )

    # Build and compile num_clients "bottom" networks followed by one "top"
    # network; the top network ends up as the last entry of `models`.
    models = []
    for party in range(num_clients+1):
        # Pick the source file of the network for this party
        suffix = "_bottom.py" if party < num_clients else "_top.py"
        args.network = "mimic3models/keras_models/lstm" + suffix

        print("==> using model {}".format(args.network))
        model_module = imp.load_source(os.path.basename(args.network), args.network)
        model = model_module.Network(input_dim=int(76/num_clients), **args_dict)

        # Compile the model: every party gets its own Adam optimizer
        print("==> compiling the model")
        optimizer_config = tf.keras.optimizers.Adam(
                    learning_rate=lr, beta_1=args.beta_1)

        loss, loss_weights = 'binary_crossentropy', None
        if target_repl:
            # Two outputs, weighted by the target replication coefficient
            loss = ['binary_crossentropy'] * 2
            loss_weights = [1 - args.target_repl_coef, args.target_repl_coef]

        model.compile(optimizer=optimizer_config,
                      loss=loss,
                      loss_weights=loss_weights)
        model.summary()
        models.append(model)

    # Load model weights.
    # `model` is whatever the loop above built last, i.e. the top network.
    n_trained_chunks = 0
    if args.load_state != "":
        model.load_weights(args.load_state)
        epoch_match = re.match(".*epoch([0-9]+).*", args.load_state)
        n_trained_chunks = int(epoch_match.group(1))


    """
    Uncomment first block if first time running.
    Use second block after running once for faster startup.
    """
    # Read data from file and save to pickle.
    # Each cache file is written inside a `with` block so the handle is
    # flushed and closed right away instead of being left to the garbage
    # collector.
    train_raw = utils.load_data(train_reader, discretizer, normalizer, args.small_part)
    with open('train_raw1.pkl', 'wb') as cache_file:
        pickle.dump(train_raw, cache_file)
    val_raw = utils.load_data(val_reader, discretizer, normalizer, args.small_part)
    with open('val_raw1.pkl', 'wb') as cache_file:
        pickle.dump(val_raw, cache_file)
    test_raw = utils.load_data(test_reader, discretizer, normalizer, args.small_part)
    with open('test_raw1.pkl', 'wb') as cache_file:
        pickle.dump(test_raw, cache_file)
    
    # Read data from pickle
    # with open('train_raw1.pkl', 'rb') as cache_file:
    #     train_raw = pickle.load(cache_file)
    # with open('val_raw1.pkl', 'rb') as cache_file:
    #     val_raw = pickle.load(cache_file)
    # with open('test_raw1.pkl', 'rb') as cache_file:
    #     test_raw = pickle.load(cache_file)
    
    if target_repl:
        # Number of rows in the first training example
        T = train_raw[0][0].shape[0]

        def extend_labels(data):
            """
            Replace data[1] by [labels, replicated_labels], where the
            replicated labels repeat each label T times.
            """
            data = list(data)
            labels = np.array(data[1])  # (B,)
            data[1] = [labels, None]
            data[1][1] = np.expand_dims(labels, axis=-1).repeat(T, axis=1)  # (B, T)
            data[1][1] = np.expand_dims(data[1][1], axis=-1)  # (B, T, 1)
            return data

        train_raw = extend_labels(train_raw)
        val_raw = extend_labels(val_raw)

    # Prepare training

    print("==> training")

    activation = tf.keras.activations.sigmoid 
    # Number of input columns handed to each client (76 split evenly)
    coords_per = int(76/num_clients)

    # Training functions
    # @tf.function
    def get_grads(x, y, H, model, server_model, i):
        """
        Compute the loss on one batch and its gradients for client `i`.

        A copy of `H` gets entry `i` replaced by the output of `model` on
        `x`; the entries are concatenated along axis 1 and passed through
        `server_model`, whose compiled loss is evaluated against `y`.

        Returns:
            (grads, loss_value), where grads lists the gradients for
            model.trainable_variables first, followed by those for
            server_model.trainable_variables.
        """
        embeddings = H.copy()
        with backprop.GradientTape() as tape:
            embeddings[i] = model(x, training=True)
            logits = server_model(tf.concat(embeddings, axis=1), training=True)
            loss_value = server_model.compiled_loss(y, logits)
        targets = model.trainable_variables + server_model.trainable_variables
        grads = tape.gradient(loss_value, targets)
        return grads, loss_value

    # @tf.function
    def train_step(x, y, model, server_model, H, local, i):
        """
        Run `local` gradient steps on the client model `model`.

        Args:
            x, y: the client's slice of the batch and the batch labels.
            model: client model to update.
            server_model: server model used to compute the loss; it is not
                updated here.
            H: array of embeddings for this batch (converted with .tolist()
                before each gradient computation).
            local: number of update steps to perform.
            i: position of this client's embedding in H.

        Returns:
            The loss of the last step, or 0 when `local` is 0.
        """
        loss_value = 0
        for _ in range(local):
            grads, loss_value = get_grads(x, y, H.tolist(), model, server_model, i)
            grads = model.optimizer._clip_gradients(grads)    # pylint: disable=protected-access
            # get_grads returns the client's gradients first and the server's
            # afterwards; keep only the client part. The count comes from the
            # model itself rather than a hard-coded number of variables.
            client_vars = model.trainable_variables
            model.optimizer.apply_gradients(zip(grads[:len(client_vars)],
                                                client_vars))
        return loss_value

    # @tf.function
    def getserver_grads(y, H, server_model):
        """
        Compute the server loss on one batch and its gradients.

        The entries of `H` are concatenated along axis 1 and passed through
        `server_model`; gradients are taken with respect to the server
        model's trainable variables only.

        Returns:
            (grads, loss_value)
        """
        embeddings = H.copy()
        with backprop.GradientTape() as tape:
            logits = server_model(tf.concat(embeddings, axis=1), training=True)
            loss_value = server_model.compiled_loss(y, logits)
        server_grads = tape.gradient(loss_value, server_model.trainable_variables)
        return server_grads, loss_value

    # @tf.function
    def trainserver_step(y, server_model, H, local):
        """
        Run `local` gradient steps on the server model.

        Args:
            y: batch labels.
            server_model: model to update.
            H: list of client embeddings for this batch.
            local: number of update steps to perform.

        Returns:
            The loss of the last step, or 0 when `local` is 0.
        """
        loss_value = 0
        for _ in range(local):
            grads, loss_value = getserver_grads(y, H, server_model)
            grads = server_model.optimizer._clip_gradients(grads)    # pylint: disable=protected-access
            # getserver_grads only differentiates with respect to the server
            # model's trainable variables, so the whole list is applied.
            server_model.optimizer.apply_gradients(zip(grads,
                                            server_model.trainable_variables))
        return loss_value

    # Get client embeddings
    # @tf.function
    def forward(x, y, model):
        """Return the output of `model` on `x` with training=False; `y` is unused."""
        return model(x, training=False)

    # Get predicted labels for calculating accuracy
    # @tf.function
    def predict(x, y, models):
        """
        Run the full pipeline on one batch without training.

        Every model except the last receives its own slice of `coords_per`
        columns along the last axis of `x`; their outputs are concatenated
        along axis 1 and passed to the last model.

        Returns:
            (logits, loss), with the loss taken from the last model's
            compiled loss against `y`.
        """
        embeddings = [
            forward(x[:, :, coords_per*k:coords_per*(k+1)], y, models[k])
            for k in range(len(models)-1)
        ]
        top_model = models[-1]
        logits = top_model(tf.concat(embeddings, axis=1), training=False)
        loss = top_model.compiled_loss(y, logits)
        return logits, loss

    # Split data into batches.
    # Each split is wrapped twice with identical construction; the
    # *_static_for_logging copies are the ones iterated for evaluation.
    train_dataset = tf.data.Dataset.from_tensor_slices((
                                    train_raw[0], 
                                    train_raw[1].reshape(-1,1)))
    train_dataset_static_for_logging = tf.data.Dataset.from_tensor_slices((
                                    train_raw[0], 
                                    train_raw[1].reshape(-1,1)))
    
    train_dataset = train_dataset.batch(args.batch_size)
    train_dataset_static_for_logging = train_dataset_static_for_logging.batch(args.batch_size)

    test_dataset = tf.data.Dataset.from_tensor_slices((
                                    test_raw[0], 
                                    test_raw[1].reshape(-1,1)))
    test_dataset_static_for_logging = tf.data.Dataset.from_tensor_slices((
                                    test_raw[0], 
                                    test_raw[1].reshape(-1,1)))
    
    test_dataset = test_dataset.batch(args.batch_size)
    test_dataset_static_for_logging = test_dataset_static_for_logging.batch(args.batch_size)


    workers = num_clients
    # Snapshot of the probabilities as given; used in the result file names.
    orig_prob = copy.deepcopy(prob)
    lrs = [args.lr] * (num_clients+1)

    for i in range(num_clients+1):
        models[i].optimizer.lr = lrs[i]
    
    losses = []
    accs_train = []
    accs_test = []

    # Old code for testing values of L
    Ls = [1000]
    Lks = [[1000] for _ in range(num_clients+1)]
    window = 5
    e = 0

    estimate_prob = copy.deepcopy(prob)
    cpus = None
    # For adaptive tests, set step sizes according to probs
    if args.change is not None:
        cpus = np.load('google-cpu-full.npy')
        rand_inds = args.change
        # Keep the selected rows and one column per epoch
        cpus = cpus[rand_inds,0:args.epochs]
        estimate_prob = [0] * (num_clients+1) 
        for i in range(num_clients):
            estimate_prob[i] = 1 - np.max(cpus[i])
        estimate_prob[-1] = 1.0

        lrs = opt(estimate_prob, num_clients, local_epochs, args.lr_type, 
                None, None, None, None, 0, 0, Ls, Lks)
        for i in range(num_clients+1):
            models[i].optimizer.lr = lrs[i]
    print(lrs)

    # Common suffix of the result files written below.
    run_tag = (f'BS{args.batch_size}_NC{args.num_clients}_LE{args.local_epochs}'
               f'_lr{args.lr_type}_prob{orig_prob}_priori{args.priori}')

    # Get initial loss and accuracy
    predictions = np.zeros(train_raw[0].shape[0])
    left = 0
    total_loss = 0
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset_static_for_logging):
        logits, loss_aggregate_model = predict(x_batch_train, y_batch_train, models)
        total_loss += loss_aggregate_model
        # Write this batch's outputs into its slot of the flat prediction array
        predictions[left: left + len(x_batch_train)] = tf.reshape(tf.identity(logits),-1)
        left = left + len(x_batch_train)
    losses.append(total_loss/len(train_dataset_static_for_logging))
    print(f"************Loss = {losses[-1]}***************")
    # `with` closes the file as soon as the dump is complete.
    with open(f'losses_varlr_{run_tag}.pkl', 'wb') as out_file:
        pickle.dump(losses, out_file)

    # Calculate Training Accuracy 
    ret = metrics.print_metrics_binary(train_raw[1], predictions, verbose=0)
    accs_train.append(list(ret.items()))
    with open(f'accs_train_varlr_{run_tag}.pkl', 'wb') as out_file:
        pickle.dump(accs_train, out_file)
    print("Train acc", ret)

    # Main training loop
    # Common suffix of the result files written after every epoch.
    run_tag = (f'BS{args.batch_size}_NC{args.num_clients}_LE{args.local_epochs}'
               f'_lr{args.lr_type}_prob{orig_prob}_priori{args.priori}')
    for epoch in tqdm(range(args.epochs)):
        #train_dataset = train_dataset.shuffle(buffer_size=train_raw[0].shape[0])
        #print("\nStart of epoch %d" % (epoch,))

        # Per-epoch table of embeddings: one row per batch, one column per client
        Hs = np.empty((math.ceil(train_raw[0].shape[0] / args.batch_size), num_clients), dtype=object)
        Hs.fill([])
        num_batches = 0

        # Update probabilities for adaptive experiments
        if args.change is not None:
            #prob[0] = 0.25*math.sin(epoch/30)+0.7
            #prob[1] = 0.25*math.sin(epoch/20-math.pi/16)+0.7
            #prob[2] = 0.25*math.sin(epoch/10+math.pi/16)+0.7
            #prob[3] = 0.25*math.sin(epoch/40+math.pi/8)+0.7
            # Replace exactly the client entries; a hard-coded slice of 4
            # would change the length of `prob` for other client counts.
            prob[:num_clients] = 1 - cpus[:num_clients,epoch]

        # Iterate over the batches of the dataset.
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            params1 = []
            params2 = []
            grads1 = []
            grads2 = []
            num_batches += 1

            # Exchange client embeddings
            for i in range(num_clients):
                x_local = x_batch_train[:,:,coords_per*i:coords_per*(i+1)]
                Hs[step,i] = forward(x_local, y_batch_train, models[i])

            # For calculating L, get initial grads and parameters
            #for i in range(num_clients):
            #    x_local = x_batch_train[:,:,coords_per*i:coords_per*(i+1)]
            #    params1.append(copy.deepcopy(models[i].trainable_variables))
            #    grad, _ = get_grads(x_local, y_batch_train, 
            #            Hs[step].tolist(), models[i], models[-1], i)
            #    grads1.append(grad[:9])
            #params1.append(copy.deepcopy(models[-1].trainable_variables))
            #grad, _ = getserver_grads(y_batch_train, Hs[step].tolist(), models[-1])
            #grads1.append(grad)

            # Train for each client 
            client_losses = [0]*num_clients
            for i in range(num_clients):
                x_local = x_batch_train[:,:,coords_per*i:coords_per*(i+1)]
                H = copy.deepcopy(Hs[step])
                # Each of the local_epochs steps is performed with
                # probability prob[i]; `le` counts the steps actually taken.
                le = 0
                for local in range(local_epochs):
                    le += random.random() < prob[i]
                client_losses[i] = train_step(x_local, y_batch_train, models[i], models[-1], H, le, i)

                # Use exponential averaging to estimate probability
                current_prob = le/local_epochs - estimate_prob[i]
                estimate_prob[i] = estimate_prob[i] + 0.01*current_prob

            # Send latest embeddings to server
            #H_most_recent = [None]*num_clients
            #for i in range(num_clients):
            #    x_local = x_batch_train[:,:,coords_per*i:coords_per*(i+1)]
            #    H_most_recent[i] = forward(x_local, y_batch_train, models[i])

            # Train server on the embeddings exchanged at the start of the step
            loss_final = trainserver_step(y_batch_train, models[-1], Hs[step].tolist(), local_epochs)

            # For calculating L, get final grads and parameters
            #for i in range(num_clients):
            #    x_local = x_batch_train[:,:,coords_per*i:coords_per*(i+1)]
            #    params2.append(copy.deepcopy(models[i].trainable_variables))
            #    grad, _ = get_grads(x_local, y_batch_train, H_most_recent, models[i], models[-1], i)
            #    grads2.append(grad[:9])
            #params2.append(copy.deepcopy(models[-1].trainable_variables))
            #grad, _ = getserver_grads(y_batch_train, Hs[step].tolist(), models[-1])
            #grads2.append(grad)

            if args.priori == "adapt":
                # Old code for calculating L
                # Keep track of previous L's, keep max over a window.
                # As written, `opt` uses fixed constants and does not read
                # newL/newLk; they are still computed and passed through.
                newL = 0
                if e < window and epoch > 0:
                    newL = np.max(Ls)
                if e >= window:
                    newL = np.max(Ls[-1*window:])

                # Same windowed maximum per party, taken over that party's
                # own history Lks[i] (not over the list of parties).
                newLk = []
                for i in range(num_clients+1):
                    newLk.append(0)
                    if e < window and epoch > 0:
                        newLk[i] = np.max(Lks[i])
                    if e >= window:
                        newLk[i] = np.max(Lks[i][-1*window:])
                e += 1

                # Calculate new learning rates
                lrs = opt(estimate_prob, num_clients, local_epochs, args.lr_type, 
                        params1, params2, grads1, grads2, newL, newLk, Ls, Lks)

                for i in range(num_clients+1):
                    models[i].optimizer.lr = lrs[i]
            
        #print("==> predicting")
        # Iterate over the batches of the dataset to calculate loss/accuracy
        # Calculate Training Loss
        predictions = np.zeros(train_raw[0].shape[0])
        left = 0
        total_loss = 0
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset_static_for_logging):
            logits, loss_aggregate_model = predict(x_batch_train, y_batch_train, models)
            total_loss += loss_aggregate_model
            predictions[left: left + len(x_batch_train)] = tf.reshape(tf.identity(logits),-1)
            left = left + len(x_batch_train)
        losses.append(total_loss/num_batches)
        #print(f"************Loss = {losses[-1]}***************")
        # Result files are rewritten after every epoch; `with` closes each
        # handle as soon as the dump is complete.
        with open(f'losses_varlr_{run_tag}.pkl', 'wb') as out_file:
            pickle.dump(losses, out_file)

        # Calculate Training Accuracy 
        ret = metrics.print_metrics_binary(train_raw[1], predictions, verbose=0)
        accs_train.append(list(ret.items()))
        with open(f'accs_train_varlr_{run_tag}.pkl', 'wb') as out_file:
            pickle.dump(accs_train, out_file)
        #print("Train acc", ret['acc'])

        # Calculate Test Accuracy 
        predictions = np.zeros(test_raw[0].shape[0])
        left = 0
        for step, (x_batch_test, y_batch_test) in enumerate(test_dataset_static_for_logging):
            logits, _ = predict(x_batch_test, y_batch_test, models)
            predictions[left: left + len(x_batch_test)] = tf.reshape(tf.identity(logits),-1)
            left = left + len(x_batch_test)
        ret = metrics.print_metrics_binary(test_raw[1], predictions, verbose=0)
        accs_test.append(list(ret.items()))
        with open(f'accs_test_varlr_{run_tag}.pkl', 'wb') as out_file:
            pickle.dump(accs_test, out_file)
        #print("test acc", ret['acc'])
