import argparse
import torch

def args_parser(argv=None):
    """Build the command-line parser, parse the arguments and add derived fields.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default), ``argparse`` falls back to ``sys.argv[1:]``,
            matching the previous behaviour. Passing an explicit list lets
            callers (notebooks, tests, other scripts) build a configuration
            without touching the process command line.

    Returns:
        argparse.Namespace: The parsed options plus the derived attribute
        ``epochs_global`` (= ``round_per_task * len(task_classes)``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='cifar100', help="name of dataset")
    parser.add_argument('--task_classes', type=int, nargs='+', default=[80, 20], help="number of data classes for each tasks")
    parser.add_argument('--method', type=str, default='gal', help="name of method")
    parser.add_argument('--img_size', type=int, default=32, help="size of images")
    parser.add_argument('--device', type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument('--batch_size', type=int, default=128, help='size of mini-batch')
    parser.add_argument('--seed', type=int, default=2023, help='random seed')
    parser.add_argument('--epochs_local', type=int, default=10, help='local epochs of each global round')
    parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate')
    parser.add_argument('--num_clients', type=int, default=10, help='number of clients')
    parser.add_argument('--local_clients', type=int, default=5, help='number of selected clients each round')
    parser.add_argument('--round_per_task', type=int, default=30, help='number of global rounds per task')
    parser.add_argument('--dirichlet_alpha', type=float, default=1e-1, help='Dirichlet concentration parameter (alpha)')
    parser.add_argument('--ema_alpha', type=float, default=0.99, help='ema')
    parser.add_argument('--ema_decay', type=float, default=0.000, help='ema alpha decay along global epoch')
    parser.add_argument('--source', action="store_true", help='train the source labeled categories')
    # argv=None makes argparse read sys.argv[1:], preserving CLI behaviour.
    args = parser.parse_args(argv)
    # Derived setting: total global rounds across all tasks, one block of
    # `round_per_task` rounds per entry in `task_classes`.
    args.epochs_global = args.round_per_task * len(args.task_classes)
    return args