import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.patches as patches

import os

import json

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score, roc_auc_score, roc_curve

from crtod.graphing import heat_plotter


# Root folder that holds every result file read and every figure written below.
save_path = './Unpredictability_Results'


# Read every synthetic benchmark CSV together with its ground-truth outliers.
# NOTE(review): `path_to_TODS` is not defined in the visible part of this file;
# it has to be set before this line runs - confirm where it is meant to come from.
data_type_folder_name = 'multidataset'
all_data_folder = path_to_TODS + 'tods/benchmark/synthetic/{}'.format(data_type_folder_name)
data_name_list = os.listdir(all_data_folder)
data_dict_csv = {}

for data_name in data_name_list:
    # skip every directory entry whose name does not mention 'csv'
    if 'csv' not in data_name:
        continue

    data_path = all_data_folder + '/' + data_name
    df = pd.read_csv(data_path)

    # standardise all columns except the last one; the last column is kept
    # as-is because it is read further down as the outlier label
    scaler = StandardScaler()
    df.iloc[:, :-1] = scaler.fit_transform(df.iloc[:, :-1])

    data_dict_csv[data_name] = df

# Short names of the model families that are plotted.
# The order is significant: the plotting code below labels the first three
# entries 'Non-Recurrent' and all remaining entries 'Recurrent'.
_non_recurrent_models = ['ocsvm', 'iforest', 'ae']
_recurrent_models = ['ar', 'GBReg', 'lstm', 'transformer_encoder', 'transformer_encoder_decoder']
models_test = _non_recurrent_models + _recurrent_models


# Read the F1 results of all experiments and combine them in `result_dict`.
results_path = '{}/synthetic_predictions/'.format(save_path)

with open(results_path + 'tods_results_4_hypersearch.json') as a_file:
    result_dict = json.load(a_file)

with open(results_path + 'transformer_encoder_results_4_hypersearch.json') as a_file:
    result_dict_transformer_encoder = json.load(a_file)

with open(results_path + 'transformer_encoder_decoder_results_4_hypersearch.json') as a_file:
    result_dict_transformer_encoder_decoder = json.load(a_file)

# Merge the two transformer result sets into the main dictionary; an entry
# that already exists under the same model name is overwritten.
result_dict.update(result_dict_transformer_encoder)
result_dict.update(result_dict_transformer_encoder_decoder)


# Gather the predictions and the outlier scores of every experiment run.
# Both dictionaries are nested as
#     [model / pipeline name][data set][contamination][run number]
# and each leaf is the last column of the matching CSV file as a numpy array.
# Files whose name contains 'proba' go to `proba_dict`, all others to
# `predict_dict`; only the 'proba' files register new names in `dataset_list`.
dataset_list = []
proba_dict = {}
predict_dict = {}

for file_name in os.listdir(results_path):
    has_scores = 'proba' in file_name
    target_dict = proba_dict if has_scores else predict_dict

    if 'hypersearch.csv' in file_name:
        pipeline_name = file_name.split('_con')[0]
        contamination = file_name.split('_con_')[1].split('_')[0]

        # the data set name follows the contamination value; score files are
        # cut after 'proba', prediction files after '_con_'.  The last three
        # characters before 'hypersearch.csv' are dropped (presumably the
        # '_<run>_' suffix - verify against the actual file names).
        name_tail = file_name.split('proba')[1] if has_scores else file_name.split('_con_')[1]
        data_set = name_tail.split(contamination)[-1][1:].split('hypersearch.csv')[0][:-3]
        if has_scores and data_set not in dataset_list:
            dataset_list.append(data_set)

        run_number = file_name.split('hypersearch.csv')[0].split('_')[-2]

        runs = target_dict.setdefault(pipeline_name, {}).setdefault(data_set, {}).setdefault(contamination, {})
        runs[run_number] = pd.read_csv(results_path + file_name).iloc[:,-1].to_numpy()

    elif 'transformer' in file_name and '.csv' in file_name:
        model_name = file_name.split('_con_')[0]
        contamination = file_name.split('_con_')[1].split('_')[0]
        data_set = file_name.split('data_')[1].split('_')[0]
        if has_scores and data_set not in dataset_list:
            dataset_list.append(data_set)

        # the run number is the single character right before '.csv'
        run_number = file_name[-5]

        runs = target_dict.setdefault(model_name, {}).setdefault(data_set, {}).setdefault(contamination, {})
        runs[run_number] = pd.read_csv(results_path + file_name).iloc[:,-1].to_numpy()



# Map every short model name to all result entries that belong to it, so that
# variants of one model with different hyper-parameters end up in one group.
# A result entry belongs to a group when the short name is a substring of its
# full name; 'transformer_encoder' additionally rejects names containing
# 'decoder' so that it does not swallow the encoder-decoder entries.
model_short_to_long = {}

for model_short_name in models_test:
    model_long_names = [
        model_long_name
        for model_long_name in result_dict
        if model_short_name in model_long_name
        and not (model_short_name == 'transformer_encoder' and 'decoder' in model_long_name)
    ]
    model_short_to_long[model_short_name] = model_long_names


# Human-readable label shown on the plots for every short model name.
model_short_to_plot_name = dict([
    ('ocsvm', 'One Class SVM'),
    ('iforest', 'Isolation Forest'),
    ('ae', 'Auto-Encoder'),
    ('ar', 'Auto-Regressive'),
    ('GBReg', 'Gradient Boosting Regression'),
    ('lstm', 'Long Short Term Memory RNN'),
    ('transformer_encoder', 'Transformer Encoder'),
    ('transformer_encoder_decoder', 'Transformer Encoder Decoder'),
])


# For every data set, model group and contamination value, keep the variant
# with the highest mean score over its runs: its mean, the standard deviation
# of its runs and its full model name.
best_scores = {data_set: {model: {} for model in models_test} for data_set in dataset_list}
best_scores_std = {data_set: {model: {} for model in models_test} for data_set in dataset_list}
best_scores_model = {data_set: {model: {} for model in models_test} for data_set in dataset_list}
best_scores_proba = {data_set: {model: {} for model in models_test} for data_set in dataset_list}
contamination_list = ['0.05', '0.1', '0.15', '0.2', '0.25']

for data_set, scores_per_model in best_scores.items():
    for model_short_name in scores_per_model:
        candidate_models = model_short_to_long[model_short_name]

        for contamination in contamination_list:
            contamination_key = str(contamination)

            # scores of all runs, one list per candidate model
            runs_per_candidate = [
                list(result_dict[model_long_name][contamination_key][data_set+ '.csv'].values())
                for model_long_name in candidate_models
            ]
            scores = [np.mean(run_scores) for run_scores in runs_per_candidate]
            scores_sd = [np.std(run_scores) for run_scores in runs_per_candidate]

            max_score = np.max(scores)
            winner = np.argmax(scores)
            max_score_model = candidate_models[winner]
            max_score_sd = scores_sd[winner]

            scores_per_model[model_short_name][contamination_key] = max_score
            best_scores_model[data_set][model_short_name][contamination_key] = max_score_model
            best_scores_std[data_set][model_short_name][contamination_key] = max_score_sd



# Compute the ROC curve and the ROC-AUC of every model group on every data set,
# using the outlier scores of the configuration that achieved the best F1 score.
# Results are stored as
#     roc_auc_dict[data_set][model]['roc'] = {'fpr': grid, 'tpr': {'mean', 'std'}}
#     roc_auc_dict[data_set][model]['auc'] = {'mean', 'std'}
# where mean / std are taken over the runs of that configuration.
roc_auc_dict = {data_set: {model: {} for model in models_test} for data_set in dataset_list}

for data_set in dataset_list:

    # ground-truth outlier labels: last column of the data set
    Y_train = data_dict_csv[data_set+'.csv'].iloc[:,-1].to_numpy()

    for model in models_test:
        # look the scores up by contamination value so that the position of the
        # maximum is guaranteed to line up with `contamination_list`
        scores_temp = [best_scores[data_set][model][value] for value in contamination_list]
        best_contamination_value = contamination_list[np.argmax(scores_temp)]
        model_long_name = best_scores_model[data_set][model][best_contamination_value]
        # Use the runs of the *best* contamination value.  This previously read
        # a loop variable left over from an earlier loop, which always selected
        # the last contamination value instead of the best one.
        proba_values_over_runs = proba_dict[model_long_name][data_set][best_contamination_value]

        roc_auc_dict[data_set][model]['roc'] = {'fpr': [], 'tpr': {}}
        roc_auc_dict[data_set][model]['auc'] = {}

        tpr_values = []
        auc_values = []

        # common false-positive-rate grid shared by all runs
        mean_fpr = np.linspace(0,1,100)

        # performing linear interpolation on the roc score so that they are all of the same length
        for proba_values in proba_values_over_runs.values():
            # skipping experiments where the outlier score was nan
            if np.isnan(proba_values).any():
                continue
            fpr, tpr = roc_curve(Y_train, proba_values)[:2]
            auc_values.append(roc_auc_score(Y_train, proba_values))
            interpolation_tpr = np.interp(mean_fpr, fpr, tpr)
            tpr_values.append(interpolation_tpr)

        # NOTE: when every run was skipped these are nan (numpy emits a warning)
        mean_tpr = np.mean(tpr_values, axis=0)
        std_tpr = np.std(tpr_values, axis=0)

        roc_auc_dict[data_set][model]['roc']['fpr'] = mean_fpr
        roc_auc_dict[data_set][model]['roc']['tpr']['mean'] = mean_tpr
        roc_auc_dict[data_set][model]['roc']['tpr']['std'] = std_tpr
        roc_auc_dict[data_set][model]['auc']['mean'] = np.mean(auc_values)
        roc_auc_dict[data_set][model]['auc']['std'] = np.std(auc_values)


# Mean ROC-AUC as a (data set x model) array, in the order of `dataset_list`
# and `models_test`.
roc_array = np.zeros((len(dataset_list), len(models_test)))
for nd, data_set in enumerate(dataset_list):
    for nm, model in enumerate(models_test):
        roc_array[nd,nm] = roc_auc_dict[data_set][model]['auc']['mean']


# Histogram of the best ROC-AUC per data set, drawn separately for the first
# three models ('Non-Recurrent') and for the remaining ones ('Recurrent').
fig, axes = plt.subplots(1,1,figsize = (8,5))

non_rec = np.max(roc_array[:,:3], axis = 1)
rec = np.max(roc_array[:,3:], axis = 1)

# ten equally wide bins covering the range 0 to 1
n_bins = 10
bins = np.arange(0,1+1/n_bins,1/n_bins)

for group_values, group_label in ((non_rec, 'Non-Recurrent'), (rec, 'Recurrent')):
    axes.hist(group_values, label = group_label, histtype='step', linewidth = 2, bins = bins, alpha = 0.7)

axes.legend(title = 'Methods', frameon = True, facecolor = None, edgecolor = 'black', loc = 'upper left')

axes.set_xlabel('AUC', fontsize = 15)
axes.set_ylabel('Frequency', fontsize = 15)
axes.set_title('Distribution Of AUC-ROC For Each Synthetic Dataset', fontsize = 15)

fig.savefig('{}/data_graphs/AUC_ROC_synthetic_histogram.pdf'.format(save_path), bbox_inches = 'tight')




# Best F1 score (over all contamination values) as a (data set x model) array.
f1_array = np.zeros((len(dataset_list), len(models_test)))

for nd, data_set in enumerate(dataset_list):
    for nm, model in enumerate(models_test):
        f1_array[nd,nm] = np.max(list(best_scores[data_set][model].values()))


# Histogram of the best F1 score per data set, drawn separately for the first
# three models ('Non-Recurrent') and for the remaining ones ('Recurrent').
fig, axes = plt.subplots(1,1,figsize = (8,5))

non_rec = np.max(f1_array[:,:3], axis = 1)
rec = np.max(f1_array[:,3:], axis = 1)

# ten equally wide bins covering the range 0 to 1
n_bins = 10
bins = np.arange(0,1+1/n_bins,1/n_bins)

for group_values, group_label in ((non_rec, 'Non-Recurrent'), (rec, 'Recurrent')):
    axes.hist(group_values, label = group_label, histtype='step', linewidth = 2, bins = bins, alpha = 0.7)

axes.legend(title = 'Methods', frameon = True, facecolor = None, edgecolor = 'black', loc = 'upper right')

# x axis: a tick every 0.2, labelled as a percentage
tick_positions = np.arange(0,1.1,0.2)
axes.set_xlabel('F1 Score', fontsize = 15)
axes.set_xticks(tick_positions)
axes.set_xticklabels(['{:.0f}%'.format(tick*100) for tick in tick_positions])

axes.set_ylabel('Frequency', fontsize = 15)
axes.set_title('Distribution Of Maximum F1 Score For Each Synthetic Dataset', fontsize = 15)

fig.savefig('{}/data_graphs/f1_synthetic_histogram.pdf'.format(save_path), bbox_inches = 'tight')

# (data set index, model index) pairs of every model that reaches the highest
# F1 score of its data set; tied models each get their own row.
row_maximum = np.max(f1_array, axis = 1, keepdims = True)
best_models = np.argwhere(f1_array == row_maximum)

# number of data sets on which each model reached the highest F1 score
counts = np.zeros(len(models_test))
for _, model_index in best_models:
    counts[model_index] += 1



# Bar chart of how often each model obtained the largest F1 score on a
# synthetic data set.
fig, axes = plt.subplots(1,1,figsize = (8,5))

# bar positions; the jump from 2 to 5 leaves a gap between the two groups
x = [0,1,2,5,6,7,8,9] 

for positions, heights, group_label in ((x[:3], counts[:3], 'Non-Recurrent'),
                                        (x[3:], counts[3:], 'Recurrent')):
    axes.bar(positions, heights, label = group_label)
axes.legend(title = 'Methods', frameon = True, facecolor = None, edgecolor = 'black', loc = 'upper right')

tick_labels = [model_short_to_plot_name[name] for name in models_test]
axes.set_xticks(x)
axes.set_xticklabels(tick_labels, rotation = 90)

axes.set_ylim(0,10.25)

axes.set_ylabel('Frequency', fontsize = 15)
axes.set_title('Frequency Of Maximum F1 Scores When Predicting Synthetic Data', fontsize = 15)

axes.grid(False, axis = 'x')

fig.savefig('{}/data_graphs/max_f1_score_synthetic_bar.pdf'.format(save_path), bbox_inches = 'tight')



def find_inlier(io, outlier_indices, direction):
    """Return the first index next to `io` that is not an outlier.

    Starting at `io`, the index is moved one step at a time while it is still
    contained in `outlier_indices`.

    Parameters
    ----------
    io : int
        Index at which the search starts.
    outlier_indices : container of int
        Indices of the outliers; anything supporting `in` works, a set makes
        every membership test constant time.
    direction : str
        'start' to walk towards smaller indices, 'end' to walk towards larger
        indices.

    Returns
    -------
    int
        The first index reached that is not in `outlier_indices` (this can be
        -1 or one past the last sample when the outliers touch a border).

    Raises
    ------
    ValueError
        If `direction` is neither 'start' nor 'end' (this used to loop forever).
    """
    if direction == 'start':
        step = -1
    elif direction == 'end':
        step = 1
    else:
        raise ValueError("direction must be 'start' or 'end', got {!r}".format(direction))

    io_test = io
    while io_test in outlier_indices:
        io_test += step
    return io_test


# Count, for every data set, how many outlier samples are isolated
# ('point_outlier') and how many are part of a run of consecutive outliers
# ('collective_outlier'); every sample of such a run is counted.
outlier_type_synthetic = {}
for dataset in dataset_list:
    # samples whose label (last column) equals 1 are the outliers
    outlier_indices = np.argwhere(data_dict_csv[dataset + '.csv'].iloc[:,-1].to_numpy() == 1).reshape(-1)
    # set copy for constant-time membership tests while walking along a run
    outlier_lookup = set(outlier_indices.tolist())

    type_counts = {'point_outlier': 0, 'collective_outlier': 0}

    # first inlier index after the run that is currently being processed
    end = 0
    for io in outlier_indices.tolist():
        if io < end:
            # still inside the run found earlier
            type_counts['collective_outlier'] += 1
            continue

        start = find_inlier(io, outlier_lookup, 'start')
        end = find_inlier(io, outlier_lookup, 'end')
        # exactly one outlier between the two surrounding inliers -> isolated
        if end - start == 2:
            type_counts['point_outlier'] += 1
        else:
            type_counts['collective_outlier'] += 1

    outlier_type_synthetic[dataset] = type_counts


# Decide which outlier type dominates each data set: 'point' when there are
# strictly more point outliers than collective outliers, otherwise 'collective'.
dominated_dataset = {}
for dataset in dataset_list:
    n_point = outlier_type_synthetic[dataset]['point_outlier']
    n_collective = outlier_type_synthetic[dataset]['collective_outlier']
    dominated_dataset[dataset] = 'point' if n_point > n_collective else 'collective'

# Count how often each model achieved the best F1 score, split by the outlier
# type that dominates the data set on which it won.
method_dominate = {'point': {model:0 for model in models_test}, 'collective': {model:0 for model in models_test}}

# `best_models` holds one (data set index, model index) row per winning model,
# so a data set with tied winners contributes several rows.  Take the data set
# from the row itself rather than assuming that row number == data set number,
# which breaks as soon as there is a tie.
for nd, nm in best_models:
    dataset = dataset_list[nd]
    model = models_test[nm]
    method_dominate[dominated_dataset[dataset]][model] += 1

# counts per model, in the order of `models_test`
point_values = np.asarray(list(method_dominate['point'].values()))
collective_values = np.asarray(list(method_dominate['collective'].values()))


# Stacked bar chart: number of times each model obtained the greatest F1 score,
# split by the outlier type that dominates the data set.
fig, axes = plt.subplots(1,1,figsize = (6,4))

# bar positions; the jump from 2 to 5 leaves a gap between the two groups
x = [0,1,2,5,6,7,8,9] 

# The full-height bar is drawn first and the collective part is painted over
# it, so the part of the first bar that stays visible is the point count.
total_values = collective_values+point_values
axes.bar(x, total_values, label = 'Point Outliers', color = 'lightblue')
axes.bar(x, collective_values, label = 'Collective Outliers', color = 'slategrey')

axes.legend(title = 'Dataset Dominated By', frameon = True, facecolor = None, edgecolor = 'black', loc = 'upper right')

tick_labels = [model_short_to_plot_name[name] for name in models_test]
axes.set_xticks(x)
axes.set_xticklabels(tick_labels, rotation = 90)

axes.set_ylim(0,8.25)

axes.set_ylabel('Frequency', fontsize = 15)
axes.set_title('Frequency Of Maximum F1 Scores When Predicting Synthetic Data', fontsize = 15)

axes.grid(False, axis = 'x')

fig.savefig('{}/data_graphs/max_f1_score_synthetic_count.pdf'.format(save_path), bbox_inches = 'tight')