import torch

def infiniteloop(dataloader):
    """Yield batches from ``dataloader`` endlessly, restarting it when exhausted.

    Each pass iterates ``dataloader`` afresh, so it must be re-iterable
    to produce more than one pass.

    Args:
        dataloader: Iterable of batches; iterated once per pass.

    Yields:
        The items produced by ``dataloader``, pass after pass, in the order
        each pass produces them.

    Raises:
        ValueError: If a complete pass over ``dataloader`` produces no items
            (an empty iterable, or a one-shot iterator that has already been
            consumed by an earlier pass).
    """
    while True:
        produced_any = False
        for batch in dataloader:
            produced_any = True
            yield batch
        if not produced_any:
            # Without this guard the outer loop would spin forever without
            # ever yielding, hanging the caller's next() call silently.
            raise ValueError(
                "infiniteloop() cannot cycle over a dataloader that yields no batches"
            )


@torch.no_grad()
def ema(model, ema_model, decay=0.9999):
    """Blend the parameters of ``model`` into ``ema_model`` in place.

    Every parameter of ``ema_model`` is overwritten with
    ``decay * ema_value + (1 - decay) * model_value``. Parameters are matched
    by their position in ``parameters()``; pairing stops at the end of the
    shorter sequence. Only tensors yielded by ``parameters()`` are updated.

    Args:
        model: Module supplying the current parameter values (not modified).
        ema_model: Module holding the running average; modified in place.
        decay: Weight kept from the existing average on each call.

    Returns:
        The same ``ema_model`` object that was passed in.
    """
    # Weight given to the incoming (current) parameter values.
    blend = 1 - decay
    pairs = zip(ema_model.parameters(), model.parameters())
    for averaged, current in pairs:
        # ``.data`` shares storage with the parameter, so these in-place ops
        # update the averaged parameter itself.
        shadow = averaged.data
        shadow.mul_(decay)
        shadow.add_(current.data, alpha=blend)

    return ema_model