from tqdm import tqdm
from moduleloader import outervar
import torch

def warmup(state, event, trainloader=outervar, net=outervar, *args, **kwargs):
    """Refresh normalization-layer running statistics by replaying ``trainloader``.

    Every batch from ``trainloader`` is pushed through ``net`` with gradients
    disabled while the network is in training mode. Layers such as
    ``torch.nn.BatchNorm2d`` therefore update their running mean/variance.
    No backward pass or optimizer step is performed here.

    Args:
        state: Unused; accepted to match the event-handler signature.
        event: Provides ``send_data_to_device``, used to move each input
            batch to the device.
        trainloader: Iterable of batches; element ``0`` of each batch is
            used as the network input (presumably ``(inputs, labels)``
            pairs -- confirm against the loader in use).
        net: The ``torch.nn.Module`` to warm up.
        *args, **kwargs: Ignored.

    If ``net`` was in eval mode on entry, it is put back into eval mode on
    exit, even when the loop raises. If it was already training, its mode
    is left untouched.
    """
    was_training = net.training
    if not was_training:
        # BatchNorm only updates its running statistics in training mode.
        net.train()

    try:
        with torch.no_grad():
            for data in tqdm(trainloader, desc="Batchnorm Warmup"):
                # Only the inputs drive the statistics; labels are not needed,
                # so they are not transferred to the device.
                net(event.send_data_to_device(data[0]))
    finally:
        # Restore eval mode even if iteration is interrupted, so a failed
        # warmup never leaves an eval-mode network stuck in train mode.
        if not was_training:
            net.eval()


def register(mf):
    """Attach :func:`warmup` as the handler for this plugin's events.

    The same handler is registered, in order, for the ``before_training``
    event and for the ``batchnorm_warmup`` event.

    Args:
        mf: Object exposing ``register_event(name, handler)``.
    """
    for event_name in ('before_training', 'batchnorm_warmup'):
        mf.register_event(event_name, warmup)
