import argparse
import torchvision.models as models

# Sorted names of every attribute of `torchvision.models` that is lowercase,
# is not a dunder, and is callable.
model_names = sorted(
    attr_name
    for attr_name, attr_value in vars(models).items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr_value)
)

def get_parser(argv=None):
    """Define the training command-line interface and parse it.

    Despite its name, this function returns the *parsed arguments*
    (an ``argparse.Namespace``), not the ``ArgumentParser`` object.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what callers invoking ``get_parser()`` with no arguments get.

    Returns:
        argparse.Namespace: One attribute per option. Dashes in long option
        names become underscores (e.g. ``--train-data-dir`` ->
        ``args.train_data_dir``); ``-k`` is stored as ``args.k``.

    Raises:
        SystemExit: Raised by argparse when a required option
            (``--train-data-dir``, ``--val-data-dir``) is missing, a value
            fails type conversion, or ``--optimizer`` is not a listed choice.
    """
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

    # --- Data locations and schedule -------------------------------------
    parser.add_argument('--train-data-dir', metavar='DIR', required=True,
                        help='Path to training dataset.')
    parser.add_argument('--val-data-dir', metavar='DIR', required=True,
                        help='Path to validation dataset.')
    # Kept as a raw string here; any decoding is left to the caller.
    parser.add_argument('--phases', type=str,
                        help='Specify epoch order of data resize and learning rate schedule: [{"ep":0,"sz":128,"bs":64},{"ep":5,"lr":1e-2}]')
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')

    # --- Optimisation hyperparameters ------------------------------------
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--init-bn0', action='store_true',
                        help='Initialize running batch norm mean to 0')
    parser.add_argument('--print-freq', '-p', default=5, type=int,
                        metavar='N',
                        help='log/print every this many steps (default: 5)')
    parser.add_argument('--no-bn-wd', action='store_true',
                        help='Remove batch norm from weight decay')

    # --- Checkpointing / evaluation / precision --------------------------
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate',
                        action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--half-prec', action='store_true',
                        help='Run model in half-precision mode')

    # --- Distributed training --------------------------------------------
    # `store_true` flags default to False, so the help text says so.
    parser.add_argument('--distributed', action='store_true',
                        help='Run distributed training (default: False)')
    parser.add_argument('--dist-url', default='env://', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')
    # Underscore spelling is intentional and preserved: the flag name is part
    # of the CLI contract with whatever launcher supplies it.
    parser.add_argument('--local_rank', default=0, type=int,
                        help="Used for multi-process training. Can either be manually set "
                             "or automatically set by using 'python -m multiproc'.")

    # --- Logging / debugging ---------------------------------------------
    parser.add_argument('--logdir', default='', type=str,
                        help='where logs go')
    parser.add_argument('--short-epoch', action='store_true',
                        help='make epochs short (for debugging)')

    # --- Model of natural variation and data shape -----------------------
    parser.add_argument('--model-path', type=str,
                        help="path for model of natural variation")
    parser.add_argument('--delta-dim', type=int, default=2,
                        help="dimension of nuisance latent space")
    parser.add_argument('--setup-verbose', action='store_true',
                        help='Print setup messages to console')
    parser.add_argument('--data-size', type=int, default=224,
                        help="Size of each image")
    parser.add_argument('--batch-size', type=int, default=256,
                        help='Training/validation batch size')

    # --- Classifier and training algorithm selection ---------------------
    parser.add_argument('--architecture', default='resnet50', type=str,
                        help='Architecture for classifier')
    parser.add_argument('--pretrained', action='store_true',
                        help='Use pretrained model')
    parser.add_argument('--apex-opt-level', default='O1', type=str,
                        help='opt_level for Apex amp initialization')
    parser.add_argument('--num-classes', default=1000, type=int,
                        help='Number of classes in dataset')
    parser.add_argument('-k', default=1, type=int,
                        help='Hyperparameter k for model-based training')
    parser.add_argument('--save-path', type=str,
                        help='Path for saving outputs')
    parser.add_argument('--mrt', action='store_true',
                        help='Run MRT with different values of k')
    parser.add_argument('--mda', action='store_true',
                        help='Run MDA with different values of k')
    parser.add_argument('--pgd', action='store_true', help='Run PGD')
    parser.add_argument('--mat', action='store_true',
                        help='RUN MAT with different values of k')
    parser.add_argument('--optimizer', type=str, default='sgd',
                        choices=['sgd', 'adadelta'],
                        help='Optimization algorithm to use')

    # argv=None makes argparse read sys.argv[1:], matching the old behaviour.
    return parser.parse_args(argv)