from main import main

import torch

import argparse

def str2bool(v):
    """Convert a command-line flag value to a bool.

    Intended as an argparse ``type=`` converter. Real bools are returned
    unchanged. Anything else is converted with ``str()``, stripped of
    surrounding whitespace and lower-cased, then matched against:
    'yes'/'true'/'t'/'y'/'1' -> True and 'no'/'false'/'f'/'n'/'0' -> False.
    The ``str()`` step means non-string defaults such as the int 0 or 1
    also work instead of failing on ``.lower()``.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean
            spelling (argparse reports this as a usage error).
    """
    if isinstance(v, bool):
        return v
    text = str(v).strip().lower()
    if text in ('yes', 'true', 't', 'y', '1'):
        return True
    if text in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

# Datasets run when --data_path_list is not given on the command line.
# NOTE(review): these look relative to --data_dir -- confirm against main().
DEFAULT_DATA_PATH_LIST = [
    "init/musk.mat",
    "init/breastw.mat",
    "arff/KDDCup99_idf.arff",
    "arff/WDBC_withoutdupl_v01.arff",
]


def build_parser():
    """Build the command-line parser for an experiment run.

    Boolean options go through ``str2bool`` and accept '0'/'1', 'true'/'false',
    'yes'/'no', etc. All boolean defaults are strings so that argparse runs
    them through ``str2bool`` too, giving real bools in the parsed result.

    Option names, and therefore the keys of the config dict handed to
    ``main``, are unchanged. That includes the historical ``classificaiton``
    spelling. The correctly spelled flags are accepted as aliases; argparse
    derives ``dest`` from the first long option, so the keys do not change.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default="data", help='data folder')
    parser.add_argument('--summary', type=str, default="baseline", help='summary for experiment')
    parser.add_argument('--device', type=str, default="cuda:0", help='device (falls back to cpu when CUDA is unavailable)', choices=['cuda:0', 'cuda:1', 'cuda:2', 'cpu'])

    parser.add_argument('--number_linear', type=int, default=2, help='number of linears')
    parser.add_argument('--activation', type=str, default="Sigmoid", help='activate function', choices=['Sigmoid', 'ReLU'])

    parser.add_argument('--epochs', type=int, default=500, help='number of epochs of training')
    parser.add_argument('--number_experiment', type=int, default=20, help='number of experiments')

    parser.add_argument('--is_ln', type=str2bool, default='0', help='ln(x)')
    parser.add_argument('--is_scale', type=str2bool, default='1', help='if use MinMaxScale')

    parser.add_argument('--if_neg_in_every_epoch', type=str2bool, default='0', help='if neg every epoch')
    parser.add_argument('--if_neg_every_feature', type=str2bool, default='1', help='if neg every feature')
    parser.add_argument('--neg_rate', type=float, default=1, help='neg rate')
    parser.add_argument('--neg_min', type=float, default=0, help='neg min')
    parser.add_argument('--neg_max', type=float, default=1, help='neg max')

    parser.add_argument('--loss', type=str, default="BCELoss", help='loss function', choices=['BCELoss', 'MSELoss'])

    parser.add_argument('--optimizer', type=str, default="SGD", help='optimizer function', choices=['SGD', 'Adam'])
    parser.add_argument('--learning_rate', type=float, default=0.005, help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-6, help='L2 norm weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')

    # The misspelled name comes first so that dest (the config key) is unchanged.
    parser.add_argument('--paint_classificaiton_auc', '--paint_classification_auc', type=str2bool, default='0', help='if paint classification auc')
    parser.add_argument('--paint_classificaiton_precision', '--paint_classification_precision', type=str2bool, default='0', help='if paint classification precision')

    parser.add_argument('--use_classification_auc_early_stopping', type=str2bool, default='0', help='if use classification auc early stop')
    parser.add_argument('--contamination_threshold', type=float, default=0.1, help='contamination bigger than it will use 0.93 sub some value')

    parser.add_argument('--use_classification_precision_early_stopping', type=str2bool, default='0', help='if use classification precision early stop')
    parser.add_argument('--sub_more_rate', type=float, default=0, help='if sub more precision')

    parser.add_argument('--delta_epochs', type=int, default=100, help='epoch threshold')
    parser.add_argument('--delta_threshold', type=float, default=0.01, help='classification auc or precision must gain at least delta_threshold every delta_epochs')

    parser.add_argument('--save_dir', type=str, default="results", help='save result root path')
    # String default (not int 0) so argparse applies str2bool and yields False.
    parser.add_argument('--save_process_model', type=str2bool, default='0', help='if save process model')
    parser.add_argument('--save_frequency', type=int, default=100, help='save model frequency')

    # Added last on purpose: 'data_path_list' stays the final config key, as it
    # was when it was assigned after parsing.
    parser.add_argument('--data_path_list', type=str, nargs='+', default=list(DEFAULT_DATA_PATH_LIST), help='dataset files to run')
    return parser


if __name__ == "__main__":
    # Parse once; a second parse_args() call re-reads sys.argv for no benefit.
    args = build_parser().parse_args()
    print(args)
    config = vars(args)
    # Use CPU whenever CUDA is unavailable, regardless of --device.
    config['device'] = torch.device(config['device'] if torch.cuda.is_available() else "cpu")
    print(config)
    main(config)