# This file defines the command-line argument parser.
import argparse
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` is a trap in argparse: it calls ``bool()`` on the raw
    string, so any non-empty text (including ``"False"`` and ``"0"``) becomes
    ``True``. This converter interprets the usual spellings instead.

    Args:
        value: The raw string from the command line (or an actual ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as False because ``bool('')`` was the only way
    # to obtain False from the previous ``type=bool`` behaviour.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got %r' % (value,))


def argument_parsing(argv=None):
    """Build the experiment argument parser and parse the command line.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``sys.argv[1:]`` is used, which matches the
            behaviour of calling this function without arguments.

    Returns:
        The ``argparse.Namespace`` holding every option defined below.
        Options spelled with dashes are exposed with underscores
        (e.g. ``--epochs-ft`` becomes ``args.epochs_ft``).
    """
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10/MNIST Training')
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--dp', default=False, type=_str2bool, help='use dp to train models')
    parser.add_argument('--epsilon', default=1, type=float, help='privacy budget')
    parser.add_argument('--delta', default=1e-5, type=float, help='delta')
    parser.add_argument('--max_grad_norm', default=1.2, type=float, help='max gradient norm for clipping')
    parser.add_argument('--batch_size', default=128, type=int, help='batch size')
    parser.add_argument('--total_epoch', default=200, type=int, help='epoch')
    parser.add_argument('--dataset', default='CIFAR10', type=str, help='dataset')
    parser.add_argument('--model', default='CNN', type=str, help='model')
    parser.add_argument('--unlearn_model', default='fisher', type=str, help='unlearning method')
    parser.add_argument('--optimizer', default='Adam', type=str, help='optimizer')
    parser.add_argument('--loss', default='CrossEntropyLoss', type=str, help='loss')
    parser.add_argument('--gpu', default='0', type=str, help='gpu')
    parser.add_argument('--retain_ratio', default=0.6, type=float, help="ratio of retain data")
    # NOTE(review): this help text duplicates --retain_ratio; the intended
    # meaning of size_ratio is not visible here, so the string is left as-is.
    parser.add_argument('--size_ratio', default=1, type=float, help="ratio of retain data")

    # MIA
    parser.add_argument('--attack_batch_size', default=128, type=int, help="batch_size of attack")
    parser.add_argument('--momentum', default=0.9, type=float, help="momentum of SGD")
    parser.add_argument('--wd', default=1e-4, type=float, help="weight decay")
    parser.add_argument('--seed', default=42, type=int, help="seed used for dataset split")
    parser.add_argument('--secseed', default=42, type=int, help="seed used for model training")

    # Unlearn
    parser.add_argument('--cutmix-prob', type=float, default=0.5, metavar='P',
                        help='Prob. for cutmix (default: 0.5)')
    parser.add_argument('--cutmix-alpha', type=float, default=1.0, metavar='A',
                        help='Alpha for cutmix (default: 1.0)')
    parser.add_argument('--clip', type=float, default=10.0, metavar='M',
                        help='Gradient clipping (default: 10)')
    parser.add_argument('--disable-bn', action='store_true', default=False,
                        help='Put batchnorm in eval mode and don\'t update the running averages')
    parser.add_argument('--regularization', default=None,
                        help='Regularization type (default: None)')
    parser.add_argument('--maxlr', type=float, default=0.1, metavar='LR',
                        help='max learning rate for SGDR (default: 0.1)')
    parser.add_argument('--minlr', type=float, default=0.005, metavar='LR',
                        help='min learning rate for SGDR (default: 0.005)')
    parser.add_argument('--init_checkpoint', required=False,
                        help='Path to init checkpoint for golatkar')
    parser.add_argument('--scheduler', default='CosineAnnealingWarmRestarts',
                        choices=['CosineAnnealingWarmRestarts', 'CosineAnnealingLR'],
                        help='PyTorch scheduler name (default: CosineAnnealingWarmRestarts)')
    parser.add_argument('--maxlr-ft', type=float, default=0.1, metavar='LR',
                        help='max learning rate for SGDR (default: 0.1)')
    parser.add_argument('--minlr-ft', type=float, default=0.005, metavar='LR',
                        help='min learning rate for SGDR (default: 0.005)')
    parser.add_argument('--name-ft', default='Finetune')
    parser.add_argument('--epochs-ft', type=int, default=30, metavar='N',
                        help='number of epochs to train (default: 30)')
    parser.add_argument('--epochs-rf', type=int, default=62, metavar='N',
                        help='number of epochs to train (default: 62)')
    parser.add_argument('--maxL-rf', type=int, default=3, metavar='UL',
                        help='Layers to retrain upperbound (default: 3)')
    parser.add_argument('--minL-rf', type=int, default=1, metavar='LL',
                        help='Layers to retrain lowerbound (default: 1)')
    parser.add_argument('--stepL-rf', type=int, default=1, metavar='LS',
                        help='Layers to retrain step size (default: 1)')
    parser.add_argument('--name-rf', default='RetrainFinal')
    parser.add_argument('--name-ftF', default='FinetuneFinal')
    parser.add_argument('--epochs-ftF', type=int, default=62, metavar='N',
                        help='number of epochs to train (default: 62)')
    parser.add_argument('--maxL-ftF', type=int, default=3, metavar='UL',
                        help='Layers to finetune upperbound (default: 3)')
    parser.add_argument('--minL-ftF', type=int, default=1, metavar='LL',
                        help='Layers to finetune lowerbound (default: 1)')
    parser.add_argument('--stepL-ftF', type=int, default=1, metavar='LS',
                        help='Layers to finetune step size (default: 1)')
    # eiu
    parser.add_argument('--eiu', type=_str2bool, default=False, help="whether to use eiu or not")
    parser.add_argument('--num-change', type=int, default=10,
                        help='No. of samples per class to exch (default: 10)')
    parser.add_argument("--exch-classes", nargs="+", default=None,
                        type=int, help='List of classes to exchange space separated')

    # epi
    parser.add_argument('--epi', type=_str2bool, default=False, help="whether to use epi or not")

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    return args