import argparse


def get_cifar10_args(model_names, argv=None):
    """Build and parse the command-line options for a CIFAR training run.

    Args:
        model_names: Architecture names accepted by ``--arch``; they are also
            listed in that option's help text.
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default (``None``) keeps the original
            command-line behavior.

    Returns:
        The parsed ``argparse.Namespace``. ``save_dir`` is rewritten to
        ``<save-dir>/<val_method>``. ``save_name`` is set to
        ``<save-dir>/<val_method>.pkl`` unless ``--save_name`` was supplied
        explicitly, in which case the user's value is kept.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
    """
    import os

    # Materialise once: the help text (join) and the `choices` membership test
    # would otherwise each consume a one-shot iterable such as a generator.
    model_names = list(model_names)

    parser = argparse.ArgumentParser(description='Proper ResNets for CIFAR10 in pytorch')

    # dataset setup
    parser.add_argument('--num-class', type=int, help='number of class (10 for cifar10 and 100 for cifar100)',
                        required=True)

    # model setup
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet44',
                        choices=model_names,
                        help='model architecture: ' + ' | '.join(model_names) +
                             ' (default: resnet44)')

    # training setup
    parser.add_argument('--epochs', default=200, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')

    # dataloader setup
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('-b', '--batch-size', default=128, type=int,
                        metavar='N', help='mini-batch size (default: 128)')

    parser.add_argument('--save-dir', dest='save_dir',
                        help='The directory used to save the trained models',
                        default='/data/omf/model/DataValidation/CV/experiments/', type=str)

    parser.add_argument('--print-freq', '-p', default=50, type=int,
                        metavar='N', help='print frequency (default: 50)')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')

    # validation setting
    parser.add_argument('--val_method', '-vm', required=True, choices=['holdout', 'kfold', 'jkfold', 'LZO',
                                                                       'split_free', 'split_free_joint',
                                                                       'split_free_random', 'split_free_test',
                                                                       'split_free_noval', 'split_free_holdout'])
    # NOTE(review): --J / --k are presumably parameters of the jkfold/kfold
    # methods; their exact meaning is defined by the caller — confirm.
    parser.add_argument('--J', type=int, default=1, required=False)
    parser.add_argument('--k', type=int, required=False, default=1)
    parser.add_argument('--save_name', type=str, required=False)

    args = parser.parse_args(argv)

    # Derive per-method output locations from the base directory. os.path.join
    # inserts a separator only when one is missing, so "--save-dir /out" and
    # "--save-dir /out/" both yield "/out/<method>" (plain concatenation
    # produced "/out<method>" for the former).
    base_dir = args.save_dir
    if args.save_name is None:
        # Only fill in a default; an explicit --save_name is honoured.
        args.save_name = os.path.join(base_dir, args.val_method + '.pkl')
    args.save_dir = os.path.join(base_dir, args.val_method)
    return args


def get_cifar_fe_args(model_names):
    """Parse command-line options for the CIFAR feature-extractor training script.

    Options are read from ``sys.argv``. The resulting namespace is printed
    before it is returned.

    Args:
        model_names: Architecture names accepted by ``--arch``; they are also
            listed in that option's help text.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Propert ResNets for CIFAR in pytorch')
    add = parser.add_argument
    arch_listing = ' | '.join(model_names)

    # --- dataset -----------------------------------------------------------
    add('--num-class', type=int, required=True,
        help='number of class (10 for cifar10 and 100 for cifar100)')

    # --- model -------------------------------------------------------------
    add('--arch', '-a', metavar='ARCH', default='resnet110', choices=model_names,
        help='model architecture: {} (default: resnet110)'.format(arch_listing))

    # --- optimisation schedule ---------------------------------------------
    add('--epochs', default=200, type=int, metavar='N',
        help='number of total epochs to run')
    add('--lr', '--learning-rate', default=0.1, type=float, metavar='LR',
        help='initial learning rate')
    add('--momentum', default=0.9, type=float, metavar='M',
        help='momentum')
    add('--weight-decay', '--wd', default=1e-4, type=float, metavar='W',
        help='weight decay (default: 1e-4)')

    # --- data loading ------------------------------------------------------
    add('-j', '--workers', default=4, type=int, metavar='N',
        help='number of data loading workers (default: 4)')
    add('-b', '--batch-size', default=128, type=int, metavar='N',
        help='mini-batch size (default: 128)')

    # --- output / misc -----------------------------------------------------
    add('--save-dir', dest='save_dir', type=str,
        default='/data/omf/model/DataValidation/CV/feature_extractor/',
        help='The directory used to save the trained models')
    add('--print-freq', '-p', default=50, type=int, metavar='N',
        help='print frequency (default: 50)')
    add('--pretrained', dest='pretrained', action='store_true',
        help='use pre-trained model')

    parsed = parser.parse_args()
    print(parsed)
    return parsed


def get_cifar_analysis_args(model_names, argv=None):
    """Parse command-line options for the CIFAR analysis script.

    The parsed namespace is printed before it is returned.

    Args:
        model_names: Architecture names accepted by ``--arch``; they are also
            listed in that option's help text.
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default (``None``) keeps the original
            command-line behavior.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
    """
    # Materialise once: the help text (join) and the `choices` membership test
    # would otherwise each consume a one-shot iterable such as a generator.
    model_names = list(model_names)

    parser = argparse.ArgumentParser(description='Proper ResNets for CIFAR in pytorch')

    # dataset setup
    parser.add_argument('--num-class', type=int, help='number of class (10 for cifar10 and 100 for cifar100)',
                        required=False, default=100)

    # model setup (help text previously claimed resnet32 while the real default is resnet44)
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet44',
                        choices=model_names,
                        help='model architecture: ' + ' | '.join(model_names) +
                             ' (default: resnet44)')

    # dataloader setup
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('-b', '--batch-size', default=128, type=int,
                        metavar='N', help='mini-batch size (default: 128)')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')

    args = parser.parse_args(argv)
    print(args)
    return args