import copy

import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score
from sklearn.neighbors import KernelDensity

from astropy.visualization import hist
from scipy.stats import wasserstein_distance

from scipy.stats import levy_stable, cauchy, fisk, weibull_min, burr, pareto, chi2

import torch
import pyod
##from pyod.models import iforest
##from pyod.models import copod
##from pyod.models import auto_encoder
##from pyod.models import knn
##from pyod.models import lof
from pyod.models import deep_svdd
##from pyod.models import anogan
##from pyod.models import ecod

from pyod.models import ocsvm

from dataset import shallow_method_data_loader as data_loader

import optuna


def min_max_standardize(data, feature_range=(0, 1)):
    """
    Rescale ``data`` by its global min/max and then centre it on zero.

    The values are first mapped to (approximately) [0, 1] using the minimum
    and maximum taken over the whole array, and the mean of the rescaled
    values is then subtracted. The result therefore has zero mean and is
    NOT confined to ``feature_range``.

    Parameters:
    - data: np.array, the data to be standardized.
    - feature_range: tuple (min, max), default=(0, 1). Accepted for
      compatibility but currently ignored (mapping to a custom range is
      disabled).

    Returns:
    - np.array, the rescaled and mean-centred data.
    """
    lowest = np.min(data)
    # The small epsilon keeps the division finite for constant input.
    span = np.max(data) - lowest + 1e-9

    unit_scaled = (data - lowest) / span
    return unit_scaled - unit_scaled.mean()


def median_standardize(data):
    """
    Express ``data`` as its relative deviation from the median.

    Each value becomes ``(value - median) / (median + 1e-9)``, where the
    median is taken over the whole array. The epsilon only protects against
    a median of exactly zero.
    """
    centre = np.median(data)
    deviation = data - centre
    return deviation / (centre + 1e-9)


def normalize(data):
    """
    Z-score ``data`` using the mean and standard deviation of the whole array.

    A small epsilon is added to the standard deviation so constant input
    yields zeros rather than a division by zero.
    """
    centred = data - data.mean()
    scale = data.std() + 1e-9
    return centred / scale


def calculate_delta(scores_, k, method='gap', temp=None):
    """
    Reduce an array of anomaly scores to a single "delta" statistic.

    The scores are sorted in descending order (the input is left untouched)
    and then summarised according to ``method``:

    - 'gap': score at rank ``k`` minus score at rank ``k + 1`` (1-based).
    - 'k-gap': mean of the top ``k`` scores minus mean of the next ``k``.
    - 'relative-gap' / 'relative-k-gap': the corresponding gap divided by its
      lower term (plus 1e-9).
    - 'topk-median': mean of the top ``k`` scores minus the median of all
      scores.
    - 'wasserstein': Wasserstein distance between a 10-bin histogram of the
      scores (left bin edges weighted by bin frequency) and a template passed
      as ``temp = (weights, positions)``.
    - 'kde-wasserstein': scores are median-standardised, a Gaussian KDE
      (bandwidth 0.1) is evaluated on 1000 points between their min and max,
      and the Wasserstein distance to the template density ``temp`` (a
      callable evaluated on 1000 points in [0, 1]) is returned, capped at 10.
    - 'kde-kl': scores go through ``min_max_standardize``, a Gaussian KDE
      (bandwidth 0.1) is evaluated on 1000 points from 0 to ``max + 1``, and
      a KL-style sum against the template density ``temp`` (a callable
      evaluated on 1000 points in [0, 1]) is returned.

    Parameters:
    - scores_: 1-D array-like of scores.
    - k: int >= 1, rank used by the gap / top-k methods; ignored by the
      template-based methods.
    - method: str, one of the names listed above.
    - temp: template for the distribution-based methods (see above).

    Returns:
    - float, the statistic (the int 10 when 'kde-wasserstein' is capped).

    Raises:
    - AssertionError: unknown ``method``.
    - ValueError: ``k < 1`` for a rank-based method, or ``temp`` is None for
      a template-based method.
    - IndexError: 'gap' / 'relative-gap' with ``k >= len(scores_)``.
    """
    assert method in ['gap', 'k-gap', 'relative-k-gap', 'relative-gap', 'topk-median', 'wasserstein', 'kde-wasserstein',
                      'kde-kl']

    rank_methods = ('gap', 'k-gap', 'relative-k-gap', 'relative-gap', 'topk-median')
    template_methods = ('wasserstein', 'kde-wasserstein', 'kde-kl')

    # A non-positive k would silently wrap around through negative indexing
    # (e.g. scores[k - 1] becomes the smallest score) or produce empty slices.
    if method in rank_methods and k < 1:
        raise ValueError(f"k must be >= 1 for method '{method}', got {k}")
    if method in template_methods and temp is None:
        raise ValueError(f"method '{method}' requires a template passed via 'temp'")

    # np.sort returns a new array, so the caller's data is never modified.
    scores = np.sort(scores_)[::-1]

    if method == 'gap':
        return scores[k - 1] - scores[k]
    elif method == 'k-gap':
        return np.mean(scores[:k]) - np.mean(scores[k:k + k])
    elif method == 'relative-k-gap':
        up = np.mean(scores[:k])
        down = np.mean(scores[k:k + k])
        return (up - down) / (down + 1e-9)
    elif method == 'relative-gap':
        up = scores[k - 1]
        down = scores[k]
        return (up - down) / (down + 1e-9)
    elif method == 'topk-median':
        return np.mean(scores[:k]) - np.median(scores)
    elif method == 'wasserstein':
        # Plain np.histogram: same counts/edges for an integer bin count, but
        # without drawing a histogram on the current matplotlib axes.
        count, bins = np.histogram(scores, bins=10)
        distribution = count / len(scores)
        bins = bins[:-1]  # use left bin edges as the support points
        temp_distribution, temp_bins = temp
        return wasserstein_distance(bins, temp_bins, distribution, temp_distribution)
    elif method == 'kde-wasserstein':
        scores = median_standardize(scores)
        kde = KernelDensity(kernel='gaussian', bandwidth=0.1)
        kde.fit(scores.reshape(-1, 1))

        x = np.linspace(np.min(scores), np.max(scores), 1000)
        dens = np.exp(kde.score_samples(x.reshape(-1, 1)))

        temp_x = np.linspace(0, 1, 1000)
        temp_dens = temp(temp_x)
        try:
            dist = wasserstein_distance(x, temp_x, dens, temp_dens)
        except ValueError:
            # Invalid weights (as rejected by scipy) are mapped to the cap.
            return 10
        if dist > 10:
            return 10
        return dist
    elif method == 'kde-kl':
        scores = min_max_standardize(scores)
        kde = KernelDensity(kernel='gaussian', bandwidth=0.1)
        kde.fit(scores.reshape(-1, 1))

        # NOTE(review): the score density is sampled on [0, max + 1] while the
        # template is sampled on [0, 1]; the sum below pairs them index by
        # index, so it is not a pointwise KL divergence — confirm intent.
        x = np.linspace(0, np.max(scores) + 1, 1000)
        dens = np.exp(kde.score_samples(x.reshape(-1, 1)))

        temp_x = np.linspace(0, 1, 1000)
        temp_dens = temp(temp_x)

        epsilon = 1e-10  # avoids log(0) and division by zero
        kl_divergence = np.sum(dens * np.log((dens + epsilon) / (temp_dens + epsilon)) * (x[1] - x[0]))
        return kl_divergence
    else:
        # Only reachable when assertions are disabled (python -O).
        raise NotImplementedError

def f1_calculator(targets, score):
    """
    Score a ranking by the hit rate among its top-``P`` items.

    ``P`` is the number of entries in ``targets`` equal to 1. The ``P``
    highest-scoring samples are treated as predicted positives and the
    fraction of them whose label is 1 (among those labelled 0 or 1) is
    returned. For binary labels the number of predicted positives equals the
    number of true positives' candidates, so precision, recall and F1 all
    take this same value.

    Parameters:
    - targets: np.array of labels (1 = positive, 0 = negative); may be 1-D
      or a column vector.
    - score: np.array of scores, higher meaning "more likely positive".

    Returns:
    - float in [0, 1]; 0.0 when there is nothing to rank (no positive
      labels), instead of failing with a division by zero.
    """
    # count how many positive samples there are
    n_positive = int(np.count_nonzero(targets.reshape(-1) == 1))
    top_idx = np.argsort(-score)[:n_positive]
    selected = targets[top_idx]

    tp = int(np.count_nonzero(selected == 1))
    fp = int(np.count_nonzero(selected == 0))

    if tp + fp == 0:
        return 0.0
    return tp / (tp + fp)


def get_temp_pdf(name='handcraft'):
    """
    Build the template density function selected by ``name``.

    Parametric choices wrap a scipy.stats distribution with fixed parameters:
    'burr', 'log-logistic' (scipy ``fisk``), 'log-cauchy', 'levy'
    (``levy_stable``), 'weibull' (``weibull_min``), 'pareto' and 'chi2'.
    'handcraft' fits a Gaussian KDE (bandwidth 0.1) to 95 values of 0.5 and
    5 values of 100 after passing them through ``min_max_standardize``.

    Parameters:
    - name: str, template name; default 'handcraft'.

    Returns:
    - callable mapping an array of x values to density values, or None when
      ``name`` is not one of the names above. The 'handcraft' callable calls
      ``x.reshape``, so it needs a NumPy array.
    """
    if name == 'handcraft':
        samples = min_max_standardize(np.array([0.5] * 95 + [100] * 5))

        kde = KernelDensity(kernel='gaussian', bandwidth=0.1)
        kde.fit(np.array(samples).reshape(-1, 1))
        return lambda x: np.exp(kde.score_samples(x.reshape(-1, 1)))

    parametric = {
        'burr': lambda x: burr.pdf(x, 2, 0.4, loc=0, scale=0.5),
        'log-logistic': lambda x: fisk.pdf(x, 2, loc=0, scale=0.5),
        # NOTE(review): this is exp() of the Cauchy pdf value, which is not
        # the log-Cauchy density — confirm this is intended.
        'log-cauchy': lambda x: np.exp(cauchy.pdf(x, loc=0, scale=0.2)),
        'levy': lambda x: levy_stable.pdf(x, 0.5, 1, loc=0, scale=0.1),
        'weibull': lambda x: weibull_min.pdf(x, 0.7, loc=0, scale=1),
        'pareto': lambda x: pareto.pdf(x, 3, loc=-1, scale=1),
        'chi2': lambda x: chi2.pdf(x, 1.2, loc=0, scale=1),
    }
    # Unknown names (including None) yield None.
    return parametric.get(name)


class Objective(object):
    """
    Optuna objective container for tuning pyod detectors on one dataset.

    The constructor loads and preprocesses the dataset through the project's
    ``data_loader`` module. Each ``*_search`` method is an Optuna objective:
    it samples hyper-parameters from the trial, fits a detector on
    ``train_x``, computes the unsupervised "delta" statistic of the training
    scores plus the supervised F1 / AUC on the test split, records all of
    them, and returns whichever one ``objective`` selects.

    Recorded per trial (lists, one entry per call): ``deltas``, ``f1s``,
    ``aucs``, ``train_scores``, ``test_scores``.
    """

    # Dataset directory used when ``data_root`` is not supplied. Written as a
    # raw string so no backslash is interpreted as an escape sequence.
    DEFAULT_DATA_ROOT = r'G:\fan\ad\other_ad_methods\datasets'

    def __init__(self, dataset_name, objective, k, n_noise, temp_name, method='k-gap', normal_c=0, save_path='./res',
                 data_root=None):
        """
        Load the dataset and initialise the per-trial record lists.

        Parameters:
        - dataset_name: str, dataset name handed to the loader.
        - objective: 'auc', 'f1' or 'delta'; the value the search methods
          return to Optuna.
        - k: int, rank parameter forwarded to ``calculate_delta``.
        - n_noise: forwarded to ``data_loader.process_dataset`` and used in
          output file names.
        - temp_name: template name forwarded to ``get_temp_pdf``.
        - method: str, delta method forwarded to ``calculate_delta``.
        - normal_c: forwarded to ``data_loader.process_dataset``.
        - save_path: str, base directory used by ``save_result``.
        - data_root: str or None, dataset directory; ``DEFAULT_DATA_ROOT``
          when None.

        Raises:
        - AssertionError: the loader returned no training data.
        """
        self.dataset_name = dataset_name
        self.method = method
        self.objective = objective
        self.k = k
        self.n_noise = n_noise
        self.save_path = save_path

        if data_root is None:
            data_root = self.DEFAULT_DATA_ROOT
        train_data, test_data, classes = data_loader.load_adbench_dataset(data_root, name=dataset_name)

        train_x, test_x, test_y = data_loader.process_dataset(train_data, test_data,
                                                              classes, normal_c,
                                                              n_noise=n_noise,
                                                              normalize=True)
        if train_x is None:
            # Explicit raise (same exception type as the former ``assert``)
            # so the check survives ``python -O``.
            raise AssertionError(f"no training data produced for dataset '{dataset_name}'")
        self.train_x = train_x
        self.test_x = test_x
        self.test_y = test_y

        self.deltas = []
        self.f1s = []
        self.aucs = []

        self.train_scores = []
        self.test_scores = []

        # Histogram-style template (weights, positions) for the 'wasserstein'
        # delta method.
        self.temp_hist = [0, 0.95, 0.05]
        self.temp_bins = [-500.5, 0., 500.5]

        # Density-function template for the KDE-based delta methods.
        self.temp_name = temp_name
        self.temp_kde = get_temp_pdf(temp_name)

    def _template(self):
        """Return the template in the form the configured delta method expects."""
        if self.method == 'wasserstein':
            return self.temp_hist, self.temp_bins
        return self.temp_kde

    def _objective_value(self, auc, f1, delta):
        """Pick the value reported to Optuna according to ``self.objective``."""
        if self.objective == 'auc':
            return auc
        elif self.objective == 'f1':
            return f1
        elif self.objective == 'delta':
            return delta
        else:
            # Kept as KeyboardInterrupt for backward compatibility with
            # callers; presumably chosen so a misconfigured study stops.
            raise KeyboardInterrupt

    def _evaluate(self, clf):
        """Fit ``clf``, record delta / F1 / AUC and scores, return the objective value."""
        clf.fit(self.train_x)
        train_scores = clf.decision_function(self.train_x)

        delta = calculate_delta(train_scores, k=self.k, method=self.method, temp=self._template())

        pred_scores = clf.decision_function(self.test_x)  # outlier scores
        f1 = f1_calculator(self.test_y, pred_scores)
        auc = roc_auc_score(self.test_y, pred_scores)

        print(f1, auc, delta)
        self.deltas.append(delta)
        self.f1s.append(f1)
        self.aucs.append(auc)

        self.train_scores.append(train_scores)
        self.test_scores.append(pred_scores)
        return self._objective_value(auc, f1, delta)

    def deepsvdd_search(self, trial):
        """
        Optuna objective for pyod's DeepSVDD.

        Parameters:
        - trial: optuna trial used to sample the hyper-parameters.

        Returns:
        - the AUC, F1 or delta of the fitted model, per ``self.objective``.
        """
        # ``step`` is passed by keyword; Optuna >= 3 declares it keyword-only.
        features = trial.suggest_int('features', 16, 256, step=8)
        hidden_neurons1 = trial.suggest_int('hidden_neurons1', 16, 256, step=8)
        hidden_neurons2 = trial.suggest_int('hidden_neurons2', 16, 256, step=8)
        activation = trial.suggest_categorical('activation', ['relu', 'tanh', 'selu', 'gelu'])
        optimizer = trial.suggest_categorical('optimizer', ['adam', 'sgd'])
        epochs = trial.suggest_int('epochs', 100, 400, step=50)
        batch_size = trial.suggest_categorical('batch_size', [64, 128, 256])
        dropout_rate = trial.suggest_float("dropout_rate", 0.0, 1.0, log=False)
        l2_regularizer = trial.suggest_float("l2_regularizer", 0.0, 1.0, log=False)

        # The two sampled layer sizes are the hidden-layer widths; they were
        # previously passed as ``output_activation`` by mistake.
        # NOTE(review): pyod describes ``n_features`` as the input
        # dimensionality; sampling it as a hyper-parameter looks suspicious —
        # confirm whether it should be ``self.train_x.shape[1]``.
        clf = deep_svdd.DeepSVDD(n_features=features, c=None, use_ae=False,
                                 hidden_neurons=[hidden_neurons1, hidden_neurons2],
                                 hidden_activation=activation,
                                 optimizer=optimizer, epochs=epochs,
                                 batch_size=batch_size, dropout_rate=dropout_rate, l2_regularizer=l2_regularizer)

        return self._evaluate(clf)

    def ocsvm_search(self, trial):
        """
        Optuna objective for pyod's OCSVM.

        ``gamma`` and ``coef0`` are always sampled (so the search space is the
        same for every trial) but only passed to the kernels that use them.

        Parameters:
        - trial: optuna trial used to sample the hyper-parameters.

        Returns:
        - the AUC, F1 or delta of the fitted model, per ``self.objective``.
        """
        kernel = trial.suggest_categorical("kernel", ["linear", 'rbf', 'sigmoid'])
        nu = trial.suggest_float("nu", 0.0, 1.0, log=False)
        gamma = trial.suggest_float("gamma", 0.0, 1000, log=False)
        coef0 = trial.suggest_float("coef0", 0.0, 1000, log=False)

        if kernel == 'rbf':
            clf = ocsvm.OCSVM(kernel=kernel, nu=nu, gamma=gamma, tol=0.0001)
        elif kernel in ('poly', 'sigmoid'):
            clf = ocsvm.OCSVM(kernel=kernel, nu=nu, gamma=gamma, coef0=coef0, tol=0.0001)
        else:
            clf = ocsvm.OCSVM(kernel=kernel, nu=nu, tol=0.0001)

        return self._evaluate(clf)

    def plot_results(self):
        """
        Plot recorded delta and F1 against AUC and save the figure as a PNG.

        Trials are ordered by AUC. The current matplotlib axes are cleared
        first, and the image is written to the working directory.
        """
        aucs = np.array(self.aucs)
        f1s = np.array(self.f1s)
        deltas = np.array(self.deltas)
        sort = aucs.argsort()

        plt.cla()

        plt.plot(aucs[sort], deltas[sort], label='gap')
        plt.plot(aucs[sort], f1s[sort], label='f1')

        n_noise = self.n_noise
        if self.objective == 'delta':
            title = f'dataset: {self.dataset_name}, delta method: {self.method}, n_noise={n_noise}, k={self.k}, temp_method: {self.temp_name}'
        else:
            title = f'dataset: {self.dataset_name}, objective: {self.objective}, n_noise={n_noise}, k={self.k}'
        plt.title(title)
        plt.xlabel('auc')
        plt.legend()
        plt.savefig(f'./dataset_{self.dataset_name}-{self.objective}-delta method_{self.method}-n_noise={n_noise}-k_{self.k}-temp_method_{self.temp_name}.png')

    def save_result(self):
        """
        Serialise the recorded metrics, scores and data splits with torch.save.

        The file is written under ``<save_path>/data/``; the directory is
        created if it does not exist yet.
        """
        import os

        f_name = self.save_path + '/' + f'data/dataset-{self.dataset_name}_{self.objective}_delta-method-{self.method}_n-noise-{self.n_noise}_k-{self.k}_temp-method-{self.temp_name}'
        # Create the target directory up front so a long study is not lost to
        # a missing folder at save time.
        os.makedirs(os.path.dirname(f_name), exist_ok=True)
        save = {
            'aucs': self.aucs,
            'f1s': self.f1s,
            'deltas': self.deltas,
            'train_x': self.train_x,
            'test_x': self.test_x,
            'test_y': self.test_y,
            'train_scores': self.train_scores,
            'test_scores': self.test_scores,
        }
        torch.save(save, f_name)

if __name__ == "__main__":
    # Script entry point: project the normal training samples and the test
    # anomalies of one dataset to 2-D with t-SNE and show them as a scatter
    # plot coloured by class.
    from sklearn.manifold import TSNE
    from sklearn.preprocessing import StandardScaler
    import pandas as pd

    datasets = [
        # Other datasets used previously (currently disabled):
        # '4_breastw', '45_wine', '12_fault', '48_arrhythmia', '14_glass',
        # '15_Hepatitis', '20_letter', '21_Lymphography', '22_magic.gamma',
        # '23_mammography', '24_mnist', '11_donors', '42_WBC',
        '30_satellite'
    ]

    objective = Objective(datasets[0], objective='delta', k=1, n_noise=0, method='gap', temp_name=None)

    # Use at most as many anomalies as there are training samples.
    n_norm = objective.train_x.shape[0]
    anomaly_mask = objective.test_y == 1

    # Class labels: 0 for every training sample, then the anomaly labels.
    anomaly_labels = torch.Tensor(objective.test_y[anomaly_mask][:n_norm]).view(-1)
    data_labels = torch.cat([torch.zeros(n_norm), anomaly_labels])

    # Features in the same order as the labels, standardised before t-SNE.
    normal_features = objective.train_x[: n_norm]
    anomaly_features = objective.test_x[anomaly_mask.reshape(-1)][:n_norm]
    data_features_proj = StandardScaler().fit_transform(
        np.concatenate([normal_features, anomaly_features]))

    tsne = TSNE(perplexity=50, n_components=2, n_jobs=32)
    X_tsne = tsne.fit_transform(data_features_proj)

    # One row per sample: the two embedding coordinates plus the class label.
    X_tsne_data = np.vstack((X_tsne.T, data_labels)).T
    df_tsne = pd.DataFrame(X_tsne_data, columns=['Dim1', 'Dim2', 'class'])

    fig = plt.figure(figsize=(8, 8))
    fig.set(tight_layout=True)
    ax = fig.add_subplot()
    scatter = ax.scatter(df_tsne['Dim1'], df_tsne['Dim2'], c=df_tsne['class'],
                         cmap='RdYlBu', marker='+')

    labels = ['Normal', 'Anomaly']
    handles = scatter.legend_elements()[0]
    legend1 = ax.legend(handles, labels, loc="lower left", fontsize=20)
    ax.add_artist(legend1)

    # The embedding axes carry no meaning, so hide them.
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    plt.show()

    # for dataset in datasets:
    #     for n_noise in [0]:
    #         for temp_name in [
    #             'burr', 'log-logistic', 'log-cauchy', 'levy',
    #                           'weibull', 'pareto', 'chi2', 'handcraft']:
    #             for method in ['kde-kl']:
    #                 # if isinstance(k_off, str):
    #                 #     k = int(n_noise * float(k_off))
    #                 # else:
    #                 #     k = n_noise + k_off
    #
    #                 # if k <= 0:
    #                 #     continue
    #
    #                 if method in ['wasserstein', 'kde-wasserstein', 'kde-kl']:
    #                     direction = 'minimize'
    #                 else:
    #                     direction = 'maximize'
    #
    #                 objective = Objective(dataset, objective='delta', k=3, n_noise=n_noise, method=method, temp_name=temp_name)
    #
    #                 algo = optuna.samplers.TPESampler(n_startup_trials=10, n_ei_candidates=24)
    #                 study = optuna.create_study(sampler=algo, direction=direction, storage="sqlite:///record/ocsvm2.db")
    #                 study.optimize(objective.ocsvm_search, n_trials=1000, show_progress_bar=True)
    #                 print(study.best_trial)
    #
    #                 objective.plot_results()
    #                 objective.save_result()

    

    
