import argparse

# Optimizer names accepted by -o/--opt. Each name must be its own list
# element, otherwise argparse's `choices` check rejects every real value.
optim_choices = ['SGD', 'Adam']

# Spellings accepted (case-insensitively) for boolean command-line options.
# The empty string is kept as a "false" spelling because it was the only way
# to obtain False when these options were parsed with `type=bool`.
_TRUE_STRINGS = frozenset(('true', 't', 'yes', 'y', 'on', '1'))
_FALSE_STRINGS = frozenset(('false', 'f', 'no', 'n', 'off', '0', ''))


def _str2bool(value):
    """Convert a command-line string to a bool.

    `type=bool` must not be used with argparse: `bool('False')` is True
    because any non-empty string is truthy. This converter maps the usual
    textual spellings explicitly instead.

    Args:
        value: the raw option value (a string), or an actual bool.

    Returns:
        The corresponding boolean.

    Raises:
        argparse.ArgumentTypeError: if `value` is not a recognised spelling;
            argparse reports it as a normal usage error.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in _TRUE_STRINGS:
        return True
    if text in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


parser = argparse.ArgumentParser(description='Cifar Classification with HEX-graphs')

parser.add_argument('-b', dest='beta', default=0, type=float,
                    metavar='beta',
                    help='beta value (default: 0)')

parser.add_argument('--pb', dest='pre_beta', default=None, type=int, help='pre beta value used to compute beta if provided')

parser.add_argument('-e', '--epochs', dest='epochs', default=60, type=int, metavar='Nb epochs',
                    help='number of total epochs to run (default: 60)')

# parser.add_argument('--gr', '--growth-rate', dest='gr', default=12, type=int,
#                     metavar='LR', help='controls the growth rate of the DenseNet (default: 12)')

parser.add_argument('-a', '--arch', dest='model', default='efficientnet_b0', type=str,
                    help='model architecture (default: efficientnet_b0)')

parser.add_argument('--pre', dest='pretrained', default=False, type=_str2bool,
                    help='use pretrained weights (default:False)')

parser.add_argument('-o', '--opt', dest='optim', default='Adam', type=str,
                    choices=optim_choices,
                    help='optimizers: ' +
                        ' | '.join(optim_choices) +
                        ' (default: Adam)')

parser.add_argument('--lr', '--learning-rate', dest='lr', default=10**(-4), type=float,
                    metavar='LR', help='initial learning rate')

parser.add_argument('--loss_norm', dest='loss_normalization', default=True, type=_str2bool,
                    help='normalization of the loss given the number of variables')

parser.add_argument('--bs', dest='bs', default=64, type=int,
                    metavar='batch size',
                    help='mini-batch size (default: 64)')

parser.add_argument('-w', '--workers', dest='num_workers', default=8, type=int, metavar='Nb workers',
                    help='number of data loading workers (default: 8)')

parser.add_argument('--tr', dest='trainset_ratio', default=1, type=float, metavar='Train set ratio',
                    help='ratio of train set kept for training (default: 1)')

parser.add_argument('--chkpt', dest='save_checkpoint', default=True, type=_str2bool, metavar='Save checkpoint',
                    help='saves the parameters of the model to a checkpoint file at the end of training (default: True)')

parser.add_argument('--load', dest='load_checkpoint', default=False, type=_str2bool, metavar='Load checkpoint',
                    help='loads the parameters of the model from a checkpoint file at the beginning of training (default: False)')

parser.add_argument('--load_run_id', dest='load_run_ID', default=None, type=str, metavar='Load run ID',
                    help='ID of the run from which to retrieve the config parameters')

parser.add_argument('--load_job_id', dest='load_job_ID', default=None, type=int, metavar='Load job ID',
                    help='ID of the job from which to load the checkpoint')

parser.add_argument('--job_id', dest='job_ID', default=None, type=int, metavar='Job ID',
                    help='ID of the current job')

# Help text now matches the actual default (False).
parser.add_argument('--clamp', dest='clamping', default=False, type=_str2bool, metavar='Clamping',
                    help='clamps the scores produced by the network to avoid numerical instabilities in the HEX-layer (default: False)')

parser.add_argument('--mac', dest='max_clamp', default=40, type=int, metavar='Max clamp',
                    help='maximum value of clamped scores (default: 40)')

parser.add_argument('--mic', dest='min_clamp', default=-15, type=int, metavar='Min clamp',
                    help='minimum value of clamped scores (default: -15)')

parser.add_argument('--sat', dest='saturation_callout', default=False, type=_str2bool, metavar='Saturation callout',
                    help='stops training when the scores are saturated (default: False)')

parser.add_argument('--gc', dest='grad_clip', default=False, type=_str2bool, metavar='Gradient clipping',
                    help='does gradient clipping (default: False)')

parser.add_argument('-m', '--momentum', dest='momentum', default=0, type=float, metavar='momentum',
                    help='momentum (default: 0)')

parser.add_argument('--sched', dest='scheduler', default=False, type=_str2bool, metavar='Scheduler',
                    help='Reduces the learning rate exponentially during training (default: False)')

parser.add_argument('--sched_step', dest='step_size', default=20, type=int, metavar='Scheduler step period',
                    help='step period of the scheduler (default: 20)')

parser.add_argument('--gamma', dest='gamma', default=0.95, type=float, metavar='Gamma scheduler',
                    help='Gamma value of the scheduler which controls the rate of lr decrease (default: 0.95)')

parser.add_argument('--do_eval', dest='do_eval', default=True, type=_str2bool, metavar='Evaluation',
                    help='does evaluation on the test set (default: True)')

parser.add_argument('--eval_step', dest='eval_step', default=1, type=int, metavar='Evaluation step',
                    help='step period of the evaluation (default: 1)')

parser.add_argument('--log', dest='log', default=True, type=_str2bool, metavar='Log metrics',
                    help='log metrics (default: True)')

parser.add_argument('--lp', dest='log_path', default='./wandb', type=str, metavar="Logging path",
                    help="Path where to save WandB logs (default: ./wandb)")

# Help text now matches the actual default (False).
parser.add_argument('--online', dest='online', default=False, type=_str2bool, metavar='Log online',
                    help='log online (default: False)')

parser.add_argument('--log_train', dest='log_train', default=False, type=_str2bool, metavar='Log train metrics',
                    help='log train metrics (default: False)')

# NOTE(review): the previous help text was copy-pasted from --loss_norm; the
# wording below is derived from the dest name `record_leaves_acc` -- confirm.
parser.add_argument('--rec_leaves', dest='record_leaves_acc', default=False, type=_str2bool,
                    help='records the accuracy on the leaves (default: False)')

parser.add_argument('--seed', default=42, type=int,
                    help='seed for initializing training.')