#!/usr/bin/env python3
# Copyright 2019 Christian Henning
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
@title           :utils/cli_args.py
@author          :ch
@contact         :henningc@ethz.ch
@created         :08/21/2019
@version         :1.0
@python_version  :3.6.8

This file has a collection of helper methods that can be used to specify
command-line arguments. In particular, arguments that are necessary for
multiple experiments (even though with different default values) should be
specified here, such that we do not define arguments (and their help texts)
multiple times.

All methods specified here are helper methods for a method that we usually
call `parse_cmd_arguments`. See :mod:`cifar.train_args` for an example.

Important note for contributors
###############################

**NEVER CHANGE DEFAULT VALUES.** Instead, add a keyword argument to the
corresponding method that allows you to change the default value when you
call the method.
"""
from datetime import datetime
from warnings import warn

def hypernet_args(parser, dhyper_chunks=11010, dhnet_arch='50,50,50',
                  dtemb_size=32, demb_size=32, dhnet_act='relu', prefix=None,
                  pf_name=None):
    """This is a helper method of the method `parse_cmd_arguments` to add
    an argument group for hypernetwork-only arguments.

    Arguments specified in this function:
        - `hyper_chunks`
        - `hnet_arch`
        - `hnet_act`
        - `temb_size`
        - `emb_size`
        - `hnet_noise_dim`
        - `hnet_dropout_rate`
        - `temb_std`
        - `sa_hnet_num_layers`
        - `sa_hnet_filters`
        - `sa_hnet_kernels`
        - `sa_hnet_attention_layers`

    Args:
        parser: Object of class :class:`argparse.ArgumentParser`.
        dhyper_chunks: Default value of option "hyper_chunks".
        dhnet_arch: Default value of option "hnet_arch".
        dtemb_size: Default value of option "temb_size".
        demb_size: Default value of option "emb_size".
        dhnet_act: Default value of option "hnet_act".
        prefix (optional): If arguments should be instantiated with a certain
            prefix. E.g., a setup requires several hypernetworks, that may need
            different settings. For instance: :code:`prefix='gen_'`.
        pf_name (optional): A name of the type of hypernetwork for which that
            prefix is needed. For instance: :code:`pf_name='generator'`. Must
            be provided whenever ``prefix`` is given.

    Returns:
        The created argument group, in case more options should be added.
    """
    assert(prefix is None or pf_name is not None)

    heading = 'Hypernet options'

    if prefix is None:
        prefix = ''
        pf_name = ''
    else:
        heading = 'Hypernet options for %s network' % pf_name
        # Trailing space, such that `n` can be glued directly in front of a
        # noun in the help texts below (e.g., '%shypernet' % n).
        pf_name += ' '

    # Abbreviations.
    p = prefix
    n = pf_name

    ### CHypernet options
    agroup = parser.add_argument_group(heading)
    # NOTE(review): argparse only applies `type` to string defaults, so if this
    # option is not given on the command line, the raw (possibly int) default
    # is kept. Consumers presumably accept both int and str -- confirm.
    agroup.add_argument('--%shyper_chunks' % p, type=str, default=dhyper_chunks,
                        help='The output size of the %shypernet. If -1, ' % n +
                             'then the hypernet output dimensionality is the ' +
                             'number of weights in the main network. ' +
                             'If it is a positive integer, the weights are ' +
                             'split into chunks and the hypernet gets as ' +
                             'additional input a trainable chunk-specific ' +
                             'embedding (see class ' +
                             'ChunkedHyperNetworkHandler for details). ' +
                             'If a string of two or three comma-separated ' +
                             'integers is provided, then a chunk of weights ' +
                             'is generated by a transpose convolutional ' +
                             'network with self-attention layers (see ' +
                             'class SAHyperNetwork for details). Default: ' +
                             '%(default)s.')
    # Weights not covered by the SAHyperNetwork chunks are discarded rather
    # than generated by a separate network, hence this option does not
    # describe any "remaining weight generator".
    agroup.add_argument('--%shnet_arch' % p, type=str, default=dhnet_arch,
                        help='A string of comma-separated integers, each ' +
                             'denoting the size of a hidden layer of the ' +
                             '%shypernetwork. This option is discarded ' % n +
                             'when using SAHyperNetwork class. Default: ' +
                             '%(default)s.')
    agroup.add_argument('--%shnet_act' % p, type=str, default=dhnet_act,
                        help='Activation function used in the ' +
                             '%shypernetwork. ' % n +
                             'If "linear", no activation function is used. ' +
                             'Default: %(default)s.',
                        choices=['linear', 'sigmoid', 'relu', 'elu'])
    agroup.add_argument('--%stemb_size' % p, type=int, default=dtemb_size,
                        help='Size of the task embedding space (input to ' +
                             'hypernet). Default: %(default)s.')
    agroup.add_argument('--%semb_size' % p, type=int, default=demb_size,
                        help='If using a hypernetwork that utilizes chunking' +
                             ', then this option defines the size of the ' +
                             'chunk embeddings. Default: %(default)s.')
    agroup.add_argument('--%shnet_noise_dim' % p, type=int, default=-1,
                        help='During training, a zero-mean noise vector will ' +
                             'be concatenated to the task embeddings to help ' +
                             'regularize the task embedding space and the ' +
                             'hypernetwork itself. During testing, zeros ' +
                             'will be concatenated. Default: %(default)s.')
    agroup.add_argument('--%shnet_dropout_rate' % p, type=float, default=-1,
                        help='Use dropout in the hypernet with the given ' +
                             'dropout probability (dropout is deactivated ' +
                             'for a rate of -1). Default: %(default)s.')
    agroup.add_argument('--%stemb_std' % p, type=float, default=-1,
                        help='If not -1, then this number will be ' +
                             'interpreted as the std of zero-mean Gaussian ' +
                             'noise that is used to perturb task embeddings ' +
                             'during training (as a regularization ' +
                             'technique). Default: %(default)s.')
    # Specific to self-attention network!
    agroup.add_argument('--%ssa_hnet_num_layers' % p, type=int, default=5,
                        help='Number of layers in the self-attention ' +
                             'hypernet. ' +
                             'See constructor argument "num_layers" of ' +
                             'class SAHyperNetwork for details. ' +
                             'Default: %(default)s.')
    agroup.add_argument('--%ssa_hnet_filters' % p, type=str,
                        default='128,512,256,128',
                        help='A string of comma-separated integers, each ' +
                             'indicating the number of output channels for a ' +
                             'layer in the self-attention hypernet. ' +
                             'See constructor argument "num_filters" of ' +
                             'class SAHyperNetwork for details. ' +
                             'Default: %(default)s.')
    # NOTE(review): the default is the int 5, not the string '5'; see the note
    # on `hyper_chunks` above.
    agroup.add_argument('--%ssa_hnet_kernels' % p, type=str, default=5,
                        help='A string of comma-separated integers, ' +
                             'indicating kernel sizes in the self-attention ' +
                             'hypernet. Note, to specify a distinct kernel ' +
                             'size per dimension of each layer, just enter a ' +
                             'list with twice the number of elements as ' +
                             'convolutional layers in the hypernet. ' +
                             'See constructor argument "kernel_size" of ' +
                             'class SAHyperNetwork for details. ' +
                             'Default: %(default)s.')
    agroup.add_argument('--%ssa_hnet_attention_layers' % p, type=str,
                        default='1,3',
                        help='A string of comma-separated integers, ' +
                             'indicating after which layers of the hypernet ' +
                             'a self-attention unit should be added. ' +
                             'See constructor argument "sa_units" of ' +
                             'class SAHyperNetwork for details. ' +
                             'Default: %(default)s.')
    return agroup

# FIXME Change the default value of `allowed_nets` to `['mlp']` once users
# have had enough time to react to the deprecation warning.
def main_net_args(parser, allowed_nets=['fc'], dfc_arch=0,
                  dmlp_arch='2000,2000', show_net_act=True, dnet_act='relu',
                  show_no_bias=False, show_dropout_rate=True,
                  ddropout_rate=-1, show_specnorm=True, show_batchnorm=True,
                  show_no_batchnorm=False, show_bn_no_running_stats=False,
                  show_bn_distill_stats=False,
                  show_bn_no_stats_checkpointing=False,
                  prefix=None, pf_name=None):
    """This is a helper function for the function `parse_cmd_arguments` to add
    an argument group for options to a main network.

    Arguments specified in this function:
        - `net_type`
        - `fc_arch`
        - `mlp_arch`
        - `net_act`
        - `no_bias`
        - `dropout_rate`
        - `specnorm`
        - `batchnorm`
        - `no_batchnorm`
        - `bn_no_running_stats`
        - `bn_distill_stats`
        - `bn_no_stats_checkpointing`

    Args:
        parser (:class:`argparse.ArgumentParser`): The argument parser to which
            the argument group should be added.
        allowed_nets (list): List of allowed network identifiers. The following
            identifiers are considered (note, we also reference the network that
            each network type targets):

            - ``mlp``: :class:`mnets.mlp.MLP`
            - ``resnet``: :class:`mnets.resnet.ResNet`
            - ``zenke``: :class:`mnets.zenkenet.ZenkeNet`
            - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet`
            - ``fc``: :class:`mnets.mlp.MLP`

              .. deprecated:: 1.0
                  Please use network type ``mlp`` instead of ``fc``. Network
                  type ``fc`` will be removed in the future.

            The option `net_type` is only added if more than one identifier
            is given; its default is the first entry of this list. The list
            is only read, never modified.

            .. warning::
                Default value of ``allowed_nets`` is going to change to
                :code:`['mlp']` in the future.
        dfc_arch: Default value of option `fc_arch`.

            .. deprecated:: 1.0
                  Please use network type ``mlp`` and argument ``dmlp_arch``
                  instead.
        dmlp_arch: Default value of option `mlp_arch`.
        show_net_act (bool): Whether the option `net_act` should be provided.
        dnet_act: Default value of option `net_act`.
        show_no_bias (bool): Whether the option `no_bias` should be provided.
        show_dropout_rate (bool): Whether the option `dropout_rate` should be
            provided.
        ddropout_rate: Default value of option ``dropout_rate``.
        show_specnorm (bool): Whether the option `specnorm` should be provided.
        show_batchnorm (bool): Whether the option `batchnorm` should be
            provided.
        show_no_batchnorm (bool): Whether the option `no_batchnorm` should be
            provided. Must not be combined with ``show_batchnorm``.
        show_bn_no_running_stats (bool): Whether the option
            `bn_no_running_stats` should be provided.
        show_bn_distill_stats (bool): Whether the option `bn_distill_stats`
            should be provided.
        show_bn_no_stats_checkpointing (bool): Whether the option
            `bn_no_stats_checkpointing` should be provided.
        prefix (optional): If arguments should be instantiated with a certain
            prefix. E.g., a setup requires several main networks, that may need
            different settings. For instance: :code:`prefix='gen_'`.
        pf_name (optional): A name of the type of main net for which that prefix
            is needed. For instance: :code:`pf_name='generator'`. Must be
            provided whenever ``prefix`` is given.

    Returns:
        The created argument group, in case more options should be added.

    Raises:
        ValueError: If ``allowed_nets`` contains both ``fc`` and ``mlp``.
    """
    assert(prefix is None or pf_name is not None)

    # Every network identifier documented above is accepted.
    # TODO Delete 'fc' from list.
    for nt in allowed_nets:
        assert nt in ['fc', 'mlp', 'resnet', 'zenke', 'bio_conv_net'], \
            'Unknown network type "%s".' % nt

    assert(not show_batchnorm or not show_no_batchnorm)

    if 'fc' in allowed_nets and 'mlp' in allowed_nets:
        # Doesn't make sense to have both.
        raise ValueError('Network type names "fc" and "mlp" refer to the ' +
                         'same network type! Note, "fc" is deprecated.')

    # TODO 'fc' should be renamed to 'mlp'.
    # `stacklevel=2` attributes the warning to the caller of this helper.
    if 'fc' in allowed_nets and len(allowed_nets) == 1:
        warn('Network type "fc" is deprecated. Default value of argument ' +
             '"allowed_nets" will be changed to [\'mlp\'] in the future!',
             DeprecationWarning, stacklevel=2)
    elif 'fc' in allowed_nets:
        # TODO change warning into error at some point.
        warn('Network type "fc" is deprecated! Use "mlp" instead.',
             DeprecationWarning, stacklevel=2)

    heading = 'Main network options'

    if prefix is None:
        prefix = ''
        pf_name = ''
    else:
        heading = 'Main network options for %s network' % pf_name
        # Trailing space, such that `n` can be glued directly in front of
        # "network" in the help texts below (e.g., '%snetwork' % n).
        pf_name += ' '

    # Abbreviations.
    p = prefix
    n = pf_name

    ### Main network options.
    agroup = parser.add_argument_group(heading)

    if len(allowed_nets) > 1:
        agroup.add_argument('--%snet_type' % p, type=str,
                            default=allowed_nets[0],
                            help='Type of network to be used for this ' +
                                 '%snetwork. Default: %%(default)s.' % n,
                            choices=allowed_nets)

    # DELETEME once we delete option 'fc'.
    if 'fc' in allowed_nets:
        agroup.add_argument('--%sfc_arch' % p, type=str, default=dfc_arch,
                            help='If using a "fc" %snetwork, this will ' % n +
                                 'specify the hidden layers. ' +
                                 'Default: %(default)s.')

    if 'mlp' in allowed_nets:
        agroup.add_argument('--%smlp_arch' % p, type=str, default=dmlp_arch,
                            help='If using a "mlp" %snetwork, this will ' % n +
                                 'specify the hidden layers. ' +
                                 'Default: %(default)s.')

    # Note, if you want to add more activation function choices here, you have
    # to add them to the corresponding function `utils.misc.str_to_act` as well!
    if show_net_act:
        agroup.add_argument('--%snet_act' % p, type=str, default=dnet_act,
                            help='Activation function used in the ' +
                                 '%snetwork. ' % n +
                                 'If "linear", no activation function is ' +
                                 'used. Default: %(default)s.',
                            choices=['linear', 'sigmoid', 'relu', 'elu'])

    if show_no_bias:
        agroup.add_argument('--%sno_bias' % p, action='store_true',
                            help='No biases will be used in the ' +
                                 '%snetwork. ' % n +
                                 'Note, does not affect normalization (like ' +
                                 'batchnorm).')

    if show_dropout_rate:
        agroup.add_argument('--%sdropout_rate' % p, type=float,
                            default=ddropout_rate,
                            help='Use dropout in the %snetwork with the ' % n +
                                 'given dropout probability (dropout is ' +
                                 'deactivated for a rate of -1). Default: ' +
                                 '%(default)s.')

    if show_specnorm:
        agroup.add_argument('--%sspecnorm' % p, action='store_true',
                            help='Enable spectral normalization in the ' +
                                 '%snetwork.' % n)

    ### Batchnorm related options.
    if show_batchnorm:
        agroup.add_argument('--%sbatchnorm' % p, action='store_true',
                            help='Enable batchnorm in the %snetwork.' % n)
    if show_no_batchnorm:
        agroup.add_argument('--%sno_batchnorm' % p, action='store_true',
                            help='Disable batchnorm in the %snetwork.' % n)

    if show_bn_no_running_stats:
        agroup.add_argument('--%sbn_no_running_stats' % p, action='store_true',
                            help='If batch normalization is used, then this ' +
                                 'option will deactivate the tracking ' +
                                 'of running statistics. Hence, statistics ' +
                                 'computed per batch will be used during ' +
                                 'evaluation.')

    if show_bn_distill_stats:
        agroup.add_argument('--%sbn_distill_stats' % p, action='store_true',
                            help='If batch normalization is used, ' +
                                 'then usually the running statistics are ' +
                                 'checkpointed for every task (e.g., in ' +
                                 'continual learning), which has linearly ' +
                                 'increasing memory requirements. If ' +
                                 'this option is activated, the running ' +
                                 'statistics will be distilled into the ' +
                                 'hypernetwork after training each task, ' +
                                 'such that only the statistics of the ' +
                                 'current and previous task have to be ' +
                                 'explicitly kept in memory.')

    if show_bn_no_stats_checkpointing:
        agroup.add_argument('--%sbn_no_stats_checkpointing' % p,
                            action='store_true',
                            help='If batch normalization is used, then ' +
                                 'this option will prevent the checkpointing ' +
                                 'of batchnorm statistics for every task. ' +
                                 'In this case, one set of statistics is ' +
                                 'used for all tasks.')

    return agroup

def init_args(parser, custom_option=True):
    """Add an argument group with options regarding network initialization.

    Helper of the method `parse_cmd_arguments`.

    Arguments specified in this function:
        - `custom_network_init`
        - `normal_init`
        - `std_normal_init`
        - `std_normal_temb`
        - `std_normal_emb`

    Args:
        parser: Object of class :class:`argparse.ArgumentParser`.
        custom_option: If ``True``, the flag `custom_network_init` is
            registered as the first option of the group.

    Returns:
        The created argument group, in case more options should be added.
    """
    group = parser.add_argument_group('Network initialization options')

    # (flag, keyword arguments for `add_argument`), in registration order.
    option_specs = []
    if custom_option:
        # This option becomes important if posthoc custom init is not that
        # trivial anymore (e.g., if networks use batchnorm). Then, the network
        # init must be customized for each such network.
        option_specs.append((
            '--custom_network_init',
            dict(action='store_true',
                 help='Whether network parameters should be '
                      'initialized in a custom way. If this flag '
                      'is set, then Xavier initialization is '
                      'applied to weight tensors (zero '
                      'initialization for bias vectors). The '
                      'initialization of chunk and task '
                      'embeddings is independent of this option.')))

    option_specs.extend([
        ('--normal_init',
         dict(action='store_true',
              help='Use weight initialization from a zero-mean '
                   'normal with std defined by the argument '
                   '\'std_normal_init\'. Otherwise, Xavier '
                   'initialization is used. Biases are '
                   'initialized to zero.')),
        ('--std_normal_init',
         dict(type=float, default=0.02,
              help='If normal initialization is used, this will '
                   'be the standard deviation used. Default: '
                   '%(default)s.')),
        ('--std_normal_temb',
         dict(type=float, default=1.,
              help='Std when initializing task embeddings. '
                   'Default: %(default)s.')),
        ('--std_normal_emb',
         dict(type=float, default=1.,
              help='If a chunked hypernetwork is used (including '
                   'self-attention hypernet), then this will be '
                   'the std of their initialization. Default: '
                   '%(default)s.')),
    ])

    for flag, kwargs in option_specs:
        group.add_argument(flag, **kwargs)

    return group

def miscellaneous_args(parser, big_data=True, synthetic_data=False,
                       show_plots=False, no_cuda=False, dout_dir=None,
                       show_publication_style=False):
    """This is a helper method of the method `parse_cmd_arguments` to add
    an argument group for miscellaneous arguments.

    Arguments specified in this function:
        - `num_workers`
        - `out_dir`
        - `use_cuda`
        - `no_cuda`
        - `loglevel_info`
        - `deterministic_run`
        - `publication_style`
        - `show_plots`
        - `data_random_seed`
        - `random_seed`

    Args:
        parser: Object of class :class:`argparse.ArgumentParser`.
        big_data: If the program processes big datasets that need to be loaded
            from disk on the fly. In this case, more options are provided.
        synthetic_data: If data is randomly generated, then we want to decouple
            this randomness from the training randomness.
        show_plots: Whether the option `show_plots` should be provided.
        no_cuda: If True, the user has to explicitly set the flag `--use_cuda`
            rather than using CUDA by default.
        dout_dir (optional): Default value of option `out_dir`. If :code:`None`,
            the default value will be
            `./out/run_<YYYY>-<MM>-<DD>_<hh>-<mm>-<ss>` that contains the
            current date and time.
        show_publication_style: Whether the option `publication_style` should be
            provided.

    Returns:
        The created argument group, in case more options should be added.
    """
    if dout_dir is None:
        # Timestamp (second resolution) taken at the time this helper is
        # called.
        dout_dir = './out/run_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    ### Miscellaneous arguments
    agroup = parser.add_argument_group('Miscellaneous options')
    if big_data:
        agroup.add_argument('--num_workers', type=int, metavar='N', default=8,
                            help='Number of workers per dataset loader. ' +
                                 'Default: %(default)s.')
    agroup.add_argument('--out_dir', type=str, default=dout_dir,
                        help='Where to store the outputs of this simulation.')
    # Exactly one of the two CUDA flags is offered, depending on whether GPU
    # usage should be opt-in (`use_cuda`) or opt-out (`no_cuda`).
    if no_cuda:
        agroup.add_argument('--use_cuda', action='store_true',
                            help='Flag to enable GPU usage.')
    else:
        agroup.add_argument('--no_cuda', action='store_true',
                            help='Flag to disable GPU usage.')
    agroup.add_argument('--loglevel_info', action='store_true',
                        help='If the console log level should be raised ' +
                             'from DEBUG to INFO.')
    agroup.add_argument('--deterministic_run', action='store_true',
                        help='Enable deterministic CuDNN behavior. Note, ' +
                             'that CuDNN algorithms are not deterministic by ' +
                             'default and results might not be reproducible ' +
                             'unless this option is activated. Note, that ' +
                             'this may slow down training significantly!')
    if show_publication_style:
        agroup.add_argument('--publication_style', action='store_true',
                            help='Whether plots should be publication-ready.')
    if show_plots:
        agroup.add_argument('--show_plots', action='store_true',
                            help='Whether plots should be shown.')
    if synthetic_data:
        agroup.add_argument('--data_random_seed', type=int, metavar='N',
                            default=42,
                            help='The data is randomly generated at every ' +
                                 'run. This seed ensures that the randomness ' +
                                 'during data generation is decoupled from ' +
                                 'the training randomness. Default: ' +
                                 '%(default)s.')
    agroup.add_argument('--random_seed', type=int, metavar='N', default=42,
                        help='Random seed. Default: %(default)s.')
    return agroup

def eval_args(parser, dval_iter=500, show_val_batch_size=False,
              dval_batch_size=256):
    """Add an argument group with validation and testing options.

    Helper of the method `parse_cmd_arguments`.

    Arguments specified in this function:
        - `val_iter`
        - `val_batch_size`

    Args:
        parser: Object of class :class:`argparse.ArgumentParser`.
        dval_iter: Default value of argument `val_iter`.
        show_val_batch_size: If ``True``, the argument `val_batch_size` is
            registered as well.
        dval_batch_size: Default value of argument `val_batch_size`.

    Returns:
        The created argument group, in case more options should be added.
    """
    group = parser.add_argument_group('Evaluation options')
    group.add_argument(
        '--val_iter', type=int, metavar='N', default=dval_iter,
        help='How often the validation should be performed during '
             'training. Default: %(default)s.')

    # Guard clause: the batch-size option is opt-in.
    if not show_val_batch_size:
        return group

    group.add_argument(
        '--val_batch_size', type=int, metavar='N', default=dval_batch_size,
        help='Batch size during validation/testing. Default: %(default)s.')
    return group

def train_args(parser, show_lr=False, dlr=0.1, show_epochs=False, depochs=-1,
               dbatch_size=32, dn_iter=100001, show_use_adam=False,
               dadam_beta1=0.9, show_use_rmsprop=False, show_use_adadelta=False,
               show_use_adagrad=False, show_clip_grad_value=False,
               show_clip_grad_norm=False, show_adam_beta1=False,
               show_momentum=True):
    """Add an argument group with options that configure network training.

    Helper of the method `parse_cmd_arguments`. The options `batch_size`,
    `n_iter` and `weight_decay` are always registered; every other option is
    registered only if its ``show_*`` switch is truthy. Registration order is
    fixed (it determines the order in the help output).

    Arguments specified in this function:
        - `batch_size`
        - `n_iter`
        - `epochs`
        - `lr`
        - `momentum`
        - `weight_decay`
        - `use_adam`
        - `adam_beta1`
        - `use_rmsprop`
        - `use_adadelta`
        - `use_adagrad`
        - `clip_grad_value`
        - `clip_grad_norm`

    Args:
        parser: Object of class :class:`argparse.ArgumentParser`.
        show_lr: Whether the `lr` - learning rate - argument should be shown.
            Might not be desired if individual learning rates per optimizer
            should be specified.
        dlr: Default value for option `lr`.
        show_epochs: Whether the `epochs` argument should be shown.
        depochs: Default value for option `epochs`.
        dbatch_size: Default value for option `batch_size`.
        dn_iter: Default value for option `n_iter`.
        show_use_adam: Whether the `use_adam` argument should be shown. Will
            also show the `adam_beta1` argument.
        dadam_beta1: Default value for option `adam_beta1`.
        show_use_rmsprop: Whether the `use_rmsprop` argument should be shown.
        show_use_adadelta: Whether the `use_adadelta` argument should be shown.
        show_use_adagrad: Whether the `use_adagrad` argument should be shown.
        show_clip_grad_value: Whether the `clip_grad_value` argument should be
            shown.
        show_clip_grad_norm: Whether the `clip_grad_norm` argument should be
            shown.
        show_adam_beta1: Whether the `adam_beta1` argument should be
            shown. Note, this argument is also shown when ``show_use_adam`` is
            ``True``.
        show_momentum: Whether the `momentum` argument should be shown.

    Returns:
        The created argument group, in case more options should be added.
    """
    group = parser.add_argument_group('Training options')

    # Each entry: (register?, flag, keyword arguments for `add_argument`).
    # The list order is the registration order.
    option_table = [
        (True, '--batch_size',
         dict(type=int, metavar='N', default=dbatch_size,
              help='Training batch size. Default: %(default)s.')),
        (True, '--n_iter',
         dict(type=int, metavar='N', default=dn_iter,
              help='Number of training iterations per task. '
                   'Default: %(default)s.')),
        (show_epochs, '--epochs',
         dict(type=int, metavar='N', default=depochs,
              help='Number of epochs per task. If -1, "n_iter" '
                   'is used instead. Default: %(default)s.')),
        (show_lr, '--lr',
         dict(type=float, default=dlr,
              help='Learning rate of optimizer(s). Default: '
                   '%(default)s.')),
        (show_momentum, '--momentum',
         dict(type=float, default=0.0,
              help='Momentum of the optimizer (only used in '
                   'SGD and RMSprop). Default: %(default)s.')),
        (True, '--weight_decay',
         dict(type=float, default=0,
              help='Weight decay of the optimizer(s). Default: '
                   '%(default)s.')),
        (show_use_adam, '--use_adam',
         dict(action='store_true',
              help='Use Adam rather than SGD optimizer.')),
        # `adam_beta1` comes along with `use_adam`, but may also be requested
        # on its own.
        (show_use_adam or show_adam_beta1, '--adam_beta1',
         dict(type=float, default=dadam_beta1,
              help='The "beta1" parameter when using torch.optim.'
                   'Adam as optimizer. Default: %(default)s.')),
        (show_use_rmsprop, '--use_rmsprop',
         dict(action='store_true',
              help='Use RMSprop rather than SGD optimizer.')),
        (show_use_adadelta, '--use_adadelta',
         dict(action='store_true',
              help='Use Adadelta rather than SGD optimizer.')),
        (show_use_adagrad, '--use_adagrad',
         dict(action='store_true',
              help='Use Adagrad rather than SGD optimizer.')),
        (show_clip_grad_value, '--clip_grad_value',
         dict(type=float, default=-1,
              help='If not "-1", gradients will be clipped using '
                   '"torch.nn.utils.clip_grad_value_". Default: '
                   '%(default)s.')),
        (show_clip_grad_norm, '--clip_grad_norm',
         dict(type=float, default=-1,
              help='If not "-1", gradient norms will be clipped '
                   'using "torch.nn.utils.clip_grad_norm_". '
                   'Default: %(default)s.')),
    ]

    for enabled, flag, kwargs in option_table:
        if enabled:
            group.add_argument(flag, **kwargs)

    return group

def cl_args(parser, show_beta=True, dbeta=0.01, show_from_scratch=False,
            show_multi_head=False, show_cl_scenario=False,
            show_split_head_cl3=True, dcl_scenario=1,
            show_num_tasks=False, dnum_tasks=1):
    """Add an argument group with common continual learning options.

    This is a helper of the method `parse_cmd_arguments`. Each option is only
    registered if its corresponding ``show_*`` flag is set (option `beta` is
    shown by default, all others are hidden by default).

    Arguments specified in this function:
        - `beta`
        - `train_from_scratch`
        - `multi_head`
        - `cl_scenario`
        - `split_head_cl3`
        - `num_tasks`

    Args:
        parser: Object of class :class:`argparse.ArgumentParser`.
        show_beta: Whether option `beta` should be shown.
        dbeta: Default value of option `beta`.
        show_from_scratch: Whether option `train_from_scratch` should be shown.
        show_multi_head: Whether option `multi_head` should be shown.
        show_cl_scenario: Whether option `cl_scenario` should be shown.
        show_split_head_cl3: Whether option `split_head_cl3` should be shown.
            Only has an effect if ``show_cl_scenario`` is ``True``.
        dcl_scenario: Default value of option `cl_scenario`.
        show_num_tasks: Whether option `num_tasks` should be shown.
        dnum_tasks: Default value of option `num_tasks`.

    Returns:
        The created argument group, so that callers can add further options.
    """
    group = parser.add_argument_group('Continual learning options')

    if show_beta:
        group.add_argument(
            '--beta', type=float, default=dbeta,
            help='Trade-off for the CL regularizer. Default: %(default)s.')

    if show_from_scratch:
        group.add_argument(
            '--train_from_scratch', action='store_true',
            help='If set, all networks are recreated after training on each '
                 'task. Hence, training starts from scratch.')

    if show_multi_head:
        group.add_argument(
            '--multi_head', action='store_true',
            help='Use a multihead setting, where each task has its own '
                 'output head.')

    if show_cl_scenario:
        group.add_argument(
            '--cl_scenario', type=int, default=dcl_scenario,
            choices=[1, 2, 3],
            help='Continual learning scenarios according to '
                 'https://arxiv.org/pdf/1809.10635.pdf. '
                 '"1" - Task-incremental learning; '
                 '"2" - Domain-incremental learning; '
                 '"3" - Class-incremental learning. '
                 'Default: %(default)s.')

        # `split_head_cl3` refers to `cl_scenario`, hence it is only offered
        # together with that option.
        if show_split_head_cl3:
            group.add_argument(
                '--split_head_cl3', action='store_true',
                help='CL scenario 3 (CL3, cmp. "cl_scenario") originally '
                     'requires to compute the softmax across all output '
                     'neurons. Though, if a task-conditioned hypernetwork is '
                     'used, the task identity had to be inferred a priori. '
                     'Hence, in CL2 and CL3 we always know the task identity, '
                     'which is why we can also compute the softmax over '
                     'single output heads in CL3 using this option.')

    if show_num_tasks:
        group.add_argument(
            '--num_tasks', type=int, metavar='N', default=dnum_tasks,
            help='Number of tasks. Default: %(default)s.')

    return group

def generator_args(agroup, dlatent_dim=3):
    """This is a helper method of the method `parse_cmd_arguments` (or more
    specifically an auxiliary method to :func:`train_args`) to add arguments to
    an argument group for options specific to a main network that should act as
    a generator.

    Arguments specified in this function:
        - `latent_dim`
        - `latent_std`

    Args:
        agroup: The argument group returned by, for instance, function
            :func:`main_net_args`.
        dlatent_dim: Default value of option `latent_dim`.

    Returns:
        The argument group ``agroup`` that was passed in. It is returned for
        consistency with the other helpers, so that calls can be chained.
    """
    ### Generator options.
    agroup.add_argument('--latent_dim', type=int, metavar='N',
                        default=dlatent_dim,
                        help='Dimensionality of the latent vector (noise ' +
                             'input to the generator). Default: %(default)s.')
    agroup.add_argument('--latent_std', type=float, default=1.0,
                        help='Standard deviation of the latent space. ' +
                             'Default: %(default)s.')

    return agroup

def data_args(parser, show_disable_data_augmentation=False):
    """This is a helper method of the function `parse_cmd_arguments` to add
    an argument group for typical dataset related options.

    Arguments specified in this function:
        - `disable_data_augmentation`

    Args:
        parser: Object of class :class:`argparse.ArgumentParser`.
        show_disable_data_augmentation: Whether option
            `disable_data_augmentation` should be shown.

            .. note::
                This is currently the only option added by this function,
                hence it must be set to ``True`` (this is checked with an
                ``assert``).

    Returns:
        The created argument group, in case more options should be added.
    """
    ### Data-specific options.
    agroup = parser.add_argument_group('Data-specific options')

    # FIXME At the moment, this is the only argument added by this function!
    assert(show_disable_data_augmentation)

    if show_disable_data_augmentation:
        agroup.add_argument('--disable_data_augmentation', action='store_true',
                        help='If activated, no data augmentation will be ' +
                             'applied. Note, this option only affects ' +
                             'datasets that have preprocessing implemented ' +
                             '(such as CIFAR-10).')

    return agroup

def check_invalid_argument_usage(args):
    """This method checks for common conflicts when using the arguments defined
    by methods in this module.

    Every check is only performed if the involved options exist in ``args``.
    Hence, this function can be applied to any parsed namespace, regardless
    of which helpers of this module were used to build the parser.

    The following things will be checked:

        - Based on the optimizer choices specified in :func:`train_args`, we
          assert here that only one optimizer is selected at a time.
        - Assert that `clip_grad_value` and `clip_grad_norm` are not set at the
          same time.
        - Assert that `split_head_cl3` is only set for `cl_scenario=3`
        - Assert that the arguments specified in function :func:`main_net_args`
          are correctly used (misuse only triggers a warning).

          .. note::
              The checks can't handle prefixes yet.
        - Assert that no two of the batch-norm options
          `bn_no_stats_checkpointing`, `bn_no_running_stats` and
          `bn_distill_stats` are set at the same time.

    Args:
        args: The parsed command-line arguments, i.e., the output of method
            :meth:`argparse.ArgumentParser.parse_args`.

    Raises:
        ValueError: If invalid argument combinations are used.
    """
    ### Only a single optimizer may be selected.
    optim_args = ['use_adam', 'use_rmsprop', 'use_adadelta', 'use_adagrad']
    # Options that are not defined count as "not selected".
    selected = [o for o in optim_args if getattr(args, o, False)]
    if len(selected) > 1:
        raise ValueError('Cannot simultaneously use 2 optimizers ' +
                         '(arguments "%s" and "%s").' % (selected[0],
                                                          selected[1]))

    if hasattr(args, 'clip_grad_value') and hasattr(args, 'clip_grad_norm'):
        # A value of "-1" disables the respective clipping method.
        if args.clip_grad_value != -1 and args.clip_grad_norm != -1:
            raise ValueError('Cannot simultaneously clip gradient values and ' +
                             'gradient norm.')

    if hasattr(args, 'cl_scenario') and hasattr(args, 'split_head_cl3'):
        if args.cl_scenario != 3 and args.split_head_cl3:
            raise ValueError('Flag "split_head_cl3" may only be set when ' +
                             'running CL scenario 3 (CL3)!')

    # TODO if `custom_network_init` is used but deactivated, then the other init
    # options have no effect -> user should be warned.

    ### Check consistent use of arguments from `main_net_args`.
    # FIXME These checks don't deal with prefixes yet!
    if hasattr(args, 'net_type') and hasattr(args, 'dropout_rate'):
        # A dropout rate of "-1" means that dropout is not used.
        if args.net_type in ['resnet', 'bio_conv_net'] and \
                args.dropout_rate != -1:
            warn('Dropout is not implemented for network %s.' % args.net_type)

    if hasattr(args, 'net_type') and hasattr(args, 'specnorm'):
        if args.net_type in ['resnet', 'zenke', 'bio_conv_net'] and \
                args.specnorm:
            warn('Spectral Normalization is not implemented for network %s.'
                 % args.net_type)

    if hasattr(args, 'net_type') and hasattr(args, 'net_act'):
        if args.net_type in ['resnet', 'zenke'] and args.net_act != 'relu':
            warn('%s network uses ReLU activation functions. ' % args.net_type +
                 'Ignoring option "net_act".')

        # NOTE(review): the condition `args.net_act != 'tanh'` was commented
        # out in the original, so this warns whenever "bio_conv_net" is used,
        # even if "tanh" was requested -- confirm this is intended.
        if args.net_type in ['bio_conv_net']: # and args.net_act != 'tanh':
            warn('%s network uses Tanh activation functions. ' % args.net_type +
                 'Ignoring option "net_act".')

    if hasattr(args, 'net_type') and hasattr(args, 'no_bias'):
        # FIXME Should be configurable for resnet in future!
        if args.net_type in ['resnet', 'zenke', 'bio_conv_net'] and \
                args.no_bias:
            warn('%s network always uses biases!' % args.net_type)

    # Determine whether batchnorm is used: `True`/`False`, or `None` if the
    # namespace contains neither `batchnorm` nor `no_batchnorm`.
    if hasattr(args, 'batchnorm'):
        bn_used = args.batchnorm
    elif hasattr(args, 'no_batchnorm'):
        bn_used = not args.no_batchnorm
    else:
        # We don't know whether it is used.
        bn_used = None

    if bn_used and hasattr(args, 'net_type'):
        if args.net_type in ['zenke', 'bio_conv_net']:
            warn('Batch Normalization is not implemented for network %s.'
                 % args.net_type)

    if bn_used is not None and hasattr(args, 'bn_no_running_stats'):
        if not bn_used and args.bn_no_running_stats:
            warn('Option "bn_no_running_stats" has no effect if batch ' +
                 'normalization not activated.')

    if bn_used is not None and hasattr(args, 'bn_distill_stats'):
        if not bn_used and args.bn_distill_stats:
            warn('Option "bn_distill_stats" has no effect if batch ' +
                 'normalization not activated.')

    if bn_used is not None and hasattr(args, 'bn_no_stats_checkpointing'):
        if not bn_used and args.bn_no_stats_checkpointing:
            warn('Option "bn_no_stats_checkpointing" has no effect if batch ' +
                 'normalization not activated.')

    ### Mutually incompatible batch-norm options (pairwise).
    bn_conflicts = [('bn_no_stats_checkpointing', 'bn_no_running_stats'),
                    ('bn_no_stats_checkpointing', 'bn_distill_stats'),
                    ('bn_no_running_stats', 'bn_distill_stats')]
    for opt1, opt2 in bn_conflicts:
        # Missing options are treated as not set.
        if getattr(args, opt1, False) and getattr(args, opt2, False):
            raise ValueError('Options "%s" and "%s" are not compatible'
                             % (opt1, opt2))

if __name__ == '__main__':
    # This module only provides helper functions for building command-line
    # parsers; there is nothing to run when it is executed directly.
    pass


