import torch
from moduleloader import outervar

def before_training(state, event, start_epoch=0, **kwargs):
    """Fetch the optimizer and attach a MultiStepLR scheduler to it.

    Both objects are stored in ``state`` under ``"optimizer"`` and
    ``"scheduler"``.

    Args:
        state: Mutable mapping. It is read for ``"milestones"`` and
            ``"gamma"`` and written with ``"optimizer"`` and ``"scheduler"``.
        event: Object exposing ``get_optimizer()``, which returns the
            ``torch.optim.Optimizer`` to schedule.
        start_epoch: Epoch training starts (or resumes) from. The scheduler
            is created with ``last_epoch=start_epoch - 1``.
        **kwargs: Accepted and unused.
    """
    state["optimizer"] = optimizer = event.get_optimizer()
    if start_epoch > 0:
        # When last_epoch != -1, PyTorch schedulers require "initial_lr" in
        # every param group and raise KeyError otherwise. A freshly built
        # optimizer lacks it. setdefault keeps any value already present,
        # e.g. on an optimizer restored from a checkpoint.
        for group in optimizer.param_groups:
            group.setdefault("initial_lr", group["lr"])
    # NOTE(review): group["lr"] is not rescaled for milestones that were
    # already passed before start_epoch. Confirm the optimizer is restored
    # from a checkpoint when resuming.
    state["scheduler"] = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=state["milestones"],
        last_epoch=start_epoch - 1,
        gamma=state["gamma"],
    )

def scheduler_step(state, event, *args, **kwargs):
    """Advance the learning-rate scheduler kept in ``state`` by one step.

    ``event`` and any extra positional or keyword arguments are accepted
    and ignored.
    """
    scheduler = state["scheduler"]
    scheduler.step()

def register(mf):
    """Register this module's defaults and event handlers with ``mf``.

    Default state: ``milestones=[10, 20]`` and ``gamma=0.1``. The scheduler
    is built on ``before_training`` and stepped on ``after_epoch``.
    """
    defaults = dict(milestones=[10, 20], gamma=0.1)
    mf.register_defaults(defaults)
    handlers = (
        ('before_training', before_training),
        ('after_epoch', scheduler_step),
    )
    for event_name, handler in handlers:
        mf.register_event(event_name, handler)
