import numpy as np
import pandas as pd

import os

import json

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score

from crtod.methods.transformers import ODRegressionEncoderTransformerModel, ODTransformerModel
from crtod.data import make_input_roll
from crtod.methods.embeddings import IdentityEmbedding

import torch

from d3m.metadata.pipeline import Pipeline
from axolotl.backend.simple import SimpleRunner

from tods import generate_dataset, load_pipeline, evaluate_pipeline




# Root directory for every result file written by this script.
save_path = './Unpredictability_Results'

# Root of the local TODS checkout. The script previously referenced this name
# without ever defining it (NameError); it is now read from the PATH_TO_TODS
# environment variable and defaults to the current directory.
path_to_TODS = os.environ.get('PATH_TO_TODS', './')

# Torch device for the crtod transformer models further down. The script
# previously referenced `device` without defining it (NameError); use CUDA
# when it is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Pipelines are loaded from the TODS benchmark folder; only the real-world
# datasets named in `datasets` are evaluated.
data_type = 'real_world'
datasets = ['swan_sf', 'water_quality']

all_pipeline_folder = os.path.join(path_to_TODS, 'tods/benchmark/synthetic/Pipeline/')

# Trimmed pipeline file name -> loaded TODS pipeline (filled in below).
pipeline_dict = {}

# Sub-folders of `all_pipeline_folder` whose pipeline is loaded and evaluated.
pipeline_folder_list_to_test = ['ocsvm_pipeline_default',
                                'GBReg_win3',
                                'GBReg_win5',
                                'iforest_subseg',
                                'lstm_pipeline_layer2_dim32',
                                'lstm_pipeline_layer2_dim64',
                                'lstm_pipeline_layer5_dim32',
                                'GBReg_subseg',
                                'ar_pipeline_win10',
                                'lstm_pipeline_default',
                                'ar_pipeline_default',
                                'lstm_pipeline_layer10_dim64',
                                'ae_subseg',
                                'GBReg_subseg_win5',
                                'ae_pipeline_default',
                                'GBReg_subseg_win3',
                                'ocsvm_subseg',
                                'ae_pipeline_14841',
                                'lstm_pipeline_layer5_dim64',
                                'iforest_pipeline_default',
                                'GBReg_default',
                                'ar_pipeline_win5',
                                'ae_pipeline_5321616325',
                                'ae_pipeline_32161632',
                                'lstm_pipeline_layer10_dim32',
                                'ae_pipeline_1416321641']


# Load one pipeline file from each selected folder into `pipeline_dict`.
for pipeline_folder in pipeline_folder_list_to_test:
    pipelines_path = os.path.join(all_pipeline_folder, pipeline_folder)
    # os.listdir returns entries in an arbitrary, filesystem-dependent order,
    # so sort before taking the last one to make the choice reproducible.
    pipeline_files = sorted(os.listdir(pipelines_path))
    if not pipeline_files:
        raise FileNotFoundError('No pipeline files found in {}'.format(pipelines_path))

    pipeline_name = pipeline_files[-1]
    pipeline = load_pipeline(os.path.join(pipelines_path, pipeline_name))

    # Key by the file name with everything from the first '_con' removed.
    pipeline_dict[pipeline_name.split('_con')[0]] = pipeline
        
# Load each selected real-world CSV. Column 0 is passed to generate_dataset as
# the target (target_index = 0); all remaining columns are standardised.
all_data_folder = os.path.join(path_to_TODS, 'tods/benchmark/realworld_data/data')
# Sorted so the datasets are processed (and reported) in a reproducible order.
data_name_list = sorted(os.listdir(all_data_folder))

data_dict = {}      # file name -> dataset built by generate_dataset
data_dict_csv = {}  # file name -> standardised DataFrame

for data_name in data_name_list:
    # Require an exact '.csv' extension; a plain substring test would also
    # accept names such as 'swan_sf.csv.bak'.
    if not data_name.endswith('.csv') or data_name[:-len('.csv')] not in datasets:
        continue

    data_path = os.path.join(all_data_folder, data_name)
    df = pd.read_csv(data_path)
    scaler = StandardScaler()
    df.iloc[:, 1:] = scaler.fit_transform(df.iloc[:, 1:])

    target_index = 0
    dataset = generate_dataset(df, target_index)

    data_dict_csv[data_name] = df
    data_dict[data_name] = dataset


# pipeline name -> contamination -> dataset file -> run number -> F1
result_dict = {}

# Model families evaluated on each dataset. A TODS pipeline is run on a dataset
# only if one of these names occurs as a substring of the pipeline name.
models_test = {'water_quality.csv': ['ocsvm',
                                     'iforest',
                                     'ae',
                                     'ar',
                                     #'GBReg',
                                     'lstm',
                                     'transformer_encoder',
                                     'transformer_encoder_decoder'],
               'swan_sf.csv': [#'ocsvm',
                               'iforest',
                               'ae',
                               'ar',
                               #'GBReg',
                               #'lstm',
                               'transformer_encoder',
                               'transformer_encoder_decoder']
              }

# run the TODS pipelines?
run_pipelines = True

contamination_list = ['0.05', '0.1', '0.15', '0.2', '0.25']

# Score file read back after every pipeline run (its last column is used as
# the outlier score). Presumably written by the pipeline itself -- TODO confirm.
tods_scores_file = 'result_temp.npy'
predictions_dir = os.path.join(save_path, 'realworld_predictions')

# Run every TODS pipeline on the data for 5 repetitions, saving the scores, the
# top-k outlier predictions and the F1 for every contamination level.
if run_pipelines:
    os.makedirs(predictions_dir, exist_ok=True)
    for run_number in range(5):
        path = '{}/realworld_predictions/{}_purpose_built_{}_hypersearch.json'.format(save_path, data_type, str(run_number))
        for pipeline_name, pipeline in pipeline_dict.items():
            pipeline_results = result_dict.setdefault(pipeline_name, {})
            for contamination in contamination_list:
                pipeline_results.setdefault(str(contamination), {})

            for data_name, dataset in data_dict.items():
                if not any(model in pipeline_name for model in models_test[data_name]):
                    continue

                for contamination in contamination_list:
                    pipeline_results[str(contamination)].setdefault(data_name, {})

                # Skip (pipeline, dataset) pairs already scored for this run.
                if all(str(run_number) in pipeline_results[str(contamination)][data_name]
                       for contamination in contamination_list):
                    continue

                evaluation_metric = 'F1'
                # The pipeline was prone to erroring, so retry it once. The score
                # file is deleted before each attempt so that a failed run can
                # never silently reuse the scores of a previous pipeline/dataset.
                for _attempt in range(2):
                    if os.path.exists(tods_scores_file):
                        os.remove(tods_scores_file)
                    pipeline_result = evaluate_pipeline(dataset, pipeline, evaluation_metric)
                    if not pipeline_result.error:
                        break

                if not os.path.exists(tods_scores_file):
                    print('No scores produced by {} on {} (run {}), skipping: {}'.format(
                        pipeline_name, data_name, run_number, pipeline_result.error))
                    continue

                proba = np.load(tods_scores_file)[:, -1]
                Y = data_dict_csv[data_name].iloc[:, 0].to_numpy()
                pipeline_stem = pipeline_name.split('.json')[0]
                data_stem = data_name.split('.csv')[0]

                for contamination in contamination_list:
                    # Flag the top_k highest-scoring points as outliers. With
                    # top_k == 0, argpartition(..., -0)[-0:] would select every
                    # point, so only flag when top_k > 0.
                    top_k = int(float(contamination) * Y.shape[0])
                    predict = np.zeros(Y.shape[0])
                    if top_k > 0:
                        predict[np.argpartition(proba, -top_k)[-top_k:]] = 1

                    f1 = f1_score(Y, predict)
                    pipeline_results[str(contamination)][data_name][str(run_number)] = f1
                    print(f1)

                    # saving the scores and the predictions
                    pd.DataFrame(proba, columns=['proba']).to_csv(os.path.join(
                        predictions_dir,
                        '{}_con_{}_proba_{}_{}_hypersearch.csv'.format(pipeline_stem, contamination,
                                                                       data_stem, str(run_number))))
                    pd.DataFrame(predict, columns=['predict']).to_csv(os.path.join(
                        predictions_dir,
                        '{}_con_{}_{}_{}_hypersearch.csv'.format(pipeline_stem, contamination,
                                                                 data_stem, str(run_number))))

                # Checkpoint after every (pipeline, dataset) pair; `with` makes
                # sure the file is flushed and closed.
                with open(path, 'w') as results_file:
                    json.dump(result_dict, results_file)



# model+hyper-params -> contamination -> dataset file -> run number -> F1
result_dict_transformer_encoder_decoder = {}
model_name = 'transformer_encoder_decoder'

# run the transformer encoder decoder tests? If False, previously saved results
# are loaded from the per-run JSON files instead.
run_tests = True

# transformer encoder decoder hyper-params that are swept
hyperparams_change = [{'nlayers': 2, 'dim_feedforward': 32},
                      {'nlayers': 2, 'dim_feedforward': 64},
                      {'nlayers': 5, 'dim_feedforward': 32},
                      {'nlayers': 5, 'dim_feedforward': 64}]

contamination_list = ['0.05', '0.1', '0.15', '0.2', '0.25']
predictions_dir = os.path.join(save_path, 'realworld_predictions')

# Run the transformer encoder-decoder on the data for 5 repetitions, saving the
# scores, the top-k outlier predictions and the F1 for every contamination level.
# `device` must be defined at module level (it is passed to the model).
for run_number in range(5):
    path = '{}/realworld_predictions/{}_mseloss_{}.json'.format(save_path, model_name, str(run_number))

    if not run_tests:
        with open(path) as a_file:
            result_dict_transformer_encoder_decoder = json.load(a_file)
        continue

    os.makedirs(predictions_dir, exist_ok=True)
    for hyperparams_run_dict in hyperparams_change:
        model_name_hyper = model_name + '_layers{}_dff{}'.format(hyperparams_run_dict['nlayers'],
                                                                 hyperparams_run_dict['dim_feedforward'])

        # consistent hyper-params (only the layer count and feed-forward size vary)
        hyperparams_dict = {'sequence_length': 5,
                            'predicted_sequence_length': 3,
                            'dim_feedforward': hyperparams_run_dict['dim_feedforward'],
                            'num_encoder_layers': hyperparams_run_dict['nlayers'],
                            'num_decoder_layers': hyperparams_run_dict['nlayers'],
                            'batch_size': 30,
                            'dropout': 0.1,
                            'epochs': 20,
                            'verbose': True,
                            'learning_rate': 0.01}

        model_results = result_dict_transformer_encoder_decoder.setdefault(model_name_hyper, {})
        for contamination in contamination_list:
            contamination_results = model_results.setdefault(str(contamination), {})
            for file_name in data_dict_csv:
                contamination_results.setdefault(file_name, {})

        for file_name, df in data_dict_csv.items():
            # Skip datasets already scored for this run and hyper-param set.
            if all(str(run_number) in model_results[str(contamination)][file_name]
                   for contamination in contamination_list):
                continue

            X = df.iloc[:, 1:].to_numpy()
            Y = df.iloc[:, 0].to_numpy()
            X = StandardScaler().fit_transform(X)

            # d_model and nhead are both set to the number of features.
            od_encoder_decoder = ODTransformerModel(
                sequence_length=hyperparams_dict['sequence_length'],
                predict_sequence_length=hyperparams_dict['predicted_sequence_length'],
                d_model=X.shape[-1],
                nhead=X.shape[-1],
                num_encoder_layers=hyperparams_dict['num_encoder_layers'],
                num_decoder_layers=hyperparams_dict['num_decoder_layers'],
                dim_feedforward=hyperparams_dict['dim_feedforward'],
                dropout=hyperparams_dict['dropout'],
                activation='relu',
                layer_norm_eps=1e-04,
                device=device,
                embedding_class=IdentityEmbedding,
                embedding_args={},
                learning_rate=hyperparams_dict['learning_rate'],
                epochs=hyperparams_dict['epochs'],
                batch_size=hyperparams_dict['batch_size'],
                verbose=hyperparams_dict['verbose'])

            od_encoder_decoder.fit(X)
            proba = od_encoder_decoder.predict_proba(X)
            file_stem = file_name.split('.csv')[0]

            for contamination in contamination_list:
                # Flag the top_k highest-scoring points as outliers. With
                # top_k == 0, argpartition(..., -0)[-0:] would select every
                # point, so only flag when top_k > 0.
                top_k = int(float(contamination) * Y.shape[0])
                predict = np.zeros(Y.shape[0])
                if top_k > 0:
                    predict[np.argpartition(proba, -top_k)[-top_k:]] = 1

                f1 = f1_score(Y, predict)
                if hyperparams_dict['verbose']:
                    print(f1)

                model_results[str(contamination)][file_name][str(run_number)] = f1

                # saving the predictions and the scores
                pd.DataFrame(predict, columns=['predict']).to_csv(os.path.join(
                    predictions_dir,
                    '{}_con_{}_{}_{}.csv'.format(model_name_hyper, contamination, file_stem, str(run_number))))
                pd.DataFrame(proba, columns=['proba']).to_csv(os.path.join(
                    predictions_dir,
                    '{}_con_{}_proba_{}_{}.csv'.format(model_name_hyper, contamination, file_stem, str(run_number))))

            # Checkpoint after every dataset; `with` makes sure the file is
            # flushed and closed.
            with open(path, 'w') as results_file:
                json.dump(result_dict_transformer_encoder_decoder, results_file)

# model+hyper-params -> contamination -> dataset file -> run number -> F1
result_dict_transformer_encoder = {}
model_name = 'transformer_encoder'

# run the transformer encoder tests? If False, previously saved results are
# loaded from the per-run JSON files instead.
run_tests = True

# transformer encoder hyper-params that are swept
hyperparams_change = [{'nlayers': 2, 'dim_feedforward': 32},
                      {'nlayers': 2, 'dim_feedforward': 64},
                      {'nlayers': 5, 'dim_feedforward': 32},
                      {'nlayers': 5, 'dim_feedforward': 64}]

contamination_list = ['0.05', '0.1', '0.15', '0.2', '0.25']
predictions_dir = os.path.join(save_path, 'realworld_predictions')

# Run the transformer encoder on the data for 5 repetitions, saving the scores,
# the top-k outlier predictions and the F1 for every contamination level.
# `device` must be defined at module level (it is passed to the model).
for run_number in range(5):
    path = '{}/realworld_predictions/{}_mseloss_{}.json'.format(save_path, model_name, str(run_number))

    if not run_tests:
        with open(path) as a_file:
            result_dict_transformer_encoder = json.load(a_file)
        continue

    os.makedirs(predictions_dir, exist_ok=True)
    for hyperparams_run_dict in hyperparams_change:
        model_name_hyper = model_name + '_layers{}_dff{}'.format(hyperparams_run_dict['nlayers'],
                                                                 hyperparams_run_dict['dim_feedforward'])

        # consistent hyper-params (only the layer count and feed-forward size vary)
        hyperparams_dict = {'sequence_length': 5,
                            'dim_feedforward': hyperparams_run_dict['dim_feedforward'],
                            'nlayers': hyperparams_run_dict['nlayers'],
                            'batch_size': 30,
                            'dropout': 0.1,
                            'epochs': 20,
                            'learning_rate': 0.01,
                            'verbose': True}

        model_results = result_dict_transformer_encoder.setdefault(model_name_hyper, {})
        for contamination in contamination_list:
            contamination_results = model_results.setdefault(str(contamination), {})
            for file_name in data_dict_csv:
                contamination_results.setdefault(file_name, {})

        for file_name, df in data_dict_csv.items():
            # Skip datasets already scored for this run and hyper-param set.
            if all(str(run_number) in model_results[str(contamination)][file_name]
                   for contamination in contamination_list):
                continue

            X = df.iloc[:, 1:].to_numpy()
            Y = df.iloc[:, 0].to_numpy()
            X = StandardScaler().fit_transform(X)

            # embedding_dim and nhead are both set to the number of features.
            od_encoder = ODRegressionEncoderTransformerModel(
                sequence_length=hyperparams_dict['sequence_length'],
                embedding_dim=X.shape[-1],
                nhead=X.shape[-1],
                dim_feedforward=hyperparams_dict['dim_feedforward'],
                nlayers=hyperparams_dict['nlayers'],
                batch_size=hyperparams_dict['batch_size'],
                epochs=hyperparams_dict['epochs'],
                dropout=hyperparams_dict['dropout'],
                verbose=hyperparams_dict['verbose'],
                learning_rate=hyperparams_dict['learning_rate'],
                device=device)

            od_encoder.fit(X)
            proba = od_encoder.predict_proba(X)
            file_stem = file_name.split('.csv')[0]

            for contamination in contamination_list:
                # Flag the top_k highest-scoring points as outliers. With
                # top_k == 0, argpartition(..., -0)[-0:] would select every
                # point, so only flag when top_k > 0.
                top_k = int(float(contamination) * Y.shape[0])
                predict = np.zeros(Y.shape[0])
                if top_k > 0:
                    predict[np.argpartition(proba, -top_k)[-top_k:]] = 1

                f1 = f1_score(Y, predict)
                if hyperparams_dict['verbose']:
                    print(f1)

                model_results[str(contamination)][file_name][str(run_number)] = f1

                # saving the predictions and the scores
                pd.DataFrame(predict, columns=['predict']).to_csv(os.path.join(
                    predictions_dir,
                    '{}_con_{}_{}_{}.csv'.format(model_name_hyper, contamination, file_stem, str(run_number))))
                pd.DataFrame(proba, columns=['proba']).to_csv(os.path.join(
                    predictions_dir,
                    '{}_con_{}_proba_{}_{}.csv'.format(model_name_hyper, contamination, file_stem, str(run_number))))

            # Checkpoint after every dataset; `with` makes sure the file is
            # flushed and closed.
            with open(path, 'w') as results_file:
                json.dump(result_dict_transformer_encoder, results_file)

