import torch

# hyperparameters
# Weight multiplied into the gradient penalty computed by cal_grad_penalty.
lambda_gp = 10


def loop(dataloader):
    """Yield batches from *dataloader* endlessly, restarting it after each pass.

    Each pass re-iterates *dataloader* with a ``for`` statement, so a
    re-iterable container (e.g. a list) is replayed indefinitely.

    Args:
        dataloader: Any iterable of batches.

    Yields:
        The items of *dataloader*, in order, repeated forever.

    Raises:
        ValueError: If a full pass over *dataloader* produces no items (an
            empty iterable, or a one-shot iterator that is already exhausted).
            Without this check the generator would spin forever inside
            ``next()`` without ever yielding.
    """
    while True:
        produced = False
        for batch in dataloader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("dataloader yielded no batches; cannot loop over it")


def cal_grad_penalty(netD, real_data, fake_data, bs, nc=3, *, gp_weight=None):
    """Compute the WGAN-GP style gradient penalty for the critic ``netD``.

    Draws one mixing coefficient ``alpha ~ U[0, 1)`` per sample, forms
    ``alpha * real_data + (1 - alpha) * fake_data``, and penalises the
    deviation of the per-sample L2 norm of ``d netD / d input`` from 1::

        penalty = mean_i (||grad_i||_2 - 1) ** 2 * weight

    Args:
        netD: Callable (e.g. an ``nn.Module``) applied to a batch shaped like
            ``real_data``; its output is differentiated w.r.t. that batch.
        real_data: Batch of real samples; dimension 0 is the batch dimension.
        fake_data: Batch of generated samples, broadcast-compatible with
            ``real_data`` (normally the same shape and device).
        bs: Batch size; must equal ``real_data.size(0)``.
        nc: Unused, kept for backward compatibility. It previously selected a
            hard-coded 28x28 (``nc == 1``) or 32x32 image size; the shape is
            now taken from ``real_data`` itself, so any resolution works.
        gp_weight: Multiplier for the penalty. When ``None`` (the default),
            the module-level ``lambda_gp`` hyperparameter is used.

    Returns:
        A 0-dim tensor. The input-gradient is computed with
        ``create_graph=True``, so the penalty can be back-propagated into
        ``netD``'s parameters.

    Raises:
        ValueError: If ``bs`` does not match ``real_data.size(0)``.
    """
    if real_data.size(0) != bs:
        raise ValueError(
            f"bs={bs} does not match real_data batch dimension {real_data.size(0)}"
        )
    weight = lambda_gp if gp_weight is None else gp_weight

    # One alpha per sample, shaped (bs, 1, ..., 1) so it broadcasts over every
    # non-batch dimension regardless of resolution. It is drawn from the CPU
    # generator exactly as before, then moved to the data's own device instead
    # of an unconditional .cuda(), so CPU and non-default-GPU inputs work too.
    alpha = torch.rand(bs, 1)
    alpha = alpha.view(bs, *([1] * (real_data.dim() - 1))).to(real_data.device)

    interpolates = alpha * real_data + (1 - alpha) * fake_data
    interpolates.requires_grad_(True)
    disc_interpolates = netD(interpolates)

    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        # ones_like matches the critic output's dtype and device (e.g. fp16).
        grad_outputs=torch.ones_like(disc_interpolates),
        # Keep the graph of the gradient so the penalty itself is
        # differentiable w.r.t. netD's parameters.
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (rather than view) also copes with non-contiguous gradients.
    gradients = gradients.reshape(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * weight
