import sys
from tqdm import tqdm
from moduleloader import outervar
import torch
import numpy as np

def init_grad_measure(state, event, net=outervar, trainloader=outervar, criterion=outervar):
    """Measure gradient norms of the network on a single training batch.

    One forward/backward pass is run on the first batch of ``trainloader``.
    Some parameters are measured: those that have a gradient and more than
    two dimensions. For each of them, the L2 norm of every slice ``grad[i]``
    is fed into a Welford accumulator; each norm is taken over all dimensions
    except the first. The mean and variance of these norms are then plotted
    against ``state["current_depth"]``.

    Args:
        state: experiment state. Reads ``"neteval"`` and ``"current_depth"``;
            writes ``"current_batch"`` and ``"labels"``.
        event: event hub providing ``Welford``, ``send_data_to_device``,
            ``send_labels_to_device``, ``optional.regularizer`` and
            ``plot_scalar``.
        net: network whose gradients are measured (injected via ``outervar``).
        trainloader: sized iterable of batches where ``data[0]`` holds the
            inputs and ``data[1]`` the labels; only the first batch is used.
        criterion: loss function; its output is reduced with ``torch.mean``.
    """
    value = event.Welford()

    if state["neteval"]:
        net.eval()

    tqdm_batch = tqdm(total=len(trainloader), position=2, desc="Batches")
    for state["current_batch"], data in enumerate(trainloader):

        # get the inputs; data is a list of [inputs, labels]
        inputs = event.send_data_to_device(data[0])
        labels = event.send_labels_to_device(data[1])
        state["labels"] = labels

        # get result & loss
        outputs = net(inputs)
        loss = torch.mean(criterion(outputs, labels))
        # NOTE(review): the regularizer is evaluated but never added to
        # ``loss``, so it does not affect the measured gradients. The call is
        # kept in case the hook has side effects; confirm whether it should be
        # part of the loss.
        regularizer = sum(event.optional.regularizer(net))

        # Clear any gradients left over from earlier use of ``net`` so the
        # measurement reflects only this backward pass.
        net.zero_grad()
        loss.backward()

        # measure gradients
        for _name, w in net.named_parameters():
            # Skip parameters without a gradient. 1-D and 2-D parameters are
            # not measured either (their measurement is currently disabled).
            if w.grad is None or w.dim() <= 2:
                continue
            # One L2 norm per leading index, taken over all remaining dims.
            # This works for any rank >= 3; the former hard-coded
            # dims [1, 2, 3] only worked for 4-D tensors.
            norms = w.grad.norm(2, dim=tuple(range(1, w.dim())))
            # tolist() moves all norms to Python floats in one call instead of
            # calling .item() on every element.
            for v in norms.tolist():
                value(v)

        # Must happen before ``break``; previously this line was unreachable.
        tqdm_batch.update(1)
        # measure only for one batch
        break
    tqdm_batch.reset()

    # log
    event.plot_scalar(value.mean, state["current_depth"], title="Gradient Mean", xlabel="Network Depth")
    event.plot_scalar(value.var, state["current_depth"], title="Gradient Var", xlabel="Network Depth")


# ensure that network with different depths is created
def overwrite_init_net(state, event):
    """Wrap ``event.init_net`` so each call builds a network of a given depth.

    The current ``event.init_net`` is saved as ``event.init_net_old``. It is
    replaced by ``init_net(current_depth)``, which does three things:
    records the depth in ``state["current_depth"]``, rewrites the
    depth-related options of the loaded model in ``state.all``, and then
    delegates to the saved original.

    Only the ``convnet`` model is supported. ``init_net`` raises
    ``ValueError`` in two cases: when no loaded model module can be found,
    or when the loaded model has no depth parameter.
    """

    def init_net(current_depth):
        state["current_depth"] = current_depth

        # The model module is the first loaded module whose name contains
        # "model". The second dotted component of that name identifies the
        # model (e.g. "convnet" for "model.convnet").
        module_key = next(
            (key for key in event._mf.modules_loaded.keys() if "model" in key),
            None,
        )
        parts = module_key.split(".") if module_key is not None else []
        if len(parts) < 2:
            raise ValueError(
                "Could not determine the loaded model module (found %r)." % (module_key,)
            )
        model = parts[1]

        if model == "convnet":
            state.all["model.convnet.depth"] = current_depth
            state.all["model.convnet.conv_blocks"] = [state.all["model.convnet.filters"]] * current_depth
        else:
            raise ValueError("There is no parameter 'depth' for that model (%s)." % model)
        return event.init_net_old()

    event.init_net_old = event.init_net
    event.init_net = init_net

def register(mf):
    """Register this experiment with the module framework ``mf``.

    Steps:
      * sets the module scope to ``"_"``;
      * loads the dependent modules;
      * registers defaults for ``depths`` (1 through 9) and ``neteval``
        (True);
      * hooks ``init_grad_measure`` onto the ``"measure"`` event and
        ``overwrite_init_net`` onto the ``"init"`` event.

    How ``mf`` resolves scopes, modules and events is defined elsewhere.
    """
    mf.redefine_scope("_")

    required_modules = ["measure-vs-netdepth", "sotacifar10", "log", "Welford"]
    mf.load(required_modules)

    default_settings = {
        "depths": list(range(1, 10)),
        "neteval": True,
    }
    mf.register_defaults(default_settings)

    event_handlers = (
        ("measure", init_grad_measure),
        ('init', overwrite_init_net),
    )
    for event_name, handler in event_handlers:
        mf.register_event(event_name, handler)

