from moduleloader import like
import numpy as np

def mix_total_loss(state, event, current_loss, regularizer):
    """Return ``current_loss + regularizer`` after phase-dependent scaling.

    A progress value is computed from the counters in ``state.all`` as
    ``(current_batch + current_epoch * num_batches) / num_batches``
    (presumably the number of epochs elapsed, fractional part included --
    confirm against whatever maintains those counters).

    While progress is below ``state["toepoch"]``:
      * ``loss_weight = progress / toepoch`` is reported through
        ``event.optional.plot_scalar`` under the title ``"loss_weight"``;
      * ``current_loss`` is multiplied by it when ``state["blendloss"]`` is
        truthy, ``regularizer`` when ``state["blendreg"]`` is truthy;
      * ``state["ceoffbefore"]`` / ``state["regoffbefore"]`` then multiply
        the respective term by 0.

    Otherwise only ``state["ceoffafter"]`` / ``state["regoffafter"]`` are
    consulted, again multiplying the respective term by 0.

    Args:
        state: Subscriptable settings object that also exposes an ``all``
            mapping holding the ``main.*`` counters.
        event: Object whose ``optional.plot_scalar`` is called during the
            first phase.
        current_loss: Loss term; only ``*`` and ``+`` are applied to it.
        regularizer: Regularization term; only ``*`` and ``+`` are applied.

    Returns:
        The sum of the (possibly scaled) loss and regularizer.
    """
    ramp_length = state["toepoch"]
    batch_index = state.all["main.current_batch"]
    batches_per_epoch = state.all["main.num_batches"]
    epoch_index = state.all["main.current_epoch"]
    elapsed = (batch_index + epoch_index * batches_per_epoch) / batches_per_epoch

    if elapsed < ramp_length:
        ramp = elapsed / ramp_length
        event.optional.plot_scalar(ramp, title="loss_weight")
        if state["blendloss"]:
            current_loss = current_loss * ramp
        if state["blendreg"]:
            regularizer = regularizer * ramp
        loss_off_key, reg_off_key = "ceoffbefore", "regoffbefore"
    else:
        loss_off_key, reg_off_key = "ceoffafter", "regoffafter"

    # Terms are switched off by multiplying with 0 rather than being replaced
    # by a literal 0, so the operand keeps its type through the final sum.
    if state[loss_off_key]:
        current_loss = current_loss * 0
    if state[reg_off_key]:
        regularizer = regularizer * 0
    return current_loss + regularizer

def register(mf):
    """Register this module's default settings and its event handler on ``mf``.

    The defaults are passed to ``mf.register_defaults``; afterwards
    ``mix_total_loss`` is registered under the event name
    ``'mix_total_loss'`` with ``unique=True``.
    """
    defaults = {
        "plot": True,
        "plot.steps": like("util.log.steps.scalar", alt=0),
        # Progress threshold separating the "before" and "after" phases.
        "toepoch": 1.0,
        # Before the threshold: scale the term by the loss weight.
        "blendloss": True,
        "blendreg": True,
        # Switches that zero out a term after / before the threshold.
        "ceoffafter": False,
        "regoffafter": False,
        "ceoffbefore": False,
        "regoffbefore": False,

        # TODO: make more general commands
        # "blend": ["fi 3.25 onr 10 for 15.3 onl"]
    }
    mf.register_defaults(defaults)
    mf.register_event("mix_total_loss", mix_total_loss, unique=True)
