import numpy as np
import pandas as pd

import os

import json

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score

from crtod.methods.transformers import ODRegressionEncoderTransformerModel, ODTransformerModel
from crtod.data import make_input_roll
from crtod.methods.embeddings import IdentityEmbedding

import torch

from d3m.metadata.pipeline import Pipeline
from axolotl.backend.simple import SimpleRunner

from tods import generate_dataset, load_pipeline, evaluate_pipeline




# Root folder under which every result file of this script (JSON summaries
# and CSV predictions) is written.
save_path = './Unpredictability_Results'


# Locations of the TODS benchmark pipelines and datasets.
# NOTE(review): `path_to_TODS` is not defined in the visible part of this
# script. It must exist before this point and, because it is concatenated
# directly below, must end with a path separator -- confirm where it is set.
data_type = 'multivariate'
data_type_folder_name = 'multidataset'

all_pipeline_folder = path_to_TODS + 'tods/benchmark/synthetic/Pipeline/'

# Maps the trimmed pipeline name -> pipeline object returned by load_pipeline.
pipeline_dict = {}

# One sub-folder of `all_pipeline_folder` per pipeline configuration to test.
pipeline_folder_list_to_test = [
    'ocsvm_pipeline_default',
    'GBReg_win3',
    'GBReg_win5',
    'iforest_subseg',
    'lstm_pipeline_layer2_dim32',
    'lstm_pipeline_layer2_dim64',
    'lstm_pipeline_layer5_dim32',
    'GBReg_subseg',
    'ar_pipeline_win10',
    'lstm_pipeline_default',
    'ar_pipeline_default',
    'lstm_pipeline_layer10_dim64',
    'ae_subseg',
    'GBReg_subseg_win5',
    'ae_pipeline_default',
    'GBReg_subseg_win3',
    'ocsvm_subseg',
    'ae_pipeline_14841',
    'lstm_pipeline_layer5_dim64',
    'iforest_pipeline_default',
    'GBReg_default',
    'ar_pipeline_win5',
    'ae_pipeline_5321616325',
    'ae_pipeline_32161632',
    'lstm_pipeline_layer10_dim32',
    'ae_pipeline_1416321641',
]

# Metric name handed to evaluate_pipeline further down in the script.
evaluation_metric = 'F1'

for pipeline_folder in pipeline_folder_list_to_test:
    pipelines_path = all_pipeline_folder + pipeline_folder

    # os.listdir returns entries in arbitrary order, so "the last entry" used
    # to depend on the file system. Sort the names, and skip hidden entries
    # (e.g. .DS_Store, .ipynb_checkpoints) that are not pipeline files, so the
    # same file is selected on every machine.
    pipeline_list = sorted(
        entry for entry in os.listdir(pipelines_path)
        if not entry.startswith('.')
    )
    if not pipeline_list:
        raise FileNotFoundError(
            'no pipeline description found in {}'.format(pipelines_path))

    # Only one file per folder is loaded: the last one in sorted order.
    pipeline_name = pipeline_list[-1]
    final_pipeline_path = '{}/{}'.format(pipelines_path, pipeline_name)
    pipeline = load_pipeline(final_pipeline_path)

    # The dictionary key is the file name up to the first '_con'.
    pipeline_name_trimmed = pipeline_name.split('_con')[0]
    pipeline_dict[pipeline_name_trimmed] = pipeline

# Folder holding the synthetic benchmark CSV files.
# NOTE(review): `path_to_TODS` is not defined in the visible part of this
# script -- confirm where it is set.
all_data_folder = path_to_TODS + 'tods/benchmark/synthetic/{}'.format(data_type_folder_name)

# os.listdir returns entries in arbitrary order; sorting makes the order in
# which datasets are processed (and results are stored) reproducible.
data_name_list = sorted(os.listdir(all_data_folder))

# file name -> object returned by generate_dataset (input to the TODS pipelines)
data_dict = {}
# file name -> standardised DataFrame (features in all but the last column)
data_dict_csv = {}

# Column index handed to generate_dataset as the target.
# NOTE(review): hard-coded for every file -- confirm it matches the label
# column of all datasets in the folder.
target_index = 5

for data_name in data_name_list:
    # Only real CSV files are datasets; a substring test would also match
    # names such as 'csv_backup' or 'data.csv.bak'.
    if not data_name.endswith('.csv'):
        continue

    data_path = '{}/{}'.format(all_data_folder, data_name)
    df = pd.read_csv(data_path)

    # Cast the feature columns to float first: writing the scaled (float)
    # values into integer columns through .iloc relies on silent upcasting,
    # which recent pandas versions deprecate.
    df = df.astype({column: float for column in df.columns[:-1]})

    # Standardise every column except the last one, which is kept untouched.
    scaler = StandardScaler()
    df.iloc[:, :-1] = scaler.fit_transform(df.iloc[:, :-1])

    dataset = generate_dataset(df, target_index)

    data_dict_csv[data_name] = df
    data_dict[data_name] = dataset





# Nested results: pipeline name -> contamination -> data file -> run -> F1.
result_dict = {}


# run the TODS pipelines?
run_pipelines = True

# Fractions of samples flagged as outliers, kept as strings because they are
# used as JSON keys and inside output file names.
contamination_list = ['0.05', '0.1', '0.15', '0.2', '0.25']

# Folder receiving the JSON summaries and the per-run CSV files.
predictions_folder = '{}/synthetic_predictions'.format(save_path)


# running the TODS pipelines on the data and saving the predicted outliers and scores as well as the F1.
for run_number in range(5):

    if not run_pipelines:
        continue

    # Create the output folder up front so the first write cannot fail on a
    # missing directory.
    os.makedirs(predictions_folder, exist_ok=True)

    run_key = str(run_number)
    path = '{}/tods_results_{}_hypersearch.json'.format(predictions_folder, run_key)

    for pipeline_name in pipeline_dict:
        # Make sure every level of the nested result structure exists.
        pipeline_results = result_dict.setdefault(pipeline_name, {})
        for contamination in contamination_list:
            contamination_results = pipeline_results.setdefault(contamination, {})
            for data_name in data_dict:
                contamination_results.setdefault(data_name, {})

        for data_name in data_dict:

            # Skip combinations for which this run is already stored at
            # every contamination level.
            if all(run_key in pipeline_results[level][data_name]
                   for level in contamination_list):
                continue

            dataset = data_dict[data_name]
            pipeline = pipeline_dict[pipeline_name]
            evaluation_metric = 'F1'
            pipeline_result = evaluate_pipeline(dataset, pipeline, evaluation_metric)

            # The outlier scores are read from the last column of
            # 'result_temp.npy'.
            # NOTE(review): presumably evaluate_pipeline writes that file; a
            # stale file from an earlier pipeline would be read silently --
            # confirm. One retry is kept as a best-effort fallback, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            try:
                proba = np.load('result_temp.npy')[:, -1]
            except Exception:
                pipeline_result = evaluate_pipeline(dataset, pipeline, evaluation_metric)
                proba = np.load('result_temp.npy')[:, -1]

            # Ground-truth labels are the last column of the CSV data.
            Y = data_dict_csv[data_name].iloc[:, -1].to_numpy()
            data_stem = data_name.split('.csv')[0]

            for contamination in contamination_list:

                # Number of highest-scoring samples to flag as outliers.
                top_k = int(float(contamination) * Y.shape[0])

                predict = np.zeros(Y.shape[0])
                # Guard top_k == 0: slicing with [-0:] would select the whole
                # array and flag every sample as an outlier.
                if top_k > 0:
                    i_outliers = np.argpartition(proba, -top_k)[-top_k:]
                    predict[i_outliers] = 1

                f1 = f1_score(Y, predict)
                pipeline_results[contamination][data_name][run_key] = f1

                # saving the results
                pd.DataFrame(predict, columns=['predict']).to_csv(
                    '{}/{}_con_{}_{}_{}_hypersearch.csv'.format(
                        predictions_folder, pipeline_name, contamination,
                        data_stem, run_key))

                print(f1)

                pd.DataFrame(proba, columns=['proba']).to_csv(
                    '{}/{}_con_{}_proba_{}_{}_hypersearch.csv'.format(
                        predictions_folder, pipeline_name, contamination,
                        data_stem, run_key))

                # Rewrite the summary after every step so partial results
                # survive a crash; the context manager closes the handle.
                with open(path, "w") as results_file:
                    json.dump(result_dict, results_file)
                    

# Nested results: model name (with hyper-params) -> contamination -> data
# file -> run -> F1.
result_dict_transformer_encoder_decoder = {}


# run the transformer encoder decoder tests?
run_tests = True

model_name = 'transformer_encoder_decoder'

# Fractions of samples flagged as outliers, kept as strings because they are
# used as JSON keys and inside output file names.
contamination_list = ['0.05', '0.1', '0.15', '0.2', '0.25']


# transformer encoder decoder hyper-params
hyperparams_change = [{'nlayers': 2, 'dim_feedforward': 32},
                      {'nlayers': 2, 'dim_feedforward': 64},
                      {'nlayers': 5, 'dim_feedforward': 32},
                      {'nlayers': 5, 'dim_feedforward': 64}]

# Folder receiving the JSON summaries and the per-run CSV files.
predictions_folder = '{}/synthetic_predictions'.format(save_path)

# running the transformer encoder-decoder on the data and saving the predicted outliers and scores as well as the F1.
for run_number in range(5):

    if not run_tests:
        continue

    # Create the output folder up front so the first write cannot fail on a
    # missing directory.
    os.makedirs(predictions_folder, exist_ok=True)

    run_key = str(run_number)
    # One JSON summary per run; it holds the results of all hyper-param sets.
    path = '{}/{}_results_{}_hypersearch.json'.format(predictions_folder, model_name, run_key)

    for hyperparams_run_dict in hyperparams_change:

        model_name_hyper = model_name + '_layers{}_dff{}'.format(
            hyperparams_run_dict['nlayers'], hyperparams_run_dict['dim_feedforward'])

        # consistent hyper-params (only the layer count and the feed-forward
        # width change between entries of hyperparams_change)
        hyperparams_dict = {'sequence_length': 5,
                            'predicted_sequence_length': 3,
                            'dim_feedforward': hyperparams_run_dict['dim_feedforward'],
                            'num_encoder_layers': hyperparams_run_dict['nlayers'],
                            'num_decoder_layers': hyperparams_run_dict['nlayers'],
                            'batch_size': 30,
                            'dropout': 0.1,
                            'epochs': 20,
                            'verbose': True,
                            'learning_rate': 0.01}

        # Make sure every level of the nested result structure exists.
        model_results = result_dict_transformer_encoder_decoder.setdefault(model_name_hyper, {})
        for contamination in contamination_list:
            contamination_results = model_results.setdefault(contamination, {})
            for file_name in data_dict_csv:
                contamination_results.setdefault(file_name, {})

        for file_name in data_dict_csv:

            # Skip before touching the data when this run is already stored
            # at every contamination level.
            if all(run_key in model_results[level][file_name]
                   for level in contamination_list):
                continue

            # Features are all columns but the last; the last holds the labels.
            X = data_dict_csv[file_name].iloc[:, :-1].to_numpy()
            Y = data_dict_csv[file_name].iloc[:, -1].to_numpy()

            scaler = StandardScaler()
            X = scaler.fit_transform(X)

            # NOTE(review): `device` is not defined in the visible part of
            # this script -- confirm where it is set before running.
            od = ODTransformerModel(sequence_length=hyperparams_dict['sequence_length'],
                                    predict_sequence_length=hyperparams_dict['predicted_sequence_length'],
                                    d_model=X.shape[-1],
                                    nhead=X.shape[-1],
                                    num_encoder_layers=hyperparams_dict['num_encoder_layers'],
                                    num_decoder_layers=hyperparams_dict['num_decoder_layers'],
                                    dim_feedforward=hyperparams_dict['dim_feedforward'],
                                    dropout=hyperparams_dict['dropout'],
                                    activation='relu',
                                    layer_norm_eps=1e-04,
                                    device=device,
                                    embedding_class=IdentityEmbedding,
                                    embedding_args={},
                                    learning_rate=hyperparams_dict['learning_rate'],
                                    epochs=hyperparams_dict['epochs'],
                                    batch_size=hyperparams_dict['batch_size'],
                                    verbose=hyperparams_dict['verbose'])

            # The model is fitted and scored on the same data.
            od.fit(X)

            proba = od.predict_proba(X)

            data_stem = file_name.split('.csv')[0]

            for contamination in contamination_list:

                # Number of highest-scoring samples to flag as outliers.
                top_k = int(float(contamination) * Y.shape[0])

                predict = np.zeros(Y.shape[0])
                # Guard top_k == 0: slicing with [-0:] would select the whole
                # array and flag every sample as an outlier.
                if top_k > 0:
                    i_outliers = np.argpartition(proba, -top_k)[-top_k:]
                    predict[i_outliers] = 1

                f1 = f1_score(Y, predict)
                if hyperparams_dict['verbose']:
                    print(f1)
                model_results[contamination][file_name][run_key] = f1

                # saving the results
                pd.DataFrame(predict, columns=['predict']).to_csv(
                    '{}/{}_con_{}_data_{}_{}.csv'.format(
                        predictions_folder, model_name_hyper, contamination,
                        data_stem, run_key))

                pd.DataFrame(proba, columns=['proba']).to_csv(
                    '{}/{}_con_{}_proba_data_{}_{}.csv'.format(
                        predictions_folder, model_name_hyper, contamination,
                        data_stem, run_key))

                # Rewrite the summary after every step so partial results
                # survive a crash; the context manager closes the handle.
                with open(path, "w") as results_file:
                    json.dump(result_dict_transformer_encoder_decoder, results_file)


# Nested results: model name (with hyper-params) -> contamination -> data
# file -> run -> F1.
result_dict_transformer_encoder = {}

# run the transformer encoder tests?
run_tests = True

model_name = 'transformer_encoder'

# Fractions of samples flagged as outliers, kept as strings because they are
# used as JSON keys and inside output file names.
contamination_list = ['0.05', '0.1', '0.15', '0.2', '0.25']


# transformer encoder hyper-params
hyperparams_change = [{'nlayers': 2, 'dim_feedforward': 32},
                      {'nlayers': 2, 'dim_feedforward': 64},
                      {'nlayers': 5, 'dim_feedforward': 32},
                      {'nlayers': 5, 'dim_feedforward': 64}]

# Folder receiving the JSON summaries and the per-run CSV files.
predictions_folder = '{}/synthetic_predictions'.format(save_path)

# running the transformer encoder on the data and saving the predicted outliers and scores as well as the F1.
for run_number in range(5):

    if not run_tests:
        continue

    # Create the output folder up front so the first write cannot fail on a
    # missing directory.
    os.makedirs(predictions_folder, exist_ok=True)

    run_key = str(run_number)
    # One JSON summary per run; it holds the results of all hyper-param sets.
    path = '{}/{}_results_{}_hypersearch.json'.format(predictions_folder, model_name, run_key)

    for hyperparams_run_dict in hyperparams_change:

        model_name_hyper = model_name + '_layers{}_dff{}'.format(
            hyperparams_run_dict['nlayers'], hyperparams_run_dict['dim_feedforward'])

        # consistent hyper-params (only the layer count and the feed-forward
        # width change between entries of hyperparams_change)
        hyperparams_dict = {'sequence_length': 5,
                            'dim_feedforward': hyperparams_run_dict['dim_feedforward'],
                            'nlayers': hyperparams_run_dict['nlayers'],
                            'batch_size': 30,
                            'dropout': 0.1,
                            'epochs': 20,
                            'learning_rate': 0.01,
                            'verbose': True}

        # Make sure every level of the nested result structure exists.
        model_results = result_dict_transformer_encoder.setdefault(model_name_hyper, {})
        for contamination in contamination_list:
            contamination_results = model_results.setdefault(contamination, {})
            for file_name in data_dict_csv:
                contamination_results.setdefault(file_name, {})

        for file_name in data_dict_csv:

            # Skip before touching the data when this run is already stored
            # at every contamination level.
            if all(run_key in model_results[level][file_name]
                   for level in contamination_list):
                continue

            # Features are all columns but the last; the last holds the labels.
            X = data_dict_csv[file_name].iloc[:, :-1].to_numpy()
            Y = data_dict_csv[file_name].iloc[:, -1].to_numpy()

            scaler = StandardScaler()
            X = scaler.fit_transform(X)

            # NOTE(review): `device` is not defined in the visible part of
            # this script -- confirm where it is set before running.
            od = ODRegressionEncoderTransformerModel(sequence_length=hyperparams_dict['sequence_length'],
                                                     embedding_dim=X.shape[-1],
                                                     nhead=X.shape[-1],
                                                     dim_feedforward=hyperparams_dict['dim_feedforward'],
                                                     nlayers=hyperparams_dict['nlayers'],
                                                     batch_size=hyperparams_dict['batch_size'],
                                                     epochs=hyperparams_dict['epochs'],
                                                     dropout=hyperparams_dict['dropout'],
                                                     verbose=hyperparams_dict['verbose'],
                                                     learning_rate=hyperparams_dict['learning_rate'],
                                                     device=device)

            # The model is fitted and scored on the same data.
            od.fit(X)

            proba = od.predict_proba(X)

            data_stem = file_name.split('.csv')[0]

            for contamination in contamination_list:

                # Number of highest-scoring samples to flag as outliers.
                top_k = int(float(contamination) * Y.shape[0])

                predict = np.zeros(Y.shape[0])
                # Guard top_k == 0: slicing with [-0:] would select the whole
                # array and flag every sample as an outlier.
                if top_k > 0:
                    i_outliers = np.argpartition(proba, -top_k)[-top_k:]
                    predict[i_outliers] = 1

                f1 = f1_score(Y, predict)
                if hyperparams_dict['verbose']:
                    print(f1)
                model_results[contamination][file_name][run_key] = f1

                # saving the results
                pd.DataFrame(predict, columns=['predict']).to_csv(
                    '{}/{}_con_{}_data_{}_{}.csv'.format(
                        predictions_folder, model_name_hyper, contamination,
                        data_stem, run_key))

                pd.DataFrame(proba, columns=['proba']).to_csv(
                    '{}/{}_con_{}_proba_data_{}_{}.csv'.format(
                        predictions_folder, model_name_hyper, contamination,
                        data_stem, run_key))

                # Rewrite the summary after every step so partial results
                # survive a crash; the context manager closes the handle.
                with open(path, "w") as results_file:
                    json.dump(result_dict_transformer_encoder, results_file)


