################################################################################################
################################################################################################
# based on https://github.com/vicliv/DTE, adapted by first author
################################################################################################
################################################################################################

from adbench.baseline.PyOD import PYOD
from baselines.GANomaly.run import GANomaly
from baselines.dagmm import DAGMM
from baselines.drocc import DROCC
from baselines.normalizing_flow import FlowModel
from baselines.goad import GOAD
from baselines.icl import ICL
from baselines.vae import Vae
from baselines.slad.slad import SLAD
from baselines.deep_iforest.dif import DIF
from baselines.MCM.main import MCM
from baselines.WNCROD.main import WNCROD
from baselines.new_pyod.PyOD_new import PYOD_new
from baselines.new_pyod.additional_methods import PYOD_like
from baselines.new_pyod.additional.ODIN import ODIN

from diffusion.dte import DTECategorical, DTEInverseGamma
from diffusion.non_param_dte import DTENonParametric
from diffusion.ddpm import DDPM

from NCSBAD.main import Sb_fit, Sb_predict_score
from NCSBAD.main_val import Sb_fit_val, Sb_predict_score_val

import argparse
import numpy as np
import os
import pandas as pd
import pickle
import time

from adbench.myutils import Utils

import sklearn.metrics as skm
from data_generator_val import DataGenerator


def low_density_anomalies(test_log_probs, num_anomalies):
    """Flag the ``num_anomalies`` lowest values of ``test_log_probs`` as anomalies.

    Helper for the F1-score: ``main`` passes negated anomaly scores, so the
    lowest values correspond to the most anomalous samples.

    Args:
        test_log_probs: 1-D array-like of values; lower means more anomalous.
            NaN entries sort last (numpy ``argpartition`` semantics), so they
            are only flagged when ``num_anomalies`` reaches into them.
        num_anomalies: how many samples to flag. Values <= 0 flag nothing;
            values >= ``len(test_log_probs)`` flag every sample.

    Returns:
        Float ndarray with the same length as ``test_log_probs``: 1.0 at the
        flagged positions, 0.0 elsewhere. Ties at the cut-off are broken
        arbitrarily.
    """
    values = np.asarray(test_log_probs)
    preds = np.zeros(len(values))
    # Clamp so argpartition's kth stays in bounds (it raises otherwise, e.g.
    # kth=-1 on an empty array or kth >= len when too many are requested).
    k = min(int(num_anomalies), len(values))
    if k <= 0:
        return preds
    anomaly_indices = np.argpartition(values, k - 1)[:k]
    preds[anomaly_indices] = 1
    return preds

# Folder holding the pickled "additional" benchmark datasets (the ADBench
# datasets are read from ``args.save_path`` instead).
_ADDITIONAL_DATA_ROOT = "./data_add"

# Additional (non-ADBench) datasets, evaluated after the ADBench ones.
_ADDITIONAL_DATASETS = [
    'Parkinson', 'abalone', 'arrhythmia', 'ecoli', 'hrss_anomalous_optimized',
    'hrss_anomalous_standard', 'mif', 'miv', 'mulcross', 'nasa', 'pen-global',
    'pen-local', 'seismic-bumps', 'wbc2', 'yeast6',
]

# (dataset, model) pairs on the additional datasets that are not run; every
# metric and timing for them is recorded as 0.0 instead.
_SKIPPED_RUNS = frozenset({
    ('mulcross', 'CBLOF'),      # CBLOF doesn't work on mulcross
    ('seismic-bumps', 'PCA'),   # PCA doesn't work on seismic-bumps
})

# Result table key -> CSV file-name suffix ("<results_path>/<seed><suffix>").
_RESULT_FILES = {
    'f1': '_AUCF1.csv',
    'aucroc': '_AUCROC.csv',
    'aucpr': '_AUCPR.csv',
    'train': '_TrainTime.csv',
    'inference': '_InferenceTime.csv',
}


def _build_model_dict():
    """Return the ordered mapping model name -> model class or scoring callable.

    Insertion order is the evaluation order. The NCSBAD variants map to None
    because they are driven by the ``Sb_*`` functions in ``_evaluate_model``.
    """
    model_dict = {'NCSBAD': None, 'NCSBADVAL': None}

    for name in ['IForest', 'OCSVM', 'COPOD', 'ECOD', 'FeatureBagging', 'HBOS', 'KNN', 'LODA',
                 'LOF', 'MCD', 'PCA', 'DeepSVDD']:
        model_dict[name] = PYOD

    for name in ['INNE', 'KPCA', 'KDE', 'GMM', 'CBLOF', 'SOD', 'LUNAR', 'SOGAAL', 'ALAD', 'AE', 'CD',
                 'MOGAAL', 'QMCD', 'Sampling']:
        model_dict[name] = PYOD_new

    for name in ['EIF', 'Ensemble', 'GEN2OUT', 'DynamicHBOS', 'COF', 'ABOD', 'LMDD']:
        model_dict[name] = PYOD_like

    model_dict['DAGMM'] = DAGMM
    model_dict['DROCC'] = DROCC
    model_dict['GOAD'] = GOAD
    model_dict['ICL'] = ICL
    model_dict['PlanarFlow'] = FlowModel
    model_dict['VAE'] = Vae
    model_dict['GANomaly'] = GANomaly
    model_dict['SLAD'] = SLAD
    model_dict['DIF'] = DIF
    model_dict['DDPM'] = DDPM
    model_dict['DTE-IG'] = DTEInverseGamma
    model_dict['DTE-C'] = DTECategorical
    model_dict['DTENonParametric'] = DTENonParametric
    model_dict['MCM'] = MCM
    model_dict['3WNCROD'] = WNCROD
    model_dict['ODIN'] = ODIN
    return model_dict


def _load_result_tables(results_dir, seed):
    """Load the per-seed result tables, or start empty ones if none exist yet.

    Returns ``(tables, paths)``: two dicts keyed like ``_RESULT_FILES`` holding
    the DataFrames and their CSV paths. Reloading existing CSVs lets a re-run
    update earlier results in place.
    """
    tables, paths = {}, {}
    for key, suffix in _RESULT_FILES.items():
        path = os.path.join(results_dir, f"{seed}{suffix}")
        try:
            tables[key] = pd.read_csv(path, index_col=0)
        except (FileNotFoundError, pd.errors.EmptyDataError):
            # Only a missing/empty file means "no results yet"; any other read
            # error propagates instead of silently overwriting old results.
            tables[key] = pd.DataFrame(data=None)
        paths[key] = path
    return tables, paths


def _save_result_tables(tables, paths):
    """Write every result table to its CSV path (used as a per-model checkpoint)."""
    for key, table in tables.items():
        table.to_csv(paths[key])


def _evaluate_model(name, model, data, dataset, seed, data_dir, utils):
    """Fit model ``name`` on one dataset and evaluate it on the test split.

    Args:
        name: key from ``_build_model_dict``; selects the fit/score protocol.
        model: the matching class or callable from that dict.
        data: unpickled dataset dict (reads 'X_train', 'y_train', 'X_test',
            'y_test', and 'X_val'/'y_val' for NCSBADVAL).
        dataset: dataset name, forwarded to the NCSBAD functions.
        seed: random seed passed to model constructors.
        data_dir: path of the dataset pickle (MCM loads the data itself).
        utils: ADBench ``Utils`` instance providing ``metric``.

    Returns:
        ``(f1, aucroc, aucpr, time_fit, time_inference)``; times are
        wall-clock seconds from ``time.time()``. Model construction is not timed.
    """
    print(name)
    if name == "MCM":
        # MCM runs its own pipeline from the pickle path and reports metrics.
        aucroc, aucpr, f1, time_fit, time_inference = MCM(data_dir)
        return f1, aucroc, aucpr, time_fit, time_inference

    X_train, X_test = data['X_train'], data['X_test']

    # Model initialization.
    if name == "VAE":
        clf = model(seed=seed, model_name=name, num_features=X_train.shape[-1])
    elif name == "ODIN":
        clf = model()
    elif name in ("3WNCROD", "NCSBAD", "NCSBADVAL"):
        clf = None  # driven by plain functions below, nothing to construct
    else:
        clf = model(seed=seed, model_name=name)

    # Training; for unsupervised models the (all-zero) y labels are discarded.
    start_time = time.time()
    if name == "3WNCROD":
        pass  # distance based, no training
    elif name == "ODIN":
        clf = clf.fit(X_test)  # ODIN is fit directly on the test split
    elif name == "NCSBAD":
        net, name_m = Sb_fit(X_train, dataset)
    elif name == "NCSBADVAL":
        net, name_m = Sb_fit_val(X_train, data['X_val'], data['y_val'], dataset)
    else:
        clf = clf.fit(X_train, np.zeros_like(data['y_train']))
    time_fit = time.time() - start_time

    # Inference: produce one anomaly score per test sample.
    start_time = time.time()
    if name == 'DAGMM':
        score = clf.predict_score(X_train, X_test)
    elif name == '3WNCROD':
        score = WNCROD(X_train, X_test)
    elif name == 'ODIN':
        score = clf.decision_scores_
    elif name == "NCSBAD":
        score = Sb_predict_score(X_test, 1, net, dataset, name_m)
    elif name == "NCSBADVAL":
        score = Sb_predict_score_val(X_test, 1, net, dataset, name_m)
    else:
        score = clf.predict_score(X_test)
    time_inference = time.time() - start_time

    # Replace NaNs before computing ANY metric so that F1 and AUC are based on
    # the same scores. np.array copies, so model attributes are not mutated.
    score = np.array(score, dtype=float)
    score[np.isnan(score)] = 0

    # F1: flag as many top-scoring samples as there are true anomalies.
    y_test = data['y_test']
    preds = low_density_anomalies(-score, int(np.sum(y_test == 1)))
    f1 = skm.f1_score(y_test, preds)
    print('F1 score: ' + str(f1))

    result = utils.metric(y_true=y_test, y_score=score)
    print('AUCROC: ' + str(result['aucroc']))
    return f1, result['aucroc'], result['aucpr'], time_fit, time_inference


def main(args):
    """Benchmark every model on the ADBench and additional datasets, per seed.

    For each seed in ``range(args.start_seed, args.end_seed)``:
    - the ADBench datasets are loaded from ``<args.save_path>/<dataset>/seed_<seed>.pkl``;
    - the additional datasets are loaded from ``./data_add/<dataset>/seed_<seed>.pkl``.

    Results go to ``<args.results_path>/<seed>_{AUCROC,AUCPR,AUCF1,TrainTime,InferenceTime}.csv``.
    The tables are saved after every (dataset, model) evaluation, so an
    interrupted run keeps its progress.

    An optional ``args.resume_from`` = ``(dataset, model_name)`` skips every
    evaluation before that pair. From that pair on (including later seeds),
    everything runs. If the pair never occurs, nothing is evaluated.
    """
    results_dir = args.results_path
    os.makedirs(results_dir, exist_ok=True)
    model_dict = _build_model_dict()
    resume_from = getattr(args, 'resume_from', None)
    if resume_from is not None:
        resume_from = tuple(resume_from)

    for seed in range(args.start_seed, args.end_seed):
        datagenerator = DataGenerator(seed=seed, test_size=0.5, normal=True)

        utils = Utils()
        utils.set_seed(seed)

        tables, paths = _load_result_tables(results_dir, seed)

        # (dataset, pickle path, is_additional): ADBench first, then additional.
        adbench_lists = (datagenerator.dataset_list_classical,
                         datagenerator.dataset_list_cv,
                         datagenerator.dataset_list_nlp)
        runs = [
            (dataset, os.path.join(args.save_path, dataset, f"seed_{seed}.pkl"), False)
            for dataset_list in adbench_lists
            for dataset in dataset_list
        ]
        runs += [
            (dataset, os.path.join(_ADDITIONAL_DATA_ROOT, dataset, f"seed_{seed}.pkl"), True)
            for dataset in _ADDITIONAL_DATASETS
        ]

        for dataset, data_dir, is_additional in runs:
            print(dataset)
            # Locally generated pickles; never point these paths at untrusted files.
            with open(data_dir, "rb") as data_file:
                data = pickle.load(data_file)

            for name, model in model_dict.items():
                print(data_dir, name)
                if resume_from is not None:
                    if (dataset, name) != resume_from:
                        continue
                    resume_from = None  # reached the resume point

                if is_additional and (dataset, name) in _SKIPPED_RUNS:
                    f1 = aucroc = aucpr = time_fit = time_inference = 0.0
                else:
                    f1, aucroc, aucpr, time_fit, time_inference = _evaluate_model(
                        name, model, data, dataset, seed, data_dir, utils)

                tables['f1'].loc[dataset, name] = f1
                tables['aucroc'].loc[dataset, name] = aucroc
                tables['aucpr'].loc[dataset, name] = aucpr
                tables['train'].loc[dataset, name] = time_fit
                tables['inference'].loc[dataset, name] = time_inference
                _save_result_tables(tables, paths)
    
if __name__ == "__main__":
    # Command-line entry point: parse the seed range and folders, then run main().
    parser = argparse.ArgumentParser(description='Settings')
    parser.add_argument('--start_seed', type=int,
        default=0, help='first random seed')
    # main() iterates range(start_seed, end_seed), so end_seed itself is not run.
    parser.add_argument('--end_seed', type=int,
        default=5, help='last random seed + 1 (exclusive upper bound)')
    parser.add_argument('--save_path', type=str,
        default='./data', help='folder with the saved data files')
    parser.add_argument('--results_path', type=str,
        default='./results/all/', help='folder the result CSV files are written to')

    args = parser.parse_args()
    main(args)
