import argparse


def get_base_args_parser(add_help=True):
    """Create the parser described as "Pretraining Neural Network".

    The option groups are registered in this order: architecture and
    dataset, storage, other, finetune.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; pass False to
            leave out the automatic ``-h/--help`` option.

    Returns:
        The populated ``argparse.ArgumentParser``.
    """
    registrars = (
        _add_architecture_and_dataset_args,
        _add_storage_args,
        _add_other_args,
        _add_finetune_args,
    )
    parser = argparse.ArgumentParser(
        add_help=add_help, description="Pretraining Neural Network"
    )
    # Each registrar hands back the parser it was given; keep rebinding so the
    # chain behaves the same as sequential reassignment.
    for register in registrars:
        parser = register(parser)
    return parser


def get_tp_args_parser(add_help=True):
    """Create the parser described as "Pruning Neural Network by TP".

    The option groups are registered in this order: architecture and
    dataset, storage, prune, other, finetune, torch_pruning.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; pass False to
            leave out the automatic ``-h/--help`` option.

    Returns:
        The populated ``argparse.ArgumentParser``.
    """
    registrars = (
        _add_architecture_and_dataset_args,
        _add_storage_args,
        _add_prune_args,
        _add_other_args,
        _add_finetune_args,
        # Options specific to this pruning variant come last.
        _add_torch_pruning_args,
    )
    parser = argparse.ArgumentParser(
        add_help=add_help, description="Pruning Neural Network by TP"
    )
    for register in registrars:
        parser = register(parser)
    return parser


def get_FPVE_args_parser(add_help=True):
    """Create the parser described as "Pruning Neural Network by F-PVE".

    The option groups are registered in this order: architecture and
    dataset, storage, prune, other, finetune, FPVE.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; pass False to
            leave out the automatic ``-h/--help`` option.

    Returns:
        The populated ``argparse.ArgumentParser``.
    """
    registrars = (
        _add_architecture_and_dataset_args,
        _add_storage_args,
        _add_prune_args,
        _add_other_args,
        _add_finetune_args,
        # Options specific to this pruning variant come last.
        _add_FPVE_args,
    )
    parser = argparse.ArgumentParser(
        add_help=add_help, description="Pruning Neural Network by F-PVE"
    )
    for register in registrars:
        parser = register(parser)
    return parser


def _add_architecture_and_dataset_args(parser):
    """Register ``--arch``, ``--dataset`` and ``--data_root`` on *parser*.

    The options are placed in an argument group titled
    "architecture and dataset".

    Args:
        parser: The ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same *parser* object that was passed in.
    """
    section = parser.add_argument_group(title="architecture and dataset")
    add_option = section.add_argument

    add_option("--arch", default="resnet50", type=str, help="model architecture")
    add_option(
        "--dataset",
        choices=["cifar100", "cub200"],
        default="cub200",
        type=str,
        help="dataset to use",
    )
    add_option(
        "--data_root",
        default="../datasets/",
        type=str,
        help="root directory for the dataset",
    )
    return parser


def _add_storage_args(parser):
    """Register the checkpoint loading/saving options on *parser*.

    Adds ``--ckpt_load_dir``, ``--ckpt_save_dir``, ``--save_every`` and
    ``--no_save_model`` to an argument group titled "storage".

    Args:
        parser: The ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same *parser* object that was passed in.
    """
    section = parser.add_argument_group(title="storage")
    add_option = section.add_argument

    add_option(
        "--ckpt_load_dir",
        default=None,
        type=str,
        help="The directory used to load the models to be pruned",
    )
    add_option(
        "--ckpt_save_dir",
        default="../ckpt",
        type=str,
        help="The directory used to save the models and logs",
    )
    add_option(
        "--save_every",
        default=1000,
        type=int,
        dest="save_every",
        help="Saves checkpoints at every specified number of epochs",
    )
    # Negative flag: the value is stored under ``save_model`` and flips from
    # True to False when the switch is present.
    add_option(
        "--no_save_model",
        dest="save_model",
        action="store_false",
        default=True,
        help="disable save model",
    )
    return parser


def _add_prune_args(parser):
    """Register ``--pruning_ratio`` and ``--iterative_steps`` on *parser*.

    The options are placed in an argument group titled "prune".

    Args:
        parser: The ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same *parser* object that was passed in.
    """
    section = parser.add_argument_group(title="prune")
    add_option = section.add_argument

    add_option(
        "--pruning_ratio",
        default=0.5,
        type=float,
        help="The ratio of parameters to prune",
    )
    add_option(
        "--iterative_steps",
        type=int,
        default=1,
        help="Number of iterative pruning steps",
    )
    return parser


def _add_other_args(parser):
    """Register the miscellaneous runtime options on *parser*.

    Adds ``--gpu``, ``--random_seed``, ``--workers``/``-j``,
    ``--use_deterministic_algorithms``, ``--fairness_eval_flag``,
    ``--fairness_type`` and ``--scalar`` to an argument group titled "other".

    Args:
        parser: The ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same *parser* object that was passed in.
    """
    section = parser.add_argument_group(title="other")
    add_option = section.add_argument

    add_option("--gpu", default=1, type=int, help="cuda training")
    add_option(
        "--random_seed", type=int, default=2023, help="seed for dataset split"
    )
    add_option(
        "--workers",
        "-j",
        metavar="N",
        type=int,
        default=32,
        help="number of data loading workers (default: 32)",
    )
    add_option(
        "--use_deterministic_algorithms",
        help="Forces the use of deterministic algorithms only.",
        action="store_true",
    )
    add_option(
        "--fairness_eval_flag",
        default=False,
        action="store_true",
        help="also eval acc per classes",
    )
    add_option(
        "--fairness_type",
        choices=["max_min"],
        type=str,
        default="max_min",
        help="fairness type for FPVE, should use with fairness_eval_flag",
    )
    add_option("--scalar", type=float, default=0.01, help="scalar for fitness")
    return parser


def _add_finetune_args(parser):
    """Register the finetuning / optimisation options on *parser*.

    All options are placed in an argument group titled "finetune". They are
    declared as a table of ``(option strings, add_argument keywords)`` and
    registered in table order, so the order shown by ``--help`` follows the
    table.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same *parser* object that was passed in.
    """
    option_table = (
        # Negative flag: stored under ``finetune``, True unless the switch is given.
        (
            ("--no_finetune",),
            {
                "action": "store_false",
                "default": True,
                "dest": "finetune",
                "help": "disable finetuning",
            },
        ),
        (
            ("--batch_size",),
            {"default": 128, "type": int, "help": "batch size (default: 128)"},
        ),
        (
            ("--ft_epochs",),
            {"type": int, "default": 20, "help": "number of finetuning epochs"},
        ),
        (("--lr",), {"type": float, "default": 0.01, "help": "learning rate"}),
        (
            ("--clip_grad_norm",),
            {
                "default": None,
                "type": float,
                "help": "the maximum gradient norm (default None)",
            },
        ),
        (
            ("--opt_name",),
            {
                "choices": ["sgd", "rmsprop", "adamw"],
                "help": "Choose the optimizer to use (default: sgd)",
                "default": "sgd",
            },
        ),
        # The default here is the string "None", which is one of the choices.
        (
            ("--lr_scheduler_name",),
            {
                "choices": [
                    "steplr",
                    "multisteplr",
                    "cosineannealinglr",
                    "exponentiallr",
                    "None",
                ],
                "default": "None",
                "type": str,
                "help": "the lr scheduler (default: None)",
            },
        ),
        (
            ("--momentum",),
            {"default": 0.9, "type": float, "metavar": "M", "help": "momentum"},
        ),
        # Two spellings of the same option, both stored under ``weight_decay``.
        (
            ("--wd", "--weight_decay"),
            {
                "default": 1e-4,
                "type": float,
                "metavar": "W",
                "help": "weight decay (default: 1e-4)",
                "dest": "weight_decay",
            },
        ),
        (
            ("--norm_weight_decay",),
            {
                "default": None,
                "type": float,
                "help": "weight decay for Normalization layers (default: None, same value as --wd)",
            },
        ),
        (
            ("--bias_weight_decay",),
            {
                "default": None,
                "type": float,
                "help": "weight decay for bias parameters of all layers (default: None, same value as --wd)",
            },
        ),
        (
            ("--label_smoothing",),
            {
                "default": 0.0,
                "type": float,
                "help": "label smoothing (default: 0.0)",
                "dest": "label_smoothing",
            },
        ),
        (
            ("--lr_warmup_epochs",),
            {
                "default": 0,
                "type": int,
                "help": "the number of epochs to warmup (default: 0)",
            },
        ),
        (
            ("--lr_warmup_method",),
            {
                "default": "constant",
                "type": str,
                "help": "the warmup method (default: constant)",
            },
        ),
        (
            ("--lr_warmup_decay",),
            {"default": 0.01, "type": float, "help": "the decay for lr"},
        ),
        (
            ("--lr_step_size",),
            {
                "default": 30,
                "type": int,
                "help": "decrease lr every step-size epochs",
            },
        ),
        (
            ("--lr_decay_gamma",),
            {
                "default": 0.1,
                "type": float,
                "help": "decrease lr by a factor of lr-gamma",
            },
        ),
        (
            ("--lr_min",),
            {
                "default": 0.0,
                "type": float,
                "help": "minimum lr of lr schedule (default: 0.0)",
            },
        ),
        # No default is given, so the attribute is None when the flag is absent.
        (
            ("--lr_decay_milestones",),
            {"type": str, "help": "milestones for learning rate decay"},
        ),
        (
            ("--amp",),
            {
                "action": "store_true",
                "default": False,
                "help": "Use torch.cuda.amp for mixed precision training",
            },
        ),
    )

    section = parser.add_argument_group(title="finetune")
    for option_strings, settings in option_table:
        section.add_argument(*option_strings, **settings)

    return parser


def _add_torch_pruning_args(parser):
    """Register the options of the "torch_pruning" argument group on *parser*.

    Adds ``--prune_method``, ``--global_pruning``, ``--max_pruning_ratio``,
    ``--layer_wise_imp`` and ``--batch_num_Hessian``.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same *parser* object that was passed in.
    """
    section = parser.add_argument_group(title="torch_pruning")
    add_option = section.add_argument

    add_option("--prune_method", default="l1", type=str, help="pruning method")
    add_option(
        "--global_pruning",
        default=False,
        action="store_true",
        help="whether to use global pruning",
    )
    add_option(
        "--max_pruning_ratio", default=1.0, type=float, help="maximum pruning ratio"
    )
    add_option(
        "--layer_wise_imp",
        default=False,
        action="store_true",
        help="whether to use layer-wise importance",
    )
    add_option(
        "--batch_num_Hessian", type=int, default=10, help="batch num for hessian"
    )
    return parser


def _add_FPVE_args(parser):
    """Register the options of the "FPVE" argument group on *parser*.

    Adds ``--fitness_mode``, ``--FPVE_fitness_data_ratio``,
    ``--evolution_epoch``/``-ep``, ``--pop_size``, ``--mutation_rate`` and
    ``--pop_init_rate``.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same *parser* object that was passed in.
    """
    group = parser.add_argument_group(title="FPVE")

    group.add_argument(
        "--fitness_mode",
        type=str,
        default="ACC_TOP1",
        choices=["ACC_TOP1", "ACC_TOP5", "ACC_FAIRNESS"],
        help="Evaluation mode of individual fitness in sub-EAs",
    )
    group.add_argument(
        "--FPVE_fitness_data_ratio",
        # argparse only runs ``type`` over string defaults, so the default is
        # written as a float literal; otherwise the attribute would be an int
        # when the flag is omitted and a float when it is supplied.
        default=0.0,
        type=float,
        help="FPVE_fitness_data ratio",
    )
    group.add_argument(
        "--evolution_epoch", "-ep", default=10, type=int, help="Evolution epoch"
    )
    group.add_argument(
        "--pop_size", default=5, type=int, help="population size for FPVE"
    )
    group.add_argument(
        "--mutation_rate", default=0.1, type=float, help="Mutation rate in sub-EAs"
    )
    group.add_argument(
        "--pop_init_rate",
        default=0.95,
        type=float,
        help="init pruning rate when init population",
    )
    return parser
