import numpy as np
import tensorflow as tf
print(tf.__version__)
import random as rn


import pandas as pd


from data_engineering.analyze_data import analyze
from data_engineering.transform_data import transform_data
from models.ospg_model import train_ospg_model
from models.dos_model import train_dos_model
from models.fqi_model import train_fqi_model
from models.es_rnn_model import train_es_rnn_model
from models.rrlsm_model import train_rrlsm_model
from utils.metrics import classification_cost, get_ec_cost_components

from tslearn.datasets import UCR_UEA_datasets

# REPRODUCIBILITY
# Fix every random source used by the script (NumPy, the stdlib `random`
# module and TensorFlow) to the same seed so repeated runs are comparable.
SEED = 42
import os
import time
# NOTE(review): PYTHONHASHSEED is read by the interpreter at start-up, so
# assigning it here does not change hash randomisation in the current process;
# only child processes inherit it. Export it in the shell for the full effect.
os.environ['PYTHONHASHSEED'] = '0'
# Expose only CUDA device 0 to this process.
# NOTE(review): this runs after `import tensorflow`; presumably still effective
# because no GPU work has been issued yet — confirm on the target setup.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
np.random.seed(SEED)
rn.seed(SEED)
tf.random.set_seed(SEED)

# Output location: main() combines RESULTS_DIR and EXPERIMENT_DF_NAME to get
# the path of the results CSV.
RESULTS_DIR = './'
EXPERIMENT_DF_NAME = 'stop_classifier_experiment.csv'


# Split fractions handed to analyze() together with FOLDS.
TRAIN_FRACTION = 0.7
VAL_FRACTION = 0.3


# Number of folds; passed to analyze() and used as the per-model loop count
# when results are recorded.
FOLDS = 10
# Names handed to tslearn's UCR_UEA_datasets().load_dataset().
# DATASET_NAME and ALPHA are parallel lists: main() zips them, so entry i of
# ALPHA belongs to entry i of DATASET_NAME and both must have the same length.
DATASET_NAME = ['CBF', 'ChlorineConcentration', 'Crop', 'ECG5000', 'ElectricDevices', 'FaceAll', 'FacesUCR','FiftyWords',
                'InsectWingbeatSound', 'MedicalImages', 'MelbournePedestrian', 'MixedShapesRegularTrain', 'NonInvasiveFetalECGThorax2',
                'StarLightCurves', 'Symbols', 'UWaveGestureLibraryX', 'WordSynonyms']
# Per-dataset alpha: passed to train_es_rnn_model() and multiplied with the
# earliness when main() derives the accuracy from the average cost.
ALPHA = [0.8, 0.4, 0.06, 0.5, 0.1, 0.01, 0.5, 0.5,
         1, 0.07, 0.8, 0.1, 0.04,
         0.3, 0.2, 0.5, 0.6]



def _build_model_configs(num_classes):
    """Return fresh hyper-parameter dicts for all models, keyed by model name.

    New dict objects are created on every call so nothing can carry over from
    one dataset to the next through a shared config object.

    Args:
        num_classes: Number of distinct labels of the current dataset; only the
            ES-RNN config stores it.

    Returns:
        dict with the keys 'rrlsm', 'ospg', 'es_rnn' and 'fqi'.
    """
    return {
        'rrlsm': {
            'samples_per_epoch': 200,
            'num_stacked_layers': 1,
            'ker_std': 0.0001,
            'rec_std': 0.3,
            'include_R': True,
            'units_hidden': 20,
        },
        'ospg': {
            'batch_size': 64,
            'os_epochs': 100,
            'samples_per_epoch': 200,
            'os_lr': 0.001,
            'clipnorm': 5,
            'use_DNN': False,
            'include_R': False,
            'num_stacked_layers': 1,
            'units_hidden': 20,
        },
        'es_rnn': {
            'batch_size': 64,
            'es_epochs': 100,
            'samples_per_epoch': 200,
            'es_lr': 0.001,
            'clipnorm': 5,
            'num_stacked_layers': 1,
            'units_hidden': 20,
            'num_classes': num_classes,
        },
        'fqi': {
            'batch_size': 64,
            'q_epochs': 100,
            'samples_per_epoch': 200,
            'q_lr': 0.001,
            'clipnorm': 5,
            'use_DNN': False,
            'include_R': True,
            'num_stacked_layers': 1,
            'units_hidden': 20,
        },
    }


def _fold_metrics(cost_matrix, stop_idxs, alpha, series_length):
    """Summarise one test fold as ``(avg_cost, accuracy, earliness)``.

    Args:
        cost_matrix: 2-D array indexed as ``[series, stop index]``.
        stop_idxs: One zero-based stop index per series of the fold.
        alpha: Weight multiplied with the earliness.
        series_length: Length of the time series; used to normalise the mean
            stop position.

    Returns:
        Tuple ``(avg_cost, accuracy, earliness)`` where ``accuracy`` is
        ``1 - (avg_cost - alpha * earliness)``. This assumes the cost is the
        error rate plus ``alpha * earliness`` — confirm against the cost model.
    """
    # Pick, for every series, the cost at the position where it was stopped.
    avg_cost = np.mean(cost_matrix[np.arange(len(stop_idxs)), stop_idxs])
    # +1 converts the zero-based index into a number of observed steps.
    earliness = (np.mean(stop_idxs) + 1) / series_length
    error_rate = avg_cost - alpha * earliness
    return avg_cost, 1 - error_rate, earliness


def main():
    """Run every (dataset, alpha) experiment and write the results to CSV.

    For each dataset the train and test parts are merged, analysed into folds
    by ``analyze`` and passed to ``train_es_rnn_model``. RRLSM, OSPG and FQI
    are then trained in that order and one result row per model and fold is
    recorded. The CSV is rewritten after every dataset, so a crash in a later
    dataset does not discard results that were already computed.

    Raises:
        ValueError: If DATASET_NAME and ALPHA differ in length.
        RuntimeError: If a dataset cannot be loaded.
    """
    start_time = time.time()

    # zip() would silently drop the surplus entries of the longer list.
    if len(DATASET_NAME) != len(ALPHA):
        raise ValueError(
            'DATASET_NAME has %d entries but ALPHA has %d; they must match'
            % (len(DATASET_NAME), len(ALPHA)))

    results_path = os.path.join(RESULTS_DIR, EXPERIMENT_DF_NAME)
    experiment_df = pd.DataFrame(
    {
        'algorithm': [],
        'dataset_name': [],
        'alpha': [],
        'num_trajectories': [],
        'num_classes': [],
        'fold': [],
        'avg_cost': [],
        'accuracy': [],
        'earliness': [],
        'train_time': [],
        'prediction_time_per_ts': []
    })

    experiments = list(zip(DATASET_NAME, ALPHA))
    for exp_ctr, (dataset_name, alpha) in enumerate(experiments, start=1):
        X_train, y_train, X_test, y_test = UCR_UEA_datasets().load_dataset(dataset_name)
        # tslearn reports a failed download/extraction by returning None values.
        if X_train is None or X_test is None:
            raise RuntimeError('dataset "%s" could not be loaded' % dataset_name)

        # The predefined split is discarded; analyze() builds its own folds.
        X = np.concatenate((X_train, X_test), axis=0)
        # Shift the labels down by one (assumes 1-based labels — TODO confirm
        # for any dataset added to DATASET_NAME).
        y = np.concatenate((y_train, y_test), axis=0) - 1

        data_stats_dict = analyze(X, y, TRAIN_FRACTION, VAL_FRACTION, FOLDS, stratify=True)
        num_classes = len(np.unique(y))
        series_length = X.shape[1]

        configs = _build_model_configs(num_classes)

        # ES RNN for Cost Model
        data_stats_dict_emp, data_stats_dict_est = train_es_rnn_model(
            configs['es_rnn'], data_stats_dict, alpha, transform_str=None)

        # (algorithm label, result-key/print prefix, trainer, config, input data)
        # NOTE(review): OSPG is trained on the `_emp` dict while RRLSM and FQI
        # use the `_est` dict; kept as in the original — confirm intentional.
        model_runs = (
            ('RRLSM', 'rrlsm', train_rrlsm_model, configs['rrlsm'], data_stats_dict_est),
            ('RNN_OSPG', 'os', train_ospg_model, configs['ospg'], data_stats_dict_emp),
            ('RNN_FQI', 'q', train_fqi_model, configs['fqi'], data_stats_dict_est),
        )
        for algorithm, prefix, train_model, config, train_data in model_runs:
            result = train_model(config, train_data, transform_str=None, is_reward_flag=0)
            stop_idxs_per_fold = result[prefix + '_stop_idxs']
            for fold in range(FOLDS):
                # All models are scored against the costs stored in the `_emp` dict.
                cost_emp = data_stats_dict_emp['test_folds'][fold][1]
                avg_cost, accuracy, earliness = _fold_metrics(
                    cost_emp, stop_idxs_per_fold[fold], alpha, series_length)
                print(prefix + '_cost: ' + str(avg_cost))
                # Rows are labelled 0..n-1, so len() is the next free label.
                experiment_df.loc[len(experiment_df)] = [
                    algorithm, dataset_name,
                    alpha,
                    X.shape[0], num_classes,
                    fold, avg_cost,
                    accuracy, earliness,
                    result['train_times'][fold],
                    result['prediction_time_per_ts'][fold]]

        # Checkpoint so finished datasets survive a failure later in the run.
        experiment_df.to_csv(results_path)
        print('done_experiment %d of %d' % (exp_ctr, len(experiments)))

    # Final write; also covers the case of an empty experiment list.
    experiment_df.to_csv(results_path)
    end_time = time.time()
    print('run_time(min): ' + str((end_time - start_time) / 60))
    print("done")


# Run the experiments only when executed as a script, not when imported.
if __name__ == '__main__':
    main()