import argparse

def get_arguments():
    parser = argparse.ArgumentParser()

    # Directory option
    parser.add_argument('--data_root', type=str, default='../data/')   
    parser.add_argument('--checkpoints', type=str, default='../backdoor_checkpoints')
    parser.add_argument('--result_checkpoints', type=str, default='./finetuning_checkpoints')
    parser.add_argument('--result_file', type=str, default='./cifar10_results.txt')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--dataset', type=str, default='cifar10')
    parser.add_argument('--saving_prefix', type=str, help='Folder in /checkpoints for saving ckpt')
    parser.add_argument('--input_height', type=int, default=32)
    parser.add_argument('--input_width', type=int, default=32)
    parser.add_argument('--input_channel', type=int, default=3)
    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--target_label', type=int, default=0)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--noise_rate', type=float, default=0.08)
    parser.add_argument('--ratio', type=float, default=0.65)
    parser.add_argument("--post_transform_option", type=str, default="use", choices=["use", "no_use", "use_modified"])
    parser.add_argument("--random_rotation", type=int, default=10)
    parser.add_argument("--random_crop", type=int, default=5)

    # ---------------------------- For fine-tuning --------------------------
    # Model hyperparameters
    parser.add_argument('--portion', type=int, default=0.05)
    parser.add_argument('--n_iters', type=int, default=20)
    parser.add_argument('--bs', type=int, default=128)
    parser.add_argument('--lr_C', type=float, default=3e-4)
    parser.add_argument('--lr_C_max1', type=float, default=0.1)
    parser.add_argument('--lr_C_max2', type=float, default=0.001)
    parser.add_argument('--scheduler_step_size', type=int, default=100)
    parser.add_argument('--schedulerG_milestones', type=list, default=[100, 200])
    parser.add_argument('--schedulerC_milestones', type=list, default=[100, 200])
    parser.add_argument('--schedulerT_milestones', type=list, default=[100, 200])
    parser.add_argument('--schedulerG_lambda', type=float, default=0.1)
    parser.add_argument('--schedulerC_lambda', type=float, default=0.1)
    parser.add_argument('--schedulerT_lambda', type=float, default=0.1)


    # parser.add_argument('--beta1', type=int, default=500, help='beta of low layer')
    # parser.add_argument('--beta2', type=int, default=1000, help='beta of middle layer')
    # parser.add_argument('--beta3', type=int, default=1000, help='beta of high layer')

    return parser