import argparse



def get_arguments():
    parser = argparse.ArgumentParser()

    parser.add_argument("--data_root", type=str, default="/root/projects/AttackDefence/data/AttackDefence") # put data under this folder
    parser.add_argument("--trigger_dir", type=str, default="/root/projects/AttackDefence/output") # put data under this folder
    parser.add_argument("--checkpoints", type=str, default="/root/projects/AttackDefence/checkpoints") # save models in this folder
    parser.add_argument("--temps", type=str, default="./temps")
    parser.add_argument("--device", type=str, default="gpu")
    parser.add_argument("--continue_training", action="store_true", default=True)
    parser.add_argument("--job_name", type=str, default="default") # put data under this folder
    
    parser.add_argument("--dataset", type=str, default="mnist")
    parser.add_argument("--attack_mode", type=str, default="all2one")
    parser.add_argument('--attack_ratio', type=float, default=0.01)
    parser.add_argument('--attack_method', type=str, default='hard')
    parser.add_argument('--attack_type', type=str, default='AETCB') # our proposed strategy
    parser.add_argument('--attack_locs', type=str, default='top-left')
    parser.add_argument('--attack_modes', type=str, default='all_col')
    parser.add_argument("--reflect", type=str, default="load")
    parser.add_argument("--reflect_mode", type=str, default="single")
    parser.add_argument("--train_mode", type=str, default="train")
    parser.add_argument("--defencer", type=str, default="ShrinkPad")

    parser.add_argument('--alpha', type=float, default=0.005)
    parser.add_argument('--freqs', type=str, default='2,4,8,16')
    parser.add_argument('--top_k', type=int, default=6)

    parser.add_argument('--trigger_dim', type=int, default=20)

    parser.add_argument("--bs", type=int, default=128)
    parser.add_argument("--lr_C", type=float, default=1e-2)
    parser.add_argument("--schedulerC_milestones", type=list, default=[100, 200, 300, 400])
    parser.add_argument("--schedulerC_lambda", type=float, default=0.1)
    parser.add_argument("--n_iters", type=int, default=20)
    parser.add_argument("--num_workers", type=int, default=6)
    parser.add_argument("--use_label_smooth", type=int, default=0)

    parser.add_argument("--target_label", type=int, default=2)
    parser.add_argument("--pc", type=float, default=0.1)
    parser.add_argument("--cross_ratio", type=float, default=2)  # rho_a = pc, rho_n = pc * cross_ratio
    parser.add_argument("--model_check_metric", type=str, default='Acc')
    parser.add_argument("--random_rotation", type=int, default=10)
    parser.add_argument("--random_crop", type=int, default=4)

    parser.add_argument("--s", type=str, default="0.55")
    parser.add_argument("--k", type=str, default="4")
    parser.add_argument(
        "--grid-rescale", type=float, default=1
    )  # scale grid values to avoid pixel values going out of [-1, 1]. For example, grid-rescale = 0.98

    return parser
