from path_learning.utils.log import get_logger
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR, ReduceLROnPlateau, MultiplicativeLR

# Module-level logger named "lr_scheduler", obtained from the project's get_logger helper.
logger = get_logger("lr_scheduler")


# Registry of supported schedulers: maps the name accepted by get_scheduler
# to the torch.optim.lr_scheduler class that is instantiated for it.
SCHEDULER_DICT = dict(
    cosine_annealing_warm_restarts=CosineAnnealingWarmRestarts,
    multiplicative_lr=MultiplicativeLR,
    reduce_on_plateau=ReduceLROnPlateau,
    lambda_lr=LambdaLR,
)

# Keyword names that must be present in the `kwargs` mapping given to
# get_scheduler, keyed by scheduler name. get_scheduler looks the required
# keywords up here, so every scheduler registered in SCHEDULER_DICT needs an
# entry (a missing one makes that scheduler unusable).
EXPECTED_KWARGS = {
    "cosine_annealing_warm_restarts": ["T_0", "T_mult", "eta_min"],
    "multiplicative_lr": ["lr_lambda"],
    "reduce_on_plateau": ["factor", "patience"],
    "lambda_lr": ["lr_lambda"],
}


def get_scheduler(scheduler_name: str, optimizer, kwargs):
    """Create the learning-rate scheduler registered under ``scheduler_name``.

    Args:
        scheduler_name: Key of ``SCHEDULER_DICT`` selecting the scheduler class.
        optimizer: Optimizer handed to the scheduler constructor as its first
            argument.
        kwargs: Mapping with the scheduler settings. Required entries are

            * ``cosine_annealing_warm_restarts``: ``T_0``, ``T_mult``, ``eta_min``
            * ``reduce_on_plateau``: ``factor``, ``patience``
            * ``multiplicative_lr`` / ``lambda_lr``: ``lr_lambda`` -- either a
              callable taking the epoch index, or a plain value that is used
              as a constant factor for every epoch.

    Returns:
        The constructed scheduler instance.

    Raises:
        AssertionError: If ``scheduler_name`` is not in ``SCHEDULER_DICT``, a
            required keyword is missing from ``kwargs``, or no construction
            rule exists for a registered scheduler.
    """
    scheduler = None
    assert scheduler_name in SCHEDULER_DICT, (
        f"Chosen scheduler name {scheduler_name} is not implemented "
        f"in scheduler dict {SCHEDULER_DICT}"
    )
    scheduler_cls = SCHEDULER_DICT[scheduler_name]

    # Tolerate a registered scheduler without an EXPECTED_KWARGS entry instead
    # of failing with an unrelated KeyError.
    expected_kwargs = list(EXPECTED_KWARGS.get(scheduler_name, []))
    # LambdaLR cannot be built without lr_lambda, whether or not the table
    # lists it.
    if scheduler_name == "lambda_lr" and "lr_lambda" not in expected_kwargs:
        expected_kwargs.append("lr_lambda")

    logger.info(f"kwargs: {kwargs}, expected kwargs: {expected_kwargs}, ")
    for kwarg in expected_kwargs:
        assert kwarg in kwargs, (
            f"Expected keyword {kwarg} for LR scheduler selection not provided "
            f"in keywords {kwargs.keys()}."
        )

    if scheduler_name in ("multiplicative_lr", "lambda_lr"):
        logger.info(f'kwargs["lr_lambda"]: {kwargs["lr_lambda"]}')
        # Both schedulers need a function of the epoch, not a value. Read the
        # setting once so later mutation of `kwargs` cannot change the
        # schedule, and pass an already-callable setting through unchanged.
        lr_lambda_setting = kwargs["lr_lambda"]
        if callable(lr_lambda_setting):
            lr_lambda = lr_lambda_setting
        else:
            lr_lambda = lambda epoch: lr_lambda_setting  # noqa: E731
        scheduler = scheduler_cls(optimizer, lr_lambda=lr_lambda)
    elif scheduler_name == "cosine_annealing_warm_restarts":
        scheduler = scheduler_cls(
            optimizer,
            T_0=kwargs["T_0"],
            T_mult=kwargs["T_mult"],
            eta_min=kwargs["eta_min"],
        )
    elif scheduler_name == "reduce_on_plateau":
        scheduler = scheduler_cls(
            optimizer,
            factor=kwargs["factor"],
            patience=kwargs["patience"],
        )

    # Guards against a name added to SCHEDULER_DICT without a branch above.
    assert scheduler is not None, "scheduler was not assigned"
    return scheduler


