################################################################################################
################################################################################################
# based on https://github.com/vicliv/DTE, adapted by first author
################################################################################################
################################################################################################

from adbench.baseline.PyOD import PYOD
from baselines.GANomaly.run import GANomaly
from baselines.dagmm import DAGMM
from baselines.drocc import DROCC
from baselines.normalizing_flow import FlowModel
from baselines.goad import GOAD
from baselines.icl import ICL
from baselines.vae import Vae
from baselines.slad.slad import SLAD
from baselines.deep_iforest.dif import DIF
from baselines.MCM.main import MCM
from baselines.new_pyod.PyOD_new import PYOD_new
from baselines.new_pyod.additional_methods import PYOD_like

from diffusion.dte import DTECategorical, DTEInverseGamma
from diffusion.non_param_dte import DTENonParametric
from diffusion.ddpm import DDPM

from NCSBAD.main import Sb_fit, Sb_predict_score

import argparse
import numpy as np
import os
import pandas as pd
import pickle
import time

from adbench.myutils import Utils

import sklearn.metrics as skm
from data_generator import DataGenerator


def low_density_anomalies(test_log_probs, num_anomalies):
    """Flag the ``num_anomalies`` samples with the lowest values.

    Helper for the F1-score: turns continuous scores into a binary
    prediction by marking exactly the ``num_anomalies`` smallest entries of
    ``test_log_probs`` as anomalies.

    Args:
        test_log_probs: array-like of per-sample values (expected to be
            one-dimensional); lower means more anomalous.
        num_anomalies: number of samples to flag. Values ``<= 0`` flag
            nothing; values larger than the number of samples flag
            everything.

    Returns:
        A float ``numpy`` array with one entry per sample, ``1`` at the
        flagged positions and ``0`` elsewhere.
    """
    test_log_probs = np.asarray(test_log_probs)
    n_samples = len(test_log_probs)
    preds = np.zeros(n_samples)

    # Clamp the request: argpartition raises when kth is out of bounds, which
    # happens for empty input or when more anomalies are requested than exist.
    num_anomalies = min(int(num_anomalies), n_samples)
    if num_anomalies <= 0:
        return preds

    # Partial sort: the first num_anomalies positions hold the smallest values.
    anomaly_indices = np.argpartition(test_log_probs, num_anomalies - 1)[:num_anomalies]
    preds[anomaly_indices] = 1
    return preds

def _load_results_frame(csv_path):
    """Return the results table stored at ``csv_path``, or an empty DataFrame.

    An existing file is reloaded so that results accumulate across runs.
    """
    try:
        return pd.read_csv(csv_path, index_col=0)
    except (OSError, ValueError):
        # Missing/unreadable file (OSError) or empty/malformed CSV (pandas'
        # EmptyDataError and ParserError derive from ValueError): start fresh.
        # Unlike a bare ``except``, this lets KeyboardInterrupt propagate.
        return pd.DataFrame(data=None)


def _build_model_dict():
    """Return the model-name -> model-class mapping, in evaluation order."""
    # NCSBAD has no wrapper class; it is driven via Sb_fit / Sb_predict_score.
    model_dict = {'NCSBAD': None}

    wrapper_groups = (
        (('KNN', 'LOF', 'PCA'), PYOD),
        (('KPCA', 'KDE', 'GMM', 'LUNAR', 'AE'), PYOD_new),
        (('EIF', 'ABOD'), PYOD_like),
    )
    for model_names, wrapper_cls in wrapper_groups:
        for model_name in model_names:
            model_dict[model_name] = wrapper_cls

    model_dict['DROCC'] = DROCC
    model_dict['GOAD'] = GOAD
    model_dict['ICL'] = ICL
    model_dict['SLAD'] = SLAD
    model_dict['DDPM'] = DDPM
    model_dict['DTE-C'] = DTECategorical
    model_dict['DTENonParametric'] = DTENonParametric
    model_dict['MCM'] = MCM
    return model_dict


def _fit_and_score(name, model_cls, data, dataset, data_dir, seed):
    """Train one model and score the test split.

    Returns:
        Tuple ``(score, time_fit, time_inference)``: the test scores and the
        training / inference durations.
    """
    if name == "MCM":
        # MCM loads the pickle itself and reports its own timings; its second
        # score output and its own AUC/PR/F1 values are not used here.
        score, _, _, _, _, time_fit, time_inference = MCM(data_dir)
        return score.flatten(), time_fit, time_inference

    if name == "NCSBAD":
        start_time = time.time()
        net, name_m = Sb_fit(data['X_train'], dataset)
        time_fit = time.time() - start_time

        start_time = time.time()
        score = Sb_predict_score(data['X_test'], 1, net, dataset, name_m)
        time_inference = time.time() - start_time
        return score, time_fit, time_inference

    # model initialization
    clf = model_cls(seed=seed, model_name=name)

    # training: labels are replaced by zeros, so models never see y_train
    start_time = time.time()
    clf = clf.fit(data['X_train'], np.zeros_like(data['y_train']))
    time_fit = time.time() - start_time

    start_time = time.time()
    score = clf.predict_score(data['X_test'])
    time_inference = time.time() - start_time
    return score, time_fit, time_inference


def main(args):
    """Run every configured model on every dataset for each seed.

    For each seed in ``range(args.start_seed, args.end_seed)`` the pickled
    split ``./{args.save_path}/{dataset}/seed_{seed}.pkl`` is loaded for each
    dataset, each model is fitted and scored, and F1, AUCROC, AUCPR, training
    time and inference time are stored in per-seed CSV files whose names are
    ``args.results_path + str(seed) + suffix``. The CSV files are rewritten
    after every (dataset, model) pair so partial results survive a crash.

    Args:
        args: namespace providing ``start_seed``, ``end_seed``, ``save_path``
            and ``results_path``. ``results_path`` is used as a plain string
            prefix, so it needs a trailing separator to denote a folder.
    """
    results_dir = args.results_path
    save_path = args.save_path
    model_dict = _build_model_dict()

    for seed in range(args.start_seed, args.end_seed):
        os.makedirs(results_dir, exist_ok=True)

        datagenerator = DataGenerator(seed=seed, test_size=0.5, normal=True)  # data generator

        utils = Utils()  # utils function
        utils.set_seed(seed)

        # Result tables, keyed by metric; reloaded if they already exist.
        csv_paths = {
            'f1': f"{results_dir}{seed}_AUCF1.csv",
            'aucroc': f"{results_dir}{seed}_AUCROC.csv",
            'aucpr': f"{results_dir}{seed}_AUCPR.csv",
            'train': f"{results_dir}{seed}_TrainTime.csv",
            'inference': f"{results_dir}{seed}_InferenceTime.csv",
        }
        frames = {key: _load_results_frame(path) for key, path in csv_paths.items()}

        # Get the datasets from ADBench
        dataset_groups = [
            datagenerator.dataset_list_classical,
            datagenerator.dataset_list_cv,
            datagenerator.dataset_list_nlp,
        ]
        for dataset_list in dataset_groups:
            for dataset in dataset_list:
                print(dataset)

                data_dir = f"./{save_path}/{dataset}/seed_{seed}.pkl"

                # NOTE: pickle can execute arbitrary code on load; only load
                # split files that were generated locally by this project.
                with open(data_dir, "rb") as data_file:
                    data = pickle.load(data_file)

                for name, model_cls in model_dict.items():
                    print(data_dir, name)
                    print(name)

                    score, time_fit, time_inference = _fit_and_score(
                        name, model_cls, data, dataset, data_dir, seed
                    )

                    # F1: flag as many samples as there are true anomalies,
                    # picking those with the highest score.
                    # NOTE(review): this runs before NaN scores are zeroed
                    # below, so NaNs enter the ranking as-is - confirm intended.
                    indices = np.arange(len(data['y_test']))
                    num_anomalies = len(indices[data['y_test'] == 1])
                    p = low_density_anomalies(-score, num_anomalies)
                    f1_score = skm.f1_score(data['y_test'], p)
                    print('F1 score: ' + str(f1_score))

                    # Replace NaN scores by 0 before computing the AUC metrics.
                    score[np.isnan(score)] = 0

                    result = utils.metric(y_true=data['y_test'], y_score=score)
                    print('AUCROC: ' + str(result['aucroc']))

                    # save results
                    values = {
                        'f1': f1_score,
                        'aucroc': result['aucroc'],
                        'aucpr': result['aucpr'],
                        'train': time_fit,
                        'inference': time_inference,
                    }
                    for key, value in values.items():
                        frames[key].loc[dataset, name] = value
                    for key, frame in frames.items():
                        frame.to_csv(csv_paths[key])

            
    
if __name__ == "__main__":
    # Command-line entry point: parse the experiment settings and run main().
    parser = argparse.ArgumentParser(description='Settings')
    parser.add_argument('--start_seed', type=int,
        default=0, help='first random seed')
    # main() iterates range(start_seed, end_seed), so end_seed is exclusive.
    parser.add_argument('--end_seed', type=int,
        default=5, help='last random seed + 1 (exclusive upper bound)')
    parser.add_argument('--save_path', type=str,
        default='./data', help='folder containing the saved data files')
    # main() uses this value as a string prefix for the CSV file names.
    parser.add_argument('--results_path', type=str,
        default='./DTE/results/test/',
        help='folder for the result CSV files (include the trailing slash)')

    args = parser.parse_args()
    main(args)
