import torch
from tqdm import tqdm, trange
from moduleloader import outervar


def main(state, event):
    """Run the full training pipeline by dispatching framework events.

    Obtains the train/test/validation loaders, the network and the loss
    through events, then trains from ``state["start_epoch"]`` up to (but not
    including) ``state["epochs"]``. After every epoch the optional
    ``after_epoch``, ``validate`` (on the validation loader) and ``save_ckpt``
    events are fired. After training, the optional ``validate`` event is
    fired on the test loader, followed by ``after_training``.

    Args:
        state: module state object. Reads ``"epochs"`` and ``"start_epoch"``.
            Writes ``"current_epoch"``, ``"num_batches"``, ``"current_batch"``,
            ``"labels"`` and ``"examples_seen"``. Also sets
            ``state.all["net"]`` and increments ``state.all["step"]``.
        event: event dispatcher used to build and drive every stage of
            training.
    """

    # load dataset
    trainloader, testloader, valloader = event.dataloader()

    # get network
    net = event.init_net()
    state.all["net"] = net

    # optionally load checkpoint
    event.optional_unique.load_ckpt(net)

    # send to device
    net = event.send_net_to_device(net)

    # get criterion
    criterion = event.init_loss()
    criterion = event.send_loss_to_device(criterion)

    # optional events (typically optimizer, learning rate scheduler, etc.)
    event.optional.before_training()

    # train loop
    tqdm_batch = tqdm(total=len(trainloader), position=2, desc="Batches")
    tqdm_epoch = tqdm(
        total=state["epochs"],
        position=1,
        desc="Epoch",
        initial=state["start_epoch"],
    )

    try:
        for state["current_epoch"] in range(state["start_epoch"], state["epochs"]):
            # Re-enter training mode every epoch: the validate event fired at
            # the end of the previous epoch may have put the network into eval
            # mode, and without this call later epochs would silently train
            # in that mode.
            net.train()
            event.optional.before_epoch()

            state["num_batches"] = len(trainloader)
            for state["current_batch"], data in enumerate(trainloader):

                # get the inputs; data is a list of [inputs, labels]
                inputs = event.send_data_to_device(data[0])
                labels = event.send_labels_to_device(data[1])
                state["labels"] = labels

                # step
                event.optional.before_step()
                event.step(inputs, labels, net, criterion)
                state.all["step"] += 1
                event.optional.after_step()

                state["examples_seen"] += len(inputs)
                tqdm_batch.update(1)

            tqdm_batch.reset()
            tqdm_epoch.update(1)

            # event for every epoch (eg. validate, scheduler)
            event.optional.after_epoch()
            event.optional.validate(valloader=valloader)
            event.optional.save_ckpt(net)
    finally:
        # Release the progress bars even if training is interrupted, so they
        # do not linger or garble later terminal output.
        tqdm_batch.close()
        tqdm_epoch.close()

    # event after training (eg. test, saving)
    event.optional.validate(valloader=testloader)
    event.optional.after_training()



def register(mf):
    """Register this module's settings, state variables and ``main`` event.

    Sets the scope to ``"main"``, declares the values that ``main`` reads and
    updates through its ``state`` argument, and binds ``main`` as the
    ``'main'`` event.

    Args:
        mf: module-registration handle supplied by the framework.
    """
    # Default settings (presumably user-overridable via configuration).
    default_settings = {"epochs": 5}

    # Bookkeeping values read/updated by the training loop in ``main``.
    helper_values = {
        "start_epoch": 0,
        "examples_seen": 0,
    }

    # Counters that ``main`` accesses through ``state.all``.
    global_values = {"step": 0}

    mf.set_scope("main")
    mf.register_defaults(default_settings)
    mf.register_helpers(helper_values)
    mf.register_globals(global_values)
    mf.register_event('main', main)
