from tqdm import tqdm
import torch
import sys
from moduleloader import outervar
import time

class IdentityCriterion(torch.nn.Module):
    """Pass-through "loss" that returns the network outputs unchanged.

    On every forward call the outputs are plotted against the labels via
    ``event.optional.plot_scalar`` (title "1D-path"), then execution pauses
    for ``delay`` seconds before the outputs are returned as-is.

    Args:
        state: Experiment state object; stored but not read by this class.
        event: Event dispatcher whose ``optional.plot_scalar`` hook is used
            for plotting.
        delay: Seconds to sleep after each plot. Defaults to 1, the
            previously hard-coded pause; values <= 0 skip the sleep.
    """

    def __init__(self, state, event, delay: float = 1):
        super().__init__()
        self.state = state
        self.event = event
        self.delay = delay

    def forward(self, outputs, labels):
        """Plot ``outputs`` against ``labels`` and return ``outputs`` unchanged."""
        # Hand the plotting hook host-side tensors that are detached from the
        # autograd graph.
        self.event.optional.plot_scalar(
            Y=outputs.detach().cpu(),
            X=labels.detach().cpu(),
            title="1D-path",
        )
        # NOTE(review): the pause presumably lets a live plot refresh between
        # steps — confirm intent. It is configurable so it can be disabled.
        if self.delay > 0:
            time.sleep(self.delay)
        return outputs

def init_loss(state, event):
    """Build and return an ``IdentityCriterion`` bound to ``state`` and ``event``."""
    criterion = IdentityCriterion(state, event)
    return criterion

def change_weights(state, event, net=outervar):
    """Add a constant to every element of one layer's weight tensor.

    Reads ``state["change_layer"]`` (index into ``net.layers``) and
    ``state["change_step"]`` (the value added element-wise to that layer's
    ``weight``). The weight is modified in place.

    Args:
        state: Experiment state mapping providing the two keys above.
        event: Event dispatcher (unused here).
        net: Network to modify. Defaults to ``outervar`` from
            ``moduleloader``. NOTE(review): presumably the loader resolves
            this to the experiment's net — confirm.
    """
    weight = net.layers[state["change_layer"]].weight
    # Mutate under no_grad rather than through ``.data``. ``.data`` hides the
    # change from autograd's version tracking, whereas an in-place op under
    # no_grad is recorded, so stale-graph misuse is reported instead of
    # silently yielding wrong gradients.
    with torch.no_grad():
        weight.add_(state["change_step"])

def register(mf):
    """Wire this experiment into the module framework ``mf``.

    In order, this function:
    1. registers default normalization/activation modules;
    2. loads the modules this experiment depends on;
    3. overrides a few globals (epoch count taken from this experiment's
       ``steps``, a single-filter net without a last layer);
    4. declares this module's defaults;
    5. hooks ``init_loss`` and ``change_weights`` into the ``init_loss`` and
       ``after_step`` events.
    """
    default_modules = (
        ("normalization.*", "batchnorm1d"),
        ("activations.*", "relu"),
    )
    for pattern, module_name in default_modules:
        mf.register_default_module(pattern, module_name)

    required_modules = [
        "seed",
        "1d-dataset",
        "net",
        # "identitynet",
        "non-optimizer",
        "gpu",
        "train",
        "log",
    ]
    mf.load(required_modules)

    def epochs_from_steps(state, event):
        # The epoch count mirrors this experiment's configured step count.
        return state["experiments.1d-net.1d-net.steps"]

    mf.overwrite_globals({
        "main.epochs": epochs_from_steps,
        "experiments.1d-net.net.filters": 1,
        "experiments.1d-net.net.last_layer": False,
    })
    mf.register_defaults(dict(change_layer=0, change_step=0.1, steps=100))

    event_hooks = (
        ("init_loss", init_loss),
        ("after_step", change_weights),
    )
    for event_name, handler in event_hooks:
        mf.register_event(event_name, handler, unique=True)
