from __future__ import absolute_import
from __future__ import print_function

import numpy as np
import argparse
import os
import imp
import re
import math

from mimic3models.in_hospital_mortality import utils
from mimic3benchmark.readers import InHospitalMortalityReader

from mimic3models.preprocessing import Discretizer, Normalizer
from mimic3models import metrics
from mimic3models import keras_utils
from mimic3models import common_utils

import tensorflow as tf
from keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.python.keras import callbacks as callbacks_module
from tensorflow.python.eager import backprop
import pickle
from sklearn import metrics as skmetrics

# Command-line interface. Common options are registered through
# common_utils.add_common_arguments; the remaining ones are declared here.
parser = argparse.ArgumentParser()
common_utils.add_common_arguments(parser)
parser.add_argument('--target_repl_coef', type=float, default=0.0)
parser.add_argument(
    '--data',
    type=str,
    default=os.path.join(os.path.dirname(__file__), '../../data/in-hospital-mortality/'),
    help='Path to the data of in-hospital mortality task')
parser.add_argument(
    '--output_dir',
    type=str,
    default='.',
    help='Directory relative which all output files are stored')
args = parser.parse_args()
print(args)

# Runs on a small part of the data get a very large save period.
# NOTE(review): presumably this is meant to suppress checkpointing -- confirm.
if args.small_part:
    args.save_every = 2**30

# Target replication is active only with a positive coefficient in 'train' mode.
target_repl = (args.target_repl_coef > 0.0
               and args.mode == 'train')

# Build readers, discretizers, normalizers
def _make_reader(subdir, listfile_name):
    """Return an InHospitalMortalityReader for `subdir` of args.data and the given listfile."""
    return InHospitalMortalityReader(dataset_dir=os.path.join(args.data, subdir),
                                     listfile=os.path.join(args.data, listfile_name),
                                     period_length=48.0)


# Note that the validation listfile is read against the 'train' directory.
train_reader = _make_reader('train', 'train_listfile.csv')
val_reader = _make_reader('train', 'val_listfile.csv')
test_reader = _make_reader('test', 'test_listfile.csv')

discretizer = Discretizer(timestep=float(args.timestep),
                          store_masks=True,
                          impute_strategy='previous',
                          start_time='zero')

# The header of the discretized representation is obtained by transforming
# the first training example.
first_example = train_reader.read_example(0)["X"]
discretizer_header = discretizer.transform(first_example)[1].split(',')
# Columns whose name contains "->" are excluded from normalization.
cont_channels = [idx for idx, name in enumerate(discretizer_header) if "->" not in name]

normalizer = Normalizer(fields=cont_channels)  # choose here which columns to standardize
normalizer_state = args.normalizer_state
if normalizer_state is None:
    # No state given on the command line: use the file expected next to this script.
    default_state = 'ihm_ts{}.input_str:{}.start_time:zero.normalizer'.format(args.timestep, args.imputation)
    normalizer_state = os.path.join(os.path.dirname(__file__), default_state)
normalizer.load_params(normalizer_state)

# Keyword arguments handed to the network constructor: all parsed options
# plus the header, the task name and the target-replication flag.
args_dict = dict(args._get_kwargs())
args_dict.update(header=discretizer_header, task='ihm', target_repl=target_repl)

# Build the model
print("==> using model {}".format(args.network))
# The network definition is loaded from the source file given by args.network.
# NOTE(review): the `imp` module is deprecated in Python 3 (removed in 3.12).
model_module = imp.load_source(os.path.basename(args.network), args.network)
# The input dimension is hard-coded to 38 here.
model = model_module.Network(input_dim=38, **args_dict)

# Optional fragments of the run name; each appears only for a positive value.
l1_tag = ".L1{}".format(args.l1) if args.l1 > 0 else ""
l2_tag = ".L2{}".format(args.l2) if args.l2 > 0 else ""
trc_tag = ".trc{}".format(args.target_repl_coef) if args.target_repl_coef > 0 else ""
suffix = (".bs{}".format(args.batch_size)
          + l1_tag
          + l2_tag
          + ".ts{}".format(args.timestep)
          + trc_tag)
model.final_name = args.prefix + model.say_name() + suffix
print("==> model.final_name:", model.final_name)


# Compile the model
print("==> compiling the model")
# Adam is used unconditionally with the learning rate and beta_1 from the
# command line; args.optimizer is not consulted here.
optimizer_config = tf.keras.optimizers.Adam(learning_rate=args.lr,
                                            beta_1=args.beta_1)

# NOTE: one can use binary_crossentropy even for (B, T, C) shape.
#       It will calculate binary_crossentropies for each class
#       and then take the mean over axis=-1. The result is (B, T).
if not target_repl:
    loss = 'binary_crossentropy'
    loss_weights = None
else:
    # Two losses, weighted by (1 - coef) and coef respectively.
    loss = ['binary_crossentropy', 'binary_crossentropy']
    loss_weights = [1 - args.target_repl_coef, args.target_repl_coef]

model.compile(optimizer=optimizer_config, loss=loss, loss_weights=loss_weights)
model.summary()

# Load model weights
n_trained_chunks = 0
if args.load_state != "":
    model.load_weights(args.load_state)
    # Recover the epoch counter from a checkpoint name of the form
    # "...epoch<N>...". A weights file without that marker has still been
    # loaded successfully, so fall back to 0 instead of failing with an
    # AttributeError on the missing regex match.
    epoch_match = re.match(r".*epoch([0-9]+).*", args.load_state)
    if epoch_match is not None:
        n_trained_chunks = int(epoch_match.group(1))
    else:
        print("==> could not parse an epoch number from {!r}; assuming 0".format(args.load_state))


# Read data
train_raw, val_raw, test_raw = [
    utils.load_data(reader, discretizer, normalizer, args.small_part)
    for reader in (train_reader, val_reader, test_reader)
]

if target_repl:
    # Number of time steps, read off the first training example.
    T = train_raw[0][0].shape[0]

    def extend_labels(data):
        """Return `data` as a list whose labels entry is [labels, replicated labels].

        The second labels array repeats every label T times along a new
        time axis and ends with a singleton axis.
        """
        extended = list(data)
        per_sample = np.array(extended[1])  # (B,)
        per_step = np.repeat(per_sample[..., np.newaxis], T, axis=1)  # (B, T)
        extended[1] = [per_sample, per_step[..., np.newaxis]]  # second entry: (B, T, 1)
        return extended

    train_raw, val_raw = extend_labels(train_raw), extend_labels(val_raw)

if args.mode == 'train':

    # Prepare training
    def _ensure_dir(directory):
        """Create `directory` (including parents) unless it already exists."""
        if not os.path.exists(directory):
            os.makedirs(directory)

    # Checkpoint file name template; {epoch} and {val_loss} are left as
    # placeholders inside the path handed to ModelCheckpoint.
    state_name = model.final_name + '.epoch{epoch}.test{val_loss}.state'
    path = os.path.join(args.output_dir, 'keras_states/' + state_name)
    # make sure save directory exists
    dirname = os.path.dirname(path)
    _ensure_dir(dirname)
    saver = ModelCheckpoint(path, verbose=1, period=args.save_every)

    # NOTE(review): in this file the callbacks built here are only referenced
    # from commented-out model.fit calls.
    metrics_callback = keras_utils.InHospitalMortalityMetrics(
        train_data=train_raw,
        val_data=val_raw,
        target_repl=(args.target_repl_coef > 0),
        batch_size=args.batch_size,
        verbose=args.verbose)

    keras_logs = os.path.join(args.output_dir, 'keras_logs')
    _ensure_dir(keras_logs)
    csv_logger = CSVLogger(os.path.join(keras_logs, model.final_name + '.csv'),
                           append=True,
                           separator=';')

    print("==> training")
    # Training is driven by a hand-written loop instead of model.fit (the
    # earlier fit call passed callbacks=[metrics_callback, saver, csv_logger]).
    n_examples = train_raw[0].shape[0]
    # Number of batches in one pass over the training data
    # (appears to be unused elsewhere in this script).
    steps = math.ceil(n_examples / args.batch_size)

    @tf.function
    def train_step(x, y):
        """Run one optimizer update on the batch (x, y) and return the batch loss.

        Only model.compiled_loss is minimised; model.losses is not passed
        in as regularization losses.
        """
        with backprop.GradientTape() as tape:
            outputs = model(x, training=True)
            batch_loss = model.compiled_loss(y, outputs)
        variables = model.trainable_variables
        gradients = tape.gradient(batch_loss, variables)
        model.optimizer.apply_gradients(zip(gradients, variables))
        return batch_loss

    def _dump_pickle(obj, filename):
        """Pickle `obj` into `filename`, closing the file handle right away."""
        with open(filename, 'wb') as handle:
            pickle.dump(obj, handle)

    def _evaluate_accuracy(features, labels):
        """Predict on the first 38 feature columns, print the binary metrics and return "acc"."""
        predictions = model.predict(features[:, :, :38], batch_size=args.batch_size, verbose=1)
        predictions = np.array(predictions)[:, 0]
        return metrics.print_metrics_binary(labels, predictions)["acc"]

    # Only the first 38 feature columns are fed to the model.
    train_x = train_raw[0][:, :, :38]
    # np.asarray leaves an ndarray untouched and also accepts a plain list of labels.
    # NOTE(review): with target replication enabled train_raw[1] is a
    # two-element list (see extend_labels), which this loop does not handle.
    train_y = np.asarray(train_raw[1]).reshape(-1, 1)
    train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
    train_dataset = train_dataset.batch(args.batch_size)
    losses = []
    accs_train = []
    accs_test = []
    for epoch in range(args.epochs):
        model.reset_metrics()
        # NOTE: per-epoch shuffling of train_dataset is currently disabled.
        print("\nStart of epoch %d" % (epoch,))

        # Iterate over the batches of the dataset.
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            loss_value = train_step(x_batch_train, y_batch_train)
            losses.append(loss_value)
            # The whole loss history is rewritten after every step
            # (presumably so partial progress survives an interruption).
            _dump_pickle(losses, 'losses_main2.p')

            # Log every 200 batches.
            if step % 200 == 0:
                print(
                    "Training loss (for one batch) at step %d: %.4f"
                    % (step, float(loss_value))
                )
                print("Seen so far: %s samples" % ((step + 1) * args.batch_size))

        print("==> prediciting")
        # Accuracy on the training set after this epoch.
        accs_train.append(_evaluate_accuracy(train_raw[0], train_raw[1]))
        _dump_pickle(accs_train, 'accs_train_main2.p')

        # Accuracy on the test set after this epoch.
        accs_test.append(_evaluate_accuracy(test_raw[0], test_raw[1]))
        _dump_pickle(accs_test, 'accs_test_main2.p')

elif args.mode == 'test':

    # ensure that the code uses test_reader
    del train_reader
    del val_reader
    del train_raw
    del val_raw

    # Reload the test split, this time also requesting the example names.
    test_reader = InHospitalMortalityReader(dataset_dir=os.path.join(args.data, 'test'),
                                            listfile=os.path.join(args.data, 'test_listfile.csv'),
                                            period_length=48.0)
    ret = utils.load_data(test_reader, discretizer, normalizer, args.small_part,
                          return_names=True)

    # The network is constructed with input_dim=38 and the training branch
    # feeds it only the first 38 feature columns, so inference has to use
    # the same slice instead of the full feature set.
    data = ret["data"][0][:, :, :38]
    labels = ret["data"][1]
    names = ret["names"]

    predictions = model.predict(data, batch_size=args.batch_size, verbose=1)
    predictions = np.array(predictions)[:, 0]
    metrics.print_metrics_binary(labels, predictions)

    # Predictions are stored under <output_dir>/test_predictions/<state file name>.csv
    path = os.path.join(args.output_dir, "test_predictions", os.path.basename(args.load_state)) + ".csv"
    utils.save_results(names, predictions, labels, path)

else:
    raise ValueError("Wrong value for args.mode")
