import copy

import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score
from sklearn.neighbors import KernelDensity

from astropy.visualization import hist
from scipy.stats import wasserstein_distance
from scipy.special import softmax

from scipy.stats import levy_stable, cauchy, fisk, weibull_min, burr, pareto, chi2

import torch
import optuna

from BO_Objective import Objective

if __name__ == "__main__":
    # Script-only dependencies; imported here so importing this module stays side-effect free.
    import itertools
    import os

    # Dataset identifiers handed to ``Objective``. Uncomment entries to include them.
    # NOTE(review): every entry is currently commented out, so the sweep below does nothing.
    datasets = [
        # '22_magic.gamma', '23_mammography',

        # '4_breastw',
        # '45_wine',
        # '12_fault',
        # '48_arrhythmia',
        # '14_glass',
        # '15_Hepatitis',
        # '20_letter', '21_Lymphography',
        # '24_mnist',
        # '25_musk', '26_optdigits',
        # '42_WBC',
        # '46_WPBC',
        # '29_Pima', '38_thyroid', '39_vertebral',
        # '47_yeast', '41_Waveform'

        # '17_InternetAds', '18_Ionosphere', '19_landsat',
        # '27_PageBlocks', '28_pendigits', '2_annthyroid',
        # '30_satellite', '31_satimage-2', '35_SpamBase', '36_speech',
        # '37_Stamps', '40_vowels',
        # '43_WDBC', '44_Wilt', '6_cardio',
        # '7_Cardiotocography',
    ]
    noise_levels = [0]
    k_values = [0.05]
    methods = ['relative-topk-median', 'avg-var-20-k']  # , 'avg-var-k', 'otsu-max-ind'

    # Methods whose score is optimized downward (distance/divergence style);
    # every other method is maximized.
    minimize_methods = {'wasserstein', 'kde-wasserstein', 'kde-kl', 'kde-wasserstein-v2', 'kde-l2'}

    # SQLite does not create missing parent directories; without this the
    # first create_study call fails with "unable to open database file".
    os.makedirs('record', exist_ok=True)
    storage = "sqlite:///record/ocsvm2.db"

    # Same iteration order as nested loops: dataset -> n_noise -> k -> method.
    for dataset, n_noise, k, method in itertools.product(datasets, noise_levels, k_values, methods):
        # NOTE(review): temp_name is forwarded to Objective; its meaning is defined
        # there, not in this file.
        if method in minimize_methods:
            direction, temp_name = 'minimize', 'None'
        else:
            direction, temp_name = 'maximize', 'New2'

        objective = Objective(dataset, objective='delta', k=k, n_noise=n_noise, method=method, temp_name=temp_name)

        # TPE with 10 random warm-up trials before model-guided sampling.
        algo = optuna.samplers.TPESampler(n_startup_trials=10, n_ei_candidates=24)
        # algo = optuna.samplers.GPSampler(n_startup_trials=10)
        study = optuna.create_study(sampler=algo, direction=direction, storage=storage)
        study.optimize(objective.ocsvm_search, n_trials=500, show_progress_bar=True)
        print(study.best_trial)

        objective.plot_results('OCSVM', root='./res-ocsvm-large_data')
        objective.save_result('OCSVM')

    

    
