import numpy as np
import tensorflow.keras as keras
import tensorflow.keras.backend as k
from data_generators.os_generator import OSBatchGenerator
from data_engineering.transform_data import transform_data
from utils.metrics import calculate_stopping_reward

import time


# Small epsilon constant; not referenced by the functions in this section —
# presumably used for numerical stability elsewhere (TODO confirm).
EPS = 1e-8
# Base seed for the RandomNormal weight initializers used by build_rrlsm_model,
# so the frozen reservoir weights are reproducible across runs.
SEED = 42



def rrlsm_loss(reward, phi):
    """Return the negated batch mean of a two-way mixture of reward columns.

    Column 0 of ``phi`` weights ``reward[:, 0]`` and its complement
    ``1 - phi[:, 0]`` weights ``reward[:, 1]``; the loss is minus the mean of
    that mixture, so minimising it maximises the mixed reward. When passed to
    ``Model.compile(loss=...)`` Keras supplies y_true as ``reward`` and y_pred
    as ``phi``. In this file it only serves the dummy compile performed before
    ``predict()``.
    """
    weight = phi[:, 0]
    mixed_reward = weight * reward[:, 0] + (1 - weight) * reward[:, 1]
    return -k.mean(mixed_reward)


def build_rrlsm_model(config, L, F):
    """Build a frozen recurrent "reservoir" mapping input sequences to hidden states.

    The time channel, the optional reward channel and the signal features are
    concatenated along the feature axis, passed through a non-trainable
    BatchNormalization layer and then through ``config['num_stacked_layers']``
    stacked, randomly initialised and frozen ``SimpleRNN`` layers.

    Args:
        config: dict; reads 'include_R', 'num_stacked_layers', 'units_hidden',
            'ker_std' and 'rec_std'.
        L: number of time steps per sequence.
        F: number of signal features per time step.

    Returns:
        A non-trainable ``keras.Model`` with inputs ``[time, reward, signal]``
        (or ``[time, signal]`` when ``config['include_R']`` is falsy). Its output
        is the per-time-step hidden state of the last RNN layer, shape
        (batch, L, units_hidden).

    Raises:
        ValueError: if ``config['num_stacked_layers']`` is smaller than 1.
    """
    num_layers = config['num_stacked_layers']
    if num_layers < 1:
        raise ValueError(f"num_stacked_layers must be >= 1, got {num_layers}")

    signal_tensor = keras.layers.Input(shape=(L, F), name='input_signal')
    time_tensor = keras.layers.Input(shape=(L, 1), name='input_time')
    if config['include_R']:
        reward_tensor = keras.layers.Input(shape=(L, 1), name='input_reward')
        model_inputs = [time_tensor, reward_tensor, signal_tensor]
    else:
        model_inputs = [time_tensor, signal_tensor]
    input_tensor = keras.layers.Concatenate(axis=2)(model_inputs)
    # NOTE(review): with trainable=False and no fit() in this file, the moving
    # statistics likely stay at their initial values (mean 0, variance 1), making
    # this layer close to an identity at predict() time — confirm that is intended.
    hidden = keras.layers.BatchNormalization(axis=2, trainable=False)(input_tensor)

    for i in range(num_layers):
        # Each layer consumes the previous layer's sequence output (true stacking).
        # Layer 0 keeps the original seeds; later layers get their own seeds so the
        # recurrent matrices differ between layers.
        hidden = keras.layers.SimpleRNN(
            units=config['units_hidden'],
            trainable=False,
            return_sequences=True,
            kernel_initializer=keras.initializers.RandomNormal(
                stddev=config['ker_std'], seed=SEED + 2 * i),
            recurrent_initializer=keras.initializers.RandomNormal(
                stddev=config['rec_std'], seed=SEED + 2 * i + 1),
            activation='tanh',
        )(hidden)

    rrlsm_model = keras.Model(model_inputs, [hidden])
    rrlsm_model.trainable = False

    return rrlsm_model


def _append_bias_column(h):
    """Return the 2-D array ``h`` (samples x features) with a trailing column of ones (intercept)."""
    return np.concatenate([h, np.ones((h.shape[0], 1))], axis=1)


def _stop_decision(immediate, continuation, is_reward_flag):
    """Return a 0/1 integer array marking where stopping now is preferred to continuing.

    With ``is_reward_flag == 1`` values are maximised, so stopping is chosen where
    ``immediate >= continuation``. Otherwise values are treated as costs and
    stopping is chosen where ``immediate <= continuation``.
    """
    if is_reward_flag == 1:
        return (immediate >= continuation) * 1
    return (immediate <= continuation) * 1


def _fit_continuation_regressions(train_h, train_rewards, L, is_reward_flag):
    """Fit one least-squares continuation-value regression per time step, backwards in time.

    Returns a list ``thetas`` of length ``L - 1``. ``thetas[n]`` maps the
    bias-augmented hidden state at step ``n`` to the estimated value of not
    stopping at step ``n``.
    """
    value = train_rewards[:, L - 1]
    thetas = [None] * max(L - 1, 0)
    for n in range(L - 2, -1, -1):
        design = _append_bias_column(train_h[:, n, :])
        # rcond=None is NumPy's current default; passing it explicitly silences the
        # FutureWarning older NumPy versions emit when rcond is omitted.
        theta = np.linalg.lstsq(design, value, rcond=None)[0]
        stop = _stop_decision(train_rewards[:, n], np.dot(design, theta), is_reward_flag)
        # Value carried back along each path: the reward at n where we stop,
        # otherwise the value already propagated from later steps.
        value = train_rewards[:, n] * stop + value * (1 - stop)
        thetas[n] = theta
    return thetas


def _predict_interventions(test_h, test_rewards, thetas, L, is_reward_flag):
    """Apply the fitted regressions to the test paths and return a 0/1 stop-decision matrix.

    The result has the shape of ``test_rewards``. Its last column is forced to 1,
    matching the training recursion, which starts from the last-step reward.
    """
    interventions = np.ones((test_rewards.shape[0], test_rewards.shape[1]))
    for n in range(L - 1):
        design = _append_bias_column(test_h[:, n, :])
        interventions[:, n] = _stop_decision(
            test_rewards[:, n], np.dot(design, thetas[n]), is_reward_flag)
    interventions[:, -1] = 1
    return interventions


def train_rrlsm_model(config, data_stats_dict, transform_str=None, is_reward_flag=1):
    """Train and evaluate the reservoir least-squares stopping model on every fold.

    For each fold, the frozen recurrent model from ``build_rrlsm_model`` turns
    the (optionally transformed) inputs into per-step hidden states. A linear
    least-squares regression per time step then estimates the continuation value
    by backward induction on the training fold. The fitted regressions decide
    when to stop on the test fold.

    Args:
        config: configuration passed to ``build_rrlsm_model`` and ``OSBatchGenerator``.
        data_stats_dict: dict with 'training_folds' and 'test_folds' lists of
            ``(inputs, rewards)`` pairs. Inputs are indexed as
            (samples, L, F) and rewards as (samples, L). When ``transform_str``
            is given, it also holds per-fold statistics under that key.
        transform_str: optional transform name forwarded to ``transform_data``.
        is_reward_flag: 1 if the rewards are to be maximised; any other value
            treats them as costs to be minimised.

    Returns:
        dict of per-fold lists:
        - 'rrlsm_rewards'
        - 'rrlsm_stop_idxs'
        - 'prediction_time_per_ts' (milliseconds per sample per time step)
        - 'train_times' (seconds)
    """
    train_folds = data_stats_dict['training_folds']
    test_folds = data_stats_dict['test_folds']
    nfolds = len(train_folds)
    # Sequence length and feature count are taken from the first training fold.
    L = train_folds[0][0].shape[1]
    F = train_folds[0][0].shape[2]

    rrlsm_rewards = []
    rrlsm_reward_idxs = []
    prediction_time_per_ts = []
    train_times = []

    for i in range(nfolds):
        data_stats = data_stats_dict[transform_str][i] if transform_str is not None else None

        # TRANSFORM INPUT DATA
        transformed_train_data = transform_data(train_folds[i][0], data_stats, transform_str)
        train_rewards = train_folds[i][1]
        transformed_test_data = transform_data(test_folds[i][0], data_stats, transform_str)
        test_rewards = test_folds[i][1]

        # BUILD MODEL
        rrlsm_model = build_rrlsm_model(config, L, F)
        # Dummy compile: the weights are fixed, the model is only used for predict().
        # `learning_rate` replaces the deprecated `lr` keyword (same value).
        rrlsm_model.compile(loss=rrlsm_loss,
                            optimizer=keras.optimizers.Adam(learning_rate=0.001, clipnorm=5))
        rrlsm_train_prediction_generator = OSBatchGenerator(
            transformed_train_data, train_rewards, config, batch_size=32, randomize=False)
        rrlsm_test_prediction_generator = OSBatchGenerator(
            transformed_test_data, test_rewards, config, batch_size=32, randomize=False)

        # TRAIN: reservoir features + backward-induction regressions.
        start_train_time = time.perf_counter()
        train_h = rrlsm_model.predict(rrlsm_train_prediction_generator)
        thetas = _fit_continuation_regressions(train_h, train_rewards, L, is_reward_flag)
        train_time = time.perf_counter() - start_train_time

        # MODEL INFERENCE
        start_predict_time = time.perf_counter()
        test_h = rrlsm_model.predict(rrlsm_test_prediction_generator)
        test_interventions = _predict_interventions(test_h, test_rewards, thetas, L, is_reward_flag)
        predict_time = time.perf_counter() - start_predict_time

        # NOTE(review): the meaning of the 0.5 argument is defined by
        # calculate_stopping_reward (not visible here) — confirm.
        stopping_reward, stop_idxs = calculate_stopping_reward(0.5, test_interventions, test_rewards)

        print(str(stopping_reward))
        rrlsm_rewards.append(stopping_reward)
        prediction_time_per_ts.append(predict_time * 1e3 / (test_interventions.shape[0] * L))
        train_times.append(train_time)
        rrlsm_reward_idxs.append(stop_idxs)

    rrlsm_results = {'rrlsm_rewards': rrlsm_rewards,
                     'rrlsm_stop_idxs': rrlsm_reward_idxs,
                     'prediction_time_per_ts': prediction_time_per_ts,
                     'train_times': train_times}

    return rrlsm_results
