import argparse

def _str2bool(value):
    """Convert a command-line string into a bool for use as an argparse ``type=``.

    ``type=bool`` cannot be used for this purpose: argparse hands the raw
    string to the converter, and ``bool("0")`` / ``bool("False")`` are both
    ``True``, so such a flag could never be switched off from the CLI.

    Args:
        value: The raw option value. A value that is already a ``bool`` is
            returned unchanged.

    Returns:
        ``True`` for "1", "true", "t", "yes", "y"; ``False`` for "0",
        "false", "f", "no", "n" (case-insensitive, surrounding whitespace
        ignored).

    Raises:
        argparse.ArgumentTypeError: For any other value, so argparse reports
            a usage error instead of silently accepting it.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (e.g. 0/1, true/false), got {!r}'.format(value))


def args_parser(argv=None):
    """Build the IncFL experiment argument parser and parse the command line.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, exactly as the
            original zero-argument call did.

    Returns:
        argparse.Namespace holding every option defined below.

    Note:
        ``--labelshift`` and ``--decay`` are parsed with ``_str2bool``, so
        ``--decay 0`` / ``--decay false`` really disable the option. When the
        flags are omitted they keep their original defaults (0 and 1).
    """
    parser = argparse.ArgumentParser(description='IncFL Experiments')
    parser.add_argument('--name',
                        default='incfl_results',
                        type=str,
                        help='name of results file')
    parser.add_argument('--dir',
                        default='.',
                        type=str,
                        help='dir for data')
    parser.add_argument('--dataset',
                        default='fmnist',
                        type=str,
                        help='type of dataset')
    parser.add_argument('--alpha',
                        default=0.3,
                        type=float,
                        help='control the non-iidness of dataset')
    parser.add_argument('--q',
                        default=10,
                        type=float,
                        help='parameter for qffl')
    parser.add_argument('--mwfed_c',
                        default=5,
                        type=float,
                        help='parameter for MW-Fed')
    parser.add_argument('--isclust',
                        default=0,
                        type=int,
                        help='perform different types of clustering')
    # NOTE(review): the help texts of --better, --leave, --local_tune and
    # --local_tune_ep repeat the --isclust text and look copy-pasted; confirm
    # what each option controls and update the help strings accordingly.
    parser.add_argument('--better',
                        default=0,
                        type=int,
                        help='perform different types of clustering')
    parser.add_argument('--leave',
                        default=0,
                        type=int,
                        help='perform different types of clustering')
    parser.add_argument('--local_tune',
                        default=0,
                        type=int,
                        help='perform different types of clustering')
    parser.add_argument('--local_tune_ep',
                        default=5,
                        type=int,
                        help='perform different types of clustering')
    # Parsed with _str2bool rather than type=bool: bool("0") would be True.
    parser.add_argument('--labelshift',
                        default=0,
                        type=_str2bool,
                        help='perform labelshift that changes P(Y|X)')
    parser.add_argument('--train_ratio',
                        type=float,
                        default=0.1,
                        help='traindata length')
    parser.add_argument('--val_ratio',
                        type=float,
                        default=0.3,
                        help='publicdata length')
    # NOTE(review): --test_ratio reuses the --val_ratio help text; confirm.
    parser.add_argument('--test_ratio',
                        type=float,
                        default=0.0,
                        help='publicdata length')
    parser.add_argument('--device',
                        default='cuda',
                        type=str)
    parser.add_argument('--size',
                        default=3,
                        type=int,
                        help='number of local workers')
    parser.add_argument('--ensize',
                        default=100,
                        type=int,
                        help='number of all workers')
    parser.add_argument('--eta',
                        default=0.1,
                        type=float,
                        help='client learning rate')
    # NOTE(review): --local_eta shares the --eta help text; confirm how the
    # two learning rates differ.
    parser.add_argument('--local_eta',
                        default=0.1,
                        type=float,
                        help='client learning rate')
    parser.add_argument('--bs',
                        default=64,
                        type=int,
                        help='batch size on each worker/client')
    parser.add_argument('--rounds',
                        default=500,
                        type=int,
                        help='total communication rounds')
    parser.add_argument('--train_ep',
                        default=30,
                        type=int,
                        help='number of local epochs')
    # NOTE(review): --local_train_ep shares the --train_ep help text; confirm.
    parser.add_argument('--local_train_ep',
                        default=30,
                        type=int,
                        help='number of local epochs')
    parser.add_argument('--model',
                        default="vgg",
                        type=str,
                        help='neural network model')
    parser.add_argument('--num_classes',
                        type=int,
                        default=10,
                        help='number of classes')
    # Parsed with _str2bool rather than type=bool so that "--decay 0" really
    # disables decay, as the help text promises.
    parser.add_argument('--decay',
                        default=1,
                        type=_str2bool,
                        help='1: decay LR, 0: no decay')
    parser.add_argument('--print_freq',
                        default=100,
                        type=int,
                        help='print info frequency')
    parser.add_argument('--seed',
                        default=1,
                        type=int,
                        help='random seed')

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    return args