from Utils.data_methods_synthetic import *
from Utils.model_methods import *
from Utils.eval_methods import *
from sklearn.svm import SVC


def experiment_run(seed, seed_shuffle, fpr,
                   num_training_normal, num_train_normal, num_val_normal, num_test_normal,
                   num_training_anom, num_train_anom, num_val_anom, num_test_anom,
                   n_dimensions, m_nonzero_entries=2, target_density=False, rejection_sampling=False, exp=1, **kwargs):
    """Run one synthetic-data experiment: two kernel SVMs followed by two neural networks.

    The data splits come from ``get_data_and_label``. An RBF SVM and a
    degree-3 polynomial SVM are evaluated with ``train_eval_svm``; the same
    splits are then wrapped with ``get_dataloader`` and passed to
    ``train_eval`` twice, first with ``one_class=False`` and then with
    ``one_class=True``.

    Args:
        seed: Seed forwarded to ``get_data_and_label`` and ``train_eval``.
            The test seed handed to ``get_data_and_label`` is ``seed + 500``.
        seed_shuffle: Forwarded to ``get_data_and_label`` as ``seed_shuffle``.
        fpr: Forwarded to ``train_eval_svm`` as ``fpr`` and to ``train_eval``
            as ``quantile``.
        num_training_normal: Forwarded positionally to ``get_data_and_label``.
        num_train_normal: Forwarded positionally to ``get_data_and_label``.
        num_val_normal: Forwarded positionally to ``get_data_and_label``.
        num_test_normal: Forwarded positionally to ``get_data_and_label``.
        num_training_anom: Forwarded positionally to ``get_data_and_label``.
        num_train_anom: Forwarded positionally to ``get_data_and_label``.
        num_val_anom: Forwarded positionally to ``get_data_and_label``.
        num_test_anom: Forwarded positionally to ``get_data_and_label``.
        n_dimensions: Forwarded positionally to ``get_data_and_label``.
        m_nonzero_entries: Forwarded to ``get_data_and_label``.
        target_density: Forwarded to ``get_data_and_label``.
        rejection_sampling: Forwarded to ``get_data_and_label``.
        exp: Experiment selector; ``1`` passes ``add=False`` and ``2`` passes
            ``add=True`` to ``get_data_and_label``.
        **kwargs: Forwarded unchanged to ``get_data_and_label``. The same
            mapping also supplies the network settings, with these defaults:
            ``rep_dim`` (10), ``classifier_layers`` (3), ``epochs`` (100),
            ``optimizer`` (``torch.optim.Adam``), ``lr`` (1e-3),
            ``patience`` (7), ``plot`` (True), ``exp_num`` (1).

    Returns:
        An 8-tuple: the two values returned by ``train_eval_svm`` for the RBF
        SVM, the same two for the polynomial SVM, then the average precision
        and AUROC entries returned by ``train_eval`` for the
        ``one_class=False`` network, and the same two for the
        ``one_class=True`` network.

    Raises:
        ValueError: If ``exp`` is neither 1 nor 2.
    """
    # Reject unsupported experiments before any data generation or training.
    if exp not in (1, 2):
        raise ValueError('Exp must be 1 or 2')
    add = bool(exp == 2)

    # label 0 for anom, 1 for normal
    splits = get_data_and_label(
        num_training_normal, num_train_normal, num_val_normal, num_test_normal,
        num_training_anom, num_train_anom, num_val_anom, num_test_anom,
        n_dimensions, m_nonzero_entries=m_nonzero_entries, add=add, target_density=target_density,
        rejection_sampling=rejection_sampling,
        seed=seed, seed_test=seed + 500, seed_shuffle=seed_shuffle, **kwargs)
    x_train, y_train, x_val, y_val, x_test, y_test = splits

    # --- Kernel SVM baselines ---
    print("RBF Kernel SVM")
    rbf_pr_auc, rbf_roc_auc = train_eval_svm(x_train, y_train, x_test, y_test, fpr=fpr, kernel='rbf')
    print("Degree 3 Polynomial Kernel SVM")
    poly_pr_auc, poly_roc_auc = train_eval_svm(
        x_train, y_train, x_test, y_test, fpr=fpr, kernel='poly', degree=3)

    # --- Neural networks ---
    # Both networks consume the same train / validation / test loaders.
    loaders = [
        get_dataloader(x_train, y_train),
        get_dataloader(x_val, y_val),
        get_dataloader(x_test, y_test),
    ]

    # Network settings, overridable through **kwargs.
    fallbacks = (
        ("rep_dim", 10),
        ("classifier_layers", 3),
        ("epochs", 100),
        ("optimizer", torch.optim.Adam),
        ("lr", 1e-3),
        ("patience", 7),
        ("plot", True),
        ("exp_num", 1),
    )
    hp = {key: kwargs.get(key, fallback) for key, fallback in fallbacks}

    def _train_network(one_class):
        # Only `one_class` differs between the two train_eval invocations.
        return train_eval(
            hp["classifier_layers"], hp["rep_dim"], *loaders, testing_datasets=None, class_label=None,
            epochs=hp["epochs"], optimizer=hp["optimizer"], lr=hp["lr"], patience=hp["patience"],
            quantile=fpr, tpr=False, plot=hp["plot"],
            seed=seed, one_class=one_class, eval_comments=False, exp_num=hp["exp_num"])

    # NN (Binary Classifier); train_eval yields nine values, three of which are used here.
    (_, bc_ckpt_path, _, _, _, bc_avg_precision, bc_auroc, _, _) = _train_network(False)
    print("Best NN (BC) Model:", bc_ckpt_path)

    # NN (One-Class)
    (_, oc_ckpt_path, _, _, _, oc_avg_precision, oc_auroc, _, _) = _train_network(True)
    print("Best NN (OC) Model:", oc_ckpt_path)

    return (rbf_pr_auc, rbf_roc_auc, poly_pr_auc, poly_roc_auc,
            bc_avg_precision, bc_auroc, oc_avg_precision, oc_auroc)


def train_eval_svm(x_train, y_train, x_test, y_test, fpr=0.05, **kwargs):
    """Fit an ``SVC`` on the training split and evaluate it on the test split.

    Args:
        x_train: Training features passed to ``SVC.fit``.
        y_train: Training labels passed to ``SVC.fit``.
        x_test: Test features passed to ``SVC.decision_function``.
        y_test: Test labels; must support ``1 - y_test`` arithmetic.
        fpr: Forwarded to ``plot_metrics`` as ``fpr``.
        **kwargs: Forwarded to the ``SVC`` constructor (e.g. ``kernel``, ``degree``).

    Returns:
        The pair of values returned by ``plot_metrics``, named PR-AUC and
        ROC-AUC by this function.
    """
    classifier = SVC(**kwargs).fit(x_train, y_train)

    # Flip both scores and labels so they are expressed in terms of the
    # anomaly side: score = 1 - decision value, label = 1 - original label.
    decision_values = classifier.decision_function(x_test)
    anomaly_scores = 1. - decision_values
    anomaly_labels = 1 - y_test

    pr_auc, roc_auc = plot_metrics(
        anomaly_labels, anomaly_scores, ["Anomaly", "Normal"], fpr=fpr, name="SVM")
    return pr_auc, roc_auc


def get_df_results(results):
    """Tabulate results with one row per original column, plus row statistics.

    Args:
        results: Any input accepted by ``pd.DataFrame(data=...)``.

    Returns:
        The transposed DataFrame with two extra columns, ``"mean"`` and
        ``"std"``, holding each row's mean and standard deviation. Both are
        computed before either column is added, so they cover only the
        original values.
    """
    table = pd.DataFrame(data=results).transpose()
    row_mean = table.mean(axis=1)
    row_std = table.std(axis=1)
    return table.assign(mean=row_mean, std=row_std)
