from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel
from HPO.base_grid_hpo import BaseGridHPO
from argparse import Namespace


def register_hyperparams(model: ContinualModel, dataset: ContinualDataset,
          args: Namespace, hp_optimiser: BaseGridHPO) -> None:
    """Register every tunable hyperparameter of ``model`` with ``hp_optimiser``.

    The method-agnostic hyperparameters are registered first, followed by the
    ones specific to ``model.NAME``.
    """
    registrars = (register_general_hyperparams, register_model_hyperparams)
    for registrar in registrars:
        registrar(model, dataset, args, hp_optimiser)


def register_lr_hyperparam(model: ContinualModel, dataset: ContinualDataset,
          args: Namespace, hp_optimiser: BaseGridHPO) -> None:
    """Register the optimiser learning rate ``lr`` as a grid-searched hyperparameter.

    The getter reads ``model.args.lr``. The setter writes the new value to
    ``args``, to ``model.args`` and to the first parameter group of
    ``model.opt``. ``dataset`` is not used here (presumably kept so all
    ``register_*`` helpers share one signature).
    """
    lr_grid = [0.2, 0.15, 0.1, 0.075, 0.05, 0.03, 0.01, 0.0075, 0.005, 0.0025]

    def read_lr():
        return model.args.lr

    def write_lr(value):
        # Keep the caller's namespace and the model's copy in sync.
        args.lr = value
        model.args.lr = value
        # Only param group 0 is updated (done this way so LUCIR code can be
        # used); this is only correct if every trainable param is in group 0.
        first_group = model.opt.param_groups[0]
        first_group['lr'] = value

    hp_optimiser.register_hyperparam('lr', read_lr, write_lr, lr_grid)


def register_general_hyperparams(model: ContinualModel, dataset: ContinualDataset,
          args: Namespace, hp_optimiser: BaseGridHPO, *,
          tune_wd: bool = False, tune_mom: bool = False,
          tune_n_epochs: bool = False) -> None:
    """Register the method-agnostic hyperparameters with ``hp_optimiser``.

    The learning rate is always registered, via ``register_lr_hyperparam`` so
    its grid and setter are defined in exactly one place. Weight decay,
    momentum and the number of epochs are opt-in through keyword-only flags;
    with the defaults only ``lr`` is registered.

    Args:
        model: continual model whose ``args`` (and, for optimiser settings,
            ``opt.param_groups[0]``) the setters update.
        dataset: not used here; forwarded to ``register_lr_hyperparam``.
        args: namespace kept in sync with ``model.args`` by every setter.
        hp_optimiser: grid optimiser the hyperparameters are registered with.
        tune_wd: also grid-search ``optim_wd`` over ``[0.0, 0.00001]``.
        tune_mom: also grid-search ``optim_mom`` over ``[0.0, 0.99]``.
        tune_n_epochs: also grid-search ``n_epochs`` over ``[20, 50, 100]``.
    """
    register_lr_hyperparam(model, dataset, args, hp_optimiser)

    def register_arg(attr, group_key, grid):
        """Register ``attr`` mirrored on ``args``/``model.args`` and, if
        ``group_key`` is given, on optimiser param group 0."""
        def getter():
            return getattr(model.args, attr)

        def setter(value):
            setattr(args, attr, value)
            setattr(model.args, attr, value)
            if group_key is not None:
                # Only group 0 is updated (done this way so LUCIR code can be
                # used); correct only if all trainable params are in group 0.
                model.opt.param_groups[0][group_key] = value

        hp_optimiser.register_hyperparam(attr, getter, setter, grid)

    if tune_wd:
        register_arg('optim_wd', 'weight_decay', [0.0, 0.00001])
    if tune_mom:
        # NOTE(review): param groups of optimisers without a 'momentum' entry
        # (e.g. Adam-style ones) would just gain an unused key; confirm the
        # optimiser type before enabling this.
        register_arg('optim_mom', 'momentum', [0.0, 0.99])
    if tune_n_epochs:
        # Epoch count is not an optimiser setting, so no param group is touched.
        register_arg('n_epochs', None, [20, 50, 100])


def register_model_hyperparams(model: ContinualModel, dataset: ContinualDataset,
          args: Namespace, hp_optimiser: BaseGridHPO) -> None:
    """Register the hyperparameters specific to the method ``model.NAME``.

    Grids per method:
      * ``der``   -- ``alpha`` in [0.2, 0.5, 1.0]
      * ``lwf``   -- ``alpha`` in [1, 3, 10, 20, 50]
      * ``esmer`` -- ``beta`` in [1.5, 1.2, 1.0], stored as ``loss_margin``
      * ``derpp`` -- ``alpha`` then ``beta``, each in [0.2, 0.5, 1.0]
    Any other method registers nothing. ``dataset`` is not used.
    """
    method = model.NAME

    if method == 'esmer':
        # ESMER's tuned value is exposed as 'beta' but stored as loss_margin;
        # the getter reads the model attribute, and the setter updates both
        # namespaces and the model attribute.
        def read_margin():
            return model.loss_margin

        def write_margin(value):
            args.loss_margin = value
            model.args.loss_margin = value
            model.loss_margin = value

        hp_optimiser.register_hyperparam('beta', read_margin, write_margin, [1.5, 1.2, 1.0])
        return

    def mirrored(attr):
        """Return accessors reading ``model.args.<attr>`` and writing it to both namespaces."""
        def read():
            return getattr(model.args, attr)

        def write(value):
            setattr(args, attr, value)
            setattr(model.args, attr, value)

        return read, write

    # Built per call so every registration receives its own fresh grid list.
    search_spaces = {
        'der': [('alpha', [0.2, 0.5, 1.0])],
        'lwf': [('alpha', [1, 3, 10, 20, 50])],
        'derpp': [('alpha', [0.2, 0.5, 1.0]), ('beta', [0.2, 0.5, 1.0])],
    }
    for hp_name, grid in search_spaces.get(method, []):
        read, write = mirrored(hp_name)
        hp_optimiser.register_hyperparam(hp_name, read, write, grid)









