import copy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score
from sklearn.neighbors import KernelDensity
from sklearn.cluster import KMeans

from astropy.visualization import hist
from scipy.stats import wasserstein_distance
from scipy.special import softmax

from scipy.stats import levy_stable, cauchy, fisk, weibull_min, burr, pareto, chi2, norm

import torch
import pyod
##from pyod.models import iforest
##from pyod.models import copod

##from pyod.models import knn
##from pyod.models import lof

##from pyod.models import anogan
##from pyod.models import ecod

from pyod.models import ocsvm
from pyod.models import deep_svdd
from pyod.models import auto_encoder
from pyod.models import dif

from dataset import shallow_method_data_loader as data_loader
from supervised import Predictor
import utils, baselines

from dpad import DPAD

import optuna
from NeuTraLAD import NeuTraLAD
from HRN import HRN
from DROCC import DROCC
from PLAD import PLAD
from SCAD import SCAD


def min_max_standardize(data, feature_range=(0, 1)):
    """Rescale ``data`` to roughly [0, 1] by its global min/max, then mean-centre it.

    Parameters:
    - data: np.array, values to transform; min/max are taken over all elements.
    - feature_range: tuple (min, max). Currently IGNORED — the result is not
      mapped into this range; the parameter is kept for call compatibility.

    Returns:
    - np.array with the same shape as ``data``: the min-max rescaled values
      minus their mean (so the output has zero mean, not a [0, 1] range).
    """
    lowest, highest = np.min(data), np.max(data)
    # The 1e-9 keeps constant input from dividing by zero.
    unit_scaled = (data - lowest) / (highest - lowest + 1e-9)
    return unit_scaled - unit_scaled.mean()


def median_standardize(data):
    """Express each value as its deviation from the median, relative to the median's magnitude.

    ``(x - median) / (|median| + 1e-9)``. Dividing by the absolute median
    (as ``calculate_delta``'s ``'relative-topk-median'`` does) keeps values
    above the median positive even when the median itself is negative;
    the previous signed denominator flipped every sign in that case.

    Parameters:
    - data: np.array of numeric values.

    Returns:
    - np.array with the same shape as ``data``.
    """
    median_val = np.median(data)
    return (data - median_val) / (abs(median_val) + 1e-9)


def normalize(data):
    """Return the z-score of ``data`` using its global mean and (population) std.

    A 1e-9 term is added to the standard deviation so constant input maps to
    zeros instead of dividing by zero.
    """
    centred = data - data.mean()
    spread = data.std() + 1e-9
    return centred / spread


def get_score_kde(scores, bandwidth, x):
    """Fit a Gaussian KDE to ``scores`` and return its density evaluated at ``x``.

    A bandwidth of exactly 0.0 (e.g. from a constant score vector) is replaced
    by the fallback 0.004545, since KernelDensity needs a positive bandwidth.
    Both ``scores`` and ``x`` are flattened to single-feature column vectors.
    """
    effective_bw = 0.004545 if bandwidth == 0.0 else bandwidth
    estimator = KernelDensity(kernel='gaussian', bandwidth=effective_bw)
    estimator.fit(scores.reshape(-1, 1))
    log_density = estimator.score_samples(x.reshape(-1, 1))
    return np.exp(log_density)

def _fraction_or_count(k, n_sample, minimum=2):
    """Convert ``k`` into a number of leading elements.

    ``k < 1`` is read as a fraction of ``n_sample`` (floored, never below
    ``minimum``); any other value is returned unchanged as an absolute count.
    """
    if k < 1:
        return max(int(n_sample * k), minimum)
    return k


def calculate_delta(scores_, k, method='gap', temp=None, predictor=None):
    """Compute a scalar statistic describing how a score vector splits into head and bulk.

    Scores are sorted in descending order first, so index 0 is the most
    anomalous sample. Whether larger or smaller is "better" is decided by
    the caller (e.g. the Optuna study direction) and is not encoded here.

    Parameters
    ----------
    scores_ : array-like
        Anomaly scores. The input is not modified.
    k : int or float
        Size of the presumed anomalous head. ``'relative-topk-median'``,
        ``'k-var'``, ``'avg-var-k'`` and ``'avg-var-20-k'`` read ``k < 1`` as
        a fraction of the sample size (minimum 2). The gap methods and
        ``'topk-median'`` need an integer count.
    method : str
        One of the names in the assertion below. ``'avg-var-k-v2'`` is
        accepted but has no implementation and raises NotImplementedError.
    temp : object, optional
        Reference template whose form depends on ``method``:

        - ``'wasserstein'``: a ``(distribution, bin_left_edges)`` pair.
        - ``'kde-wasserstein'``: a callable pdf ``f(x)``.
        - ``'kde-kl'``: a fitted ``KernelDensity`` or a callable pdf.
    predictor : object, optional
        Needed for ``method='predictor'``; its ``eval`` method is called.

    Returns
    -------
    float
        Some methods clamp degenerate cases to a fixed value (0, 10 or 100);
        see the individual branches.

    Raises
    ------
    AssertionError
        If ``method`` is not a recognised name.
    NotImplementedError
        For recognised names without an implementation.
    """
    assert method in ['gap', 'k-gap', 'relative-k-gap', 'relative-gap', 'topk-median', 'wasserstein', 'kde-wasserstein',
                      'predictor', 'kde-wasserstein-v2', 'clustering', 'otsu', 'k-var', 'kde-l2',
                      'avg-var-otsu', 'avg-var-otsu-v2', 'avg-var-k-v2',
                      'relative-topk-median', 'avg-var-k', 'otsu-max-ind', 'avg-var-20-k', 'kde-kl']
    # np.sort returns a new array, so the caller's scores are never mutated.
    scores = np.sort(np.asarray(scores_))[::-1]

    if method == 'gap':
        # Gap between the k-th and (k+1)-th largest scores.
        return scores[k - 1] - scores[k]
    elif method == 'k-gap':
        # Mean of the top-k block minus mean of the next-k block.
        return np.mean(scores[:k]) - np.mean(scores[k:k + k])
    elif method == 'relative-k-gap':
        up = np.mean(scores[:k])
        down = np.mean(scores[k:k + k])
        return (up - down) / (down + 1e-9)
    elif method == 'relative-gap':
        up = scores[k - 1]
        down = scores[k]
        return (up - down) / (down + 1e-9)
    elif method == 'topk-median':
        return np.mean(scores[:k]) - np.median(scores)
    elif method == 'relative-topk-median':
        top_size = _fraction_or_count(k, scores.shape[0])
        up = np.mean(scores[:top_size])
        down = np.median(scores)
        delta = (up - down) / (abs(down) + 1e-9)
        # A (near-)zero median makes the ratio meaningless; report 0 instead.
        if np.isnan(delta) or delta > 1000000:
            return 0
        return delta
    elif method == 'wasserstein':
        n_samples = len(scores)
        # np.histogram gives the same counts and edges as astropy's ``hist``
        # with an integer bin count, without drawing onto the current
        # matplotlib axes (and accumulating artists) on every call.
        count, bins = np.histogram(scores, bins=10)
        distribution = count / n_samples
        bins = bins[:-1]  # left bin edges stand in for the bin positions
        temp_distribution, temp_bins = temp
        return wasserstein_distance(bins, temp_bins, distribution, temp_distribution)
    elif method == 'kde-wasserstein':
        scores = softmax(scores)
        bandwidth = (np.max(scores) - np.min(scores)) / 20
        kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth)
        kde.fit(scores.reshape(-1, 1))

        x = np.linspace(np.min(scores), np.max(scores), 1000)
        dens = np.exp(kde.score_samples(x.reshape(-1, 1)))
        temp_dens = temp(x)
        try:
            # +1e-15 keeps the weights from being all zero.
            dist = wasserstein_distance(x, x, dens + 1e-15, temp_dens)
        except ValueError:
            return 10
        return min(dist, 10) if dist > 10 else dist
    elif method == 'kde-wasserstein-v2':
        v = np.median(scores)
        scores = softmax(scores)
        bandwidth = (np.max(scores) - np.min(scores)) / 20
        x = np.linspace(0, 1, 1000)
        dens = get_score_kde(scores, bandwidth, x)

        # Synthetic reference: 95% of samples at the median, 5% shifted by +100.
        n_sample = scores.shape[0]
        reference = [v] * int(n_sample * 0.95) + [v + 100] * int(n_sample * 0.05)
        reference = softmax(np.array(reference))
        ref_bandwidth = (np.max(reference) - np.min(reference)) / 20
        ref_dens = get_score_kde(reference, ref_bandwidth, x)

        return wasserstein_distance(x, x, dens + 1e-15, ref_dens)
    elif method == 'clustering':
        km = KMeans(2, init=np.array([[scores.min()], [scores.max()]]))
        km.fit(scores.reshape((-1, 1)))
        left = scores[km.labels_ == 0]
        right = scores[km.labels_ == 1]
        intra_var, inter_var = utils.get_intra_inter(left, right)
        if intra_var == 0:
            return 0
        return inter_var / intra_var
    elif method == 'otsu':
        intra_var, inter_var = utils.otsu_thresholding_1d(scores)
        if intra_var == 0:
            return 0
        return inter_var / intra_var
    elif method == 'k-var':
        left_size = _fraction_or_count(k, scores.shape[0])
        left, right = scores[:left_size], scores[left_size:]
        intra_var, inter_var = utils.get_intra_inter(left, right)
        if intra_var == 0:
            return 0
        delta = inter_var / (intra_var + 1e-15)
        if delta > 100:
            return 0
        return delta
    elif method in ('avg-var-k', 'avg-var-20-k'):
        # Mean inter/intra variance ratio over split points from ``head``
        # up to the first 10% (or 20%) of the sorted scores.
        head = _fraction_or_count(k, scores.shape[0])
        end_fraction = 0.10 if method == 'avg-var-k' else 0.20
        _, intra_var, inter_var = utils.otsu_thresholding_unbiased(scores)
        end = int(scores.size * end_fraction)
        delta = np.mean(inter_var[head:end + 1] / (intra_var[head:end + 1] + 1e-15))
        if np.isnan(delta):  # empty window
            return 0
        return delta
    elif method == 'avg-var-otsu':
        opt_ind, intra_var, inter_var = utils.otsu_thresholding_unbiased(scores)
        head = int(scores.size * 0.05)
        lo, hi = min(opt_ind, head), max(opt_ind, head) + 1
        delta = inter_var[lo:hi] / (intra_var[lo:hi] + 1e-15)
        delta = np.mean(delta[delta < 100])  # drop exploding ratios
        if np.isnan(delta):
            return 0
        return delta
    elif method == 'avg-var-otsu-v2':
        opt_ind, intra_var, inter_var = utils.otsu_thresholding_unbiased(scores)
        head = int(scores.size * 0.05)
        lo, hi = min(opt_ind, head), max(opt_ind, head) + 1
        return np.mean(inter_var[lo:hi]) / (np.mean(intra_var[lo:hi]) + 1e-15)
    elif method == 'otsu-max-ind':
        opt_ind, _, _ = utils.otsu_thresholding_unbiased(scores)
        size = scores.shape[0]
        if opt_ind == 0:
            return 0
        return (size - opt_ind) / size
    elif method == 'kde-l2':
        bandwidth = (np.max(scores) - np.min(scores)) / 20
        if bandwidth == 0.0:
            bandwidth = 0.004545
        kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth)
        kde.fit(scores.reshape(-1, 1))
        x = np.linspace(np.min(scores), np.max(scores), 1000)
        dens = np.exp(kde.score_samples(x.reshape(-1, 1)))

        # Gaussian target centred on the median, scaled so its peak matches
        # the KDE's density at the median.
        scale = norm.pdf(0) / np.exp(kde.score_samples([[np.median(scores)]]))[0]
        target = norm.pdf(x, loc=np.median(scores), scale=scale)
        delta = np.linalg.norm(dens - target)
        if delta < 1e-5:
            return 100
        return delta
    elif method == 'kde-kl':
        scores = min_max_standardize(scores)
        kde = KernelDensity(kernel='gaussian', bandwidth=0.1)
        kde.fit(scores.reshape(-1, 1))

        x = np.linspace(0, np.max(scores) + 1, 1000)
        dens = np.exp(kde.score_samples(x.reshape(-1, 1)))
        # The template may be a fitted KernelDensity or a plain pdf callable
        # (the form produced by ``get_temp_pdf``).
        if hasattr(temp, 'score_samples'):
            temp_dens = np.exp(temp.score_samples(x.reshape(-1, 1)))
        else:
            temp_dens = temp(x)

        epsilon = 1e-10
        # Riemann-sum approximation of KL(dens || temp_dens) on the grid.
        return np.sum(dens * np.log((dens + epsilon) / (temp_dens + epsilon)) * (x[1] - x[0]))
    elif method == 'predictor':
        scores = softmax(scores)
        min_val = np.min(scores)
        max_val = np.max(scores)

        kde = KernelDensity(kernel='gaussian', bandwidth=0.1)
        kde.fit(scores.reshape(-1, 1))
        x = np.linspace(0, 1, 1000)
        dens = np.exp(kde.score_samples(x.reshape(-1, 1)))
        return predictor.eval(dens, (min_val, max_val), pos=False, normalize=True)
    else:
        raise NotImplementedError

def f1_calculator(targets, score):
    """Return the hit rate among the top-``n_pos`` scored samples.

    ``n_pos`` is the number of labels equal to 1. The ``n_pos`` highest
    scores are flagged, and the result is ``tp / (tp + fp)`` over the
    flagged samples. Labels 1 count as tp and labels 0 as fp. For binary 0/1
    labels, flagging exactly ``n_pos`` samples makes precision, recall and F1
    coincide, so this equals F1 at that cut-off.

    Returns 0.0 when nothing countable is flagged (e.g. no positive labels),
    instead of raising ZeroDivisionError.
    """
    targets = np.asarray(targets)
    n_positive = int(np.count_nonzero(targets.reshape(-1) == 1))
    flagged = targets[np.argsort(-np.asarray(score))[:n_positive]]

    tp = np.count_nonzero(flagged == 1)
    fp = np.count_nonzero(flagged == 0)
    if tp + fp == 0:
        return 0.0
    return tp / (tp + fp)


def get_temp_pdf(name='handcraft'):
    """Return a reference density callable ``f(x) -> np.array`` for template ``name``.

    Parametric templates wrap a scipy.stats pdf with fixed parameters.
    ``'handcraft'`` fits a Gaussian KDE (bandwidth 0.1) to the softmax of a
    95/5 mixture of the values 0.5 and 100. Any other name returns None.

    NOTE(review): ``'log-cauchy'`` returns exp(Cauchy pdf), which is not the
    log-Cauchy density; kept unchanged so existing results stay reproducible.
    """
    parametric = {
        'burr': lambda x: burr.pdf(x, 2, 0.4, loc=0, scale=0.5),
        'log-logistic': lambda x: fisk.pdf(x, 2, loc=0, scale=0.5),
        'log-cauchy': lambda x: np.exp(cauchy.pdf(x, loc=0, scale=0.2)),
        'levy': lambda x: levy_stable.pdf(x, 0.5, 1, loc=0, scale=0.1),
        'weibull': lambda x: weibull_min.pdf(x, 0.7, loc=0, scale=1),
        'pareto': lambda x: pareto.pdf(x, 3, loc=-1, scale=1),
        'chi2': lambda x: chi2.pdf(x, 1.2, loc=0, scale=1),
    }
    pdf = parametric.get(name)
    if pdf is not None:
        return pdf
    if name != 'handcraft':
        return None

    mixture = softmax(np.array([0.5] * 95 + [100] * 5))
    estimator = KernelDensity(kernel='gaussian', bandwidth=0.1)
    estimator.fit(np.array(mixture).reshape(-1, 1))
    return lambda x: np.exp(estimator.score_samples(x.reshape(-1, 1)))


class Objective(object):
    def __init__(self, dataset_name, objective, k, n_noise, temp_name, method='k-gap', permutation_seed=0, normal_c=0, save_path='./res'):
        """Load and preprocess one dataset and set up per-trial bookkeeping.

        Parameters
        ----------
        dataset_name : str
            Key into ``dataset_permutation_record.tar`` and name passed to the
            ADBench loader.
        objective : str
            Which value the ``*_search`` methods return: ``'auc'``, ``'f1'``
            or ``'delta'``.
        k : int or float
            Forwarded to ``calculate_delta``.
        n_noise : forwarded to ``data_loader.process_dataset``.
        temp_name : str
            Template name passed to ``get_temp_pdf``.
        method : str
            Model-selection criterion. For every method except
            ``'generated'`` the validation split is merged into the training
            data and ``val_x`` is set to None.
        permutation_seed : int
            Index of the stored training-row permutation to apply.
        normal_c : forwarded to ``data_loader.process_dataset``.
        save_path : str
            Stored on the instance; not used in this method.

        Raises
        ------
        ValueError
            If the stored permutation length differs from the number of
            training rows, or preprocessing returns no training data.
        """
        self.dataset_name = dataset_name
        self.method = method
        self.objective = objective
        self.k = k
        self.n_noise = n_noise
        self.save_path = save_path

        dataset_permutation_dict = torch.load('dataset_permutation_record.tar')
        permu = dataset_permutation_dict[dataset_name][permutation_seed]

        train_data, test_data, classes = data_loader.load_adbench_dataset('./dataset',
                                                                          name=dataset_name)
        # Explicit raises instead of asserts: asserts vanish under ``python -O``.
        if len(permu) != train_data.shape[0]:
            raise ValueError(
                f"permutation length {len(permu)} does not match "
                f"{train_data.shape[0]} training rows for dataset {dataset_name!r}")
        train_data = train_data[permu]
        train_x, val_x, test_x, test_y, _mu, _std = data_loader.process_dataset(
            train_data, test_data, classes, normal_c, n_noise=n_noise, normalize=True)
        if train_x is None:
            raise ValueError(f"preprocessing produced no training data for dataset {dataset_name!r}")
        self.train_x = train_x
        self.test_x = test_x
        self.val_x = val_x
        self.test_y = test_y

        if method not in ['generated']:
            self.train_x = np.concatenate([train_x, val_x])
            self.val_x = None

        # Per-trial history; the *_search methods append one entry per trial.
        self.hps = []
        self.clfs = []

        self.deltas = []
        self.f1s = []
        self.aucs = []
        self.selected_ind = []

        self.train_scores = []
        self.test_scores = []

        self.temp_hist = [0, 0.95, 0.05]
        self.temp_bins = [-500.5, 0., 500.5]  # [-1000, -1, 1, 1000]

        self.temp_name = temp_name
        self.temp_kde = get_temp_pdf(temp_name)

        if method == 'predictor':
            self.predictor = Predictor.Predictor('./supervised/kde_auc_net_softmax_sgd-no_pos.pth.tar')
        else:
            self.predictor = None


    def deepsvdd_search(self, trial):
        """Optuna objective: fit a pyod DeepSVDD with sampled hyperparameters.

        Samples two hidden-layer widths and an L2 regulariser and fits on
        ``self.train_x``. Train and test data are then scored, the results
        are appended to the per-trial history lists, and the value selected
        by ``self.objective`` (``'auc'``, ``'f1'`` or ``'delta'``) is returned.

        Raises
        ------
        KeyboardInterrupt
            If ``self.objective`` is none of the names above (presumably to
            abort the study — NOTE(review): confirm intent).
        """
        hidden_neurons1 = trial.suggest_int('hidden_neurons1', 16, 256, step=8)
        hidden_neurons2 = trial.suggest_int('hidden_neurons2', 16, 256, step=8)
        l2_regularizer = trial.suggest_float("l2_regularizer", 1e-6, 1.0, log=True)

        # Larger batches for large training sets.
        bs = 1024 if self.train_x.shape[0] > 5000 else 256

        # The sampled widths belong in ``hidden_neurons``. Previously they were
        # passed as ``output_activation`` (per the pyod docs only used by the
        # autoencoder variant, i.e. ignored with use_ae=False) while
        # ``hidden_neurons=None`` selected pyod's default layers, so the width
        # search never affected the model.
        clf = deep_svdd.DeepSVDD(n_features=self.train_x.shape[-1], c=None, use_ae=False,
                                 hidden_neurons=[hidden_neurons1, hidden_neurons2],
                                 hidden_activation='leaky_relu',
                                 optimizer='adam', epochs=200,
                                 batch_size=bs, dropout_rate=0.2, l2_regularizer=l2_regularizer)

        clf.fit(self.train_x)
        train_scores = clf.decision_function(self.train_x)

        temp = self.temp_kde
        if self.objective == 'delta':
            if self.method == 'EM':
                unif_score = baselines.get_unif_score(self.train_x, clf)
                _, delta = baselines.get_em_mv_original(train_scores, unif_score, self.train_x)
            elif self.method in ['MC', 'HITS']:
                # Consensus methods need a history of earlier trials' scores.
                if len(self.train_scores) < 10:
                    delta, selected_idx = 0, 0
                else:
                    score_mat = np.vstack([np.array(self.train_scores), train_scores]).T
                    if self.method == 'MC':
                        delta, selected_idx = baselines.mc(score_mat)
                    else:
                        delta, selected_idx = baselines.hits(score_mat)
                self.selected_ind.append(selected_idx)
            elif self.method in ['generated', 'MMP', 'generated2', 'generated3', 'generated1.5']:
                delta = self.delta_from_generated(clf, self.method)
            else:
                delta = calculate_delta(train_scores, k=self.k, method=self.method, temp=temp, predictor=self.predictor)
        else:
            delta = 0

        pred_scores = clf.decision_function(self.test_x)  # outlier scores
        f1 = f1_calculator(self.test_y, pred_scores)
        auc = roc_auc_score(self.test_y, pred_scores)
        self.clfs.append(clf)

        print(f1, auc, delta)
        self.deltas.append(delta)
        self.f1s.append(f1)
        self.aucs.append(auc)
        self.hps.append({
            'hidden_neurons1': hidden_neurons1,
            'hidden_neurons2': hidden_neurons2,
            'l2_regularizer': l2_regularizer,
        })

        self.train_scores.append(train_scores)
        self.test_scores.append(pred_scores)
        if self.objective == 'auc':
            return auc
        elif self.objective == 'f1':
            return f1
        elif self.objective == 'delta':
            return delta
        else:
            raise KeyboardInterrupt

    def dpad_search(self, trial):
        """Optuna objective: train DPAD with sampled hyperparameters.

        Samples hidden widths, ``gamma``, ``lamb`` and the neighbourhood size
        ``k_dpad``, then trains and scores train and test data. The results
        are appended to the per-trial history lists, and the value selected by
        ``self.objective`` (``'auc'``, ``'f1'`` or ``'delta'``) is returned.

        Raises
        ------
        KeyboardInterrupt
            If ``self.objective`` is none of the names above.
        """
        hidden_neurons1 = trial.suggest_int('hidden_neurons1', 16, 256, step=8)
        hidden_neurons2 = trial.suggest_int('hidden_neurons2', 16, 256, step=8)
        gamma = trial.suggest_float("gamma", 1e-6, 100, log=True)
        lamb = trial.suggest_float("lamb", 1e-6, 1000, log=True)
        # suggest_int rejects high < low; keep the upper bound >= 3 so very
        # small training sets (< 15 rows) do not make every trial fail.
        k_high = max(3, int(0.2 * self.train_x.shape[0]))
        k_dpad = trial.suggest_int("k_dpad", 3, k_high, step=1, log=False)

        clf = DPAD.DPAD(train_x=self.train_x, test_x=self.test_x, test_y=self.test_y,
                        gamma=gamma, lamb=lamb, k=k_dpad,
                        hidden_dims=[hidden_neurons1, hidden_neurons2],
                        num_classes=hidden_neurons2 // 2,
                        bs=4096,
                        n_epochs=200,
                        learning_rate=1e-3,
                        adam=1,
                        )
        clf.training()
        train_scores = clf.decision_function(self.train_x)

        pred_scores = clf.decision_function(self.test_x)  # outlier scores
        f1 = f1_calculator(self.test_y, pred_scores)
        auc = roc_auc_score(self.test_y, pred_scores)

        self.f1s.append(f1)
        self.aucs.append(auc)
        self.clfs.append(clf)

        temp = self.temp_kde
        if self.objective == 'delta':
            if self.method == 'EM':
                unif_score = baselines.get_unif_score(self.train_x, clf)
                _, delta = baselines.get_em_mv_original(train_scores, unif_score, self.train_x)
            elif self.method in ['MC', 'HITS']:
                # Consensus methods need a history of earlier trials' scores.
                if len(self.train_scores) < 10:
                    delta, selected_idx = 0, 0
                else:
                    score_mat = np.vstack([np.array(self.train_scores), train_scores]).T
                    if self.method == 'MC':
                        delta, selected_idx = baselines.mc(score_mat)
                    else:
                        delta, selected_idx = baselines.hits(score_mat)
                self.selected_ind.append(selected_idx)
            elif self.method in ['generated', 'MMP', 'generated2', 'generated3', 'generated1.5']:
                delta = self.delta_from_generated(clf, self.method)
            else:
                delta = calculate_delta(train_scores, k=self.k, method=self.method, temp=temp, predictor=self.predictor)
        else:
            delta = 0

        print(f1, auc, delta)
        self.deltas.append(delta)

        self.hps.append({
            'hidden_neurons1': hidden_neurons1,
            'hidden_neurons2': hidden_neurons2,
            'gamma': gamma,
            'lamb': lamb,
            'k_dpad': k_dpad,
        })

        self.train_scores.append(train_scores)
        # Store the real test scores (previously ``[]``), as every other
        # search method does.
        self.test_scores.append(pred_scores)
        if self.objective == 'auc':
            return auc
        elif self.objective == 'f1':
            return f1
        elif self.objective == 'delta':
            return delta
        else:
            raise KeyboardInterrupt

    def ocsvm_search(self, trial):
        """Optuna objective: fit a pyod OCSVM with sampled hyperparameters.

        Samples ``kernel``, ``nu``, ``gamma`` and ``coef0`` and fits on
        ``self.train_x``. Train and test data are then scored, the results
        are appended to the per-trial history lists, and the value selected
        by ``self.objective`` (``'auc'``, ``'f1'`` or ``'delta'``) is returned.

        Raises
        ------
        KeyboardInterrupt
            If ``self.objective`` is none of the names above.
        """
        kernel = trial.suggest_categorical("kernel", ["linear", 'rbf', 'sigmoid'])
        nu = trial.suggest_float("nu", 0.0, 1.0, log=False)
        gamma = trial.suggest_float("gamma", 1e-6, 100, log=True)
        coef0 = trial.suggest_float("coef0", 0.0, 1000, log=False)
        clf = ocsvm.OCSVM(kernel=kernel, nu=nu, gamma=gamma, coef0=coef0, tol=0.0001, max_iter=1000)

        clf.fit(self.train_x)
        train_scores = clf.decision_function(self.train_x)

        temp = self.temp_kde
        if self.objective == 'delta':
            if self.method == 'EM':
                unif_score = baselines.get_unif_score(self.train_x, clf)
                _, delta = baselines.get_em_mv_original(train_scores, unif_score, self.train_x)
            elif self.method in ['MC', 'HITS']:
                # Consensus methods need a history of earlier trials' scores.
                if len(self.train_scores) < 10:
                    delta, selected_idx = 0, 0
                else:
                    score_mat = np.vstack([np.array(self.train_scores), train_scores]).T
                    if self.method == 'MC':
                        delta, selected_idx = baselines.mc(score_mat)
                    else:
                        delta, selected_idx = baselines.hits(score_mat)
                self.selected_ind.append(selected_idx)
            elif self.method in ['generated', 'MMP', 'generated2', 'generated3', 'generated1.5']:
                delta = self.delta_from_generated(clf, self.method)
            else:
                delta = calculate_delta(train_scores, k=self.k, method=self.method, temp=temp, predictor=self.predictor)
        else:
            delta = 0

        pred_scores = clf.decision_function(self.test_x)  # outlier scores
        f1 = f1_calculator(self.test_y, pred_scores)
        auc = roc_auc_score(self.test_y, pred_scores)

        print(f1, auc, delta)
        self.deltas.append(delta)
        self.f1s.append(f1)
        self.aucs.append(auc)
        self.clfs.append(clf)

        # Record every sampled hyperparameter (kernel and coef0 were missing),
        # so a selected trial can be reproduced from ``self.hps``.
        self.hps.append({
            'kernel': kernel,
            'coef0': coef0,
            'gamma': gamma,
            'nu': nu,
        })

        self.train_scores.append(train_scores)
        self.test_scores.append(pred_scores)
        if self.objective == 'auc':
            return auc
        elif self.objective == 'f1':
            return f1
        elif self.objective == 'delta':
            return delta
        else:
            raise KeyboardInterrupt

    def ocsvm_search2(self, trial):
        """Optuna objective: OCSVM that only passes kernel-relevant parameters.

        Like ``ocsvm_search``, but ``gamma`` is only given to rbf/poly/sigmoid
        kernels and ``coef0`` only to poly/sigmoid. Returns the value selected
        by ``self.objective`` (``'auc'``, ``'f1'`` or ``'delta'``).

        Raises
        ------
        KeyboardInterrupt
            If ``self.objective`` is none of the names above.
        """
        kernel = trial.suggest_categorical("kernel", ["linear", 'rbf', 'sigmoid'])
        nu = trial.suggest_float("nu", 0.0, 1.0, log=False)
        gamma = trial.suggest_float("gamma", 1e-6, 100, log=True)
        coef0 = trial.suggest_float("coef0", 0.0, 1000, log=False)

        svm_kwargs = dict(kernel=kernel, nu=nu, tol=0.0001, max_iter=1000)
        if kernel in ('rbf', 'poly', 'sigmoid'):
            svm_kwargs['gamma'] = gamma
        if kernel in ('poly', 'sigmoid'):
            svm_kwargs['coef0'] = coef0
        clf = ocsvm.OCSVM(**svm_kwargs)

        clf.fit(self.train_x)
        train_scores = clf.decision_function(self.train_x)

        temp = self.temp_kde
        if self.objective == 'delta':
            if self.method == 'EM':
                unif_score = baselines.get_unif_score(self.train_x, clf)
                _, delta = baselines.get_em_mv_original(train_scores, unif_score, self.train_x)
            elif self.method in ['MC', 'HITS']:
                # Consensus methods need a history of earlier trials' scores.
                if len(self.train_scores) < 10:
                    delta, selected_idx = 0, 0
                else:
                    score_mat = np.vstack([np.array(self.train_scores), train_scores]).T
                    if self.method == 'MC':
                        delta, selected_idx = baselines.mc(score_mat)
                    else:
                        delta, selected_idx = baselines.hits(score_mat)
                self.selected_ind.append(selected_idx)
            elif self.method in ['generated', 'MMP', 'generated2', 'generated3', 'generated1.5']:
                # Same routing as the other search methods. Previously 'MMP'
                # went to ``self.mmp_delta`` and the 'generated*' criteria
                # fell through to calculate_delta, which rejects them.
                delta = self.delta_from_generated(clf, self.method)
            else:
                delta = calculate_delta(train_scores, k=self.k, method=self.method, temp=temp, predictor=self.predictor)
        else:
            delta = 0

        pred_scores = clf.decision_function(self.test_x)  # outlier scores
        f1 = f1_calculator(self.test_y, pred_scores)
        auc = roc_auc_score(self.test_y, pred_scores)

        print(f1, auc, delta)
        self.deltas.append(delta)
        self.f1s.append(f1)
        self.aucs.append(auc)
        self.clfs.append(clf)

        self.hps.append({
            'kernel': kernel,
            'coef0': coef0,
            'gamma': gamma,
            'nu': nu,
        })

        self.train_scores.append(train_scores)
        self.test_scores.append(pred_scores)
        if self.objective == 'auc':
            return auc
        elif self.objective == 'f1':
            return f1
        elif self.objective == 'delta':
            return delta
        else:
            raise KeyboardInterrupt
        
    def ae_search(self, trial):
        """Optuna objective: fit a pyod AutoEncoder with sampled hyperparameters.

        Samples two hidden-layer widths and a weight decay and fits on
        ``self.train_x``. Train and test data are then scored, the results
        are appended to the per-trial history lists, and the value selected
        by ``self.objective`` (``'auc'``, ``'f1'`` or ``'delta'``) is returned.

        Raises
        ------
        KeyboardInterrupt
            If ``self.objective`` is none of the names above.
        """
        widths = [
            trial.suggest_int('hidden_neurons1', 16, 256, step=8),
            trial.suggest_int('hidden_neurons2', 16, 256, step=8),
        ]
        decay = trial.suggest_float('weight_decay', 1e-6, 0.1, log=True)

        model = auto_encoder.AutoEncoder(
            hidden_neuron_list=widths,
            hidden_activation_name='leaky_relu',
            optimizer_name='adam',
            epoch_num=100,
            batch_size=256,
            dropout_rate=0.2,
            optimizer_params={'weight_decay': decay},
        )
        model.fit(self.train_x)
        fit_scores = model.decision_function(self.train_x)

        template = self.temp_kde
        delta = 0
        if self.objective == 'delta':
            generated_methods = ['generated', 'MMP', 'generated2', 'generated3', 'generated1.5']
            if self.method == 'EM':
                uniform_scores = baselines.get_unif_score(self.train_x, model)
                _, delta = baselines.get_em_mv_original(fit_scores, uniform_scores, self.train_x)
            elif self.method in ['MC', 'HITS']:
                # Consensus criteria only kick in once 10 earlier trials exist.
                if len(self.train_scores) >= 10:
                    stacked = np.vstack([np.array(self.train_scores), fit_scores]).T
                    consensus = baselines.mc if self.method == 'MC' else baselines.hits
                    delta, chosen = consensus(stacked)
                else:
                    delta, chosen = 0, 0
                self.selected_ind.append(chosen)
            elif self.method in generated_methods:
                delta = self.delta_from_generated(model, self.method)
            else:
                delta = calculate_delta(fit_scores, k=self.k, method=self.method, temp=template,
                                        predictor=self.predictor)

        holdout_scores = model.decision_function(self.test_x)  # outlier scores
        f1 = f1_calculator(self.test_y, holdout_scores)
        auc = roc_auc_score(self.test_y, holdout_scores)

        print(f1, auc, delta)
        self.deltas.append(delta)
        self.f1s.append(f1)
        self.aucs.append(auc)
        self.clfs.append(model)

        self.hps.append({
            'hidden_neurons1': widths[0],
            'hidden_neurons2': widths[1],
            'weight_decay': decay,
        })

        self.train_scores.append(fit_scores)
        self.test_scores.append(holdout_scores)

        metrics = {'auc': auc, 'f1': f1, 'delta': delta}
        if self.objective not in metrics:
            raise KeyboardInterrupt
        return metrics[self.objective]

    def neuTralAD_search(self, trial):
        """Optuna objective: train NeuTraLAD with sampled hyperparameters.

        Samples the number of transformations ``k``, temperature ``tau`` and
        layer widths, then trains and scores train and test data. The results
        are appended to the per-trial history lists, and the value selected
        by ``self.objective`` (``'auc'``, ``'f1'`` or ``'delta'``) is returned.

        A ``ValueError`` raised while computing delta or the test metrics is
        absorbed. In that case f1 and auc are recorded as 0, and delta as 0
        (10000 for the ``'EM'`` criterion).

        Raises
        ------
        KeyboardInterrupt
            If ``self.objective`` is none of the names above.
        """
        k = trial.suggest_int('k', 4, 128, step=1)
        tau = trial.suggest_float('tau', 1e-5, 1, log=True)
        hidden_dims = trial.suggest_int('hidden_dims', 16, 256, step=8)
        enc_hdim = trial.suggest_int('enc_hdim', 16, 128, step=8)
        trans_hdim = trial.suggest_int('trans_hdim', 16, 128, step=8)

        clf = NeuTraLAD.NeuTraLAD(dataset_name='None', train_x=self.train_x, test_x=self.test_x, test_y=self.test_y,
                                  k=k, tau=tau, hidden_dims=hidden_dims, enc_hdim=enc_hdim, trans_hdim=trans_hdim)

        clf.fit()
        train_scores = clf.decision_function(self.train_x)
        # Pre-bind so the history append below cannot hit an unbound name when
        # the ValueError fires before the test set has been scored.
        pred_scores = None
        try:
            temp = self.temp_kde
            if self.objective == 'delta':
                if self.method == 'EM':
                    unif_score = baselines.get_unif_score(self.train_x, clf)
                    _, delta = baselines.get_em_mv_original(train_scores, unif_score, self.train_x)
                elif self.method in ['MC', 'HITS']:
                    # Same consensus handling as the other search methods.
                    if len(self.train_scores) < 10:
                        delta, selected_idx = 0, 0
                    else:
                        score_mat = np.vstack([np.array(self.train_scores), train_scores]).T
                        if self.method == 'MC':
                            delta, selected_idx = baselines.mc(score_mat)
                        else:
                            delta, selected_idx = baselines.hits(score_mat)
                    self.selected_ind.append(selected_idx)
                elif self.method in ['generated', 'MMP', 'generated2', 'generated3', 'generated1.5']:
                    delta = self.delta_from_generated(clf, self.method)
                else:
                    delta = calculate_delta(train_scores, k=self.k, method=self.method, temp=temp, predictor=self.predictor)
            else:
                delta = 0

            pred_scores = clf.decision_function(self.test_x)  # outlier scores
            f1 = f1_calculator(self.test_y, pred_scores)
            auc = roc_auc_score(self.test_y, pred_scores)
        except ValueError:
            delta = 0
            f1 = 0
            auc = 0
            if self.method == 'EM':
                delta = 10000

        print(f1, auc, delta)
        self.deltas.append(delta)
        self.f1s.append(f1)
        self.aucs.append(auc)
        # Keep self.clfs aligned with the other per-trial lists, as siblings do.
        self.clfs.append(clf)

        self.hps.append({
            'k': k,
            'tau': tau,
            'hidden_dims': hidden_dims,
            'enc_hdim': enc_hdim,
            'trans_hdim': trans_hdim,
        })

        self.train_scores.append(train_scores)
        self.test_scores.append(pred_scores)
        if self.objective == 'auc':
            return auc
        elif self.objective == 'f1':
            return f1
        elif self.objective == 'delta':
            return delta
        else:
            raise KeyboardInterrupt

    def hrn_search(self, trial):
        """Optuna-style objective: fit one HRN configuration sampled from ``trial``.

        Samples ``n``, ``lamb``, ``latent_dim`` and ``weight_decay``, fits the
        model and scores the training data. It then computes:

        * the selection score ``delta`` (only when ``self.objective == 'delta'``;
          otherwise 0);
        * the test-set F1 and ROC-AUC.

        Everything is recorded on ``self`` (``deltas``, ``f1s``, ``aucs``,
        ``hps``, ``train_scores``, ``test_scores``).

        A ``ValueError`` raised while scoring marks the trial as failed:
        F1/AUC become 0 and delta becomes 0 (10000 for the ``'EM'`` method).
        The recorded test scores become ``[]``.

        Returns:
            The value named by ``self.objective`` (``'auc'``, ``'f1'`` or ``'delta'``).

        Raises:
            KeyboardInterrupt: if ``self.objective`` is none of the above.
        """
        # Search space; sampled in a fixed order.
        n = trial.suggest_int('n', 1, 100, step=1)
        lamb = trial.suggest_float('lamb', 1e-6, 1, log=True)
        latent_dim = trial.suggest_int('latent_dim', 16, 256, step=8)
        weight_decay = trial.suggest_float('weight_decay', 1e-6, 0.1, log=True)

        model = HRN.HRN(train_x=self.train_x, test_x=self.test_x, test_y=self.test_y,
                        latent_dim=latent_dim, lamb=lamb, n=n, weight_decay=weight_decay)
        model.fit()
        fit_scores = model.decision_function(self.train_x)

        try:
            kde_temp = self.temp_kde
            if self.objective != 'delta':
                delta = 0
            elif self.method == 'EM':
                uniform_scores = baselines.get_unif_score(self.train_x, model)
                _, delta = baselines.get_em_mv_original(fit_scores, uniform_scores, self.train_x)
            elif self.method in ('generated', 'MMP', 'generated2', 'generated3', 'generated1.5'):
                delta = self.delta_from_generated(model, self.method)
            else:
                delta = calculate_delta(fit_scores, k=self.k, method=self.method,
                                        temp=kde_temp, predictor=self.predictor)

            outlier_scores = model.decision_function(self.test_x)
            f1 = f1_calculator(self.test_y, outlier_scores)
            auc = roc_auc_score(self.test_y, outlier_scores)
        except ValueError:
            # Scoring failed: record this trial with sentinel values.
            f1 = auc = 0
            delta = 10000 if self.method == 'EM' else 0
            outlier_scores = []

        print(f1, auc, delta)
        self.deltas.append(delta)
        self.f1s.append(f1)
        self.aucs.append(auc)
        self.hps.append(dict(n=n, lamb=lamb, latent_dim=latent_dim, weight_decay=weight_decay))
        self.train_scores.append(fit_scores)
        self.test_scores.append(outlier_scores)

        for name, value in (('auc', auc), ('f1', f1), ('delta', delta)):
            if self.objective == name:
                return value
        raise KeyboardInterrupt

    def drocc_search(self, trial):
        """Optuna-style objective: fit one DROCC configuration sampled from ``trial``.

        Samples ``gamma``, ``lamb``, ``latent_dim`` and ``radius``, fits the
        model and scores the training data. It then computes:

        * the selection score ``delta`` (only when ``self.objective == 'delta'``;
          otherwise 0);
        * the test-set F1 and ROC-AUC.

        Everything is recorded on ``self`` (``deltas``, ``f1s``, ``aucs``,
        ``hps``, ``train_scores``, ``test_scores``).

        A ``ValueError`` raised while scoring marks the trial as failed:
        F1/AUC become 0 and delta becomes 0 (10000 for the ``'EM'`` method).
        The recorded test scores become ``[]``.

        Returns:
            The value named by ``self.objective`` (``'auc'``, ``'f1'`` or ``'delta'``).

        Raises:
            KeyboardInterrupt: if ``self.objective`` is none of the above.
        """
        # Search space; sampled in a fixed order.
        gamma = trial.suggest_int('gamma', 1, 100, step=1)
        lamb = trial.suggest_float('lamb', 1e-3, 100, log=True)
        latent_dim = trial.suggest_int('latent_dim', 16, 256, step=8)
        radius = trial.suggest_float('radius', 0.5, 100, log=False)

        model = DROCC.DROCC(train_x=self.train_x, test_x=self.test_x, test_y=self.test_y,
                            latent_dim=latent_dim, lamb=lamb, radius=radius, gamma=gamma)
        model.fit()
        fit_scores = model.decision_function(self.train_x)

        try:
            kde_temp = self.temp_kde
            if self.objective != 'delta':
                delta = 0
            elif self.method == 'EM':
                uniform_scores = baselines.get_unif_score(self.train_x, model)
                _, delta = baselines.get_em_mv_original(fit_scores, uniform_scores, self.train_x)
            elif self.method in ('generated', 'MMP', 'generated2', 'generated3', 'generated1.5'):
                delta = self.delta_from_generated(model, self.method)
            else:
                delta = calculate_delta(fit_scores, k=self.k, method=self.method,
                                        temp=kde_temp, predictor=self.predictor)

            outlier_scores = model.decision_function(self.test_x)
            f1 = f1_calculator(self.test_y, outlier_scores)
            auc = roc_auc_score(self.test_y, outlier_scores)
        except ValueError:
            # Scoring failed: record this trial with sentinel values.
            f1 = auc = 0
            delta = 10000 if self.method == 'EM' else 0
            outlier_scores = []

        print(f1, auc, delta)
        self.deltas.append(delta)
        self.f1s.append(f1)
        self.aucs.append(auc)
        self.hps.append(dict(gamma=gamma, lamb=lamb, latent_dim=latent_dim, radius=radius))
        self.train_scores.append(fit_scores)
        self.test_scores.append(outlier_scores)

        for name, value in (('auc', auc), ('f1', f1), ('delta', delta)):
            if self.objective == name:
                return value
        raise KeyboardInterrupt

    def plad_search(self, trial):
        """Optuna-style objective: fit one PLAD configuration sampled from ``trial``.

        Samples ``lamb``, ``weight_decay``, ``hidden_dims1`` and
        ``hidden_dims2``, fits the model and scores the training data. It then
        computes:

        * the selection score ``delta`` (only when ``self.objective == 'delta'``;
          otherwise 0);
        * the test-set F1 and ROC-AUC.

        Everything is recorded on ``self`` (``deltas``, ``f1s``, ``aucs``,
        ``hps``, ``train_scores``, ``test_scores``).

        A ``ValueError`` raised while scoring marks the trial as failed:
        F1/AUC become 0 and delta becomes 0 (10000 for the ``'EM'`` method).
        The recorded test scores become ``[]``.

        Returns:
            The value named by ``self.objective`` (``'auc'``, ``'f1'`` or ``'delta'``).

        Raises:
            KeyboardInterrupt: if ``self.objective`` is none of the above.
        """
        # Search space; sampled in a fixed order.
        lamb = trial.suggest_float('lamb', 1e-4, 100, log=True)
        weight_decay = trial.suggest_float('weight_decay', 1e-6, 0.1, log=True)
        hidden_dims1 = trial.suggest_int('hidden_dims1', 16, 256, step=8)
        hidden_dims2 = trial.suggest_int('hidden_dims2', 16, 256, step=8)

        model = PLAD.PLAD(train_x=self.train_x, test_x=self.test_x, test_y=self.test_y,
                          lamb=lamb, weight_decay=weight_decay,
                          hidden_dims1=hidden_dims1, hidden_dims2=hidden_dims2)
        model.fit()
        fit_scores = model.decision_function(self.train_x)

        try:
            kde_temp = self.temp_kde
            if self.objective != 'delta':
                delta = 0
            elif self.method == 'EM':
                uniform_scores = baselines.get_unif_score(self.train_x, model)
                _, delta = baselines.get_em_mv_original(fit_scores, uniform_scores, self.train_x)
            elif self.method in ('generated', 'MMP', 'generated2', 'generated3', 'generated1.5'):
                delta = self.delta_from_generated(model, self.method)
            else:
                delta = calculate_delta(fit_scores, k=self.k, method=self.method,
                                        temp=kde_temp, predictor=self.predictor)

            outlier_scores = model.decision_function(self.test_x)
            f1 = f1_calculator(self.test_y, outlier_scores)
            auc = roc_auc_score(self.test_y, outlier_scores)
        except ValueError:
            # Scoring failed: record this trial with sentinel values.
            f1 = auc = 0
            delta = 10000 if self.method == 'EM' else 0
            outlier_scores = []

        print(f1, auc, delta)
        self.deltas.append(delta)
        self.f1s.append(f1)
        self.aucs.append(auc)
        self.hps.append(dict(weight_decay=weight_decay, lamb=lamb,
                             hidden_dims1=hidden_dims1, hidden_dims2=hidden_dims2))
        self.train_scores.append(fit_scores)
        self.test_scores.append(outlier_scores)

        for name, value in (('auc', auc), ('f1', f1), ('delta', delta)):
            if self.objective == name:
                return value
        raise KeyboardInterrupt

    def dif_search(self, trial):
        """Optuna-style objective: fit one pyod DIF configuration sampled from ``trial``.

        Samples ``n_ensemble``, ``n_estimators``, ``hidden_dims1`` and
        ``hidden_dims2``. The representation size is ``hidden_dims2 // 2`` and
        ``max_samples`` is capped at the number of training rows. After fitting
        on ``self.train_x`` it computes:

        * the selection score ``delta`` (only when ``self.objective == 'delta'``;
          otherwise 0);
        * the test-set F1 and ROC-AUC.

        Everything is recorded on ``self``.

        A ``ValueError`` raised while scoring marks the trial as failed:
        F1/AUC become 0 and delta becomes 0 (10000 for the ``'EM'`` method).
        The recorded test scores become ``[]``.

        Returns:
            The value named by ``self.objective`` (``'auc'``, ``'f1'`` or ``'delta'``).

        Raises:
            KeyboardInterrupt: if ``self.objective`` is none of the above.
        """
        # Search space; sampled in a fixed order.
        n_ensemble = trial.suggest_int('n_ensemble', 10, 300, step=5)
        n_estimators = trial.suggest_int('n_estimators', 4, 100, step=4)
        hidden_dims1 = trial.suggest_int('hidden_dims1', 16, 256, step=8)
        hidden_dims2 = trial.suggest_int('hidden_dims2', 16, 256, step=8)

        model = dif.DIF(
            batch_size=1024,
            representation_dim=hidden_dims2 // 2,
            hidden_neurons=[hidden_dims1, hidden_dims2],
            hidden_activation='tanh',
            skip_connection=False,
            n_ensemble=n_ensemble,
            n_estimators=n_estimators,
            max_samples=min(256, self.train_x.shape[0]),
            contamination=0.1,
            random_state=None,
            device='cuda',
        )
        model.fit(self.train_x)
        fit_scores = model.decision_function(self.train_x)

        try:
            kde_temp = self.temp_kde
            if self.objective != 'delta':
                delta = 0
            elif self.method == 'EM':
                uniform_scores = baselines.get_unif_score(self.train_x, model)
                _, delta = baselines.get_em_mv_original(fit_scores, uniform_scores, self.train_x)
            elif self.method in ('generated', 'MMP', 'generated2', 'generated3', 'generated1.5'):
                delta = self.delta_from_generated(model, self.method)
            else:
                delta = calculate_delta(fit_scores, k=self.k, method=self.method,
                                        temp=kde_temp, predictor=self.predictor)

            outlier_scores = model.decision_function(self.test_x)
            f1 = f1_calculator(self.test_y, outlier_scores)
            auc = roc_auc_score(self.test_y, outlier_scores)
        except ValueError:
            # Scoring failed: record this trial with sentinel values.
            f1 = auc = 0
            delta = 10000 if self.method == 'EM' else 0
            outlier_scores = []

        print(f1, auc, delta)
        self.deltas.append(delta)
        self.f1s.append(f1)
        self.aucs.append(auc)
        self.hps.append(dict(n_ensemble=n_ensemble, n_estimators=n_estimators,
                             hidden_dims1=hidden_dims1, hidden_dims2=hidden_dims2))
        self.train_scores.append(fit_scores)
        self.test_scores.append(outlier_scores)

        for name, value in (('auc', auc), ('f1', f1), ('delta', delta)):
            if self.objective == name:
                return value
        raise KeyboardInterrupt

    def scad_search(self, trial):
        """Optuna-style objective: fit one SCAD configuration sampled from ``trial``.

        Samples ``k`` (upper bound ``min(n_columns - 1, 128)``, where
        ``n_columns`` is ``self.train_x.shape[1]``), ``tau``, ``hidden_dims``
        and ``weight_decay``. After fitting it computes:

        * the selection score ``delta`` (only when ``self.objective == 'delta'``;
          otherwise 0);
        * the test-set F1 and ROC-AUC.

        Everything is recorded on ``self``.

        A ``ValueError`` raised while scoring marks the trial as failed:
        F1/AUC become 0 and delta becomes 0 (10000 for the ``'EM'`` method).
        The recorded test scores become ``[]``.

        Returns:
            The value named by ``self.objective`` (``'auc'``, ``'f1'`` or ``'delta'``).

        Raises:
            KeyboardInterrupt: if ``self.objective`` is none of the above.
        """
        n_columns = self.train_x.shape[1]

        # Search space; sampled in a fixed order.
        k = trial.suggest_int('k', 2, min(n_columns - 1, 128), step=1)
        tau = trial.suggest_float('tau', 1e-5, 1, log=True)
        hidden_dims = trial.suggest_int('hidden_dims', 16, 256, step=8)
        weight_decay = trial.suggest_float('weight_decay', 1e-6, 0.1, log=True)

        # SCAD takes these positionally, in this exact order.
        model = SCAD.SCAD(self.train_x, self.test_x, self.test_y,
                          k, tau, weight_decay, hidden_dims, device='cuda')
        model.fit()
        fit_scores = model.decision_function(self.train_x)

        try:
            kde_temp = self.temp_kde
            if self.objective != 'delta':
                delta = 0
            elif self.method == 'EM':
                uniform_scores = baselines.get_unif_score(self.train_x, model)
                _, delta = baselines.get_em_mv_original(fit_scores, uniform_scores, self.train_x)
            elif self.method in ('generated', 'MMP', 'generated2', 'generated3', 'generated1.5'):
                delta = self.delta_from_generated(model, self.method)
            else:
                delta = calculate_delta(fit_scores, k=self.k, method=self.method,
                                        temp=kde_temp, predictor=self.predictor)

            outlier_scores = model.decision_function(self.test_x)
            f1 = f1_calculator(self.test_y, outlier_scores)
            auc = roc_auc_score(self.test_y, outlier_scores)
        except ValueError:
            # Scoring failed: record this trial with sentinel values.
            f1 = auc = 0
            delta = 10000 if self.method == 'EM' else 0
            outlier_scores = []

        print(f1, auc, delta)
        self.deltas.append(delta)
        self.f1s.append(f1)
        self.aucs.append(auc)
        self.hps.append(dict(k=k, tau=tau, hidden_dims=hidden_dims, weight_decay=weight_decay))
        self.train_scores.append(fit_scores)
        self.test_scores.append(outlier_scores)

        for name, value in (('auc', auc), ('f1', f1), ('delta', delta)):
            if self.objective == name:
                return value
        raise KeyboardInterrupt
    def ocsvm_munual_grid_search(self, nu, gamma):
        """Fit an RBF pyod OCSVM with the given ``nu``/``gamma`` and record the result.

        Unlike the trial-based searches:

        * ``delta`` is always computed with ``calculate_delta``, whatever
          ``self.objective`` is;
        * no exception is caught;
        * the fitted model is also appended to ``self.clfs``.

        Returns:
            The value named by ``self.objective`` (``'auc'``, ``'f1'`` or ``'delta'``).

        Raises:
            KeyboardInterrupt: if ``self.objective`` is none of the above.
        """
        model = ocsvm.OCSVM(kernel='rbf', nu=nu, gamma=gamma, tol=0.0001, max_iter=1000)
        model.fit(self.train_x)
        fit_scores = model.decision_function(self.train_x)

        kde_temp = self.temp_kde
        delta = calculate_delta(fit_scores, k=self.k, method=self.method,
                                temp=kde_temp, predictor=self.predictor)

        outlier_scores = model.decision_function(self.test_x)
        f1 = f1_calculator(self.test_y, outlier_scores)
        auc = roc_auc_score(self.test_y, outlier_scores)

        print(f1, auc, delta)
        self.deltas.append(delta)
        self.f1s.append(f1)
        self.aucs.append(auc)
        self.clfs.append(model)
        self.hps.append(dict(gamma=gamma, nu=nu))
        self.train_scores.append(fit_scores)
        self.test_scores.append(outlier_scores)

        for name, value in (('auc', auc), ('f1', f1), ('delta', delta)):
            if self.objective == name:
                return value
        raise KeyboardInterrupt

    def delta_from_generated(self, clf, method):
        """Dispatch to the generated-sample delta estimator named by ``method``.

        * ``'MMP'`` -> ``self.mmp_delta(clf)``.
        * ``'generated'``, ``'generated1.5'``, ``'generated2'``, ``'generated3'``
          -> ``self.mmp_native(clf, std=...)``, where the noise standard
          deviation is 1, 1.5, 2 and 3 respectively.

        Raises:
            NotImplementedError: for any other ``method``.
        """
        if method == 'MMP':
            return self.mmp_delta(clf)
        elif method == 'generated':
            return self.mmp_native(clf, std=1)
        elif method == 'generated2':
            return self.mmp_native(clf, std=2)
        elif method == 'generated3':
            # Was std=2 (copy-paste of 'generated2'); the suffix names the std.
            return self.mmp_native(clf, std=3)
        elif method == 'generated1.5':
            return self.mmp_native(clf, std=1.5)
        else:
            raise NotImplementedError

    def mmp_native(self, clf, std=1.0):
        """Score ``clf`` on validation data versus zero-mean Gaussian noise.

        The noise array has the same shape as ``self.val_x`` and is scaled by
        ``std``. The validation and noise scores are passed to
        ``utils.get_intra_inter``. The result is
        ``inter_var / (intra_var + 1e-6)``.
        """
        val_score = clf.decision_function(self.val_x)

        # Standard-normal noise shaped like val_x (float32 via torch.Tensor).
        noise = torch.randn_like(torch.Tensor(self.val_x))
        gen_score = clf.decision_function(noise.numpy() * std)

        intra_var, inter_var = utils.get_intra_inter(val_score, gen_score)
        return inter_var / (intra_var + 1e-6)

    def mmp_delta(self, clf):
        """Score ``clf`` on validation data versus perturbed copies of it.

        Each validation row is perturbed with Gaussian noise. The per-row scale
        is that row's nearest-neighbour distance within ``self.val_x``, clipped
        to ``[0, 20]``. The perturbed rows are filtered with
        ``utils.filter_generated_max`` and randomly subsampled to at most
        ``len(self.val_x)`` rows.

        Returns:
            ``inter_var / (intra_var + 1e-6)``, computed by
            ``utils.get_intra_inter`` from the validation scores and the
            generated-sample scores.

        NOTE(review): ``self.val_x`` is indexed as (n_samples, n_features)
        throughout.
        """
        # Pairwise Euclidean distances between validation rows (torch.cdist, p=2).
        res1 = torch.cdist(torch.Tensor(self.val_x), torch.Tensor(self.val_x)).numpy()
        # Push the zero self-distance on the diagonal up to the global max so
        # the column-wise min below picks the nearest *other* row.
        res1 = res1 + np.eye(self.val_x.shape[0]) * np.max(res1)
        d = np.min(res1, axis=0).reshape(-1, 1)
        d = np.clip(d, 0, 20)

        # (2000 // n) noisy copies of every row, noise scaled per row by d.
        # NOTE(review): if val_x has more than 2000 rows, 2000 // n == 0 and no
        # samples are generated -- confirm this cannot happen for the datasets used.
        # NOTE(review): `torch tensor * numpy array` relies on torch accepting the
        # ndarray operand; verify on the torch version in use.
        generated = torch.randn(2000 // self.val_x.shape[0], self.val_x.shape[0], self.val_x.shape[1]) * d + self.val_x
        generated = generated.reshape(-1, self.val_x.shape[1])
        # Project helper; presumably drops generated points by some max criterion
        # relative to val_x -- semantics not visible here, verify in utils.
        generated = utils.filter_generated_max(self.val_x, generated)
        # Random subsample of at most len(val_x) generated rows.
        gen_x = generated[np.random.permutation(generated.shape[0])][:self.val_x.shape[0]]

        val_score = clf.decision_function(self.val_x)

        gen_score = clf.decision_function(gen_x)

        intra_var, inter_var = utils.get_intra_inter(val_score, gen_score)
        delta = inter_var / (intra_var + 1e-6)
        return delta


    def plot_results(self, ad_method, root='./', log=False, is_sort=False, is_cummax=True):
        """Plot per-trial delta ("gap") together with AUC and F1 and save it as a PNG.

        The x-axis is a trial position. Delta goes on the left y-axis; AUC and
        F1 share a twin right y-axis.

        Args:
            ad_method: anomaly-detection method name; used only in the file name.
            root: directory prefix of the output file (joined with ``'/'``).
            log: plot ``log10(delta)`` instead of the raw delta.
            is_sort: order trials by ascending delta.
            is_cummax: (default) at each position, show the earliest trial that
                holds the running maximum of delta so far. Takes precedence over
                ``is_sort``.
        """
        aucs = np.array(self.aucs)
        f1s = np.array(self.f1s)
        deltas = np.array(self.deltas)

        if is_sort:
            sort = deltas.argsort()
        else:
            sort = range(len(deltas))

        if is_cummax:
            ddf = pd.DataFrame({'d': deltas},
                               index=range(len(deltas))
                               )
            ddf['cummax'] = ddf.d.cummax()
            ddf['idx'] = ddf.index

            # Map every row to the first trial index that reached its running max.
            ddf_ = ddf.merge(ddf.groupby('cummax')[['idx']].first().reset_index(), on='cummax')
            sort = ddf_['idx_y'].to_numpy()

        # Use a dedicated figure (closed below) instead of clearing/reusing the
        # global current axes, so repeated calls do not accumulate open figures.
        fig, ax1 = plt.subplots()
        x = np.arange(len(sort))
        if log:
            ax1.plot(x, np.log10(deltas[sort]), label='gap', c='b')
        else:
            ax1.plot(x, deltas[sort], label='gap', c='b')

        ax2 = ax1.twinx()
        ax2.plot(x, aucs[sort], label='aucs', c='r')
        ax2.plot(x, f1s[sort], label='f1', c='orange')

        n_noise = self.n_noise
        if self.objective == 'delta':
            title = f'dataset: {self.dataset_name}, delta method: {self.method}, n_noise={n_noise}, k={self.k}, temp_method: {self.temp_name}'
        else:
            title = f'dataset: {self.dataset_name}, objective: {self.objective}, n_noise={n_noise}, k={self.k}'
        ax1.set_title(title)
        # The x-axis of a twinx() axes is hidden, so the label must go on ax1.
        ax1.set_xlabel('iter')
        fig.legend(bbox_to_anchor=(1, 0.22), bbox_transform=ax1.transAxes)
        fig.savefig(root + '/' + f'ad_method_{ad_method}-dataset_{self.dataset_name}-{self.objective}-delta method_{self.method}-n_noise={n_noise}-k_{self.k}-temp_method_{self.temp_name}.png')
        plt.close(fig)

    def plot_results_auc(self, ad_method, log=False):
        """Plot delta ("gap") and F1 against AUC, with trials ordered by AUC.

        Draws on the current pyplot axes after clearing them, then saves a PNG
        in the working directory.

        Args:
            ad_method: anomaly-detection method name; used only in the file name.
            log: plot ``log10(delta)`` (and label the y-axis 'log-scale').

        NOTE(review): the output name matches the one ``plot_results`` uses with
        its default ``root='./'``, so the two overwrite each other -- confirm
        this is intended.
        """
        aucs, f1s, deltas = np.array(self.aucs), np.array(self.f1s), np.array(self.deltas)
        order = aucs.argsort()
        auc_axis = aucs[order]

        plt.cla()

        gap = np.log10(deltas[order]) if log else deltas[order]
        plt.plot(auc_axis, gap, label='gap')
        if log:
            plt.ylabel('log-scale')
        plt.plot(auc_axis, f1s[order], label='f1')

        n_noise = self.n_noise
        if self.objective != 'delta':
            title = f'dataset: {self.dataset_name}, objective: {self.objective}, n_noise={n_noise}, k={self.k}'
        else:
            title = f'dataset: {self.dataset_name}, delta method: {self.method}, n_noise={n_noise}, k={self.k}, temp_method: {self.temp_name}'
        plt.title(title)
        plt.xlabel('auc')
        plt.legend()
        plt.savefig(f'./ad_method_{ad_method}-dataset_{self.dataset_name}-{self.objective}-delta method_{self.method}-n_noise={n_noise}-k_{self.k}-temp_method_{self.temp_name}.png')

    def save_result(self, ad_method):
        """Serialize the recorded search history with ``torch.save``.

        The file goes to ``<save_path>/data/`` under a name built from the
        method, dataset, objective, delta method, noise level, ``k`` and temp
        method. The saved dict holds the per-trial metrics, scores and
        hyper-parameters, the train/test data, and ``selected_ind``.
        """
        file_name = self.save_path + '/' + (
            f'data/ad-method-{ad_method}_dataset-{self.dataset_name}_{self.objective}'
            f'_delta-method-{self.method}_n-noise-{self.n_noise}_k-{self.k}'
            f'_temp-method-{self.temp_name}'
        )
        payload = dict(
            aucs=self.aucs,
            f1s=self.f1s,
            deltas=self.deltas,
            train_x=self.train_x,
            test_x=self.test_x,
            test_y=self.test_y,
            train_scores=self.train_scores,
            test_scores=self.test_scores,
            hps=self.hps,
            selected_ind=self.selected_ind,
        )
        torch.save(payload, file_name)

if __name__ == "__main__":
    # Intentionally a no-op when this file is executed as a script.
    ...
    

    
