import argparse

def _str2bool(value):
    """Convert a command-line token into a real ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string, so such flags can never be switched off.
    This converter interprets the usual textual spellings instead.

    Args:
        value: The raw command-line string (or an actual ``bool``).

    Returns:
        The parsed boolean value.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string was the only falsy input under ``type=bool``; keep it.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got %r" % (value,))


def args_parser(argv=None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument call.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # general
    parser.add_argument('-dev', "--device", type=str, default="cuda",
                        choices=["cpu", "cuda"])
    parser.add_argument('-did', "--device_id", type=str, default="0", help='id of employed GPUs')
    parser.add_argument('-data', "--dataset", type=str, default="dataset")
    parser.add_argument('-t_data', "--test_data", type=str, default="test_data")
    parser.add_argument('-m_data', "--mask4d_ls", type=str, default="mask4d_ls")
    # parser.add_argument('-nb', "--num_classes", type=int, default=10)
    parser.add_argument('-m', "--model", type=str, default="SRN")
    parser.add_argument("--prompt_model", type=str, default="Prompt_base")
    parser.add_argument('-lbs', "--batch_size", type=int, default=4)
    parser.add_argument("--batch_size_trntst", type=int, default=4)
    parser.add_argument('-lr', "--local_learning_rate", type=float, default=0.0004,
                        help="Local learning rate")
    # Boolean options use _str2bool so that e.g. "-ld False" really yields False.
    parser.add_argument('-ld', "--learning_rate_decay", type=_str2bool, default=True)
    parser.add_argument('-ldg', "--learning_rate_decay_gamma", type=float, default=0.5)
    # nargs='+' keeps the parsed value a list, consistent with the list default.
    parser.add_argument("--milestones", type=int, default=[50, 100, 150, 200, 250], nargs='+',
                        help='milestones for MultiStepLR')

    parser.add_argument('-gr', "--global_rounds", type=int, default=1000)
    parser.add_argument('-ls', "--local_steps", type=int, default=2)
    parser.add_argument('-algo', "--algorithm", type=str, default="FedAvg")
    parser.add_argument('-jr', "--join_ratio", type=float, default=1.0,
                        help="Ratio of clients per round")
    parser.add_argument('-rjr', "--random_join_ratio", type=_str2bool, default=False,
                        help="Random ratio of clients per round")
    parser.add_argument('-nc', "--num_clients", type=int, default=2,
                        help="Total number of clients")
    parser.add_argument('-eg', "--eval_gap", type=int, default=1,
                        help="Rounds gap for evaluation")
    parser.add_argument('-ab', "--auto_break", type=_str2bool, default=False)
    # practical
    parser.add_argument('-cdr', "--client_drop_rate", type=float, default=0.0,
                        help="Rate for clients that train but drop out")
    parser.add_argument('-tsr', "--train_slow_rate", type=float, default=0.0,
                        help="The rate for slow clients when training locally")
    parser.add_argument('-ssr', "--send_slow_rate", type=float, default=0.0,
                        help="The rate for slow clients when sending global model")
    parser.add_argument('-ts', "--time_select", type=_str2bool, default=False,
                        help="Whether to group and select clients at each round according to time cost")
    parser.add_argument('-tth', "--time_threthold", type=float, default=10000,
                        help="The threthold for droping slow clients")
    parser.add_argument('-bt', "--beta", type=float, default=0.4,
                        help="Average moving parameter for pFedMe, Second learning rate of Per-FedAvg, \
                            or L1 regularization weight of FedTransfer")
    parser.add_argument('-lam', "--lamda", type=float, default=1.0,
                        help="Regularization weight")
    # FedProx
    parser.add_argument('-mu', "--mu", type=float, default=0,
                        help="Proximal rate for FedProx")
    # pFedMe
    parser.add_argument('-K', "--K", type=int, default=5,
                        help="Number of personalized training steps for pFedMe")
    parser.add_argument('-lrp', "--p_learning_rate", type=float, default=0.01,
                        help="personalized learning rate to caculate theta aproximately using K steps")

    # @WJM: cp from HSI ##################################################################################################
    # Set-up #################################################
    # parser.add_argument('--device', default='0,1', help='CUDA ID')
    parser.add_argument('--debug', type=int, default=1,
                        help='wheter to use debug mode for training, if True, a very small training set will be applied')
    parser.add_argument('--workers', type=int, default=8, help='keep, num of workers for dataloader processing')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--cluster', default="US", choices=['US', 'CHN'], help="cluster location")
    # data #################################################
    parser.add_argument('--train_data_path', default="./NIPS2021/Data/training/simu/",
                        help="training data directory")
    parser.add_argument('--mask_path', default="./masks/mask_files", help="mask directory")
    parser.add_argument("--mask_ids", type=int, default=[1, 5], nargs='+',
                        help='real mask [names] to be employed, See masks/figures/ for more info')
    parser.add_argument("--mask_id_single", type=int, default=0,
                        help='use one mask [index] of the mask_ids to train a singleM2M model')
    parser.add_argument('--test_data_path', default="./NIPS2021/Data/testing/simu/mat/",
                        help="testing data directory")

    # federated learning #################################################

    parser.add_argument('--trn_split', type=int, default=1,
                        help="keep, if True(1), use private datasets for users, if False(0), use public/shared dataset, if Customize(2), then split the private datasets accroding to the [trn_split_ratio]")
    parser.add_argument('--trn_split_ratio', type=int, default=[1,1], nargs='+',
                        help='need to be the same len as num_clients! split the original dataset according to ratios')
    # network #################################################

    parser.add_argument('--meas_init', default="meas", choices=['meas*mask', 'meas'],
                        help="if initialzing the measurement with the mask")
    parser.add_argument('--mask_op', default="rand_crop", choices=['rand_crop', 'fixed256'],
                        help="random crop masks or use single mask for train/test")
    parser.add_argument('--patch_size', type=int, default=256, help="input data patch size")
    parser.add_argument('--trial_num', type=int, default=10, help="testing trials for random mask, [mask_op]=rand_crop")

    # optimization #################################################
    parser.add_argument("--epoch_sum_num", type=int, default=5000,
                        help='total number of samples to be retrieved for an epoch: batch_num = epoch_sum_num/batch_size')
    parser.add_argument('--optimizer', default="ADAM", choices=['ADAM', 'PERAVG'],)
    parser.add_argument("--backbone_lr", type=float, default=0.0004,
                        help="only used in MABFT mode, for backbone optimization")

    parser.add_argument('--base_optimizer', default="ADAM", choices=['ADAM', 'PERAVG'],)
    parser.add_argument("--base_local_learning_rate", type=float, default=0.00001,
                        help="Local learning rate")
    parser.add_argument("--base_learning_rate_decay_gamma", type=float, default=0.5)
    # nargs='+' keeps the parsed value a list, consistent with the list default.
    parser.add_argument("--base_milestones", type=int, default=[250], nargs='+',
                        help='milestones for MultiStepLR')
    parser.add_argument("--base_learning_rate_decay", type=_str2bool, default=True)

    # checkpoint #################################################
    parser.add_argument("--psnr_set", type=int, default=10,
                        help='start to save the checkpoint if validation performance achieves this psnr value')
    parser.add_argument("--last_train", type=int, default=0,
                        help='specify the checkpoint to be load for breakpoint training/testing')
    parser.add_argument('--model_path', default="results", help="checkpoint directory")
    parser.add_argument('--model_save_filename', default="",
                        help="upper level folder of the checkpoint, i.e., 2022_01_10_15_36_25")

    # testing  #################################################
    parser.add_argument('--last_train_ls', type=int, default=[0], nargs='+',
                        help='provide a series of pre-trained models as ensemble')
    parser.add_argument('--model_save_filename_ls', type=str, default=[''], nargs='+',
                        help='provide a series of pre-trained models as ensemble')
    parser.add_argument("--focus", type=str, default='Personalization',choices=['Personalization', 'Generalization'],
                        help='focus on which properity of the pFL/clients')

    # MST   #################################################
    parser.add_argument("--adaptor", type=str, default=None,
                        help='None: do not apply,  [LnPlain]: linear adaptor, downscaled; [ConvPlain]: Conv adaptor, no downscale')
    parser.add_argument("--param_init", type=str, default='default',
                        help='different global model initializaion')

    # FEDAPTOR   #################################################
    parser.add_argument('--last_train_clients', type=int, default=[0], nargs='+',
                        help='provide pre-trained models as for clients')
    parser.add_argument('--model_save_filename_clients', type=str, default=[''], nargs='+',
                        help='provide pre-trained models for clients')
    parser.add_argument("--test_mode", action='store_true', help='skip the training, but only perform the global evaluation')
    parser.add_argument("--check", action='store_true', help='check/print the grad and params of adaptor once specified')

    # FedAvg-like FL frameworks   #################################################
    parser.add_argument("--CA", action='store_true', help='for non personlized FL: CA(Client Aggregation), once sepcified, will initialize a global model by weighted sum of client weights. The global round is set as 0')

    # Learning mode   #################################################
    parser.add_argument("--PTP", action='store_true',help='PTP(Pre-trained Personlization), use value 0/1, onece specified, will load pre-trained backbone params to clients, and fix the backbones in optmization')

    parser.add_argument("--MB", action='store_true',help='MB (meta backbone), no adaptor allowed, update backbone weights as meta-weight')

    parser.add_argument("--MABFT", action='store_true',help='MABFT (meta adaptor + backbone fine tuning), adaptor involed in MAML, backboned upon MSE, alternating training')

    parser.add_argument("--FMABFT", action='store_true',help='FMABFT (first meta adaptor,  backbone fine tuning), adaptor involed in MAML, backboned upon MSE, alternating training, stable version of MABFT')
    parser.add_argument("--backbone_interval", type=int, default=2, help='update the backbone every [backbone_interval] epoch.')

    parser.add_argument("--SYNCMABFT", action='store_true',help='SYNCFMABFT (syncanize meta adaptor and backbone fine tuning), adaptor involed in MAML, backboned upon MSE, stable version of MABFT')

    parser.add_argument("--MP", action='store_true',help='MP (mask prompt learning for personlization),  if specified, must specify the [prompt_model]')
    parser.add_argument("--last_train_prompt", type=int, default=0,
                        help='specify the prompt model checkpoint to be load for breakpoint training/testing')
    parser.add_argument('--model_save_filename_prompt', default="",
                        help="upper level folder of the prompt model checkpoint, i.e., 2022_01_10_15_36_25")
    parser.add_argument("--prompt_learning_rate", type=float, default=0.0004,help="prompt_model learning rate")
    # NOTE(review): the two PTP help strings below are identical to the
    # *_prompt ones above and look copy-pasted -- confirm the intended wording.
    parser.add_argument("--last_train_PTP", type=int, default=0,
                        help='specify the prompt model checkpoint to be load for breakpoint training/testing')
    parser.add_argument('--model_save_filename_PTP', default="",
                        help="upper level folder of the prompt model checkpoint, i.e., 2022_01_10_15_36_25")
    parser.add_argument("--backbone", type=str, default="MST-S")

    parser.add_argument("--CF", action='store_true',help='CF (client frozen), used in FedMP/MPT, ablation of only proving a shared prompt network, the client keeps frozen')


    # Prompt network settings #################################################
    parser.add_argument("--Prompt_BN", action='store_true',help='add a BN layer at the end of the network')
    parser.add_argument("--Prompt_BLK", type=int, default=2, help='number of residual blocks in the promt_net')
    # NOTE(review): help text duplicates --Prompt_BLK; presumably describes an
    # embedding dimension -- confirm and update.
    parser.add_argument("--embed_dim", type=int, default=32, help='number of residual blocks in the promt_net')


    # MPT settings #################################################
    # NOTE(review): the three local_steps_* help strings duplicate --Prompt_BLK
    # and do not describe these options -- confirm intended meaning and update.
    parser.add_argument("--local_steps_warmup", type=int, default=3, help='number of residual blocks in the promt_net')
    parser.add_argument("--local_steps_B", type=int, default=1, help='number of residual blocks in the promt_net')
    parser.add_argument("--local_steps_P", type=int, default=2, help='number of residual blocks in the promt_net')
    parser.add_argument("--align_intensity", type=float, default=0.001,help="align intensity")

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    return args