from tqdm import tqdm
from moduleloader import outervar
import torch

def validate(state, event, valloader, net=outervar, *args, **kwargs):
    """Evaluate ``net`` on ``valloader`` and record top-1/top-5 accuracy.

    Runs the network in eval mode under ``torch.no_grad()``. Per-sample top-1
    and top-5 hits (0 or 1) go into two ``event.Welford`` accumulators. Both
    means are forwarded to ``event.optional.plot_scalar``, and the top-1 mean
    is stored in ``state.all["val_accuracy"]``. The network is put back into
    training mode with ``net.train()`` afterwards, even if evaluation raises.
    Does nothing when ``valloader`` is falsy (e.g. ``None``).

    Args:
        state: Shared state object. ``state.all`` is written to: the labels of
            each batch under ``"main.labels"`` (overwritten per batch) and the
            top-1 accuracy under ``"val_accuracy"``.
        event: Event/plugin hub providing ``Welford``,
            ``send_data_to_device``, ``send_labels_to_device`` and
            ``optional.plot_scalar``.
        valloader: Iterable of batches where ``data[0]`` is the input and
            ``data[1]`` the labels, or a falsy value to skip validation.
        net: Model to evaluate. Presumably injected by the module loader via
            the ``outervar`` default — TODO confirm.
        *args, **kwargs: Accepted for event-signature compatibility; unused.
    """
    if not valloader:
        return

    net.eval()
    acc_1 = event.Welford()
    acc_5 = event.Welford()
    try:
        with torch.no_grad():
            for data in valloader:
                _inputs = event.send_data_to_device(data[0])
                _labels = event.send_labels_to_device(data[1])
                state.all["main.labels"] = _labels
                output = net(_inputs)

                # topk is taken over dim 1, so output is assumed to be
                # (batch, num_classes). Clamp k so models with fewer than 5
                # classes don't make topk raise; "top-5" then covers every class.
                k = min(5, output.size(1))
                _, pred = output.topk(k, 1, largest=True, sorted=True)

                # Broadcast each sample's label across its k predictions.
                # Assumes one class index per sample — TODO confirm.
                _labels = _labels.view(_labels.size(0), -1).expand_as(pred)
                correct = pred.eq(_labels).float()

                # compute top-1/top-5: per-sample count of hits among the
                # first 1 / first (up to) 5 predictions.
                correct_5 = correct[:, :5].sum(1).cpu().numpy()
                correct_1 = correct[:, :1].sum(1).cpu().numpy()

                for c in correct_1:
                    acc_1(c)
                for c in correct_5:
                    acc_5(c)
    finally:
        # Restore training mode even if evaluation failed part-way through.
        net.train()

    state.all["val_accuracy"] = acc_1.mean

    event.optional.plot_scalar(acc_1.mean, title="validation_acc_1")
    event.optional.plot_scalar(acc_5.mean, title="validation_acc_5")


def register(mf):
    """Hook this module into the module loader ``mf``.

    Loads the ``Welford`` module (``validate`` builds its accumulators via
    ``event.Welford()``) and binds ``validate`` to the ``'validate'`` event.
    """
    mf.load('Welford')
    # NOTE: 'after_epoch' and 'after_training' were previously considered as
    # alternative trigger events for validate; only 'validate' is used.
    mf.register_event('validate', validate)
