import numpy as np
import tensorflow.keras as keras
import tensorflow.keras.backend as k
from data_generators.os_generator import OSBatchGenerator
from data_generators.os_mc_generator import OSMCBatchGenerator
from data_engineering.transform_data import transform_data
from utils.metrics import calculate_stopping_reward
import tensorflow as tf
import time

# Small constant: lower clip bound applied before log() and guard against dividing by a zero reward sum.
EPS = 1e-5


def os_pg_loss(is_reward_flag):
    """Build the policy-gradient loss for the optimal-stopping policy network.

    The returned Keras loss turns the per-step model outputs ``phi`` into a
    distribution ``psi`` over the step at which the policy first stops, then
    applies a score-function (log-derivative) estimator weighted by
    baseline-subtracted returns.

    Args:
        is_reward_flag: 1 when ``reward`` holds rewards to maximise; 0 when it
            holds costs, in which case each entry r is effectively replaced by
            (row_sum - r). Presumably only 0/1 are meaningful — TODO confirm.

    Returns:
        A ``loss(reward, phi)`` callable suitable for ``model.compile``.
    """
    def loss(reward, phi):
        """Return the negative mean policy-gradient objective for a batch.

        NOTE(review): assumes ``reward`` and ``phi`` share a layout whose axis 1
        is time with length L (e.g. (batch, L, 1)) — confirm against the batch
        generator.
        """
        # Normalise each sample's rewards by their total over time (EPS avoids /0).
        reward_sum = k.sum(reward, axis=1, keepdims=True)
        reward_distribution = (reward / (reward_sum + EPS))
        # For costs (flag == 0) take the complement so that smaller costs score higher.
        reward_distribution = is_reward_flag * reward_distribution + (1 - is_reward_flag) * (1 - reward_distribution)

        # Log space keeps the running product of (1 - phi) from underflowing.
        log_one_minus_phi = k.log(k.clip(1-phi, EPS, 1))
        log_phi = k.log(k.clip(phi, EPS, 1))

        # psi[j] = P(first stop at step j), treating phi as per-step stop probabilities:
        #   first step:    phi_0
        #   middle steps:  prod_{i<j} (1 - phi_i) * phi_j     for 0 < j < L-1
        #   last step:     prod_{i<L-1} (1 - phi_i)           (all remaining mass)
        psi_j = k.exp(k.cumsum(log_one_minus_phi[:, :-2], axis=1) + log_phi[:, 1:-1])
        psi_1 = k.expand_dims(phi[:, 0], axis=1)
        psi_L = k.exp(k.sum(log_one_minus_phi[:, :-1], axis=1, keepdims=True))
        psi = k.concatenate([psi_1, psi_j, psi_L], axis=1)
        # Rescale back by the row sum, then subtract the batch-wide mean as a baseline.
        r_ij = reward_distribution * reward_sum
        r_a = r_ij - k.mean(r_ij)
        # Score-function trick: d/dtheta [stop_gradient(r_a * psi) * log(psi)]
        # equals r_a * d(psi)/dtheta, the gradient of the advantage-weighted return.
        r_psi = k.stop_gradient(r_a * psi)
        pg_loss = k.sum(r_psi * k.log(k.clip(psi, EPS, 1)), axis=1)

        # Negate: Keras minimises, while the objective is to be maximised.
        return -k.mean(pg_loss)

    return loss

def stable_os_reward(reward, phi):
    """Batch-mean expected reward of the stochastic stopping policy.

    Each per-step output ``phi`` is treated as the probability of stopping at
    that step given no earlier stop; the final step absorbs all remaining
    probability. The value is sum_j P(first stop at j) * reward[j], averaged
    over the batch. It is registered as a Keras metric and monitored
    (``val_stable_os_reward``) for early stopping.

    NOTE(review): assumes axis 1 of ``reward`` and ``phi`` is time — confirm.
    """
    # Log-space survival / stop terms; clipping keeps log() finite.
    log_survive = k.log(k.clip(1-phi, EPS, 1))
    log_stop = k.log(k.clip(phi, EPS, 1))

    # Interior steps: survive every earlier step, then stop here.
    interior_probs = k.exp(k.cumsum(log_survive[:, :-2], axis=1) + log_stop[:, 1:-1])
    interior_term = k.sum(reward[:, 1:-1] * interior_probs, axis=1)

    # First step: stop immediately with probability phi[:, 0].
    first_term = reward[:, 0] * phi[:, 0]

    # Final step: reached only when the policy never stopped earlier.
    final_term = k.exp(k.sum(log_survive[:, :-1], axis=1)) * reward[:, -1]

    return k.mean(first_term + interior_term + final_term)


def dnn_ospg(config, input):
    """Stack ``config['num_stacked_layers']`` per-timestep ReLU dense blocks.

    Each block applies one Dense(``config['units_hidden']``, relu) to every
    timestep through TimeDistributed, then BatchNormalization on axis 2. With
    zero layers the input tensor is returned unchanged.
    """
    hidden = input
    for _ in range(config['num_stacked_layers']):
        dense = keras.layers.Dense(units=config['units_hidden'], activation='relu')
        hidden = keras.layers.TimeDistributed(dense)(hidden)
        hidden = keras.layers.BatchNormalization(axis=2)(hidden)
    return hidden


def rnn_ospg(config, input):
    """Stack ``config['num_stacked_layers']`` GRU layers, each followed by BatchNorm.

    Every GRU (tanh, ``config['units_hidden']`` units) returns the full
    sequence, so the output keeps one hidden vector per timestep. With zero
    layers the input tensor is returned unchanged.
    """
    hidden = input
    for _ in range(config['num_stacked_layers']):
        gru = keras.layers.GRU(units=config['units_hidden'], return_sequences=True, activation='tanh')
        hidden = gru(hidden)
        hidden = keras.layers.BatchNormalization(axis=2)(hidden)
    return hidden


def build_ospg_model(config, L, F):
    """Assemble the (uncompiled) stopping-policy network.

    Model inputs, in order: ``input_time`` (L, 1), then ``input_reward``
    (L, 1) only when ``config['include_R']`` is truthy, then ``input_signal``
    (L, F). They are concatenated on axis 2 and batch-normalised. They then
    pass through a dense trunk (``config['use_DNN']``) or a GRU trunk. The
    ``policy_out`` layer maps each timestep to a sigmoid value, giving an
    output of shape (L, 1) per sample.
    """
    signal_in = keras.layers.Input(shape=(L, F), name='input_signal')
    time_in = keras.layers.Input(shape=(L, 1), name='input_time')
    model_inputs = [time_in]
    if config['include_R']:
        model_inputs.append(keras.layers.Input(shape=(L, 1), name='input_reward'))
    model_inputs.append(signal_in)

    features = keras.layers.Concatenate(axis=2)(model_inputs)
    features = keras.layers.BatchNormalization(axis=2)(features)

    trunk = dnn_ospg if config['use_DNN'] else rnn_ospg
    hidden = trunk(config, features)

    head = keras.layers.Dense(units=1, activation='sigmoid', name='dense')
    proba_out = keras.layers.TimeDistributed(head, name='policy_out')(hidden)

    return keras.Model(list(model_inputs), [proba_out])




def train_ospg_model(config, data_stats_dict, transform_str=None, is_reward_flag=1):
    """Train and test one OSPG model per cross-validation fold.

    For fold ``i``, each of ``data_stats_dict['training_folds'][i]``,
    ``['validation_folds'][i]`` and ``['test_folds'][i]`` is an
    ``(inputs, rewards)`` pair. Each fold is processed as follows:

    - The inputs are transformed with ``transform_data``.
    - A fresh model is built and fitted, with early stopping on the
      validation ``stable_os_reward`` metric.
    - The learned stochastic policy is sampled once on the test fold.

    Args:
        config: Hyper-parameter dict (``os_lr``, ``clipnorm``, ``batch_size``,
            ``os_epochs`` plus the model-building keys).
        data_stats_dict: Fold data. When ``transform_str`` is given, it must
            also hold per-fold transform statistics under that key.
        transform_str: Name of the input transform, or None.
        is_reward_flag: 1 if rewards are maximised, 0 if they are costs to be
            minimised. It selects the loss variant and the early-stopping
            direction.

    Returns:
        Dict of per-fold lists: ``os_rewards``, ``os_stop_idxs``,
        ``prediction_time_per_ts`` (ms per sample-timestep) and
        ``train_times`` (seconds).
    """
    training_folds = data_stats_dict['training_folds']
    nfolds = len(training_folds)
    # Fold inputs are indexed (sample, time, feature): L time steps, F features.
    L = training_folds[0][0].shape[1]
    F = training_folds[0][0].shape[2]

    os_rewards = []
    os_stop_idxs = []
    prediction_time_per_ts = []
    train_times = []

    for i in range(nfolds):

        # Per-fold statistics for the input transform (None when untransformed).
        data_stats = None
        if transform_str is not None:
            data_stats = data_stats_dict[transform_str][i]

        # TRANSFORM INPUT DATA
        transformed_train_data = transform_data(training_folds[i][0], data_stats, transform_str)
        train_rewards = training_folds[i][1]
        transformed_val_data = transform_data(data_stats_dict['validation_folds'][i][0], data_stats, transform_str)
        val_rewards = data_stats_dict['validation_folds'][i][1]
        transformed_test_data = transform_data(data_stats_dict['test_folds'][i][0], data_stats, transform_str)
        test_rewards = data_stats_dict['test_folds'][i][1]

        # BUILD MODEL
        ospg_model = build_ospg_model(config, L, F)

        # COMPILE MODEL
        # Early-stop on the validation expected reward: maximise rewards,
        # minimise costs (is_reward_flag == 0).
        es_mode = 'max' if is_reward_flag == 1 else 'min'
        es_callback = tf.keras.callbacks.EarlyStopping(monitor='val_stable_os_reward', mode=es_mode, verbose=1, patience=5)

        # Pass `learning_rate`, not the deprecated `lr` alias. Newer Keras
        # optimizers drop `lr` with only a warning and silently ignore os_lr.
        ospg_model.compile(loss=os_pg_loss(is_reward_flag),
                           metrics=[stable_os_reward],
                           optimizer=keras.optimizers.Adam(learning_rate=config['os_lr'], clipnorm=config['clipnorm']))

        # FIT MODEL
        # Presumably the 4th argument is the batch size, so validation is one full batch.
        os_train_generator = OSBatchGenerator(transformed_train_data, train_rewards, config, config['batch_size'], randomize=True)
        os_val_generator = OSBatchGenerator(transformed_val_data, val_rewards, config, transformed_val_data.shape[0], randomize=False)
        start_train_time = time.time()
        ospg_model.fit(os_train_generator,
                       validation_data=os_val_generator,
                       callbacks=[es_callback],
                       epochs=config['os_epochs'], shuffle=False, verbose=0)
        end_train_time = time.time()

        # PREDICT ON TEST SET
        os_test_prediction_generator = OSBatchGenerator(transformed_test_data, test_rewards, config, config['batch_size'], randomize=False)
        start_predict_time = time.time()
        os_test_predictions = ospg_model.predict(os_test_prediction_generator)
        end_predict_time = time.time()
        # NOTE(review): assumes the generator yields every test sample in order,
        # so prediction rows line up with test_rewards — confirm.
        num_test_samples = os_test_predictions.shape[0]

        # Sample the stochastic policy once. Step t is flagged when
        # phi_t >= U[0, 1), i.e. with probability phi_t. The final step is
        # always flagged.
        interventions = (os_test_predictions[:, :, 0] >= np.random.uniform(size=(num_test_samples, L))) * 1
        interventions[:, -1] = 1
        reward, stop_idxs = calculate_stopping_reward(0.5, interventions, test_rewards)

        print(str(reward))
        os_rewards.append(reward)
        # Milliseconds per (sample, time step).
        prediction_time_per_ts.append((end_predict_time-start_predict_time) * (10**3) / (num_test_samples * L))
        train_times.append((end_train_time-start_train_time))
        os_stop_idxs.append(stop_idxs)

    os_results = {'os_rewards': os_rewards,
                  'os_stop_idxs': os_stop_idxs,
                  'prediction_time_per_ts': prediction_time_per_ts,
                  'train_times': train_times}

    return os_results

################# MC training for American Option Pricing #############################

def train_mc_ospg_model(config, option_parameters, num_train_folds, num_test_folds):
    """Train OSPG models on Monte-Carlo generated data and evaluate them.

    For each of ``num_train_folds`` repetitions, a fresh model is trained for
    one epoch on an ``OSMCBatchGenerator`` built with
    ``config['train_samples_per_epoch']``. It is then evaluated (its
    ``stable_os_reward`` metric) on ``num_test_folds`` differently seeded
    test generators.

    Args:
        config: Hyper-parameter dict; reads ``L``, ``F``, ``os_lr``,
            ``clipnorm``, ``batch_size``, ``train_samples_per_epoch``,
            ``test_samples_per_epoch`` plus the model-building keys.
        option_parameters: Passed through to ``OSMCBatchGenerator``.
        num_train_folds: Number of independently trained models.
        num_test_folds: Number of test generators evaluated per model.

    Returns:
        Dict of per-train-fold lists:

        - ``os_reward_mean``: mean test reward across test folds.
        - ``os_reward_std``: standard deviation across test folds.
        - ``prediction_time_per_ts``: ms; see the note at its computation.
        - ``train_times``: seconds.
    """
    os_reward_mean = []
    os_reward_std = []
    prediction_time_per_ts = []
    train_times = []

    L = config['L']
    F = config['F']

    rnd_seed = 0
    for i in range(num_train_folds):

        # BUILD MODEL
        ospg_model = build_ospg_model(config, L, F)
        # Pass `learning_rate`, not the deprecated `lr` alias. Newer Keras
        # optimizers drop `lr` with only a warning and silently ignore os_lr.
        ospg_model.compile(loss={'policy_out': os_pg_loss(is_reward_flag=1)},
                           metrics={'policy_out': stable_os_reward},
                           optimizer=keras.optimizers.Adam(learning_rate=config['os_lr'], clipnorm=config['clipnorm']))

        # FIT MODEL
        os_train_generator = OSMCBatchGenerator(config, option_parameters, config['train_samples_per_epoch'], seed=rnd_seed, mode='train')
        start_train_time = time.time()
        ospg_model.fit(os_train_generator,
                       epochs=1, shuffle=False, verbose=1)
        end_train_time = time.time()

        # EVALUATE ON TEST FOLDS
        # Test seeds start right after this fold's training seed range.
        # NOTE(review): the next training fold is seeded with this same value,
        # i.e. the seed of test fold 0 here. This is harmless for leakage,
        # since every fold trains a fresh model, but results across train
        # folds are not independent — confirm intent.
        rnd_seed += config['train_samples_per_epoch']
        test_rewards_list = []
        pred_time_list = []
        for j in range(num_test_folds):
            os_test_generator = OSMCBatchGenerator(config, option_parameters, config['test_samples_per_epoch'], seed=(rnd_seed + j*config['test_samples_per_epoch']))
            start_evaluate_time = time.time()
            # evaluate() returns [loss, stable_os_reward]; keep the metric.
            fold_reward = ospg_model.evaluate(os_test_generator, verbose=0)[1]
            end_evaluate_time = time.time()
            # NOTE(review): the divisor is test_samples_per_epoch * batch_size,
            # with no division by L, unlike the per-time-step figure in
            # train_ospg_model. The timing also includes loss/metric
            # computation — confirm the intended unit.
            pred_time = (end_evaluate_time - start_evaluate_time) * (10**3) / (config['test_samples_per_epoch']*config['batch_size'])
            test_rewards_list.append(fold_reward)
            pred_time_list.append(pred_time)

        reward_mean = np.mean(test_rewards_list)
        reward_std = np.std(test_rewards_list)
        pred_time_mean = np.mean(pred_time_list)

        print(str(reward_mean))
        os_reward_mean.append(reward_mean)
        os_reward_std.append(reward_std)
        prediction_time_per_ts.append(pred_time_mean)
        train_times.append((end_train_time-start_train_time))

    os_results = {'os_reward_mean': os_reward_mean,
                  'os_reward_std': os_reward_std,
                  'prediction_time_per_ts': prediction_time_per_ts,
                  'train_times': train_times}

    return os_results