import argparse
import warnings
from typing import Iterable
import logging

# Import-time side effect: set up the root logger at INFO level so the
# logging.info(...) call in get_args is emitted. basicConfig does nothing if
# the root logger already has handlers configured.
logging.basicConfig(level=logging.INFO)

def float_or_none(value):
    '''Convert a string to float, mapping the word "none" (any letter case) to None.'''
    is_none_literal = value.lower() == 'none'
    return None if is_none_literal else float(value)

def int_or_none(value):
    '''Convert a string to int, mapping the word "none" (any letter case) to None.'''
    if value.lower() != 'none':
        return int(value)
    return None
    
def str_or_none(value):
    '''Return the string unchanged, except that the word "none" (any letter case) becomes None.'''
    lowered = value.lower()
    return value if lowered != 'none' else None

def _add_basic_arguments(parser):
    '''Register the "basic" group: dataset, network, eps, batch sizes, PGD attack, DDP port and logger options.'''
    # Basic arguments
    parser.add_argument('--dataset', required=True, type=str, help='Dataset to use.')
    parser.add_argument('--net', required=True, type=str, help='Network to use.')
    parser.add_argument('--init', default='default', type=str, help='Initialization to use.')
    parser.add_argument('--load-model', default=None, type=str, help='Path of the model to load (None for randomly initialized models).')
    parser.add_argument('--frac-valid', default=None, type=float, help='Fraction of validation samples (None for no validation).')
    parser.add_argument('--save-dir', default=None, type=str, help='Path to save the logs and the best checkpoint.')
    parser.add_argument('--random-seed', default=123, type=int, help="Global random seed for setting up torch, numpy and random.")
    parser.add_argument('--train-eps', required=False, type=float, help='Input epsilon to train with. Set eps=0 for standard training.')
    parser.add_argument('--test-eps', required=True, type=float, help='Input epsilon to test with.')
    parser.add_argument('--train-batch', default=100, type=int, help='Batch size for training.')
    parser.add_argument('--test-batch', default=100, type=int, help='Batch size for testing.')

    # PGD attack arguments
    parser.add_argument('--step-size', default=None, type=float, help='The size of each pgd step. Step size is scaled by the corresponding search box size, i.e. size should be chosen in (0, 1].')
    parser.add_argument('--train-steps', default=None, type=int, help='The number of pgd steps taken during training.')
    parser.add_argument('--test-steps', default=None, type=int, help='The number of pgd steps taken during testing.')
    parser.add_argument('--restarts', default=1, type=int, help='the number of pgd restarts.')
    parser.add_argument("--grad-accu-batch", default=None, type=int, help="If None, do not use grad accumulation; If an int, use the specified number as the batch size and accumulate grad for the whole batch (train/test).")

    parser.add_argument('--port', default=None, type=str, help='Port to use in DDP setup.')

    # Neptune Logger
    parser.add_argument("--neptune-tags", default=None, type=str, nargs='*', help="Tags for neptune logger.")
    parser.add_argument("--disable-neptune", action='store_true', help='Whether to disable neptune logging.')


def _add_train_arguments(parser):
    '''Register the "train" group: optimizer, schedules, training-method switches and debugging limits.'''
    # Optimizer and learning rate scheduling
    parser.add_argument('--resume-train', default=None, type=str, help='Neptune id of the run for which to resume training (None for randomly initialized models).')
    parser.add_argument('--opt', default='adam', type=str_or_none, choices=['adam', 'sgd', 'clipup', None], help='Optimizer to use.')
    parser.add_argument('--n-epochs', default=1, type=int, help='Number of train epochs.')
    parser.add_argument('--lr', default=1e-3, type=float, help='Learning rate for optimizer.')
    # NOTE(review): help text is identical to --lr and looks copy-pasted; get_args
    # falls back to --lr when this is left unset. Confirm the intended meaning.
    parser.add_argument('--lr-std', default=None, type=float, help='Learning rate for optimizer.')
    parser.add_argument('--lr-milestones', default=None, type=int, nargs='*', help='The milestones for MultiStepLR.')
    parser.add_argument('--lr-decay-factor', default=0.2, type=float, help='The decay rate of lr.')
    parser.add_argument('--momentum', default=0.9, type=float, help='Momentum for SGD optimizer.')
    parser.add_argument('--grad-clip', default=1e10, type=float, help="Maximum gradient L2 norm for each step.")
    parser.add_argument('--num-nonneg-layer', default=0, type=int, help='Number of non-negative layers, projected after each step.')
    parser.add_argument('--model-selection', default="robust_accu", type=str_or_none, help='The criterium for selecting models.')

    # Multiple facet loss combination
    parser.add_argument('--multi-facets', default=None, type=str, nargs='*', help='The facets to be used, e.g., pgd ibp.')
    parser.add_argument('--facets-weight', default=None, type=float, nargs='*', help='The weights for each facet.')
    parser.add_argument('--facets-eps-ratio', default=None, type=float, nargs='*', help='The eps ratios for each facet. The first one must be 1, i.e., use the eps defined by scheduler.')

    # Euclidean regularization
    parser.add_argument('--L1-reg', default=0, type=float, help='the L1 reg coefficient.')
    parser.add_argument('--L2-reg', default=0, type=float, help='the L2 reg coefficient.')

    # customized functionality
    parser.add_argument('--save-every-epoch', action='store_true', help='Whether to store the model after every epoch.')
    parser.add_argument('--verbose-first-epoch', action='store_true', help='Whether to verbose the first epoch.')
    parser.add_argument('--verbose-gap', default=0.05, type=float, help='Percentage in the first epoch for each logging.')

    # Configuration of basic robust training
    parser.add_argument('--no-anneal', action='store_true', help='Whether to use eps annealing. Behavior can be customized, e.g. specify using train_eps or test_eps.')
    parser.add_argument('--robust-weight-start', default=1, type=float, help='the start value of the weight of the robust loss')
    parser.add_argument('--robust-weight-end', default=1, type=float, help='the end value of the weight of the robust loss')
    parser.add_argument('--start-epoch-robust-weight', default=0, type=int)
    parser.add_argument('--end-epoch-robust-weight', default=0, type=int)

    # Configuration of PGD training
    parser.add_argument('--use-pgd-training', action='store_true', help='Whether to use PGD training. This would override configuration of all other training methods, i.e. resulting in purely PGD training.')

    # Configuration of Certified training
    parser.add_argument('--start-epoch-eps', default=0, type=int, help="Start epoch of eps annealing.")
    parser.add_argument('--end-epoch-eps', default=40, type=int, help="End epoch of eps annealing.")
    parser.add_argument('--eps-start', default=0, type=float, help="Start value of eps annealing.")
    parser.add_argument('--eps-end', default=0, type=float, help="End value of eps annealing.")
    parser.add_argument("--schedule", default="smooth", type=str, choices=["smooth", "linear", "step"], help="Schedule for eps annealing.")
    parser.add_argument("--step-epoch", default=1, type=int, help="Epoch for each step; only takes effect for step schedule.")

    # IBP training
    parser.add_argument('--use-vanilla-ibp', action='store_true', help='Whether to use vanilla IBP. This would override use-TAPS-training. If combined with use_small_box, it would invoke SABR.')
    # Configuration of fast regularization
    parser.add_argument('--fast-reg', default=0, type=float, help="Weight of fast regularization. This regularization shortens eps annealing for IBP and increases the performance of IBP-based methods in general.")
    parser.add_argument('--min-eps-reg', default=1e-6, type=float, help="Minimum eps used for regularization computation.")

    # HBox Training
    parser.add_argument('--use-HBox-training', action='store_true', help='Whether to use HBox.')

    # Zono Training
    parser.add_argument('--use-Zono-training', action='store_true', help='Whether to use Zono abstract domain.')

    # (Small box) SABR Training
    parser.add_argument('--use-small-box', action='store_true', help='Whether to use small box. When combined with use-vanilla-ibp, it invokes SABR; when combined with use-TAPS-training. it invokes STAPS.')
    parser.add_argument('--eps-shrinkage', default=1, type=float, help="The effective eps would be shrinkage * eps. Equivalent to lambda in SABR paper.")
    parser.add_argument('--relu-shrinkage', default=None, type=float_or_none, help="A positive constant smaller than 1, indicating the ratio of box shrinkage after each ReLU. Only useful in eps=2/255 CIFAR10 in SABR paper (set to 0.8). None for no ReLU shrinkage.")

    # Configuration of IBP-R regularization
    parser.add_argument('--IBPR-reg', default=0, type=float, help="The weight for IBP-R weight regularization. IBP-R implementation is still under developing.")

    # TAPS training
    # The help text used to say "combined with use-TAPS-training", i.e. with itself;
    # per the --use-small-box help above, the STAPS combination is with use-small-box.
    parser.add_argument('--use-TAPS-training', action='store_true', help='Whether to use TAPS. When combined with use-small-box, it invokes STAPS.')
    parser.add_argument('--block-sizes', default=None, type=int, nargs='*', help='A list of sizes of different blocks. Must sum up to the total number of layers in the network.')
    parser.add_argument('--estimation-batch', default=None, type=int, help='Batch size for bound estimation.')
    parser.add_argument('--soft-thre', default=0.5, type=float, help='The hyperparameter of soft gradient link. Equivalent to c in TAPS paper.')
    parser.add_argument('--TAPS-grad-scale', default=1, type=float, help='The gradient scale of TAPS gradient w.r.t. box gradient. Equivalent to w in TAPS paper.')
    # NOTE(review): help text is identical to --TAPS-grad-scale and looks copy-pasted; confirm the intended meaning.
    parser.add_argument('--TAPS-anneal-length', default=0, type=int, help='The gradient scale of TAPS gradient w.r.t. box gradient. Equivalent to w in TAPS paper.')
    parser.add_argument('--no-ibp-anneal', action='store_true', help='Whether to use IBP for annealing. Typically used for checking whether TAPS is out-of-memory. Use IBP for eps annealing can increase performance in general.')

    # DeepPoly training
    parser.add_argument('--use-DP-training', action='store_true', help='Whether to use DeepPoly.')
    parser.add_argument('--use-DPZero-training', action='store_true', help='Whether to use DeepPoly_Zero.')
    parser.add_argument('--use-DPBox-training', action='store_true', help='Whether to use DeepPoly_box.')
    parser.add_argument('--loss-smoothing', default=None, type=str_or_none, help="Factor to use for smooth transitioning between the two variants of relu DP transformers. None means no smoothing. Larger factor means more severe smoothing (large factors will still introduce discontinuities at 0 and 1 because of sigmoid).")

    parser.add_argument('--log-unstable', action='store_true', help='Whether to log amount of unstable neurons.')
    parser.add_argument('--precomp-bounds', action='store_true', help='Whether to precompute intermediate bounds for each batch.')
    parser.add_argument('--reuse-bound-mode', default="standard", type=str, choices=["standard", "full", "pre_intersect", "post_intersect"], help='How to reuse precompute intermediate bounds.')
    parser.add_argument('--transform-bounds', default="none", type=str, choices=["none", "translation"], help='Whether/how to transform precomputed bounds.')
    parser.add_argument('--loss-fusion', action='store_true', help='Whether to use loss fusion.')
    parser.add_argument('--sync-batches-across-actors', action='store_true', help='Whether to use the same random seed for each actor.')
    parser.add_argument('--compute-bounds-each-actor', action='store_true', help='Whether to precompute the dp bounds within each actor. If not set the bounds are computed in the main process and shared with the actors. If not set, --sync-batches-across-actors will be set to True.')

    parser.add_argument('--repeat-batch', default=-1, type=int, help='How many individuals from the population of each actor will be evaluated on the same batch. Must be -1 or an even positive integer. Defaults to -1 (whole popsize evaluated on the same bacth).')

    # Debugging:
    parser.add_argument('--max-batches', default=-1, type=int, help='How many batches to do at most per epoch')
    # NOTE(review): help text is identical to --max-batches and looks copy-pasted;
    # get_args overwrites this value with --max-batches whenever that one is positive.
    parser.add_argument('--max-batches-train-eval', default=100, type=int, help='How many batches to do at most per epoch')

    parser.add_argument('--subbatch-size', default=None, type=int_or_none, help='Subbatch used in evo evaluation.')

    parser.add_argument("--freeze-layers", default="", type=str, help="List of layers to exclude from training. Provide a list of indices separated by comma. Defaults to none (empty string).")


def _add_cert_arguments(parser):
    '''Register the "cert" group: certification file, timeout, MN-BaB config and sample index range.'''
    # certify
    parser.add_argument('--load-certify-file', default=None, type=str, help='the certify file to load. A single filename in the same directory as the model.')
    parser.add_argument('--timeout', default=1000, type=float, help='the time limit for certifying one label.')
    parser.add_argument('--mnbab-config', default=None, type=str, help='the config file for MN-BaB.')
    parser.add_argument('--tolerate-error', action='store_true', help='Whether to ignore MNBaB errors. Normally these are memory overflows.')
    parser.add_argument('--start-idx', default=None, type=int, help='the start index of the input in the test dataset (inclusive).')
    parser.add_argument('--end-idx', default=None, type=int, help='the end index of the input in the test dataset (exclusive).')


def _add_evo_arguments(parser):
    '''Register the "evo" group: hyperparameters of the evolutional search (PGPE / GA).'''
    parser.add_argument('--std-init', default=1e-3, type=float, help='Standard deviation at initialization of search.')
    parser.add_argument('--std-min', default=1e-9, type=float, help='Minimum standard deviation during search.')
    parser.add_argument('--std-max', default=1., type=float, help='Maximum standard deviation during search.')
    parser.add_argument('--popsize', default=None, type=int, help='Population size.')
    parser.add_argument('--num-actors', default=1, type=int, help='Number of parallel actors.')
    parser.add_argument('--start-epoch-std', default=None, type=int, help="Start epoch of std annealing from std_init to std_min.")
    parser.add_argument('--end-epoch-std', default=None, type=int, help="End epoch of std annealing from std_init to std_min.")
    parser.add_argument('--start-epoch-psr', default=None, type=int, help="Start epoch of ratio of perturbation space annealing from psr_init to psr_min.")
    parser.add_argument('--end-epoch-psr', default=None, type=int, help="End epoch of ratio of perturbation space annealing from psr_init to psr_min.")
    parser.add_argument('--psr-min', default=None, type=float, help='Minimum ratio of perturbation space during search.')
    parser.add_argument('--psr-init', default=None, type=float, help='Initial ratio of perturbation space during search.')
    parser.add_argument('--use-current-std', action='store_true', help='Whether to set std_init to the std of current param group.')

    # PGPE
    parser.add_argument('--use-PGPE-evo', action='store_true', help='Whether to use PGPE.')

    parser.add_argument('--scale-grads', action='store_true', help='Whether to scale gradients based on the absolute loss value.')
    parser.add_argument('--only-compute-one-step', action='store_true', help='Whether to stop after one gradient step and save the gradient.')
    parser.add_argument('--batch-idx', default=0, type=int, help='Batch idx to compute the gradient. Only use when --only-compute-one-step is set.')
    parser.add_argument('--fitness-factor', default=1., type=float, help='Factor to multiply the fitness for bigger gradients.')
    # NOTE(review): help text is identical to --scale-grads and looks copy-pasted; confirm the intended meaning.
    parser.add_argument('--dist-sampler-train', action='store_true', help='Whether to scale gradients based on the absolute loss value.')

    # genetics
    parser.add_argument('--use-GA-evo', action='store_true', help='Whether to use GA.')
    parser.add_argument('--num-elites', default=None, type=int, help='Number of elites for each search.')
    parser.add_argument('--num-parents', default=None, type=int, help='Number of parents for each search.')


def get_args(include: Iterable = ("basic", "train", "cert", "evo")):
    '''
    Build the command line parser for the requested argument groups, parse
    sys.argv with it, then validate and post-process the parsed values.

    @param:
        include: the parameter groups wanted. Any iterable of group names;
            one-shot iterables (e.g. generators) are materialized once.
            - basic: dataset, net, batch, roots etc.
            - train: optimizer, training methods etc.
            - cert: timeout etc.
            - evo: hyperparam for evolutional search
    @return:
        The argparse.Namespace of parsed arguments with derived values filled in:
            - use_ddp is always set to False.
            - train: eps_end falls back to train_eps when it is 0,
              estimation_batch falls back to train_batch when unset (both only
              if the basic group is present), sync_batches_across_actors is
              forced to True when bounds are precomputed in the main process,
              max_batches_train_eval is overridden by a positive max_batches,
              lr_std falls back to lr when unset.
            - evo: repeat_batch, simulated_actors and simulated_batch are derived
              from popsize, num_actors, repeat_batch and train_batch.
    @raise:
        AssertionError: if a consistency check on the parsed arguments fails.
        SystemExit: raised by argparse itself on malformed command lines.
    '''
    # Materialize once: the groups are logged and then tested for membership
    # several times, which would exhaust a generator after the first use.
    # A plain string is left untouched so its substring membership keeps working.
    if not isinstance(include, str):
        include = tuple(include)

    parser = argparse.ArgumentParser(description='A easy-to-modify library for IBP-based certified training.')

    logging.info(f"Using arguments group: {', '.join(include)}")

    if "basic" in include:
        _add_basic_arguments(parser)
    if "train" in include:
        _add_train_arguments(parser)
    if "cert" in include:
        _add_cert_arguments(parser)
    if "evo" in include:
        _add_evo_arguments(parser)

    args = parser.parse_args()

    # check training parameters
    if "train" in include:
        if args.use_TAPS_training:
            assert args.block_sizes is not None and len(args.block_sizes) == 2, "TAPS requires block_sizes to be a list containing 2 integers summing up to the total number of layers."
        if not args.multi_facets:
            # The flags are booleans, so their sum counts the selected methods.
            num_methods = sum((
                args.use_pgd_training,
                args.use_vanilla_ibp,
                args.use_HBox_training,
                args.use_Zono_training,
                args.use_TAPS_training,
                args.use_DP_training,
                args.use_DPZero_training,
                args.use_DPBox_training,
            ))
            # Fires for zero selected methods as well as for several.
            assert num_methods == 1, "Exactly one training method must be selected when --multi-facets is not given."

        # train_eps / train_batch belong to the basic group; only use them as
        # fallbacks when that group was actually registered.
        if args.eps_end == 0 and hasattr(args, "train_eps"):
            args.eps_end = args.train_eps
        if args.estimation_batch is None and hasattr(args, "train_batch"):
            args.estimation_batch = args.train_batch
        if args.relu_shrinkage is not None:
            assert 0 <= args.relu_shrinkage <= 1, "Shrinkage must be between 0 and 1."

        if args.precomp_bounds:
            if not args.compute_bounds_each_actor and not args.sync_batches_across_actors:
                logging.warning("sync_batches_across_actors will be set to True because compute_bounds_each_actor is False.")
                args.sync_batches_across_actors = True

    if "cert" in include:
        # load_model belongs to the basic group; treat a missing attribute like an unset one.
        assert getattr(args, "load_model", None) is not None, "A saved model is required to be loaded."
        assert (args.start_idx is None) == (args.end_idx is None), "If a start idx or end idx is specified, then both must be specified"

    if "evo" in include:
        assert args.use_PGPE_evo or args.use_GA_evo, "At least one evolutional search method should be used."

        # The derived values below read arguments registered by other groups.
        missing = [name for name in ("repeat_batch", "sync_batches_across_actors", "train_batch") if not hasattr(args, name)]
        assert not missing, f"The 'evo' group needs arguments from the 'basic' and 'train' groups (missing: {', '.join(missing)})."
        # --popsize defaults to None, which would otherwise fail in the integer division below.
        assert args.popsize is not None, "--popsize must be specified for evolutional search."

        if args.repeat_batch <= 0:
            args.repeat_batch = args.popsize // args.num_actors
            args.simulated_actors = args.num_actors
        else:
            assert (args.repeat_batch % 2 == 0), f"--repeat-batch must be -1 or a positive even integer, got {args.repeat_batch} instead"
            args.simulated_actors = args.popsize // args.repeat_batch
        if args.sync_batches_across_actors:
            args.simulated_actors = args.simulated_actors // args.num_actors
        args.simulated_batch = args.train_batch * args.simulated_actors
        logging.info("simulated_actors=%s, simulated_batch=%s", args.simulated_actors, args.simulated_batch)

    # max_batches, max_batches_train_eval, lr_std and lr only exist when the
    # train group is registered; touching them otherwise raises AttributeError.
    if "train" in include:
        if args.max_batches > 0:
            args.max_batches_train_eval = args.max_batches
        if args.lr_std is None:
            args.lr_std = args.lr

    args.use_ddp = False
    return args
