# Optimizer module: the training-step (optimization) event plus optimizer setup and lookup events.
import torch 
from moduleloader import outervar, like

def optimize(state, event, inputs, labels, net, criterion):
    """Run one optimization step of ``net`` on a single batch.

    Computes the mean of ``criterion(outputs, labels)``, adds every term
    returned by ``event.optional.regularizer(net)``, back-propagates the
    combined loss and updates the parameters with ``state["optimizer"]``.

    Args:
        state: module state; ``state["optimizer"]`` holds the optimizer used
            for the parameter update.
        event: event dispatcher; ``event.optional.regularizer(net)`` must
            return an iterable of loss terms (assumed empty when no
            regularizer is registered -- TODO confirm with moduleloader).
        inputs: batch fed to ``net``.
        labels: targets passed to ``criterion`` together with the outputs.
        net: the model being trained.
        criterion: loss function; its result is averaged with ``torch.mean``.

    Returns:
        ``(current_loss, 0)``: the data loss *without* the regularizer, and a
        constant ``0`` kept unchanged for compatibility with existing callers.
        The returned loss has already been back-propagated, so callers should
        not call ``backward()`` on it again.
    """
    optimizer = state["optimizer"]

    # forward pass & data loss
    outputs = net(inputs)
    current_loss = torch.mean(criterion(outputs, labels))
    # sum() of an empty iterable is 0, so "no regularizer" leaves the loss unchanged
    regularizer = sum(event.optional.regularizer(net))
    total_loss = current_loss + regularizer

    # Clear gradients left over from the previous step (PyTorch accumulates
    # them), then back-propagate the *regularized* loss and apply the update.
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    return current_loss, 0

def init_optimizer(state, event, net=outervar):
    """Create the optimizer for ``net`` and cache it in ``state["optimizer"]``.

    Construction is delegated to the ``init_optimizer`` event, which receives
    the network and its parameter iterator.

    ``net`` defaults to ``outervar`` -- presumably resolved by the module
    loader to a variable from an outer scope; verify against moduleloader.
    """
    build = event.init_optimizer
    state["optimizer"] = build(net, net.parameters())

def get_optimizer(state, event):
    """Return the optimizer stored in ``state["optimizer"]``.

    Raises:
        ValueError: if ``state["optimizer"]`` is still ``None``, i.e. no
            optimizer has been stored yet.
    """
    current = state["optimizer"]
    if current is not None:
        return current
    raise ValueError("Optimizer not set yet. Maybe you used the wrong order in your module loading?")

def register(mf):
    """Register this module's scope, defaults, helpers and events with ``mf``.

    NOTE(review): what each ``mf`` call does is defined by the module loader,
    which is not visible here; comments below only describe the arguments.
    """
    mf.redefine_scope("optimizer")
    # The negative lookahead excludes exactly "optimizer.optimizer" (this base
    # module); every other name matching "optimizer..." gets "sgd" as default.
    mf.register_default_module("^(?!optimizer.optimizer$)optimizer.*", "sgd")

    plot_defaults = {
        # true when any value returned by the optional "plot" event is truthy
        "plot": lambda state, event: any(event.optional.plot()),
        # presumably mirrors util.log.steps.scalar, falling back to 0
        "plot.steps": like("util.log.steps.scalar", alt=0),
    }
    mf.register_defaults(plot_defaults)

    helper_values = {"optimizer": None, "momentum": 0.0}
    mf.register_helpers(helper_values)

    # unique=True presumably restricts these events to a single handler.
    mf.register_event('step', optimize, unique=True)
    mf.register_event('before_training', init_optimizer)
    mf.register_event('get_optimizer', get_optimizer, unique=True)
