import copy

import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score
from sklearn.neighbors import KernelDensity

from astropy.visualization import hist
from scipy.stats import wasserstein_distance
from scipy.special import softmax

from scipy.stats import levy_stable, cauchy, fisk, weibull_min, burr, pareto, chi2

import torch
import pyod
##from pyod.models import iforest
##from pyod.models import copod
##from pyod.models import auto_encoder
##from pyod.models import knn
##from pyod.models import lof
from pyod.models import deep_svdd
##from pyod.models import anogan
##from pyod.models import ecod

from pyod.models import ocsvm

from dataset import shallow_method_data_loader as data_loader

import optuna


from BO_Objective import Objective


def main():
    """Run a DeepSVDD hyperparameter search for every dataset/method configuration.

    For each (dataset, method, n_noise, k) combination this builds an
    ``Objective``, runs a 500-trial Optuna study with a TPE sampler on
    ``objective.deepsvdd_search``, prints the best trial, and then calls
    ``objective.plot_results('DeepSVDD', root='./res')`` and
    ``objective.save_result('DeepSVDD')``.
    """
    from itertools import product

    datasets = [
        # '5_campaign', '8_celeba', '9_census', '11_donors', '13_fraud', '16_http',  # discarded
        # '32_shuttle', '33_skin', '34_smtp', '10_cover',  # discarded
        # '3_backdoor',  # large

        '48_arrhythmia', '49_shuttle',
        '12_fault', '14_glass', '15_Hepatitis',
        '17_InternetAds', '18_Ionosphere', '19_landsat',
        '20_letter', '21_Lymphography', '24_mnist',
        '25_musk', '26_optdigits', '27_PageBlocks', '28_pendigits', '29_Pima', '2_annthyroid',
        '30_satellite', '31_satimage-2', '35_SpamBase', '36_speech',
        '37_Stamps', '38_thyroid', '39_vertebral', '40_vowels', '41_Waveform', '42_WBC',
        '43_WDBC', '44_Wilt', '45_wine', '46_WPBC', '47_yeast', '4_breastw', '6_cardio',
        '7_Cardiotocography',

        '22_magic.gamma', '23_mammography', '1_ALOI',  # mid
    ]
    methods = ['generated', 'relative-topk-median', 'avg-var-20-k', 'EM']
    noise_levels = [0]
    k_values = [0.05]

    # Methods whose objective value is minimized; every other method is maximized.
    minimize_methods = {'wasserstein', 'kde-wasserstein', 'EM'}
    temp_name = 'New2'

    # product() iterates in the same order as the equivalent nested for-loops.
    for dataset, method, n_noise, k in product(datasets, methods, noise_levels, k_values):
        direction = 'minimize' if method in minimize_methods else 'maximize'

        objective = Objective(dataset, objective='delta', k=k, n_noise=n_noise, method=method,
                              temp_name=temp_name, save_path='./res', permutation_seed=0)

        # TPE with 10 start-up trials and 24 expected-improvement candidates.
        # Alternative sampler: optuna.samplers.GPSampler(n_startup_trials=10).
        sampler = optuna.samplers.TPESampler(n_startup_trials=10, n_ei_candidates=24)
        # To persist the study, pass storage="sqlite:///record/ocsvm2.db".
        study = optuna.create_study(sampler=sampler, direction=direction)
        study.optimize(objective.deepsvdd_search, n_trials=500, show_progress_bar=True)
        print(study.best_trial)

        objective.plot_results('DeepSVDD', root='./res')
        objective.save_result('DeepSVDD')


if __name__ == "__main__":
    main()

    

    
