import argparse

# Command-line options for PyHEXgraph ImageNet training.  This section builds
# the module-level ``parser``; every option is defined on it below.

# Accepted spellings for boolean option values (compared case-insensitively).
_TRUE_STRINGS = frozenset({'true', 't', 'yes', 'y', '1', 'on'})
_FALSE_STRINGS = frozenset({'false', 'f', 'no', 'n', '0', 'off'})


def _str2bool(value):
    """Convert a command-line string such as ``'True'`` or ``'false'`` to a bool.

    argparse's ``type=bool`` is a trap: it calls ``bool(string)``, and every
    non-empty string -- including ``'False'`` and ``'0'`` -- is truthy, so
    ``--flag False`` would *enable* the flag.  This converter keeps the
    existing ``--flag True`` / ``--flag False`` command-line syntax working
    while interpreting the value correctly.

    Args:
        value: Raw string from the command line (a bool is returned unchanged).

    Returns:
        ``True`` for true/t/yes/y/1/on, ``False`` for false/f/no/n/0/off
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised boolean
            spelling; argparse reports it as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (e.g. True/False, yes/no, 1/0), got {value!r}"
    )


optim_choices = ['Adam', 'SGD']

parser = argparse.ArgumentParser(description='PyHEXgraph ImageNet Training')

# Float options use float defaults so the parsed value has the same type
# whether or not the flag is given on the command line.
parser.add_argument('-b', dest='beta', default=0.0, type=float,
                    metavar='N',
                    help='beta value (default: 0)')

parser.add_argument('--spl', dest='spl', default=False, type=_str2bool,
                    help='use spl implementation (default:False)')

parser.add_argument('--pb', dest='pre_beta', default=None, type=int,
                    help='pre beta value used to compute beta if provided')

parser.add_argument('-e', '--epochs', dest='epochs', default=60, type=int, metavar='Nb epochs',
                    help='number of total epochs to run (default: 60)')

parser.add_argument('-a', '--arch', dest='model', default='efficientnet_b0', type=str,
                    help='model architecture (default: efficientnet_b0)')

parser.add_argument('--pre', dest='pretrained', default=False, type=_str2bool,
                    help='use pretrained weights (default:False)')

parser.add_argument('-o', '--opt', dest='optim', default='Adam', type=str,
                    choices=optim_choices,
                    help='optimizer: ' +
                         ' | '.join(optim_choices) +
                         ' (default: Adam)')

parser.add_argument('--lr', '--learning-rate', dest='lr', default=10**(-4), type=float,
                    metavar='LR', help='initial learning rate')

parser.add_argument('--loss_norm', dest='loss_normalization', default=True, type=_str2bool,
                    help='normalization of the loss given the number of variables')

parser.add_argument('--bs', dest='bs', default=64, type=int,
                    metavar='N',
                    help='mini-batch size (default: 64)')

parser.add_argument('-w', '--workers', dest='num_workers', default=8, type=int, metavar='Nb workers',
                    help='number of data loading workers (default: 8)')

parser.add_argument('--nb_leaves', dest='nb_leaves', default=100, type=int, metavar='Nb leaves',
                    help='number of leaf nodes kept of the hierarchy (default: 100)')

parser.add_argument('-p', '--pruning', dest='pruning', default=True, type=_str2bool, metavar='Pruning',
                    help='prunes the hierarchy tree : delete nodes that have only one parent and one child (default: True)')

parser.add_argument('--ae', '--assume_exclu', dest='assume_exclusive', default=True, type=_str2bool, metavar='Assume exclusive',
                    help='assumes variable nodes to be exclusive if they dont share any child nodes (default: True)')

parser.add_argument('--mapped', dest='mapped', default=False, type=_str2bool, metavar='Mapped',
                    help='maps samples of unselected leaf nodes to the closest parent node in the hierarchy (default: False)')

parser.add_argument('--tr', dest='trainset_ratio', default=1.0, type=float, metavar='Train set ratio',
                    help='ratio of train set kept for training (default: 1)')

parser.add_argument('--chkpt', dest='save_checkpoint', default=True, type=_str2bool, metavar='Save checkpoint',
                    help='saves the parameters of the model to a checkpoint file at the end of training (default: True)')

parser.add_argument('--load', dest='load_checkpoint', default=False, type=_str2bool, metavar='Load checkpoint',
                    help='loads the parameters of the model from a checkpoint file at the beginning of training (default: False)')

parser.add_argument('--load_run_id', dest='load_run_ID', default=None, type=str, metavar='Load run ID',
                    help='ID of the run from which to retrieve the config parameters')

parser.add_argument('--load_job_id', dest='load_job_ID', default=None, type=int, metavar='Load job ID',
                    help='ID of the job from which to load the checkpoint')

parser.add_argument('--job_id', dest='job_ID', default=None, type=int, metavar='Job ID',
                    help='ID of the current job')

parser.add_argument('--clamp', dest='clamping', default=False, type=_str2bool, metavar='Clamping',
                    help='clamps the scores produced by the network to avoid numerical instabilities in the HEX-layer (default: False)')

parser.add_argument('--mac', dest='max_clamp', default=40, type=int, metavar='Max clamp',
                    help='maximum value of clamped scores (default: 40)')

parser.add_argument('--mic', dest='min_clamp', default=-15, type=int, metavar='Min clamp',
                    help='minimum value of clamped scores (default: -15)')

parser.add_argument('--sat', dest='saturation_callout', default=False, type=_str2bool, metavar='Saturation callout',
                    help='stops training when the scores are saturated (default: False)')

parser.add_argument('--gc', dest='grad_clip', default=False, type=_str2bool, metavar='Gradient clipping',
                    help='does gradient clipping (default: False)')

parser.add_argument('-m', '--momentum', dest='momentum', default=0.0, type=float, metavar='momentum',
                    help='momentum (default: 0)')

parser.add_argument('--sched', dest='scheduler', default=False, type=_str2bool, metavar='Scheduler',
                    help='Reduces the learning rate exponentially during training (default: False)')

parser.add_argument('--sched_step', dest='step_size', default=20, type=int, metavar='Scheduler step period',
                    help='step period of the scheduler (default: 20)')

parser.add_argument('--gamma', dest='gamma', default=0.95, type=float, metavar='Gamma scheduler',
                    help='Gamma value of the scheduler which controls the rate of lr decrease (default: 0.95)')

parser.add_argument('--do_eval', dest='do_eval', default=True, type=_str2bool, metavar='Evaluation',
                    help='does evaluation on the test set (default: True)')

parser.add_argument('--eval_step', dest='eval_step', default=1, type=int, metavar='Evaluation step',
                    help='step period of the evaluation (default: 1)')

# BooleanOptionalAction registers both --log and --no-log and takes no value,
# so ``type``/``metavar`` are meaningless here; they are deprecated since
# Python 3.12 and rejected from 3.14 on.
parser.add_argument('--log', dest='log', action=argparse.BooleanOptionalAction, default=True,
                    help='log metrics (default: True)')

parser.add_argument('--lp', dest='log_path', default='./wandb', type=str, metavar="Logging path",
                    help="Path where to save WandB logs (default: ./wandb)")

parser.add_argument('--online', dest='online', default=True, type=_str2bool, metavar='Log online',
                    help='log online (default: True)')

parser.add_argument('--log_train', dest='log_train', default=False, type=_str2bool, metavar='Log train metrics',
                    help='log train metrics (default: False)')

parser.add_argument('--rec_leaves', dest='record_leaves_acc', default=False, type=_str2bool,
                    help='record the accuracy on leaf nodes (default: False)')

# Random seed used to initialise training.
parser.add_argument('--seed', type=int, default=42, dest='seed',
                    help='seed for initializing training.')

# ---------------------------------------------------------------------------
# Disabled options, kept for reference (they appear to come from a standard
# PyTorch ImageNet training script).
# NOTE(review): if re-enabled, '-p' would clash with -p/--pruning and '-e'
# with -e/--epochs, both already registered earlier in this file, and
# argparse would raise a conflicting-option error.
# ---------------------------------------------------------------------------
# parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
#                     metavar='W', help='weight decay (default: 1e-4)',
#                     dest='weight_decay')
# parser.add_argument('-p', '--print-freq', default=10, type=int,
#                     metavar='N', help='print frequency (default: 10)')
# parser.add_argument('--resume', default='', type=str, metavar='PATH',
#                     help='path to latest checkpoint (default: none)')
# parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
#                     help='evaluate model on validation set')
# parser.add_argument('--world-size', default=-1, type=int,
#                     help='number of nodes for distributed training')
# parser.add_argument('--rank', default=-1, type=int,
#                     help='node rank for distributed training')
# parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
#                     help='url used to set up distributed training')
# parser.add_argument('--dist-backend', default='nccl', type=str,
#                     help='distributed backend')
# parser.add_argument('--gpu', default=None, type=int,
#                     help='GPU id to use.')
# parser.add_argument('--multiprocessing-distributed', action='store_true',
#                     help='Use multi-processing distributed training to launch '
#                         'N processes per node, which has N GPUs. This is the '
#                         'fastest way to use PyTorch for either single node or '
#                         'multi node data parallel training')