

# Selects which imblearn dataset list is built further down in this module:
# 'svm' leaves out 'webpage' and 'arrhythmia', while 'catboost' additionally
# appends the datasets that declare categorical columns.
CLASSIFIER_TYPE = 'catboost'
# Not referenced anywhere in this module; presumably read by the code that
# imports this configuration (TODO confirm).
MODEL_TYPE = 'AE'
# Default split number used for every dataset entry that does not pin its own.
FOLD = 1

def stringify(x):
    """Join the items of ``x`` into a single space-separated string.

    ``None`` is passed through unchanged. Any other iterable has each of its
    items converted with ``str`` and joined by single spaces, so an empty
    iterable yields an empty string.
    """
    if x is None:
        return None
    return ' '.join(map(str, x))

class KeelDataset:
    """File locations and categorical-column info for one KEEL dataset split.

    Only path strings are built here; nothing is read from disk.

    Args:
        dataset: Dataset name; used both as directory and file-name prefix.
        split: Split identifier embedded in the train/test file names.
        cat_columns: Optional iterable of categorical columns, or ``None``.
    """

    def __init__(self, dataset, split, cat_columns=None):
        # Every file of a dataset lives in its own directory and shares this stem.
        stem = f"datasets/Keel/preprocessed/{dataset}/{dataset}"

        self.name = dataset
        self.cat_columns = cat_columns
        # Space-separated rendering of cat_columns (None stays None).
        self.cat_columns_argument = stringify(cat_columns)

        self.train_pt = f"{stem}_train_{split}.pt"
        self.test_pt = f"{stem}_test_{split}.pt"
        # This file name carries no split identifier.
        self.new_minority_pt = f"{stem}_new_minority_deep_smote.pt"


class ImblearnDataset:
    """File locations and categorical-column info for one imblearn dataset split.

    Only path strings are built here; nothing is read from disk.

    Args:
        dataset: Dataset name; used both as directory and file-name prefix.
        split: Split identifier embedded in the train/test file names.
        cat_columns: Optional iterable of categorical columns, or ``None``.
    """

    def __init__(self, dataset, split, cat_columns=None):
        # Every file of a dataset lives in its own directory and shares this stem.
        stem = f"datasets/imblearn/{dataset}/{dataset}"

        self.name = dataset
        self.cat_columns = cat_columns
        # Space-separated rendering of cat_columns (None stays None).
        self.cat_columns_argument = stringify(cat_columns)

        self.train_pt = f"{stem}_train_{split}.pt"
        self.test_pt = f"{stem}_test_{split}.pt"
        # This file name carries no split identifier.
        self.new_minority_pt = f"{stem}_new_minority_deep_smote.pt"


# KEEL datasets as (name, split, cat_columns) tuples. None of them declares
# categorical columns, and all use FOLD except "page-blocks-1-3_vs_4", which is
# pinned to split 5: the only fold for which not all methods achieve AP=1.
keel_datasets = [
    (name, 5 if name == "page-blocks-1-3_vs_4" else FOLD, None)
    for name in (
        "glass-0-1-6_vs_2",
        "glass2",
        "glass4",
        "page-blocks-1-3_vs_4",
        "yeast-0-5-6-7-9_vs_4",
        "yeast-1_vs_7",
        "yeast-1-2-8-9_vs_7",
        "yeast-1-4-5-8_vs_7",
        "yeast-2_vs_4",
        "yeast-2_vs_8",
        "yeast4",
        "yeast5",
        "yeast6",
    )
]

# Imblearn datasets as (name, split, cat_columns) tuples; none of the entries
# below declares categorical columns. The list is written once so that the SVM
# and non-SVM variants cannot drift apart; the SVM variant is derived from it.
imblearn_datasets = [('ecoli', 3, None),
                     ('letter_img', FOLD, None),
                     ('libras_move', 4, None),
                     ('mammography', FOLD, None),
                     ('ozone_level', FOLD, None),
                     ('pen_digits', FOLD, None),
                     ('satimage', FOLD, None),
                     ('spectrometer', FOLD, None),
                     ('us_crime', FOLD, None),
                     ('webpage', FOLD, None),  # large dataset + high #F
                     ('wine_quality', FOLD, None),
                     ('yeast_me2', FOLD, None),
                     ('yeast_ml8', FOLD, None),
                     ('coil_2000', FOLD, None),
                     ('oil', FOLD, None),
                     ('optical_digits', FOLD, None),
                     ('arrhythmia', FOLD, None),  # too high #F for SVM
                     ('car_eval_34', FOLD, None)]

if CLASSIFIER_TYPE == 'svm':
    # 'webpage' (large dataset + high #F) and 'arrhythmia' (too high #F) are
    # left out for SVM; the order of the remaining entries is unchanged.
    imblearn_datasets = [entry for entry in imblearn_datasets
                         if entry[0] not in ('webpage', 'arrhythmia')]

if CLASSIFIER_TYPE == 'catboost':
    # Only for catboost: append the datasets that declare categorical columns.
    # The third tuple element is the cat_columns list passed on to
    # ImblearnDataset (presumably column indices; confirm against the consumer).
    imblearn_datasets.extend([
        ('abalone', FOLD, [0]),
        ('abalone_19', FOLD, [0]),
        ('sick_euthyroid', FOLD,
         [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 17, 19, 21, 23]),
        ('thyroid_sick', FOLD,
         [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 22, 24, 27]),
        ('solar_flare_m0', FOLD, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
    ])


# Turn every (name, split, cat_columns) entry into its descriptor object:
# all KEEL datasets first, followed by the imblearn datasets.
DATASETS = [KeelDataset(name, fold, categorical)
            for name, fold, categorical in keel_datasets]
DATASETS += [ImblearnDataset(name, fold, categorical)
             for name, fold, categorical in imblearn_datasets]
