import copy

import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score
from sklearn.neighbors import KernelDensity

from astropy.visualization import hist
from scipy.stats import wasserstein_distance
from scipy.special import softmax

from scipy.stats import levy_stable, cauchy, fisk, weibull_min, burr, pareto, chi2

import torch
import pyod
##from pyod.models import iforest
##from pyod.models import copod
##from pyod.models import auto_encoder
##from pyod.models import knn
##from pyod.models import lof
from pyod.models import deep_svdd
##from pyod.models import anogan
##from pyod.models import ecod

from pyod.models import ocsvm

from dataset import shallow_method_data_loader as data_loader

import optuna
from dpad import DPAD

import BO_Objective
from BO_Objective import Objective


if __name__ == "__main__":
    # Script entry point: train DPAD with fixed hyperparameters on each
    # selected dataset, repeat the experiment five times, print AUC / F1 on
    # the test split and store both metrics in one file per (dataset, run).
    import os  # local import: only needed by this entry point

    # Directory that receives the per-run result files.
    save_dir = './save-default'
    # torch.save does not create parent directories; create the folder up
    # front so a missing directory cannot abort the script after a complete
    # training run has already been paid for.
    os.makedirs(save_dir, exist_ok=True)

    # Datasets to evaluate. Commented-out entries are currently disabled and
    # can be re-enabled by uncommenting them.
    datasets = [
        # '48_arrhythmia', '49_shuttle',
        #  '12_fault',  '14_glass', '15_Hepatitis',
        # '17_InternetAds', '18_Ionosphere', '19_landsat',
        #  '20_letter', '21_Lymphography',  '24_mnist',
        # '25_musk', '26_optdigits', '27_PageBlocks', '28_pendigits', '29_Pima', '2_annthyroid',
        # '30_satellite', '31_satimage-2',   '35_SpamBase', '36_speech',
        # '37_Stamps', '38_thyroid', '39_vertebral',
        #
        '40_vowels', '41_Waveform', '42_WBC',
        '43_WDBC', '44_Wilt', '45_wine', '46_WPBC', '47_yeast', '4_breastw',  '6_cardio',
        '7_Cardiotocography',

        '22_magic.gamma', '23_mammography',

        '1_ALOI',  # mid
    ]

    for dataset in datasets:
        for method in ['avg-var-20-k']:
            for run in range(5):
                for k in [0.05]:
                    # The Wasserstein-based methods use the 'handcraft'
                    # template name; every other method uses 'None2'.
                    # (The optimisation direction that used to be assigned
                    # here was never read, so it has been dropped.)
                    if method in ['wasserstein', 'kde-wasserstein']:
                        temp_name = 'handcraft'
                    else:
                        temp_name = 'None2'

                    # The Objective instance is used here as the source of
                    # the train / test splits (train_x, test_x, test_y).
                    objective = Objective(dataset, objective='delta', k=k, n_noise=0, method=method,
                                          temp_name=temp_name,
                                          permutation_seed=0)

                    # DPAD model with fixed (non-tuned) hyperparameters.
                    clf = DPAD.DPAD(train_x=objective.train_x, test_x=objective.test_x, test_y=objective.test_y,
                                    gamma=0.01, lamb=0.1, k=10,
                                    hidden_dims=[256, 128],
                                    num_classes=128,
                                    bs=4096,
                                    n_epochs=200,
                                    learning_rate=1e-3,
                                    adam=1,
                                    )
                    clf.training()
                    # train_scores = clf.decision_function(objective.train_x)

                    pred_scores = clf.decision_function(objective.test_x)  # outlier scores
                    f1 = BO_Objective.f1_calculator(objective.test_y, pred_scores)
                    auc = roc_auc_score(objective.test_y, pred_scores)
                    print(dataset, {'auc': auc, 'f1': f1})

                    # NOTE: the 'defualt' spelling in the file name is kept
                    # unchanged so the output paths stay the same as before.
                    torch.save({'auc': auc, 'f1': f1}, f'{save_dir}/DPAD-{dataset}-0-defualt-std-{run}.save')

                    # objective.plot_results('DPAD', root='./res-0905')
                    # objective.save_result('DPAD')

    

    
