import copy

import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score
from sklearn.neighbors import KernelDensity

from astropy.visualization import hist
from scipy.stats import wasserstein_distance
from scipy.special import softmax

from scipy.stats import levy_stable, cauchy, fisk, weibull_min, burr, pareto, chi2

import torch
import pyod
##from pyod.models import iforest
##from pyod.models import copod
##from pyod.models import auto_encoder
##from pyod.models import knn
##from pyod.models import lof
from pyod.models import deep_svdd
##from pyod.models import anogan
##from pyod.models import ecod

from pyod.models import ocsvm

from dataset import shallow_method_data_loader as data_loader

import optuna


from BO_Objective import Objective


def build_ocsvm_search_space():
    """Return the OCSVM grid-search space as ``{'nu': [...], 'gamma': [...]}``.

    ``nu`` values are produced as ``i / 10`` so each one is the exactly-rounded
    decimal (0.1, 0.2, ..., 0.9). The previous ``np.arange(0.1, 1.0, 0.1)``
    accumulated rounding error, yielding values such as 0.30000000000000004
    and 0.7000000000000001 that then propagated into recorded results.

    Fresh lists are returned on every call, so callers may mutate them freely.
    """
    return {
        'nu': [i / 10 for i in range(1, 10)],
        'gamma': [100, 50, 10, 5, 1, 0.5, 0.1, 0.01, 1e-3, 1e-4, 1e-5] + [5e-3, 5e-4, 5e-5, 5e-6],
    }


if __name__ == "__main__":
    datasets = [
        # '22_magic.gamma', '23_mammography',
        '1_ALOI', # mid


        # '48_arrhythmia', '49_shuttle',
        # '12_fault', '14_glass', '15_Hepatitis',

        # '17_InternetAds',
        # '18_Ionosphere', '19_landsat',
        # '20_letter', '21_Lymphography', '24_mnist',
        # '25_musk', '26_optdigits', '27_PageBlocks', '28_pendigits', '29_Pima', '2_annthyroid',
        # '30_satellite', '31_satimage-2', '35_SpamBase', '36_speech',
        #
        # '37_Stamps', '38_thyroid', '39_vertebral', '40_vowels', '41_Waveform', '42_WBC',
        # '43_WDBC', '44_Wilt', '45_wine', '46_WPBC', '47_yeast', '4_breastw', '6_cardio',
        # '7_Cardiotocography',
                ]

    # The hyperparameter grid does not depend on dataset/method/noise/k,
    # so build it once instead of on every experiment iteration.
    search_space = build_ocsvm_search_space()

    for dataset in datasets:
        for method in ['generated']:
            for n_noise in [0]:
                for k in [0.05]:
                    # `direction` is only consumed by the commented-out Optuna
                    # alternative below; the manual grid search ignores it.
                    if method in ['wasserstein', 'kde-wasserstein']:
                        direction = 'minimize'
                        temp_name = 'handcraft'
                    else:
                        direction = 'maximize'
                        temp_name = 'Grid Search-generated'

                    objective = Objective(dataset, objective='delta', k=k, n_noise=n_noise, method=method,
                                          temp_name=temp_name,
                                          save_path='./res-grid-new')

                    # Exhaustive grid: nu in the outer loop, gamma in the inner loop.
                    for nu in search_space['nu']:
                        for gamma in search_space['gamma']:
                            objective.ocsvm_munual_grid_search(nu, gamma)

                    # algo = optuna.samplers.GridSampler(search_space=search_space)
                    # study = optuna.create_study(sampler=algo, direction=direction, storage="sqlite:///record/ocsvm2.db")
                    # study.optimize(objective.ocsvm_search, show_progress_bar=True)
                    # print(study.best_trial)

                    objective.plot_results('OCSVM', root='./res-grid-new')
                    objective.save_result('OCSVM')

    

    
