from Utils.pipeline import *

dataset_name = "arrhythmia"  # one of: kitsune, kdd, thyroid, arrhythmia, nlp

#############   select class for training   ###################
# Arrhythmia label ids, from the dataset's anomaly list:
#    1 CAD        2 Old AMI     3 Old IMI     4 ST          5 SB
#    6 PVC        7 PSC         8 Left BBB    9 Right BBB  10 Left VH
#   14 AF        15 Others
# Class 0 is always the first entry (presumably the normal class); the other
# ids select which anomaly types are treated as training classes (confirm
# against main()). Previously used alternative: Left/Right BBB -> [0, 8, 9].
# (An older, now-disabled variant derived training_classes from
#  kwargs_data["dataset_train"] for the NLP harm datasets
#  misinfo / disinfo / toxic / spam / sensitive.)
training_classes = [0] + [2, 3]  # class 0 + Old AMI + Old IMI

###############################################################
# Per-dataset settings: kwargs_data is handed to main() as-is, and n_dim is the
# feature dimensionality used as the first entry of rep_dim for the NN and
# for DeepSVDD.
if dataset_name == "kitsune":
    kwargs_data = {"mirai": False, "custom": "fuzzing"}
    n_dim = 115
elif dataset_name == "kdd":
    kwargs_data = {}
    n_dim = 119
elif dataset_name == "thyroid":
    kwargs_data = {}
    n_dim = 21
elif dataset_name == "arrhythmia":
    kwargs_data = {"training_classes": training_classes}
    n_dim = 279
elif dataset_name == "nlp":
    kwargs_data = {}
    n_dim = 384
else:
    # Keep this list in sync with the branches above.
    raise ValueError(
        f"Unknown dataset_name {dataset_name!r}: must be one of "
        "'kitsune', 'kdd', 'thyroid', 'arrhythmia' or 'nlp'"
    )

# Validation fraction (presumably the share of data held out for validation).
val_split = 0.1
# Feature scaler instance handed to main().
scaler = MinMaxScaler()

# Column names handed to main() for row matching / labels in each split.
train_index_match_col, train_label_col = 'attack_map', 'normal_flag'
test_index_match_col = test_label_col = 'attack_map'

# Models to benchmark, grouped under two keys: "unsup_AD" (presumably
# unsupervised anomaly detectors) and "sup_BC" (presumably supervised binary
# classifiers). Single model names also used in this script: NN, SVM,
# IsolationForest, DeepSVDD.
model_type = OrderedDict([
    ("unsup_AD", ["IsolationForest", "OCSVM", "DeepSVDD"]),
    ("sup_BC", ["RF", "SVM", "NN"]),
])

# ---- NN classifier architecture ----
one_class = False
quadratic_bump = False
sigmoid_head = True
classifier_layers = 2 + 1  # NOTE(review): presumably 2 hidden + 1 output; confirm in get_classifier_and_name
n_neurons = 200
# Layer widths: the input dimension n_dim followed by n_neurons repeated
# (classifier_layers - 1) times, e.g. [n_dim, 200, 200].
rep_dim = [n_dim] + [n_neurons] * (classifier_layers - 1)
assert len(rep_dim) == classifier_layers
activation = torch.nn.LeakyReLU()
dropout = 0

classifier, classifier_name = get_classifier_and_name(
    "NN",
    classifier_layers=classifier_layers,
    rep_dim=rep_dim,
    activation=activation,
    one_class=one_class,
    dropout=dropout,
    sigmoid_head=sigmoid_head,
    quadratic_bump=quadratic_bump,
    seed=179,
)

# ---- NN training hyperparameters ----
# CHANGE EPOCH (e.g. 5 for a quick smoke run)
epochs = 500
weight_decay = 0  # 1e-2 has also been tried
optimizer, optimizer_params = torch.optim.Adam, {'lr': 3e-4}

# No learning-rate schedule. Alternatives tried previously:
#   torch.optim.lr_scheduler.CosineAnnealingLR with {"T_max": epochs}
#   pl_bolts LinearWarmupCosineAnnealingLR with
#       {"warmup_epochs": epochs // 10, "max_epochs": epochs}
lr_scheduler, lr_scheduler_params = None, {}

use_hinge = False
patience = 7        # presumably early-stopping patience
batch_size = 1024
repeats = 3         # handed to main() as `repeats`

theory_nn, custom = False, True

# ---- Per-model hyperparameters ----
# Shallow baselines (presumably forwarded to the corresponding estimators).
svm_params = {"C": 1., "class_weight": 'balanced'}
# CHANGE n_estimators
rf_params = {"n_estimators": 100, "min_samples_split": 10, "min_samples_leaf": 1,
             "max_features": None, "max_leaf_nodes": None, "class_weight": "balanced_subsample"}

# Supervised NN: reuses the classifier and the NN training settings.
# lr_scheduler_params is copied from the configured value so that enabling a
# scheduler also delivers its arguments (a separate dict object is kept, as
# before, so nn_params does not alias the top-level setting).
nn_params = {"classifier": classifier, "classifier_name": classifier_name,
             "positive_class": 0, "epochs": epochs, "optimizer": optimizer, "optimizer_params": optimizer_params,
             "lr_scheduler": lr_scheduler, "lr_scheduler_params": dict(lr_scheduler_params),
             "loss_fn": torch.nn.functional.binary_cross_entropy,
             "neg_labels": False, "patience": patience, "seed": 42,
             # NOTE(review): exp_name is hard-coded to 'kdd' regardless of
             # dataset_name — confirm whether that is intended.
             "exp_name": 'kdd', "batch_size": batch_size}
ocsvm_params = {"cache_size": 2000, "nu": 1e-10}
isolationforest_params = {"n_estimators": 100, "contamination": "auto"}
lof_params = {}  # not included in args["shallow_params"] in this script

# Deep SVDD uses its own literal training settings, independent of the NN
# variables. CHANGE EPOCH here too when changing the NN epochs.
# rep_dim: layer widths n_dim -> n_dim // 2 -> n_dim // 4.
deepsvdd_params = {
    "rep_dim": [n_dim, n_dim // 2, n_dim // 4], "epochs": 500, "batch_size": 1024, "patience": 7,
    "optimizer": torch.optim.Adam, "optimizer_params": {'lr': 3e-4}, "lr_scheduler": None,
    # NOTE(review): exp_name hard-coded to "kdd" here as well — confirm.
    "lr_scheduler_params": dict(), "exp_name": "kdd"}

# ---- Synthetic anomalies ----
synthetic_anom_ratio = 0  # presumably 0 disables synthetic anomalies
synthetic_val_anom_constant, synthetic_anom_test_ratio = True, False

# ---- Feature encoding ----
one_hot_col_len, binary_cols = None, True
delta = 0.0

# ---- Evaluation ----
fpr, eval_only = 0.05, False  # fpr: presumably the target false-positive rate

# Keyword arguments forwarded to main(); keys are in the same order as before.
args = dict(
    repeats=repeats,
    model_type=model_type,
    # NN architecture
    classifier_layers=classifier_layers,
    rep_dim=rep_dim,
    one_class=one_class,
    sigmoid_head=sigmoid_head,
    quadratic_bump=quadratic_bump,
    activation=activation,
    dropout=dropout,
    # NN training
    epochs=epochs,
    weight_decay=weight_decay,
    optimizer_params=optimizer_params,
    lr_scheduler=lr_scheduler,
    lr_scheduler_params=lr_scheduler_params,
    use_hinge=use_hinge,
    patience=patience,
    batch_size=batch_size,
    theory_nn=theory_nn,
    custom=custom,
    # Data / synthetic anomalies
    num_real_training=1,  # can change the proportion of real normal data used during training
    synthetic_anom_ratio=synthetic_anom_ratio,
    synthetic_anom_test_ratio=synthetic_anom_test_ratio,
    synthetic_val_anom_constant=synthetic_val_anom_constant,
    one_hot_col_len=one_hot_col_len,
    binary_cols=binary_cols,
    delta=delta,
    # Evaluation
    fpr=fpr,
    eval_only=eval_only,
    # Per-model hyperparameters
    shallow_params=dict(
        svm_params=svm_params,
        rf_params=rf_params,
        ocsvm_params=ocsvm_params,
        isolationforest_params=isolationforest_params,
    ),
    deepsvdd_params=deepsvdd_params,
    deepsvdd_ERM=False,
    nn_params=nn_params,
)

# Call main() with the positional settings followed by every entry of `args`
# as a keyword argument, then print whatever it returns.
aggregated_results = main(
    dataset_name,
    kwargs_data,
    val_split,
    training_classes,
    scaler,
    train_index_match_col,
    train_label_col,
    test_index_match_col,
    test_label_col,
    **args,
)
print(aggregated_results)
