import argparse

def make_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the federated-learning experiments.

    Every option has a default (``store_true`` flags default to ``False``),
    so ``make_parser().parse_args([])`` yields a complete configuration.

    Returns:
        An ``argparse.ArgumentParser`` with all experiment options registered,
        in the order they appear in ``--help``.
    """
    parser = argparse.ArgumentParser()

    # Datasets accepted both for the main task (--dataset) and for the
    # distillation data (--dataset_distill); shared so the two cannot drift.
    dataset_choices = ['cifar10', 'cifar100', 'mnist', 'fashion-mnist',
                       'emnist-letter', 'emnist-digit', 'shakespeare']

    # NOTE: argparse applies `type` only to *string* defaults; non-string
    # defaults are returned as-is. Float options therefore use float literals
    # (e.g. `1.` rather than `1`) so the parsed value is always a float.

    # general configurations
    parser.add_argument('--n_pd_rounds', type=int, default=5000, help='total dual rounds for PDFL')
    parser.add_argument('--n_global_rounds', type=int, default=5000, help='total communication rounds for FL')
    parser.add_argument('--test_batch_size', type=int, default=200)
    parser.add_argument('--use_ray', action='store_true')
    parser.add_argument('--load_model', action='store_true')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--test_metric', type=str, choices=['accuracy', 'class_wise_accuracy'],
                        default='class_wise_accuracy', help='what to report in tensorboard')
    parser.add_argument('--eval_freq', type=int, default=1, help='how often the test loss should be checked')
    parser.add_argument('--weighted', action='store_true', help='allow clients to have different weights initially')
    parser.add_argument('--loss_fn', type=str, choices=['focal-loss', 'cross-entropy-loss', 'mcr'],
                        default='cross-entropy-loss', help='loss functional')

    # tricks for NN training
    parser.add_argument('--no_data_augmentation', action='store_true', help='disable the data augmentation')
    parser.add_argument('--use_gradient_clip', action='store_true')
    parser.add_argument('--gradient_clip_constant', type=float, default=5.)
    parser.add_argument('--weight_decay', type=float, default=1e-3)

    # Experiment setup
    parser.add_argument('--heterogeneity', type=str, choices=['mix', 'dir'], default='mix',
                        help='Type of heterogeneity, mix or dir(dirichlet)')
    parser.add_argument('--homo_ratio', type=float, default=1.)
    parser.add_argument('--dir_level', type=float, default=.3, help='hyperparameter of the Dirichlet distribution')
    # NOTE(review): the help text says only cifar10 is supported although
    # seven choices are accepted -- confirm which datasets actually work.
    parser.add_argument('--dataset', type=str, choices=dataset_choices,
                        default='cifar10', help='dataset (and the corresponding task), now only support cifar10')
    parser.add_argument('--dense_hid_dims', type=str, default='384-192')
    parser.add_argument('--conv_hid_dims', type=str, default='64-64')
    parser.add_argument('--model', type=str, choices=['mlp', 'convnet', 'resnet'], default='convnet')
    parser.add_argument('--learner', type=str, choices=['fed-avg', 'fed-pd', 'scaffold', 'ffgb-d', 'fedavg-d'],
                        default='fed-pd')
    parser.add_argument('--formulation', type=str, choices=['imbalance-fl', 'ratioloss-fl', 'GHMC_loss'],
                        default='imbalance-fl', help='formulation for handling class imbalance problem')
    parser.add_argument('--n_workers', type=int, default=50)
    parser.add_argument('--n_workers_per_round', type=int, default=5)
    parser.add_argument('--l2_reg', type=float, default=-1.)
    parser.add_argument('--remove_classes', action='store_true')

    # Class-imbalance construction
    parser.add_argument('--imbalance', action='store_true', help='create imbalance among classes')
    parser.add_argument('--n_minority', type=int, default=1, help='number of minority classes')
    parser.add_argument('--reduce_to_ratio', type=float, default=1.)

    # General hyperparameters
    parser.add_argument('--local_lr', type=float, default=0.1)
    parser.add_argument('--global_lr', type=float, default=1.)
    parser.add_argument('--local_epoch', type=int, default=5)
    parser.add_argument('--client_step_per_epoch', type=int, default=5)

    # Hyperparameters for the formulation "imbalance-fl"
    parser.add_argument('--lambda_lr', type=float, default=1.)
    parser.add_argument('--tolerance_epsilon', type=float, default=1.)
    parser.add_argument('--n_p_steps', type=int, default=5, help="primal steps per dual step in PDFL")

    # Hyperparameters for fed-pd
    parser.add_argument('--eta', type=float, default=10.)
    parser.add_argument('--fed_pd_dual_lr', type=float, default=1.)

    # General hyperparameters for Functional Federated Learning
    parser.add_argument('--local_dataloader_batch_size', type=int, default=50)
    parser.add_argument('--distill_dataloader_batch_size', type=int, default=50)
    parser.add_argument('--dataset_distill', type=str, choices=dataset_choices, default='cifar10')

    # Hyperparameters for FFGB-D
    parser.add_argument('--functional_lr', type=float, default=10.)
    parser.add_argument('--f_l2_reg', type=float, default=5e-3)
    parser.add_argument('--local_steps', type=int, default=1)
    parser.add_argument('--weak_learner_epoch', type=int, default=120)
    parser.add_argument('--weak_learner_lr', type=float, default=1e-3)
    parser.add_argument('--weak_learner_weight_decay', type=float, default=0.)

    # Hyperparameters for FEDAVG-D
    parser.add_argument('--fedavg_d_local_lr', type=float, default=.01)
    parser.add_argument('--fedavg_d_local_epoch', type=int, default=50)
    parser.add_argument('--fedavg_d_weight_decay', type=float, default=1e-3)

    # distill oracle
    parser.add_argument('--distill_oracle', type=str, choices=["kl", "l2"], default="l2")
    parser.add_argument('--distill_oracle_epoch', type=int, default=10)
    parser.add_argument('--distill_oracle_lr', type=float, default=1e-3)
    parser.add_argument('--distill_oracle_weight_decay', type=float, default=.001)

    # Hyperparameters for the 'mcr' option of --loss_fn
    parser.add_argument('--gam1', type=float, default=1.,
                        help='gamma1 for tuning empirical loss (default: 1.)')
    parser.add_argument('--gam2', type=float, default=1.,
                        help='gamma2 for tuning empirical loss (default: 1.)')
    parser.add_argument('--eps', type=float, default=0.5,
                        help='eps squared (default: 0.5)')
    parser.add_argument('--fd', type=int, default=128,
                        help='dimension of feature dimension (default: 128)')

    return parser