import torch
import load_dataset
from data_loader import Data_Loader
from train import trainer
import argparse


# Dataset identifiers kept for reference only; ``main`` does not iterate over
# them by default. NOTE(review): these look like ADBench names and were
# presumably meant for ``load_dataset.load_adbench_dataset`` -- confirm before
# passing them to ``main(..., dataset_names=...)``.
_ADBENCH_DATASET_NAMES = (
    '10_cover', '11_donors', '12_fault', '13_fraud', '14_glass', '15_Hepatitis', '16_http',
    '17_InternetAds', '18_Ionosphere', '19_landsat',
    '1_ALOI', '20_letter', '21_Lymphography', '22_magic.gamma', '23_mammography', '24_mnist',
    '25_musk', '26_optdigits', '27_PageBlocks', '28_pendigits', '29_Pima', '2_annthyroid',
    '30_satellite', '31_satimage-2', '32_shuttle', '33_skin', '34_smtp', '35_SpamBase', '36_speech',
    '37_Stamps', '38_thyroid', '39_vertebral', '3_backdoor', '40_vowels', '41_Waveform', '42_WBC',
    '43_WDBC', '44_Wilt', '45_wine', '46_WPBC', '47_yeast', '4_breastw', '5_campaign', '6_cardio',
    '7_Cardiotocography', '8_celeba', '9_census', '48_arrhythmia',
)


def main(num, *, data_root='G:\\fan\\ad\\datasets\\data', dataset_names=('kdd',),
         save_dir='./save', cli_args=None):
    """Train and evaluate one run over the configured datasets and save the scores.

    For every dataset name, the raw data is loaded with
    ``load_dataset.load_dataset``; then, for every entry of the returned
    ``classes`` (collapsed to ``[0]`` when there are two or fewer), the data is
    prepared with ``load_dataset.process_dataset``, a fresh ``trainer`` is
    trained and evaluated, and the resulting F1 / AUC pair is printed and
    written with ``torch.save``.

    Args:
        num: Run identifier; only used in the name of the saved result file.
        data_root: Directory handed to ``load_dataset.load_dataset``. The
            default is the machine-specific path the script was written with.
        dataset_names: Iterable of dataset names to process.
        save_dir: Directory that receives the ``SCAD-*.save`` files. It is
            created if it does not exist.
        cli_args: Configuration object passed to ``trainer``. When ``None``,
            the module-level ``args`` (set by the ``__main__`` section) is used.

    Raises:
        NameError: If ``cli_args`` is ``None`` and no module-level ``args``
            exists (e.g. when this module is imported instead of executed).
    """
    import os

    # Resolve the configuration before any expensive work so that a missing
    # global fails immediately rather than after the data has been loaded.
    run_args = args if cli_args is None else cli_args

    # torch.save does not create parent directories; without this, a missing
    # output directory would only surface after a complete training run.
    os.makedirs(save_dir, exist_ok=True)

    for dataset_ in dataset_names:
        print(dataset_)
        train_data, test_data, classes = load_dataset.load_dataset(data_root, name=dataset_)

        # With at most two classes, a single pass using class 0 is performed.
        if len(classes) <= 2:
            classes = [0]

        for normal_c in classes:
            train_dataset, test_dataset, labels = load_dataset.process_dataset(
                train_data, test_data,
                classes, normal_c,
                b_size=None,
                normalize=True,
            )
            # process_dataset signals "nothing to train on" with None; skip it.
            if train_dataset is None:
                continue

            trainer_object = trainer(run_args)
            f_score, auc_score = trainer_object.train_and_evaluate(
                train_dataset, test_dataset, labels)
            print(dataset_, normal_c, "F1, AUC: ", f_score, auc_score)
            torch.save({'f1': f_score, 'auc': auc_score},
                       f'{save_dir}/SCAD-{dataset_}-{normal_c}-std-{num}.save')


if __name__ == '__main__':
    # Command-line interface. The parsed namespace is bound to the module-level
    # name ``args``, which main() reads when it builds the trainer.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--batch_size',
        type=int,
        default=256,
        help='batch size for training',
    )
    parser.add_argument(
        '--dataset',
        type=str,
        default='annthyroid',
        help='name of dataset',
    )
    parser.add_argument(
        '--faster_version',
        type=str,
        default='no',
        help='faster version with a lower number of repeats',
    )
    args = parser.parse_args()

    # Use the first CUDA device when one is available, otherwise the CPU.
    # NOTE(review): ``device`` is not referenced by the code visible in this
    # file -- confirm whether anything else depends on it.
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    # Ten runs; the run index becomes part of each saved result's file name.
    for run_idx in range(10):
        main(run_idx)
    # f1,auc = main()
    # print("F1, AUC: ",f1, auc)

