import copy

import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score
from sklearn.neighbors import KernelDensity

from astropy.visualization import hist
from scipy.stats import wasserstein_distance
from scipy.special import softmax

from scipy.stats import levy_stable, cauchy, fisk, weibull_min, burr, pareto, chi2

import torch
import pyod
##from pyod.models import iforest
##from pyod.models import copod
##from pyod.models import auto_encoder
##from pyod.models import knn
##from pyod.models import lof
from pyod.models import deep_svdd
##from pyod.models import anogan
##from pyod.models import ecod

from pyod.models import ocsvm

from dataset import shallow_method_data_loader as data_loader

import optuna


from BO_Objective import Objective


def resolve_study_settings(method):
    """Return ``(direction, temp_name)`` for an Optuna study on *method*.

    The Wasserstein-family methods ('wasserstein', 'kde-wasserstein',
    'kde-wasserstein-v2') are minimised and use the 'handcraft' template;
    every other method is maximised and passes ``temp_name=None``.
    """
    if method in ('wasserstein', 'kde-wasserstein', 'kde-wasserstein-v2'):
        return 'minimize', 'handcraft'
    return 'maximize', None


def main():
    """Run the DeepSVDD hyper-parameter search over every configured setting.

    For each (dataset, method, n_noise, k) combination this builds an
    ``Objective``, runs an Optuna TPE study stored in the SQLite database
    under ``record/`` for 1000 trials using ``Objective.deepsvdd_search``,
    prints the best trial, and calls ``plot_results`` / ``save_result`` with
    the label 'DeepSVDD'.
    """
    import os

    datasets = [
        # '4_breastw',
        # '45_wine',
        # '12_fault',
        # '48_arrhythmia',
        # '14_glass', '15_Hepatitis',
        # '20_letter', '21_Lymphography',
        # '22_magic.gamma', '23_mammography',
        # '24_mnist',
        '32_shuttle',
    ]
    # Forwarded verbatim to Objective; note this is the string 'None',
    # not the None singleton.
    methods = ['None']
    noise_levels = [0, 5]
    k_values = [1]

    # NOTE(review): the DB file is named after OCSVM although this script
    # runs the DeepSVDD search; the path is kept so results stay in one file.
    storage_dir = 'record'
    # SQLite does not create missing parent directories, so without this the
    # first create_study call fails with "unable to open database file".
    os.makedirs(storage_dir, exist_ok=True)
    storage = f"sqlite:///{storage_dir}/ocsvm2.db"

    for dataset in datasets:
        for method in methods:
            # Depends only on the method, so resolve it once per method.
            direction, temp_name = resolve_study_settings(method)
            for n_noise in noise_levels:
                for k in k_values:
                    objective = Objective(dataset, objective='auc', k=k, n_noise=n_noise,
                                          method=method, temp_name=temp_name)

                    sampler = optuna.samplers.TPESampler(n_startup_trials=10, n_ei_candidates=24)
                    # No study_name is passed, so Optuna auto-generates a unique
                    # name and every run adds a new study to the database.
                    study = optuna.create_study(sampler=sampler, direction=direction, storage=storage)
                    study.optimize(objective.deepsvdd_search, n_trials=1000, show_progress_bar=True)
                    print(study.best_trial)

                    objective.plot_results('DeepSVDD')
                    objective.save_result('DeepSVDD')


if __name__ == "__main__":
    main()
    

    
