from data_process import load_data
from nod import NOD
from utils import list_length

import time
import os
import json

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.metrics import precision_score, f1_score

def _build_version(config):
    """Return the run-version string built from the training settings in *config*.

    Joins, with '_': the summary tag, "<number_linear>linear", the activation,
    "epoch<epochs>", the loss, the optimizer, the learning rate with '.'
    removed, and "earlystop1"/"earlystop0" depending on whether either
    classification-based early-stopping flag is enabled.
    """
    early_stop = int(config['use_classification_auc_early_stopping']
                     or config['use_classification_precision_early_stopping'])
    name_list = [
        config['summary'],
        str(config['number_linear']) + "linear",
        config['activation'],
        "epoch" + str(config['epochs']),
        config['loss'],
        config['optimizer'],
        str(config['learning_rate']).replace('.', ''),
        "earlystop" + str(early_stop),
    ]
    return "_".join(name_list)


def _mean_std_pct(values):
    """Return ``(mean, population std)`` of *values* as percentages rounded to 2 decimals."""
    return round(np.mean(values) * 100, 2), round(np.std(values, ddof=0) * 100, 2)


def _config_for_json(config):
    """Return a copy of *config* without the keys that are not written to the JSON file.

    'summary_key', 'contamination' and 'num' are per-dataset / per-run values
    that main() stores in config. 'device' is dropped as well, presumably
    because it is not JSON-serializable (TODO confirm). Filtering a copy,
    instead of deleting keys, tolerates missing keys and leaves the caller's
    dict usable for another main() call.
    """
    excluded = {'device', 'summary_key', 'contamination', 'num'}
    return {key: value for key, value in config.items() if key not in excluded}


def _plot_training_curves(config, summary_key, contamination, save_result_path,
                          loss_list, auc_list_plot,
                          classification_auc_list, classification_precision_list):
    """Plot one experiment's per-epoch curves and save them as a PNG.

    Draws loss and AUC against epoch number. When enabled in *config*, it also
    draws the classification AUC and precision curves. The figure is written to
    ``<save_result_path>/loss_plots/<summary_key>/<config['num']>.png``.
    Every curve list is expected to hold ``config['epochs']`` values.
    """
    fig = plt.figure()
    try:
        epochs_x = range(1, config["epochs"] + 1)
        plt.plot(epochs_x, loss_list, label="loss")
        plt.plot(epochs_x, auc_list_plot, label="auc", c='red')
        # (sic) 'classificaiton' is the key spelling the config uses.
        if config['paint_classificaiton_auc']:
            plt.plot(epochs_x, classification_auc_list, label='cl_auc', c='purple')
        if config['paint_classificaiton_precision']:
            plt.plot(epochs_x, classification_precision_list, label='cl_pre', c='green')
        # Take labels from the plotted lines. A fixed positional label list
        # mislabels the precision curve when the AUC curve is not drawn.
        plt.legend()

        plt.title("{}, ratio {}, oc_auc {}, cl_precision(outlier) {}".format(
            summary_key.split('_')[0],
            round(contamination * 100, 2),
            round(auc_list_plot[-1] * 100, 2),
            round(classification_precision_list[-1] * 100, 2)))

        pic_save_path = os.path.join(save_result_path, "loss_plots", summary_key)
        os.makedirs(pic_save_path, exist_ok=True)
        plt.savefig(os.path.join(pic_save_path, str(config['num']) + ".png"))
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def main(config):
    """Train and evaluate NOD on every dataset in ``config['data_path_list']``.

    For each dataset, runs ``config['number_experiment']`` independent
    train/predict experiments and saves a training-curve plot per experiment.
    It then aggregates the following into a summary table:

    * ROC-AUC, average precision (AUPRC), precision and F1 as mean/std percentages;
    * the mean of the last classification AUC/precision reported during training.

    The table is printed and written to ``<save_dir>/<version>/<version>.csv``.
    The configuration, minus per-run bookkeeping keys, is written next to it as
    ``<version>.json``.

    Args:
        config: Experiment configuration dict. It also carries the per-run
            values handed to ``NOD`` ('contamination', 'features_size', 'num',
            'summary_key', 'version'), so it is updated in place.
    """
    columns = ["dataset", "auc", "auc_std", "auprc", "auprc_std",
               "precision", "precision_std", "f1", "f1_std", "cl_auc", "cl_pre"]
    rows = []
    save_result_path = ""

    for data_path in config['data_path_list']:
        # File name up to its first '.', e.g. "dir/foo.csv" -> "foo".
        summary_key = data_path.split("/")[-1].split(".")[0]
        scores = {"auc": [], "auprc": [], "precision": [], "f1": []}
        cl_auc_list, cl_pre_list = [], []

        init_data, init_label, contamination = load_data(
            os.path.join(config['data_dir'], data_path))
        config['contamination'] = contamination
        config['features_size'] = init_data.shape[-1]

        print("-" * 10 + summary_key + " is training" + "-" * 10)
        for num in range(config['number_experiment']):
            config['num'] = num
            config['summary_key'] = summary_key
            config['version'] = _build_version(config)
            save_result_path = os.path.join(config["save_dir"], config["version"])
            print(config['version'])

            start_train_time = time.time()
            model = NOD(config)
            (loss_list, auc_list_plot, classification_auc_list,
             classification_precision_list, sum_time_classification_threshold,
             sum_time_paint_auc) = model.train(init_data, init_label)
            # Ensure the [-1] lookups below have a value when training
            # reported no classification scores.
            if not classification_auc_list:
                classification_auc_list = [1]
            if not classification_precision_list:
                classification_precision_list = [1]
            print("train cost {} s, including get auc {} s and threshold {} s".format(
                time.time() - start_train_time, sum_time_paint_auc,
                sum_time_classification_threshold))

            # Bring each curve to config['epochs'] entries via list_length
            # (presumably padding after early stopping; see utils).
            epochs = config['epochs']
            loss_list = list_length(loss_list, epochs)
            auc_list_plot = list_length(auc_list_plot, epochs)
            classification_auc_list = list_length(classification_auc_list, epochs)
            classification_precision_list = list_length(classification_precision_list, epochs)

            _plot_training_curves(config, summary_key, contamination, save_result_path,
                                  loss_list, auc_list_plot,
                                  classification_auc_list, classification_precision_list)

            start_predict_time = time.time()
            model.predict()
            print("predict cost {} s".format(time.time() - start_predict_time))

            # Ranking metrics use model.labels_; thresholded metrics use
            # model.pred_labels (presumably scores vs. 0/1 predictions; confirm in NOD).
            scores["auc"].append(roc_auc_score(init_label, model.labels_))
            scores["auprc"].append(average_precision_score(init_label, model.labels_))
            scores["precision"].append(precision_score(init_label, model.pred_labels))
            scores["f1"].append(f1_score(init_label, model.pred_labels))
            cl_auc_list.append(classification_auc_list[-1])
            cl_pre_list.append(classification_precision_list[-1])

        # Aggregate over the config['number_experiment'] runs of this dataset.
        row = {"dataset": summary_key}
        for metric, values in scores.items():
            row[metric], row[metric + "_std"] = _mean_std_pct(values)
        row["cl_auc"] = round(np.mean(cl_auc_list) * 100, 2)
        row["cl_pre"] = round(np.mean(cl_pre_list) * 100, 2)
        rows.append(row)

    result_statistics = pd.DataFrame(rows, columns=columns)
    print(result_statistics)

    if not save_result_path:
        # No experiment ran (empty data_path_list or number_experiment == 0).
        # Still derive the output directory instead of writing to the CWD.
        config['version'] = _build_version(config)
        save_result_path = os.path.join(config["save_dir"], config["version"])
    # The directory only exists already if at least one plot was saved into it.
    os.makedirs(save_result_path, exist_ok=True)
    result_statistics.to_csv(
        os.path.join(save_result_path, config['version'] + ".csv"), index=False)

    with open(os.path.join(save_result_path, config['version'] + ".json"), 'w') as f:
        json.dump(_config_for_json(config), f)