import argparse
import torch

def _str2bool(value):
    """Convert a command-line string to a bool.

    ``type=bool`` is a trap with argparse: it calls ``bool()`` on the raw
    string, so every non-empty value (including ``"False"``) becomes ``True``.
    This converter interprets the text instead.

    Accepted (case-insensitive): true/t/yes/y/1 and false/f/no/n/0. The empty
    string is also treated as ``False`` because, with the former ``type=bool``,
    that was the only value that produced ``False``.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


def args_parser(argv=None):
    """Build the experiment argument parser and parse the command line.

    Args:
        argv: optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            function did before this parameter existed.

    Returns:
        argparse.Namespace holding one attribute per option defined below.
        Options spelled with hyphens (``--nb-epochs``, ``--epoch-aggregation``,
        ``--pruning-max``, ``--pruning-step``) are exposed with underscores
        (``args.nb_epochs`` and so on).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--data', type=str, default='cifar10',
                        help="dataset we want to train on")

    parser.add_argument('--num_agents', type=int, default=100,
                        help="number of agents:K")

    parser.add_argument('--agent_frac', type=float, default=0.2,
                        help="fraction of agents per round:C")

    parser.add_argument('--num_corrupt', type=int, default=40,
                        help="number of corrupt agents")

    parser.add_argument('--rounds', type=int, default=200,
                        help="number of communication rounds:R")

    parser.add_argument('--aggr', type=str, default='avg',
                        help="aggregation function to aggregate agents' local weights")

    parser.add_argument('--local_ep', type=int, default=2,
                        help="number of local epochs:E")

    parser.add_argument('--bs', type=int, default=64,
                        help="local batch size: B")

    parser.add_argument('--client_lr', type=float, default=0.1,
                        help='clients learning rate')
    parser.add_argument('--server_lr', type=float, default=1,
                        help='servers learning rate for signSGD')

    parser.add_argument('--target_class', type=int, default=7,
                        help="target class for backdoor attack")

    # NOTE(review): this help text duplicates --target_class; the intended
    # meaning of --attack_goal is not visible here -- confirm and reword.
    parser.add_argument('--attack_goal', type=int, default=5,
                        help="target class for backdoor attack")
    # NOTE(review): store_true combined with default=True means this option
    # is always True and cannot be switched off from the command line.
    parser.add_argument('--attack_all2one', action='store_true', default=True,
                        help="attack goal is all to one")

    parser.add_argument('--poison_frac', type=float, default=0.5,
                        help="fraction of dataset to corrupt for backdoor attack")
    # Pattern names noted by the original author: plus, DBA, square.
    parser.add_argument('--pattern_type', type=str, default='plus',
                        help="shape of bd pattern")

    parser.add_argument('--snap', type=int, default=1,
                        help="do inference in every num of snap rounds")
    # No type= is given: the default is a torch.device, but a value supplied
    # on the command line stays a plain string.
    parser.add_argument('--device',
                        default=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
                        help="To use cuda, set to a specific GPU ID.")
    parser.add_argument('--num_workers', type=int, default=1,
                        help="num of workers for multithreading")

    # The next three options previously carried help text copy-pasted from
    # --num_workers; the replacements only restate the option names.
    parser.add_argument('--anneal_factor', type=float, default=0.0001,
                        help="anneal factor")
    parser.add_argument('--method', type=str, default="Grace",
                        help="method name")
    parser.add_argument('--se_threshold', type=float, default=1e-4,
                        help="se threshold")
    parser.add_argument('--non_iid', action='store_true', default=False)
    parser.add_argument('--partition', type=str, default="None",
                        help="homo,noniid-labeldir,noniid-#label0,iid-diff-quantity,mixed")
    parser.add_argument('--debug', action='store_true', default=False)
    parser.add_argument('--alpha', type=float, default=0.5)
    parser.add_argument('--attack', type=str, default="badnet")
    parser.add_argument('--lr_decay', type=float, default=1)

    parser.add_argument('--dis_check_gradient', action='store_true', default=False)
    parser.add_argument('--wd', type=float, default=1e-4)

    parser.add_argument('--cease_poison', type=float, default=10000)

    # class representation Parameters
    parser.add_argument('--alpha_cr', type=float, default=10)
    parser.add_argument('--last_local_ep', type=int, default=10,
                        help="the number of local epochs of last")
    # ten local epochs to train the local head, followed by one or five
    # epochs for the representation
    parser.add_argument('--local_rep_ep', type=int, default=1,
                        help="the number of local epochs of Fed_Rep's feature")
    parser.add_argument('--beta', default=0.001, type=float,
                        help='the value of beta for Fed_VIB')
    parser.add_argument('--beta2', default=0, type=float,
                        help='the value of beta2 for Z')
    parser.add_argument('--dimZ', default=256, type=int,
                        help='dimension of encoding Z in Fed_VIB')
    parser.add_argument('--CMI', default=0.001, type=float,
                        help='the value of CMI in FedSR')
    parser.add_argument('--L2R', default=0.001, type=float,
                        help='the value of L2R in FedSR')
    parser.add_argument('--num_avg_train', default=15, type=int,
                        help='the number of samples when '
                             'perform multi-shot train')
    parser.add_argument('--num_avg', default=30, type=int,
                        help='the number of samples when '
                             'perform multi-shot prediction')

    parser.add_argument('--similar', default=None, type=str,
                        help='the similar of client:cosin')
    parser.add_argument('--mask_update', default=None, type=str,
                        help='bernoulli,None')
    parser.add_argument('--noise', type=float, default=0.001)

    # _str2bool instead of type=bool so that "--selection False" is False.
    parser.add_argument('--selection', default=True, type=_str2bool,
                        help='client selection or not')
    parser.add_argument('--alpha_sel', type=float, default=0.1,
                        help='cucb parameter')
    parser.add_argument('--param_clip_thres', type=int, default=20,
                        help='param_clip_thres')

    # _str2bool instead of type=bool so that "--drop False" is False.
    parser.add_argument('--drop', default=True, type=_str2bool,
                        help='global model drop')

    parser.add_argument('--alpha_pru', type=float, default=0.8,
                        help='Search area design Parameter')
    parser.add_argument('--beta_pru', type=float, default=0.5,
                        help='Search area design Parameter')
    parser.add_argument('--nb-epochs', type=int, default=40,
                        help='the number of iterations for training')
    parser.add_argument('--epoch-aggregation', type=int, default=20,
                        help='print results every few iterations')
    parser.add_argument('--val_ratio', type=float, default=0.01,
                        help='Search area design Parameter')
    parser.add_argument('--checkpoint', type=str,
                        default='./save/clean_model/100round_model.pt',
                        help='The checkpoint to be pruned')

    parser.add_argument('--pruning-max', type=float, default=0.02,
                        help='the maximum number/threshold for pruning')
    parser.add_argument('--pruning-step', type=float, default=0.005,
                        help='the step size for evaluating the pruning')

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)
    return args