import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.patches as patches

import os

import json

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score, roc_auc_score, roc_curve

from crtod.graphing import heat_plotter



save_path = './Unpredictability_Results'

# loading the data with the ground truth outliers
datasets = ['swan_sf', 'water_quality']

# Root directory of the TODS repository.
# NOTE(review): `path_to_TODS` was used here without ever being defined in this
# script (NameError). It is now read from the PATH_TO_TODS environment variable,
# defaulting to the current directory -- confirm this matches how the script is run.
path_to_TODS = os.environ.get('PATH_TO_TODS', '.')
all_data_folder = os.path.join(path_to_TODS, 'tods/benchmark/realworld_data/data')
data_name_list = os.listdir(all_data_folder)
# file name ('<dataset>.csv') -> DataFrame with standardised feature columns
data_dict_csv = {}

for data_name in data_name_list:
    if 'csv' in data_name and data_name.split('.csv')[0] in datasets:
        data_path = '{}/{}'.format(all_data_folder, data_name)
        df = pd.read_csv(data_path)
        # Standardise every column except the first one
        # (presumably a timestamp/index column -- TODO confirm).
        # Assigning through df[cols] replaces the columns outright, so integer
        # columns become float instead of relying on pandas' deprecated silent
        # upcasting of in-place `df.iloc[...] = ...` assignments.
        feature_cols = df.columns[1:]
        df[feature_cols] = StandardScaler().fit_transform(df[feature_cols])

        data_dict_csv[data_name] = df

# Models to plot for each dataset, in bar-chart order. The bar charts label the
# leading entries "Non-Recurrent" and the remaining ones "Recurrent".
models_test = dict(
    water_quality=[
        'ocsvm', 'iforest', 'ae',
        'ar', 'lstm', 'transformer_encoder', 'transformer_encoder_decoder',
    ],
    swan_sf=[
        'iforest', 'ae',
        'ar', 'transformer_encoder', 'transformer_encoder_decoder',
    ],
)


# loading in the f1 results
# Directory holding the saved F1 JSON files and the per-run prediction CSVs.
results_path = '{}/realworld_predictions/'.format(save_path)

# Each JSON is indexed as [model_name][contamination]['<dataset>.csv'] -> dict of
# scores (presumably one per run -- TODO confirm against the producer).
with open(results_path + 'real_world_purpose_built_4_hypersearch.json') as json_file:
    result_dict = json.load(json_file)
with open(results_path + 'transformer_encoder_mseloss_4.json') as json_file:
    result_dict_transformer_encoder = json.load(json_file)
with open(results_path + 'transformer_encoder_decoder_mseloss_4.json') as json_file:
    result_dict_transformer_encoder_decoder = json.load(json_file)

# Merge the transformer results into the main table (encoder first, then
# encoder-decoder); a model name present in both overrides the earlier entry.
result_dict.update(result_dict_transformer_encoder)
result_dict.update(result_dict_transformer_encoder_decoder)
    

# loading in the predictions and outlier scores for each test
#   proba_dict[model][data_set][contamination][run]   -> values from '*proba*' CSVs
#   predict_dict[model][data_set][contamination][run] -> values from the other CSVs
# Each value is the last column of the CSV as a numpy array.
proba_dict = {}
predict_dict = {}


def _store_run(target, model_name, data_set, contamination, run_number, csv_path):
    """Load the last column of ``csv_path`` into
    ``target[model_name][data_set][contamination][run_number]``,
    creating the intermediate dictionaries as needed."""
    runs = (target.setdefault(model_name, {})
                  .setdefault(data_set, {})
                  .setdefault(contamination, {}))
    runs[run_number] = pd.read_csv(csv_path).iloc[:, -1].to_numpy()


def _trailing_run_number(stem):
    """Return the run number at the end of ``stem``: all trailing digits, or the
    final character when ``stem`` does not end in a digit."""
    run_number = stem[len(stem.rstrip('0123456789')):]
    return run_number or stem[-1:]


for file_name in os.listdir(results_path):
    is_proba = 'proba' in file_name
    target = proba_dict if is_proba else predict_dict

    if 'hypersearch.csv' in file_name:
        pipeline_name = file_name.split('_con')[0]
        contamination = file_name.split('_con_')[1].split('_')[0]
        # the run number is the '_'-separated token right before 'hypersearch.csv'
        run_number = file_name.split('hypersearch.csv')[0].split('_')[-2]
        if is_proba:
            tail = file_name.split('proba')[1]
        else:
            tail = file_name.split('_con_')[1]
        data_set_and_run = tail.split(contamination)[-1][1:].split('hypersearch.csv')[0]
        # drop the trailing '_<run>_'; sized by the parsed run number (previously a
        # fixed [:-3], which corrupted the data set name for runs >= 10)
        data_set = data_set_and_run[:-(len(run_number) + 2)]
        _store_run(target, pipeline_name, data_set, contamination, run_number,
                   results_path + file_name)

    elif 'transformer' in file_name and '.csv' in file_name:
        model_name = file_name.split('_con_')[0]
        contamination = file_name.split('_con_')[1].split('_')[0]
        # the run number sits just before the 4-character '.csv' extension
        # (previously only the single character file_name[-5] was taken)
        run_number = _trailing_run_number(file_name[:-4])
        if is_proba:
            data_set_and_run = file_name.split('proba_')[1]
        else:
            data_set_and_run = file_name.split(contamination)[1][1:]
        # drop the trailing '_<run>.csv' (previously a fixed [:-6])
        data_set = data_set_and_run[:-(len(run_number) + 5)]
        _store_run(target, model_name, data_set, contamination, run_number,
                   results_path + file_name)


# grouping the models by their names. This ensures that models with different hyper parameters are grouped together
#   model_short_to_long[short name] -> every key of result_dict that belongs to it
model_short_to_long = {}

# Resolve each short name once, in the order it first appears in models_test.
for model_short_name in dict.fromkeys(
        name for names in models_test.values() for name in names):
    # NOTE(review): this is a plain substring test, so e.g. 'ae' or 'ar' would also
    # claim any long name that merely contains those letters -- confirm the naming
    # scheme makes this safe. 'transformer_encoder' additionally has to exclude the
    # 'transformer_encoder_decoder' variants, which contain it as a substring.
    model_short_to_long[model_short_name] = [
        model_long_name
        for model_long_name in result_dict
        if model_short_name in model_long_name
        and (model_short_name != 'transformer_encoder' or 'decoder' not in model_long_name)
    ]


# Human-readable model names used as x-axis labels on the bar charts.
model_short_to_plot_name = dict(
    ocsvm='One Class SVM',
    iforest='Isolation Forest',
    ae='Auto-Encoder',
    ar='Auto-Regressive',
    lstm='Long Short Term Memory RNN',
    transformer_encoder='Transformer Encoder',
    transformer_encoder_decoder='Transformer Encoder Decoder',
)


# building dictionaries containing the best scores for each group of models:
#   best_scores[data_set][short][contamination]       -> highest mean score among the group's long-name models
#   best_scores_std[data_set][short][contamination]   -> std of that same long-name model's scores
#   best_scores_model[data_set][short][contamination] -> the long name that achieved it
best_scores = {key: {model: {} for model in models_test[key]} for key in models_test.keys()}
best_scores_std = {key: {model: {} for model in models_test[key]} for key in models_test.keys()}
best_scores_model = {key: {model: {} for model in models_test[key]} for key in models_test.keys()}
contamination_list = ['0.05', '0.1', '0.15', '0.2', '0.25']

for data_set in best_scores.keys():
    for model_short_name in best_scores[data_set].keys():
        model_long_names = model_short_to_long[model_short_name]
        if not model_long_names:
            # np.max([]) would otherwise fail with an opaque "zero-size array" error
            raise ValueError(
                'no entries in result_dict match model {!r}'.format(model_short_name))

        for contamination in contamination_list:
            # one list of scores per long-name model (presumably one per run -- TODO confirm)
            run_scores = [list(result_dict[model_long_name][contamination][data_set + '.csv'].values())
                          for model_long_name in model_long_names]
            mean_scores = [np.mean(scores) for scores in run_scores]
            best = int(np.argmax(mean_scores))

            best_scores[data_set][model_short_name][contamination] = mean_scores[best]
            best_scores_model[data_set][model_short_name][contamination] = model_long_names[best]
            best_scores_std[data_set][model_short_name][contamination] = np.std(run_scores[best])
            


def _plot_max_f1_bar(data_set, x_positions, n_non_recurrent, y_max, title, file_name):
    """Draw and save a bar chart of each model's best F1 score on ``data_set``.

    For every model in ``best_scores[data_set]`` the bar height is the maximum of
    ``best_scores[data_set][model][contamination]`` over ``contamination_list``; the
    error bar is the ``best_scores_std`` entry at the contamination level achieving
    that maximum. The first ``n_non_recurrent`` models are drawn as 'Non-Recurrent',
    the rest as 'Recurrent'.

    Args:
        data_set: key into ``best_scores`` / ``best_scores_std``.
        x_positions: one x coordinate per model, in ``best_scores[data_set]`` order.
        n_non_recurrent: number of leading models in the 'Non-Recurrent' group.
        y_max: upper y-axis limit.
        title: plot title.
        file_name: PDF file name, written to ``<save_path>/data_graphs/``.

    Returns:
        The ``(fig, axes)`` pair.

    Raises:
        ValueError: if ``x_positions`` does not have one entry per model.
    """
    model_list = list(best_scores[data_set].keys())
    if len(x_positions) != len(model_list):
        raise ValueError('{} x positions given for {} models on {}'.format(
            len(x_positions), len(model_list), data_set))

    # rows: contamination levels (contamination_list order), columns: models
    scores = np.array([[best_scores[data_set][model][contamination] for model in model_list]
                       for contamination in contamination_list], dtype=float)
    scores_std = np.array([[best_scores_std[data_set][model][contamination] for model in model_list]
                           for contamination in contamination_list], dtype=float)

    max_scores = np.max(scores, axis=0)
    arg_max_scores = np.argmax(scores, axis=0)
    std_scores = scores_std[arg_max_scores, np.arange(scores_std.shape[1])]

    fig, axes = plt.subplots(1, 1, figsize=(8, 5))

    n = n_non_recurrent
    axes.bar(x_positions[:n], max_scores[:n], label='Non-Recurrent')
    axes.bar(x_positions[n:], max_scores[n:], label='Recurrent')
    errorbar_style = dict(barsabove=True, ls='none', c='black', capsize=10,
                          elinewidth=1, markeredgewidth=1)
    axes.errorbar(x_positions[:n], max_scores[:n], yerr=std_scores[:n], **errorbar_style)
    axes.errorbar(x_positions[n:], max_scores[n:], yerr=std_scores[n:], **errorbar_style)

    axes.legend(title='Methods', frameon=True, facecolor=None, edgecolor='black')

    axes.set_xticks(x_positions)
    axes.set_xticklabels([model_short_to_plot_name[model] for model in model_list], rotation=90)

    # y axis labelled as percentages: 0% .. 100% in steps of 10%
    y_ticks = np.arange(0, 1.1, 0.1)
    axes.set_yticks(y_ticks)
    axes.set_yticklabels(['{:.0f}%'.format(tick * 100) for tick in y_ticks])

    axes.set_ylim(0, y_max)

    axes.set_ylabel('F1 Score', fontsize=15)
    axes.set_title(title, fontsize=15)

    axes.grid(False, axis='x')

    # savefig fails if the output directory does not exist yet
    out_dir = '{}/data_graphs'.format(save_path)
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig('{}/{}'.format(out_dir, file_name), bbox_inches='tight')
    return fig, axes


# plotting the bar chart for the f1 score on the water quality dataset
# (first three models are non-recurrent; the gap in x separates the two groups)
_plot_max_f1_bar('water_quality', [0, 1, 2, 5, 6, 7, 8], 3, 0.425,
                 'Maximum F1 Score On Water Quality Data',
                 'max_f1_score_water_quality.pdf')

# plotting the bar chart for the f1 score on the SWAN-SF dataset
# (first two models are non-recurrent)
_plot_max_f1_bar('swan_sf', [0, 1, 4, 5, 6], 2, 0.625,
                 'Maximum F1 Score On SWAN-SF Data',
                 'max_f1_score_swan_sf.pdf')



