from __future__ import absolute_import
from __future__ import print_function

import numpy as np
import argparse
import os
import imp
import re
import math

from mimic3models.in_hospital_mortality import utils
from mimic3benchmark.readers import InHospitalMortalityReader

from mimic3models.preprocessing import Discretizer, Normalizer
from mimic3models import metrics
from mimic3models import keras_utils
from mimic3models import common_utils

import tensorflow as tf
from keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.python.eager import backprop
import pickle
from sklearn import metrics as skmetrics

# Command-line interface: the shared options registered by common_utils plus
# the options that are specific to this script.
parser = argparse.ArgumentParser()
common_utils.add_common_arguments(parser)
parser.add_argument('--target_repl_coef', default=0.0, type=float)
parser.add_argument(
    '--data',
    type=str,
    default=os.path.join(os.path.dirname(__file__), '../../data/in-hospital-mortality/'),
    help='Path to the data of in-hospital mortality task',
)
parser.add_argument(
    '--output_dir',
    type=str,
    default='.',
    help='Directory relative which all output files are stored',
)
parser.add_argument(
    '--num_clients',
    type=int,
    default=2,
    help='Number of clients to split data between vertically',
)
parser.add_argument(
    '--local_epochs',
    type=int,
    default=1,
    help='Number of local epochs to run at each client before synchronizing',
)
args = parser.parse_args()
print(args)
num_clients, local_epochs = args.num_clients, args.local_epochs

# A small-part run overrides save_every with 2**30.
if args.small_part:
    args.save_every = 2**30

# Target replication is only switched on in train mode with a positive coefficient.
target_repl = (args.mode == 'train' and args.target_repl_coef > 0.0)

# Build readers, discretizers, normalizers
# Each reader pairs a dataset directory with a listfile; note that the
# validation listfile is read against the 'train' directory.
train_reader, val_reader, test_reader = (
    InHospitalMortalityReader(dataset_dir=os.path.join(args.data, split_dir),
                              listfile=os.path.join(args.data, listfile_name),
                              period_length=48.0)
    for split_dir, listfile_name in (('train', 'train_listfile.csv'),
                                     ('train', 'val_listfile.csv'),
                                     ('test', 'test_listfile.csv'))
)

discretizer = Discretizer(start_time='zero',
                          impute_strategy='previous',
                          store_masks=True,
                          timestep=float(args.timestep))

# Element [1] of the discretizer output for the first training example is a
# comma-separated header; columns whose name contains "->" are excluded from
# the set of channels handed to the normalizer.
discretizer_header = discretizer.transform(train_reader.read_example(0)["X"])[1].split(',')
cont_channels = [idx for idx, column in enumerate(discretizer_header) if "->" not in column]

normalizer = Normalizer(fields=cont_channels)  # choose here which columns to standardize
normalizer_state = args.normalizer_state
if normalizer_state is None:
    # No explicit state file given: fall back to the one expected next to this script.
    normalizer_state = os.path.join(
        os.path.dirname(__file__),
        'ihm_ts{}.input_str:{}.start_time:zero.normalizer'.format(args.timestep, args.imputation))
normalizer.load_params(normalizer_state)

# Keyword arguments forwarded to the network constructor: every parsed CLI
# option plus the three extra entries below.
args_dict = dict(args._get_kwargs())
args_dict.update(header=discretizer_header, task='ihm', target_repl=target_repl)

# Make models for each client
# The network definition file is identical for every client, so it is loaded
# once here instead of being re-executed on every loop iteration.
model_module = imp.load_source(os.path.basename(args.network), args.network)

models = []
for i in range(num_clients):
    # Build the model
    print("==> using model {}".format(args.network))
    # Each client is sized for an equal share of 76 input features.
    # NOTE(review): 76 is hard-coded and, when it is not divisible by
    # num_clients, the integer division leaves the remainder unassigned —
    # confirm this matches the width of the loaded data.
    model = model_module.Network(input_dim=76 // num_clients, multi=True, **args_dict)

    # Compile the model
    print("==> compiling the model")
    # Every client gets its own optimizer instance.
    optimizer_config = tf.keras.optimizers.Adam(
        learning_rate=args.lr, beta_1=args.beta_1)
    if target_repl:
        loss = ['binary_crossentropy'] * 2
        loss_weights = [1 - args.target_repl_coef, args.target_repl_coef]
    else:
        loss = 'binary_crossentropy'
        loss_weights = None

    model.compile(optimizer=optimizer_config,
                  loss=loss,
                  loss_weights=loss_weights)
    model.summary()
    models.append(model)

# Load model weights
# NOTE(review): `model` is the last model bound by the client-construction
# loop, so only that one client is restored here; no other entry of `models`
# receives weights from load_state — confirm this is intended.
n_trained_chunks = 0
if args.load_state != "":
    model.load_weights(args.load_state)
    # The epoch counter is recovered from the checkpoint name ("...epoch<N>...").
    epoch_match = re.match(r".*epoch([0-9]+).*", args.load_state)
    if epoch_match is None:
        raise ValueError(
            "Cannot infer the epoch from load_state {!r}: "
            "expected 'epoch<number>' in the file name".format(args.load_state))
    n_trained_chunks = int(epoch_match.group(1))


# Read data
# The three splits go through the same discretizer/normalizer pipeline, in
# the order train, validation, test.
train_raw, val_raw, test_raw = (
    utils.load_data(split_reader, discretizer, normalizer, args.small_part)
    for split_reader in (train_reader, val_reader, test_reader)
)

if target_repl:
    # Length of the first axis of the first training example; used as the
    # number of repetitions of each label.
    T = train_raw[0][0].shape[0]

    def extend_labels(data):
        """Return a list copy of `data` whose labels entry is replaced by a pair.

        Element 1 of the result becomes ``[labels, replicated]`` where
        ``labels`` is the original labels as an array and ``replicated``
        repeats every label T times along a new axis 1 and then gains a
        trailing unit axis. The other elements of `data` are carried over
        unchanged.
        """
        extended = list(data)
        labels = np.array(extended[1])  # (B,)
        tiled = labels[..., np.newaxis].repeat(T, axis=1)  # (B, T)
        extended[1] = [labels, tiled[..., np.newaxis]]  # second entry: (B, T, 1)
        return extended

    train_raw, val_raw = extend_labels(train_raw), extend_labels(val_raw)

if args.mode == 'train':

    # Prepare training
    print("==> training")

    # Width of each client's slice along the last input axis: an equal share
    # of 76, truncated to an integer.
    coords_per = int(76 / num_clients)

    # Sigmoid, applied to the combined client outputs in the functions below.
    activation = tf.keras.activations.sigmoid

    @tf.function
    def apply_grads(x, y, H, model):
        """Compute the gradients of `model`'s compiled loss on one batch.

        The model output for `x` is added to `H` and passed through the
        module-level `activation` before the loss against `y` is taken.
        `H` is used as given; gradients are taken only with respect to
        `model.trainable_variables`.

        Returns:
            A ``(grads, loss_value)`` tuple.
        """
        with backprop.GradientTape() as tape:
            local_out = model(x, training=True)
            combined = activation(H + local_out)
            loss_value = model.compiled_loss(y, combined)#, regularization_losses=model.losses)
        return tape.gradient(loss_value, model.trainable_variables), loss_value

    def train_step(x, y, model, H, local):
        """Apply `local` successive gradient updates to `model` on one batch.

        `H` is passed unchanged to every update. Gradients go through the
        optimizer's own clipping hook before being applied.

        Returns:
            The loss from the last update, or 0 if `local` is 0.
        """
        loss_value = 0
        for _ in range(local):
            raw_grads, loss_value = apply_grads(x, y, H, model)
            clipped = model.optimizer._clip_gradients(raw_grads)    # pylint: disable=protected-access
            model.optimizer.apply_gradients(zip(clipped, model.trainable_variables))
        return loss_value

    @tf.function
    def forward(x, y, model):
        """Return `model`'s output for `x` with training=False.

        `y` is accepted for signature compatibility with the callers but is
        not used.
        """
        return model(x, training=False)

    @tf.function
    def predict(x, y, models):
        """Combine all client models into one prediction for `x`.

        Client k receives the slice ``[coords_per*k, coords_per*(k+1))`` of
        the last axis of `x`; the outputs are summed and passed through the
        module-level `activation`. `y` is not used.
        """
        total = 0
        for idx, client_model in enumerate(models):
            lo = coords_per * idx
            total += client_model(x[:, :, lo:lo + coords_per], training=False)
        return activation(total)

    # Build the batched training set. Batches carry every feature column; each
    # client's vertical slice is cut from the last axis inside the loop below.
    # NOTE(review): this needs train_raw[1] to expose .reshape; when
    # target_repl is set, extend_labels has replaced it with a list — confirm.
    train_dataset = tf.data.Dataset.from_tensor_slices((
        train_raw[0],
        train_raw[1].reshape(-1, 1)))
    train_dataset = train_dataset.batch(args.batch_size)

    def _dump_pickle(obj, path):
        """Pickle `obj` to `path`, closing the file handle when done."""
        with open(path, 'wb') as handle:
            pickle.dump(obj, handle)

    def _evaluate(raw):
        """Score the combined client models on `raw`, one sample at a time.

        `raw[0]` holds the inputs and `raw[1]` the labels. Prints the binary
        metrics and returns the dict produced by metrics.print_metrics_binary.
        """
        predictions = np.zeros(raw[0].shape[0])
        for idx in range(raw[0].shape[0]):
            logits = predict(np.expand_dims(raw[0][idx], axis=0),
                             raw[1][idx].reshape(1, 1), models)
            predictions[idx] = tf.reshape(tf.identity(logits), -1)
        return metrics.print_metrics_binary(raw[1], predictions)

    # Shared tail of the three output file names (written to the working directory).
    file_suffix = '_multi_k' + str(num_clients) + '.p'
    num_batches = math.ceil(train_raw[0].shape[0] / args.batch_size)

    losses = []
    accs_train = []
    accs_test = []
    for epoch in range(args.epochs):
        # NOTE: batches are visited in the same order every epoch (no shuffle).
        print("\nStart of epoch %d" % (epoch,))

        # Hs[step, i] holds client i's forward output for batch `step`. Every
        # entry of a row is assigned before that row is read, so the array
        # needs no initial fill.
        Hs = np.empty((num_batches, num_clients), dtype=object)
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            loss_vals = [0] * num_clients

            # First pass: every client produces its output for its own slice
            # before any model is updated.
            for i in range(num_clients):
                x_local = x_batch_train[:, :, coords_per*i:coords_per*(i+1)]
                Hs[step, i] = forward(x_local, y_batch_train, models[i])

            # Second pass: each client trains against the sum of the *other*
            # clients' first-pass outputs (total minus its own).
            for i in range(num_clients):
                x_local = x_batch_train[:, :, coords_per*i:coords_per*(i+1)]
                H = tf.cast(np.sum(Hs[step]) - Hs[step][i], tf.float32)
                loss_vals[i] = train_step(x_local, y_batch_train, models[i], H, local_epochs)

            # Only the first client's loss is tracked; the history is
            # re-written to disk after every batch.
            losses.append(loss_vals[0])
            _dump_pickle(losses, 'losses' + file_suffix)

            # Log every 20 batches.
            if step % 20 == 0:
                print(
                    "Training loss (for one batch) at step %d: %.4f"
                    % (step, float(losses[-1]))
                )
                print("Seen so far: %s samples" % ((step + 1) * args.batch_size))

        # End-of-epoch accuracy on the training split.
        print("==> predicting")
        accs_train.append(_evaluate(train_raw)["acc"])
        _dump_pickle(accs_train, 'accs_train' + file_suffix)

        # End-of-epoch accuracy on the test split.
        print("==> predicting")
        accs_test.append(_evaluate(test_raw)["acc"])
        _dump_pickle(accs_test, 'accs_test' + file_suffix)

elif args.mode == 'test':

    # ensure that the code uses test_reader
    del train_reader
    del val_reader
    del train_raw
    del val_raw

    test_reader = InHospitalMortalityReader(dataset_dir=os.path.join(args.data, 'test'),
                                            listfile=os.path.join(args.data, 'test_listfile.csv'),
                                            period_length=48.0)
    ret = utils.load_data(test_reader, discretizer, normalizer, args.small_part,
                          return_names=True)

    data = ret["data"][0]
    labels = ret["data"][1]
    names = ret["names"]

    predictions = model.predict(data, batch_size=args.batch_size, verbose=1)
    predictions = np.array(predictions)[:, 0]
    metrics.print_metrics_binary(labels, predictions)

    path = os.path.join(args.output_dir, "test_predictions", os.path.basename(args.load_state)) + ".csv"
    utils.save_results(names, predictions, labels, path)

else:
    raise ValueError("Wrong value for args.mode")
