################################################################################################
################################################################################################
# based on https://github.com/yzhao062/pyod/tree/master?tab=readme-ov-file#kingma2013auto, 
# Version 2.0.2 adapted by first author, for running in this benchmark and environment
################################################################################################
################################################################################################

#Orig:
from adbench.myutils import Utils
import numpy as np

#add the baselines from the pyod package
from baselines.new_pyod.inne import INNE
from baselines.new_pyod.kpca import KPCA
from baselines.new_pyod.kde import KDE
from baselines.new_pyod.gmm import GMM
from baselines.new_pyod.cblof import CBLOF
from baselines.new_pyod.sod import SOD
from baselines.new_pyod.lunar import LUNAR
from baselines.new_pyod.so_gaal import SO_GAAL
from baselines.new_pyod.alad import ALAD
from baselines.new_pyod.auto_encoder import AutoEncoder
from baselines.new_pyod.cd import CD
from baselines.new_pyod.loci import LOCI
from baselines.new_pyod.lscp import LSCP
from baselines.new_pyod.mad import MAD
from baselines.new_pyod.mo_gaal import MO_GAAL
from baselines.new_pyod.qmcd import QMCD
from baselines.new_pyod.rgraph import RGraph
from baselines.new_pyod.sampling import Sampling
from baselines.new_pyod.sos import SOS


class PYOD_new():
    """Uniform fit / score wrapper around the detectors imported from ``baselines.new_pyod``.

    A model is selected by its short name (a key of ``model_dict``), fitted with
    :meth:`fit` and queried with :meth:`predict_score`. Optional hyper-parameter
    tuning is driven by :meth:`grid_hp` / :meth:`grid_search`; every grid is
    currently ``None``, so by default all models run with their constructor defaults.
    """

    # Constructor keyword that receives the tuned value, per tunable model name.
    # Shared by grid_search() and fit() so both always agree on what can be tuned.
    # 'LSCP' is deliberately absent: besides the tuned value it needs a
    # ``detector_list`` of base detectors, and none is imported in this module.
    _TUNED_KWARG = {
        'INNE': 'n_estimators',
        'CBLOF': 'n_clusters',
        'LOCI': 'alpha',
        'SOD': 'n_neighbors',
        'SOS': 'perplexity',
        'SOGAAL': 'stop_epochs',
        'MOGAAL': 'stop_epochs',
    }

    def __init__(self, seed, model_name, tune=False):
        '''
        :param seed: seed for reproducible results
        :param model_name: model name, must be a key of ``self.model_dict``
        :param tune: if necessary, tune the hyper-parameter based on the validation set constructed by the labeled anomalies
        '''
        self.seed = seed
        self.utils = Utils()

        self.model_name = model_name
        # short name -> detector class
        self.model_dict = {'INNE':INNE, 'KPCA':KPCA, 'KDE':KDE, 'GMM':GMM, 'CBLOF':CBLOF, 'SOD':SOD, 'LUNAR':LUNAR,
                           'SOGAAL':SO_GAAL, 'ALAD':ALAD, 'AE':AutoEncoder, 'CD':CD, 'LOCI':LOCI, 'LSCP':LSCP, 'MAD': MAD,
                           'MOGAAL':MO_GAAL,'QMCD': QMCD, 'RGraph':RGraph, 'Sampling':Sampling, 'SOS':SOS}

        self.tune = tune

    def grid_hp(self, model_name):
        '''
        define the hyper-parameter search grid for the different unsupervised models

        :param model_name: model name to look up
        :return: list of candidate values for the model's tuned hyper-parameter,
                 or None when the model is not tuned (currently the case for every model)
        :raises KeyError: if ``model_name`` has no entry in the grid table
        '''

        param_grid_dict = {'INNE': None, # n_estimators, default=200
                           'KPCA': None,
                           'KDE': None,
                           'GMM': None,
                           'CBLOF': None,
                           'SOD': None,
                           'LUNAR': None,
                           'SOGAAL': None,
                           'ALAD': None,
                           'AE': None,
                           'CD': None,
                           'LOCI': None,
                           'LSCP': None,
                           'MAD': None,
                           'MOGAAL': None,
                           'RGraph': None,
                           'QMCD': None,
                           'Sampling': None,
                           'SOS': None,
                           }

        return param_grid_dict[model_name]

    def grid_search(self, X_train, y_train, ratio=None):
        '''
        implement the grid search for unsupervised models and return the best hyper-parameter
        the ratio could be the ground truth anomaly ratio of input dataset

        :param X_train: training samples, indexable by an integer index array
        :param y_train: labels aligned with X_train (1 = anomaly, 0 = normal)
        :param ratio: anomaly ratio used to build the validation set; required
                      whenever the model has a search grid
        :return: the candidate with the highest validation AUCPR, or None when
                 the model has no search grid
        :raises NotImplementedError: if a grid is defined for a model that has no
                                     entry in ``_TUNED_KWARG``
        '''

        # set seed
        self.utils.set_seed(self.seed)
        # get the hyper-parameter grid
        param_grid = self.grid_hp(self.model_name)

        if param_grid is not None:
            tuned_kwarg = self._TUNED_KWARG.get(self.model_name)
            if tuned_kwarg is None:
                # fail loudly instead of silently scoring every candidate 0.0
                raise NotImplementedError(
                    f'hyper-parameter tuning is not supported for {self.model_name}')

            # index of normal and abnormal samples
            idx_a = np.where(y_train==1)[0]
            idx_n = np.where(y_train==0)[0]
            # draw normals (with replacement) so the validation set matches the requested anomaly ratio
            idx_n = np.random.choice(idx_n, int((len(idx_a) * (1-ratio)) / ratio), replace=True)

            idx = np.append(idx_n, idx_a) #combine
            np.random.shuffle(idx) #shuffle

            # validation set (and the same anomaly ratio as in the original dataset)
            X_val = X_train[idx]
            y_val = y_train[idx]

            # fitting
            metric_list = []
            model_cls = self.model_dict[self.model_name]
            for param in param_grid:
                try:
                    model = model_cls(**{tuned_kwarg: param}).fit(X_train)
                    # model performance on the validation set
                    score_val = model.decision_function(X_val)
                    metric = self.utils.metric(y_true=y_val, y_score=score_val, pos_label=1)
                    metric_list.append(metric['aucpr'])
                except Exception as err:
                    # best effort: a failing candidate gets the worst score, but
                    # KeyboardInterrupt / SystemExit are no longer swallowed
                    print(f'Candidate {param} of {self.model_name} failed: {err!r}')
                    metric_list.append(0.0)

            best_param = param_grid[np.argmax(metric_list)]

        else:
            metric_list = None
            best_param = None

        print(f'The candidate hyper-parameter of {self.model_name}: {param_grid},',
              f' corresponding metric: {metric_list}',
              f' the best candidate: {best_param}')

        return best_param

    def fit(self, X_train, y_train, ratio=None):
        '''
        fit the selected model, optionally after tuning its hyper-parameter

        :param X_train: training samples
        :param y_train: labels aligned with X_train; only used for tuning and
                        passed through to the detector's ``fit``
        :param ratio: anomaly ratio, required when tuning is performed
        :return: self, with the fitted detector stored in ``self.model``
        :raises NotImplementedError: if a tuned value was found for a model that
                                     has no entry in ``_TUNED_KWARG``
        '''
        # NOTE(review): the auto-encoder is registered under the key 'AE', so this
        # branch does not fire for any registered model -- confirm whether 'AE'
        # is meant to be trained on normal samples only.
        if self.model_name in ['AutoEncoder', 'VAE']:
            # only use the normal samples to fit the model
            idx_n = np.where(y_train==0)[0]
            X_train = X_train[idx_n]
            y_train = y_train[idx_n]

        # selecting the best hyper-parameters of unsupervised model for fair comparison (if labeled anomalies is available)
        # (check the cheap flag first so the labels are only summed when tuning is requested)
        if self.tune and np.sum(y_train) > 0:
            assert ratio is not None
            best_param = self.grid_search(X_train, y_train, ratio)
        else:
            best_param = None

        # set seed
        self.utils.set_seed(self.seed)

        model_cls = self.model_dict[self.model_name]

        # fit based on the best param
        if best_param is not None:
            tuned_kwarg = self._TUNED_KWARG.get(self.model_name)
            if tuned_kwarg is None:
                raise NotImplementedError(
                    f'hyper-parameter tuning is not supported for {self.model_name}')
            self.model = model_cls(**{tuned_kwarg: best_param}).fit(X_train)

        else:
            # unsupervised method would ignore the y labels
            self.model = model_cls().fit(X_train, y_train)

        return self

    # from pyod: for consistency, outliers are assigned with larger anomaly scores
    def predict_score(self, X):
        '''
        :param X: samples to score
        :return: the fitted model's ``decision_function`` output for X
        '''
        score = self.model.decision_function(X)
        return score