import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from data_generators.es_generator import ESBatchGenerator
from data_engineering.transform_data import transform_data
from data_engineering.analyze_data import analyze_fold
from utils.metrics import classification_cost
import time
import copy

# Small positive constant. Not referenced by the functions visible in this
# module section — presumably a numerical-stability epsilon used elsewhere;
# TODO confirm before removing.
EPS = 1e-8

def compute_rl_cost_tables(es_rnn, data, labels, alpha):
    """Build empirical and estimated cost tables from the model's predictions.

    Args:
        es_rnn: Object exposing ``predict(data)``; its output is reduced with a
            max over axis 2, so it must have at least three axes.
        data: Indexable model input. ``data[1].shape[0]`` and
            ``data[1].shape[1]`` supply the number of rows (N) and the
            sequence length (L) of the cost tables.
        labels: Forwarded unchanged to ``classification_cost``.
        alpha: Multiplier applied to the elapsed-time penalty.

    Returns:
        A tuple ``(cost_emp, cost_est, predictions)`` where

        * ``cost_emp`` is ``classification_cost(predictions, labels, L)`` plus
          the time penalty,
        * ``cost_est`` is ``1 - max(predictions, axis=2)`` plus the time
          penalty,
        * ``predictions`` is the raw output of ``es_rnn.predict(data)``.
    """
    input_shape = data[1].shape
    num_rows = input_shape[0]
    seq_len = input_shape[1]

    predictions = es_rnn.predict(data)

    # Penalty grows linearly with the step index: step t (1-based) costs
    # alpha * t / L, identical for every row.
    step_fraction = np.linspace(1, seq_len, num=seq_len) / seq_len
    time_penalty = alpha * np.repeat(step_fraction[np.newaxis, :], num_rows, axis=0)

    # Estimated misclassification cost: one minus the largest value along axis 2.
    estimated_mcc = 1 - np.amax(predictions, axis=2)
    # Empirical misclassification cost, as defined by the project helper.
    empirical_mcc = classification_cost(predictions, labels, seq_len)

    return empirical_mcc + time_penalty, estimated_mcc + time_penalty, predictions

def build_es_rnn_model(config, L, F):
    """Assemble the (uncompiled) stacked-GRU model with a per-step softmax head.

    Args:
        config: Mapping read for ``'num_stacked_layers'``, ``'units_hidden'``
            and ``'num_classes'``.
        L: Length of the second axis of both inputs.
        F: Size of the last axis of the signal input.

    Returns:
        A ``keras.Model`` taking ``[input_time, input_signal]`` (in that order)
        and producing the output of a time-distributed softmax ``Dense`` layer
        with ``config['num_classes']`` units.
    """
    signal_in = keras.layers.Input(shape=(L, F), name='input_signal')
    time_in = keras.layers.Input(shape=(L, 1), name='input_time')

    # The time channel is placed in front of the signal features on the last
    # axis, then the joined tensor is batch-normalised along that axis.
    joined = keras.layers.Concatenate(axis=2)([time_in, signal_in])
    hidden = keras.layers.BatchNormalization(axis=2)(joined)

    # Each stacked layer is a sequence-returning GRU followed by batch norm.
    for _ in range(config['num_stacked_layers']):
        gru_layer = keras.layers.GRU(
            units=config['units_hidden'],
            return_sequences=True,
            activation='tanh',
        )
        hidden = gru_layer(hidden)
        hidden = keras.layers.BatchNormalization(axis=2)(hidden)

    # One shared softmax classifier applied independently at every time step.
    classifier = keras.layers.Dense(
        units=config['num_classes'],
        activation='softmax',
        name='dense',
    )
    class_probabilities = keras.layers.TimeDistributed(classifier)(hidden)

    return keras.Model([time_in, signal_in], class_probabilities)


def _attach_es_outputs(es_rnn, split_key, fold, inputs, labels, alpha, L, stats_emp, stats_est):
    """Score one split of one fold and write costs/predictions into both dicts.

    For ``stats_emp`` and ``stats_est`` alike, element ``[1]`` of
    ``stats[split_key][fold]`` is replaced by the matching cost table and
    element ``[0]`` is extended along axis 2 with the model predictions.
    """
    # Time input fed to the model: value t / L at step t, repeated per row.
    time_input = np.tile(
        np.expand_dims(np.arange(L, dtype=float), axis=(0, 2)),
        (inputs.shape[0], 1, 1),
    ) / L
    cost_emp, cost_est, es_preds = compute_rl_cost_tables(
        es_rnn, [time_input, inputs], labels, alpha)

    for stats, cost in ((stats_emp, cost_emp), (stats_est, cost_est)):
        entry = stats[split_key][fold]
        entry[1] = cost
        entry[0] = np.concatenate([entry[0], es_preds], axis=2)


def train_es_rnn_model(config, data_stats_dict, alpha, transform_str=None):
    """Train one model per fold and return data dicts augmented with its outputs.

    For every fold a fresh model is built, compiled with sparse categorical
    cross-entropy and Adam, and fitted with early stopping on ``val_loss``
    (patience 5). The fitted model is then run on the training, validation and
    test splits; its predictions are appended to the (untransformed) inputs
    along axis 2 and the labels are replaced by cost tables.

    Args:
        config: Mapping read here for ``'es_lr'``, ``'clipnorm'``,
            ``'batch_size'`` and ``'es_epochs'``; it is also passed to
            ``build_es_rnn_model`` and ``ESBatchGenerator``.
        data_stats_dict: Mapping with ``'training_folds'``,
            ``'validation_folds'`` and ``'test_folds'``, each indexable by fold
            and holding entries whose ``[0]`` is the input array and ``[1]``
            the labels, plus a fold-indexable ``'mu_sigma'`` entry. When
            ``transform_str`` is given, ``data_stats_dict[transform_str][fold]``
            is handed to ``transform_data``. The argument itself is not
            modified; results are written into deep copies.
        alpha: Time-penalty multiplier forwarded to ``compute_rl_cost_tables``.
        transform_str: Optional key selecting per-fold statistics for
            ``transform_data``; ``None`` passes ``None`` as the statistics.

    Returns:
        ``(rl_data_stats_dict_emp, rl_data_stats_dict_est)``: two deep copies of
        ``data_stats_dict`` whose fold entries carry the empirical and the
        estimated cost tables respectively.
    """
    nfolds = len(data_stats_dict['training_folds'])
    L = data_stats_dict['training_folds'][0][0].shape[1]
    F = data_stats_dict['training_folds'][0][0].shape[2]

    # update cost matrix for use by other models
    rl_data_stats_dict_emp = copy.deepcopy(data_stats_dict)
    rl_data_stats_dict_est = copy.deepcopy(data_stats_dict)

    split_keys = ('training_folds', 'validation_folds', 'test_folds')

    for i in range(nfolds):

        data_stats = None
        if transform_str is not None:
            data_stats = data_stats_dict[transform_str][i]

        # TRANSFORM INPUT DATA (same statistics for all three splits)
        transformed = {}
        labels = {}
        for split_key in split_keys:
            transformed[split_key] = transform_data(data_stats_dict[split_key][i][0], data_stats)
            labels[split_key] = data_stats_dict[split_key][i][1]

        # BUILD MODEL
        es_rnn = build_es_rnn_model(config, L, F)

        # COMPILE MODEL
        es_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', verbose=0, patience=5)
        # `learning_rate` is the supported argument name; the old `lr` alias is
        # deprecated and rejected by newer Keras releases.
        es_rnn.compile(loss=keras.losses.SparseCategoricalCrossentropy(),
                       optimizer=keras.optimizers.Adam(learning_rate=config['es_lr'],
                                                       clipnorm=config['clipnorm']))

        # FIT MODEL
        train_data = transformed['training_folds']
        val_data = transformed['validation_folds']
        es_train_generator = ESBatchGenerator(train_data, labels['training_folds'], config,
                                              config['batch_size'], randomize=True)
        # Validation batch size equals the number of validation rows.
        es_val_generator = ESBatchGenerator(val_data, labels['validation_folds'], config,
                                            val_data.shape[0], randomize=False)
        es_rnn.fit(es_train_generator,
                   validation_data=es_val_generator,
                   callbacks=[es_callback],
                   epochs=config['es_epochs'], shuffle=False, verbose=0)

        # Train data
        _attach_es_outputs(es_rnn, 'training_folds', i, train_data, labels['training_folds'],
                           alpha, L, rl_data_stats_dict_emp, rl_data_stats_dict_est)
        # Refresh the per-fold 'mu_sigma' entry from the widened training inputs.
        # NOTE(review): only 'mu_sigma' is refreshed, whatever transform_str is —
        # confirm no other statistics need recomputing for the extra channels.
        _, rl_data_stats_dict_est['mu_sigma'][i] = analyze_fold(rl_data_stats_dict_est['training_folds'][i][0])
        _, rl_data_stats_dict_emp['mu_sigma'][i] = analyze_fold(rl_data_stats_dict_emp['training_folds'][i][0])

        # Validation and test data
        for split_key in ('validation_folds', 'test_folds'):
            _attach_es_outputs(es_rnn, split_key, i, transformed[split_key], labels[split_key],
                               alpha, L, rl_data_stats_dict_emp, rl_data_stats_dict_est)

    return rl_data_stats_dict_emp, rl_data_stats_dict_est