################################################################################################
################################################################################################
# based on https://github.com/yzhao062/pyod/tree/master?tab=readme-ov-file#kingma2013auto, 
# Version 2.0.2 adapted by first author, for running in this benchmark and environment
################################################################################################
################################################################################################


from adbench.myutils import Utils
import numpy as np

#add the baselines from the pyod package

from baselines.new_pyod.additional.ExtendedIForest import ExtendedIForest
from baselines.new_pyod.additional.combination import maximization
from baselines.new_pyod.additional.ensemble import Ensemble
from baselines.new_pyod.additional.gen2out import gen2Out
from baselines.new_pyod.additional.HBOS import DynamicHBOS
from baselines.new_pyod.additional.cof import COF
from baselines.new_pyod.additional.abod import ABOD
from baselines.new_pyod.additional.lmdd import LMDD




class PYOD_like:
    '''
    Uniform fit / predict_score wrapper around the additional PyOD-style detectors
    registered in ``self.model_dict``.

    When ``tune`` is enabled and labeled anomalies are present in the training labels,
    a single hyper-parameter per model is selected by grid search (see ``grid_hp`` and
    ``grid_search``); otherwise the detector is built with its default arguments.
    '''

    # Constructor keyword that receives the tuned value, keyed by model name.
    # Only models registered in ``model_dict`` are listed; a model with a grid in
    # ``grid_hp`` but no entry here cannot be tuned and raises NotImplementedError.
    # NOTE(review): 'n_estimators' for EIF is taken from the comment next to its grid
    # in ``grid_hp`` -- confirm against the ExtendedIForest constructor.
    _TUNED_PARAM_NAME = {
        'EIF': 'n_estimators',
        'ABOD': 'n_neighbors',
        'COF': 'n_neighbors',
        'LMDD': 'dis_measure',
    }

    def __init__(self, seed, model_name, tune=False):
        '''
        :param seed: seed for reproducible results
        :param model_name: model name, one of the keys of ``self.model_dict``
        :param tune: if necessary, tune the hyper-parameter based on the validation set constructed by the labeled anomalies
        '''
        self.seed = seed
        self.utils = Utils()

        self.model_name = model_name
        self.model_dict = {'EIF':ExtendedIForest, 'Ensemble':Ensemble, 'GEN2OUT':gen2Out, 'DynamicHBOS':DynamicHBOS, 'COF':COF, 'ABOD':ABOD, 'LMDD':LMDD}

        self.tune = tune

    def _tuned_param_name(self):
        '''
        return the constructor keyword that is tuned for ``self.model_name``

        :raises NotImplementedError: if no tuned keyword is defined for the model
        '''
        try:
            return self._TUNED_PARAM_NAME[self.model_name]
        except KeyError:
            raise NotImplementedError(
                f'hyper-parameter tuning is not implemented for {self.model_name}') from None

    def _make_tuned_model(self, param):
        '''
        instantiate (without fitting) the selected detector with its tuned keyword set to ``param``
        '''
        return self.model_dict[self.model_name](**{self._tuned_param_name(): param})

    def grid_hp(self, model_name):
        '''
        define the hyper-parameter search grid for different unsupervised model

        :param model_name: model name
        :return: list of candidate values, or None if the model is not tuned
        :raises KeyError: if ``model_name`` has no entry in the grid
        '''

        param_grid_dict = {'EIF': [1000], # n_estimators, default=200
                            'Ensemble': None,
                            'GEN2OUT': None,
                            'DynamicHBOS': None,
                            'COF': None,
                            'ABOD': None,
                            'LMDD': None,
                           }

        return param_grid_dict[model_name]

    def grid_search(self, X_train, y_train, ratio=None):
        '''
        implement the grid search for unsupervised models and return the best hyper-parameters
        the ratio could be the ground truth anomaly ratio of input dataset

        :param X_train: training features, indexable by an integer index array
        :param y_train: training labels (1 = anomaly, 0 = normal)
        :param ratio: anomaly ratio of the validation set, required (0 < ratio <= 1) when the model has a grid
        :return: the candidate with the highest validation AUCPR, or None if the model has no grid
        :raises ValueError: if the model has a grid and ``ratio`` is missing or outside (0, 1]
        :raises NotImplementedError: if the model has a grid but no tuned keyword is defined
        '''

        # set seed
        self.utils.set_seed(self.seed)
        # get the hyper-parameter grid
        param_grid = self.grid_hp(self.model_name)

        if param_grid is not None:
            if ratio is None or not 0 < ratio <= 1:
                raise ValueError(f'ratio must be in (0, 1] for the grid search, got {ratio!r}')

            # fail early (instead of scoring every candidate as 0.0) if the model cannot be tuned
            param_name = self._tuned_param_name()

            # index of normal and abnormal samples
            idx_a = np.where(y_train==1)[0]
            idx_n = np.where(y_train==0)[0]
            # resample the normal samples so that the validation set has the requested anomaly ratio
            n_normal = int((len(idx_a) * (1-ratio)) / ratio)
            idx_n = np.random.choice(idx_n, n_normal, replace=True)

            idx = np.append(idx_n, idx_a) #combine
            np.random.shuffle(idx) #shuffle

            # validation set (and the same anomaly ratio as in the original dataset)
            X_val = X_train[idx]
            y_val = y_train[idx]

            # fitting
            metric_list = []
            for param in param_grid:
                try:
                    model = self.model_dict[self.model_name](**{param_name: param}).fit(X_train)
                    # model performance on the validation set
                    score_val = model.decision_function(X_val)
                    metric = self.utils.metric(y_true=y_val, y_score=score_val, pos_label=1)
                    metric_list.append(metric['aucpr'])

                except Exception as err:
                    # best effort: a failing candidate is kept in the list with the worst score
                    print(f'Grid search of {self.model_name} failed for {param_name}={param}: {err!r}')
                    metric_list.append(0.0)

            best_param = param_grid[np.argmax(metric_list)]

        else:
            metric_list = None
            best_param = None

        print(f'The candidate hyper-parameter of {self.model_name}: {param_grid},',
              f' corresponding metric: {metric_list}',
              f' the best candidate: {best_param}')

        return best_param

    def fit(self, X_train, y_train, ratio=None):
        '''
        fit the selected detector, optionally after tuning its hyper-parameter

        :param X_train: training features
        :param y_train: training labels (1 = anomaly, 0 = normal)
        :param ratio: anomaly ratio, required when tuning is requested and labeled anomalies exist
        :return: self, with the fitted detector stored in ``self.model``
        :raises ValueError: if tuning is requested with labeled anomalies but ``ratio`` is None
        '''
        if self.model_name in ['AutoEncoder', 'VAE']:
            # only use the normal samples to fit the model
            idx_n = np.where(y_train==0)[0]
            X_train = X_train[idx_n]
            y_train = y_train[idx_n]

        # selecting the best hyper-parameters of unsupervised model for fair comparison (if labeled anomalies is available)
        if self.tune and np.sum(y_train) > 0:
            if ratio is None:
                raise ValueError('ratio is required to tune the hyper-parameter')
            best_param = self.grid_search(X_train, y_train, ratio)
        else:
            best_param = None

        # set seed
        self.utils.set_seed(self.seed)

        # fit based on the best param
        if best_param is not None:
            self.model = self._make_tuned_model(best_param).fit(X_train)

        else:
            # unsupervised method would ignore the y labels
            self.model = self.model_dict[self.model_name]().fit(X_train, y_train)

        return self

    # from pyod: for consistency, outliers are assigned with larger anomaly scores
    def predict_score(self, X):
        '''
        return the anomaly scores of ``X`` given by the fitted detector's ``decision_function``
        '''
        print(self.model)
        score = self.model.decision_function(X)
        return score