import argparse
import os


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loss_function',
                        choices=[
                            'DeepHit',
                            'DRSA',
                            'logarithmic_pwl',
                            'logarithmic_simple_pwl',
                            'Brier_pwl',
                            'ProperRPS_pwl',
                            'SurvivalCRPS_pwl',
                            'Portnoy_pwl',
                            'SurvivalGame_Brier'
                        ],
                        default='logarithmic_pwl')
    parser.add_argument('--DeepHit_alpha', type=float, default=0.1)
    parser.add_argument('--DRSA_alpha', type=float, default=0.25)
    parser.add_argument('--neural_network', choices=['MLP', 'LSTM'],
                        default='MLP')
    parser.add_argument('--model',
                        choices=['Cox', 'Kaplan-Meier', 'Softmax',
                                 'SurvivalGame'],
                        default='Softmax')
    parser.add_argument('--withoutEM', action='store_true')  # default = False

    parser.add_argument('--num_epoch', type=int, default=300)
    parser.add_argument('--learning_rate', type=float, default=0.0001)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--num_neuron', type=int, default=128)
    parser.add_argument('--num_bin', type=int, default=32)

    parser.add_argument('--early_stopping_epoch', type=int, default=0)
    parser.add_argument('--early_stopping_threshold', type=int, default=0)

    parser.add_argument('--cross_validation', type=int, default=0)
    parser.add_argument('--use_pytorch_lightning', action='store_false')  # defalt = True

    parser.add_argument('--output_log', action='store_false')  # defalt = True
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--save_prediction', default=False)

    parser.add_argument('--dataset_name',
                        choices=['flchain', 'prostateSurvival', 'support'],
                        default='flchain')

    args = parser.parse_args()

    args.dir_name = os.path.join(args.log_dir, 'log+'+args.dataset_name)
    if args.model=='Softmax':
        args.model_name = '%s_%d' % (args.loss_function, args.num_bin)
    else:
        args.model_name = '%s_%d' % (args.model, args.num_bin)
    if args.loss_function=='DeepHit':
        args.model_name += '_%f' % args.DeepHit_alpha
    if args.withoutEM:
        args.model_name += '_withoutEM'

    return args
