# Optimizer module: training step, accuracy computation, loss/accuracy plotting and optimizer setup.
import torch 
from moduleloader import outervar, like

def optimize(state, event, inputs, labels, net, criterion):
    """Run one training step on a single batch.

    Clears gradients and runs the forward pass. It then builds the total loss from
    the mean criterion loss and the summed regularizer terms. The two are combined
    through ``event.mix_total_loss`` when the event object provides it, otherwise
    they are simply added. The step then back-propagates, fires the optional
    ``before_actual_step`` hooks and steps the optimizer. Finally it stores the
    batch accuracy in ``state.all["last_accuracy"]`` and calls ``plot_step``.

    Args:
        state: module state; the optimizer is read from ``state["optimizer"]``.
        event: event dispatcher used for regularizers, hooks and plotting.
        inputs: batch fed to ``net``.
        labels: targets for ``inputs``.
        net: the model being trained.
        criterion: loss function; its output is reduced with ``torch.mean``.

    Returns:
        Tuple ``(current_loss, current_acc)``: the mean criterion loss (without
        the regularizer) and the batch accuracy.

    Raises:
        ValueError: if ``state["optimizer"]`` is still ``None``, presumably
            because ``init_optimizer`` (run on ``before_training``) has not
            executed yet.
    """
    optimizer = state["optimizer"]
    if optimizer is None:
        # Fail with the same explicit message as get_optimizer instead of an
        # opaque "'NoneType' object has no attribute 'zero_grad'".
        raise ValueError("Optimizer not set yet. Maybe you used the wrong order in your module loading?")

    # init step
    optimizer.zero_grad()

    # get result & loss
    outputs = net(inputs)
    current_loss = torch.mean(criterion(outputs, labels))
    # presumably one term per registered regularizer handler; sum([]) == 0
    regularizer = sum(event.optional.regularizer(net))

    # collect all losses; a module may override how they are combined
    if hasattr(event, 'mix_total_loss'):
        total_loss = event.mix_total_loss(current_loss, regularizer)
    else:
        total_loss = current_loss + regularizer

    # optimize
    total_loss.backward()
    event.optional.before_actual_step()
    optimizer.step()

    # get acc
    current_acc = get_acc(outputs, labels, net)
    state.all["last_accuracy"] = current_acc

    # event after every step
    plot_step(state, event, current_loss, current_acc)

    return current_loss, current_acc

def get_acc(outputs, labels, net):
    """Return the fraction of rows whose arg-max along dim 1 equals the label.

    Args:
        outputs: model outputs; predictions are ``outputs.max(dim=1)`` indices
            (assumes shape ``(batch, classes)`` — TODO confirm).
        labels: target class indices, one per row of ``outputs``.
        net: unused; kept for interface compatibility.

    Returns:
        A 0-d ``float64`` tensor holding ``correct / labels.size(0)``.

    Raises:
        RuntimeError: if ``outputs`` cannot be reduced along dimension 1
            (e.g. it is ``None`` or has fewer than two dimensions).
    """
    with torch.no_grad():
        try:
            _, predicted = outputs.max(dim=1)
        except (AttributeError, TypeError, IndexError, RuntimeError) as err:
            # Only translate failures of the reduction itself; interrupts and
            # unrelated errors are no longer swallowed by a bare except.
            raise RuntimeError("Optimize must be called before get_acc") from err
        correct = (predicted == labels).sum().double()
        total = torch.tensor(labels.size(0)).double().to(correct.device)
        return correct / total

def plot_step(state, event, current_loss, current_acc):
    """Plot loss and accuracy every ``state["plot.steps"]`` batches.

    Does nothing unless ``state["plot"]`` is truthy. The batch counter is read
    from ``state.all["main.current_batch"]``. When it is a multiple of
    ``plot.steps``, both scalars are sent through the optional ``plot_scalar``
    event with titles ``"loss"`` and ``"accuracy"``.

    ``plot.steps`` is registered with a fallback of 0. A ``None`` or
    non-positive interval therefore disables plotting instead of raising
    ``ZeroDivisionError``.

    Args:
        state: module state with ``plot``/``plot.steps`` and ``state.all``.
        event: event dispatcher providing ``optional.plot_scalar``.
        current_loss: loss tensor; plotted via ``.item()``.
        current_acc: accuracy tensor; plotted via ``.item()``.
    """
    if not state["plot"]:
        return
    steps = state["plot.steps"]
    if steps is None or steps <= 0:
        return
    if state.all["main.current_batch"] % steps == 0:
        event.optional.plot_scalar(current_loss.item(), title="loss")
        event.optional.plot_scalar(current_acc.item(), title="accuracy")

def init_optimizer(state, event, net=outervar):
    """Build the optimizer through the ``init_optimizer`` event and store it.

    The event receives the network and its ``parameters()``. Whatever it
    returns is kept in ``state["optimizer"]``.
    """
    build = event.init_optimizer
    state["optimizer"] = build(net, net.parameters())

def get_optimizer(state, event):
    """Return the optimizer held in ``state["optimizer"]``.

    Raises:
        ValueError: when no optimizer has been stored yet (the value is ``None``).
    """
    optimizer = state["optimizer"]
    if optimizer is not None:
        return optimizer
    raise ValueError("Optimizer not set yet. Maybe you used the wrong order in your module loading?")

def register(mf):
    """Register this module's scope, defaults, helpers and events with ``mf``.

    Defaults:
        plot: truthy if any handler of the optional ``plot`` event returns a
            truthy value.
        plot.steps: taken from ``util.log.steps.scalar`` (``alt=0`` fallback).

    Helpers:
        optimizer: ``None`` until ``init_optimizer`` stores one.
        momentum: ``0.0``.

    Events:
        step -> ``optimize`` (unique),
        before_training -> ``init_optimizer``,
        get_optimizer -> ``get_optimizer`` (unique).
    """
    mf.redefine_scope("optimizer")
    # NOTE(review): presumably makes "sgd" the default for optimizer.* module
    # slots other than "optimizer.optimizer" — confirm against moduleloader.
    mf.register_default_module("^(?!optimizer.optimizer$)optimizer.*", "sgd")

    # any() already yields a bool, so no explicit "is True" comparison is needed.
    defaults = {
        "plot": lambda state, event: any(event.optional.plot()),
        "plot.steps": like("util.log.steps.scalar", alt=0),
    }
    mf.register_defaults(defaults)

    helpers = dict(optimizer=None, momentum=0.0)
    mf.register_helpers(helpers)

    mf.register_event('step', optimize, unique=True)
    mf.register_event('before_training', init_optimizer)
    mf.register_event('get_optimizer', get_optimizer, unique=True)
