
import argparse
from argparse import ArgumentParser
from typing import Union

from datasets import NAMES as DATASET_NAMES
from models import get_all_models


def add_experiment_args(parser: ArgumentParser) -> None:
    """
    Adds the arguments used by all the models.

    Registers, in place on ``parser``: dataset/model selection, optimizer
    settings, training length, data balancing/duplication options, the
    distribution mode, backbone settings and the ``--pec_*`` options.

    :param parser: the parser instance
    """
    parser.add_argument('--dataset', type=str, required=False,
                        choices=DATASET_NAMES,
                        help='Which dataset to perform experiments on.')
    parser.add_argument('--resize_image_shape', type=int, default=-1)
    parser.add_argument('--model', type=str, required=False,
                        help='Model name.', choices=get_all_models())
    parser.add_argument('--classes_per_task', type=int)
    parser.add_argument('--classes_first_task', type=int, default=-1)

    parser.add_argument('--lr', type=float, required=False,
                        help='Learning rate.')
    # str2bool is defined further down in this module; it is only looked up
    # when this function runs, so the forward reference is fine.
    parser.add_argument('--force_no_augmentations', type=str2bool, default=False)
    parser.add_argument('--optim_kind', type=str, default="sgd", choices=["sgd", "adam"])
    parser.add_argument('--optim_wd', type=float, default=0.,
                        help='optimizer weight decay.')
    parser.add_argument('--optim_mom', type=float, default=0.,
                        help='optimizer momentum.')
    parser.add_argument('--optim_nesterov', type=int, default=0,
                        help='optimizer nesterov momentum.')
    parser.add_argument('--optim_reset_every_task', type=str2bool, default=False)
    parser.add_argument('--optim_scheduler', type=str, default="none", choices=["none", "linear"])

    parser.add_argument('--n_epochs', type=int,
                        help='Number of epochs.')
    parser.add_argument('--batch_size', type=int,
                        help='Batch size.')

    parser.add_argument('--balance_truncate_data', type=str2bool, default=False)
    parser.add_argument('--svhn_balance_truncate_test_data', type=str2bool, default=False)
    # NOTE(review): type=eval executes arbitrary Python taken from the command
    # line. That is acceptable for a trusted local CLI, but if only literals
    # (lists/numbers) are expected, ast.literal_eval would be safer. Confirm
    # that no caller relies on passing expressions before switching.
    parser.add_argument('--multiplicate_classes', type=eval)
    parser.add_argument('--multiplicate_by', type=eval)

    parser.add_argument('--distributed', type=str, default='no', choices=['no', 'dp', 'ddp'])

    parser.add_argument('--backbone', type=str)
    # argparse only applies `type` to string defaults, so this list default is
    # used as-is; a command-line value is eval'd (see the note above).
    parser.add_argument('--resnet_num_blocks', type=eval, default=[2, 2, 2, 2])
    parser.add_argument('--resnet_num_filters', type=int, default=64)
    parser.add_argument('--mlp_hidden_size', type=int, default=100)

    parser.add_argument("--pec_architecture", type=str)
    parser.add_argument("--pec_num_layers", type=int, default=2)
    parser.add_argument("--pec_width", type=int)
    parser.add_argument("--pec_teacher_width_multiplier", type=int, default=100)
    parser.add_argument("--pec_output_dim", type=int)
    parser.add_argument("--pec_activation", type=str, default="relu")
    parser.add_argument("--pec_normalize_layers", type=str2bool, default=True)
    # nargs="+" with type=eval: each whitespace-separated token is eval'd
    # separately, and the results are collected into a list.
    parser.add_argument("--pec_conv_layers", nargs="+", type=eval)
    parser.add_argument("--pec_conv_reduce_spatial_to", type=int)


def add_management_args(parser: ArgumentParser) -> None:
    """
    Adds the run-management arguments (seeding, logging, validation, debugging,
    wandb settings and evaluation frequency).

    :param parser: the parser instance, extended in place
    """
    parser.add_argument('--seed', type=int, default=None,
                        help='The random seed.')
    parser.add_argument('--notes', type=str, default=None,
                        help='Notes for this run.')

    parser.add_argument('--non_verbose', default=0, choices=[0, 1], type=int, help='Make progress bars non verbose')
    # Setting the flag to 1 turns logging off, so the help text says "Disable"
    # (the previous text said "Enable", which was inverted).
    parser.add_argument('--disable_log', default=0, choices=[0, 1], type=int, help='Disable csv logging')

    parser.add_argument('--validation', default=0, choices=[0, 1], type=int,
                        help='Test on the validation set')
    parser.add_argument('--ignore_other_metrics', default=0, choices=[0, 1], type=int,
                        help='disable additional metrics')
    parser.add_argument('--debug_mode', type=int, default=0, help='Run only a few forward steps per epoch')
    parser.add_argument('--nowand', default=0, choices=[0, 1], type=int, help='Inhibit wandb logging')
    parser.add_argument('--wandb_entity', type=str, help='Wandb entity')
    parser.add_argument('--wandb_project', type=str, help='Wandb project name')
    parser.add_argument('--eval_every_n_task', type=int, default=1)


def add_rehearsal_args(parser: ArgumentParser) -> None:
    """
    Registers the options shared by every rehearsal-based method.

    Both options concern the memory buffer: its total size and the batch size
    drawn from it. Neither option has a default, so both are None unless given.

    :param parser: the parser instance, extended in place
    """
    buffer_options = (
        ('--buffer_size', dict(type=int, required=False,
                               help='The size of the memory buffer.')),
        ('--minibatch_size', dict(type=int,
                                  help='The batch size of the memory buffer.')),
    )
    for flag, spec in buffer_options:
        parser.add_argument(flag, **spec)


def str2bool(v: Union[bool, str]) -> bool:
    """
    Interprets a command-line token as a boolean (usable as an argparse ``type``).

    Booleans are passed through unchanged. Strings are matched
    case-insensitively against yes/true/t/y/1 (True) and no/false/f/n/0 (False).

    :param v: the value to interpret
    :return: the corresponding boolean
    :raises argparse.ArgumentTypeError: if the string matches neither set
    """
    if isinstance(v, bool):
        return v
    token = v.lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
