import torch
import torch.nn as nn

def compute_gradient_penalty(discriminator, real_imgs, fake_imgs):
    """Calculate the gradient penalty loss for WGAN-GP.

    The penalty is the batch mean of ``(||grad||_2 - 1) ** 2`` where ``grad``
    is the gradient of ``sigmoid(discriminator(x))`` with respect to ``x``,
    evaluated at random per-sample interpolations between ``real_imgs`` and
    ``fake_imgs``.

    Args:
        discriminator: Callable mapping a batch to one score per sample.
        real_imgs: Batch of real samples; the first dimension is the batch.
            Any number of trailing dimensions is supported.
        fake_imgs: Batch of generated samples with the same shape as
            ``real_imgs``.

    Returns:
        A scalar tensor that is differentiable with respect to the
        discriminator parameters (the gradient is taken with
        ``create_graph=True``).
    """
    # The penalty only regularises the discriminator, so cut the graph that
    # produced the inputs; otherwise the double-backward pass would also run
    # through the generator behind ``fake_imgs``.
    real_imgs = real_imgs.detach()
    fake_imgs = fake_imgs.detach()

    # One interpolation weight per sample, shaped [B, 1, ..., 1] so that it
    # broadcasts over every trailing dimension (equals [B, 1] for 2-D input).
    alpha_shape = [real_imgs.shape[0]] + [1] * (real_imgs.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_imgs.device)

    # Random interpolation between real and fake samples
    interpolates = (alpha * real_imgs + (1 - alpha) * fake_imgs).requires_grad_(True)
    d_interpolates = torch.sigmoid(discriminator(interpolates))

    # Gradient of the scores w.r.t. the interpolates. ``grad_outputs`` mirrors
    # the score tensor so any per-sample output shape is accepted.
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    # Flatten per sample; ``reshape`` also copes with non-contiguous gradients.
    gradients = gradients.reshape(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

def train_wgan_pg(real_imgs, generator, discriminator, optimizer_gen, optimizer_dis, device, train_gen=False, **kwargs):
    """Run one WGAN-GP training step.

    The discriminator is always updated once. The generator is updated only
    when ``train_gen`` is true, so the caller decides how many discriminator
    steps are taken per generator step.

    Args:
        real_imgs: Batch of real samples; the first dimension is the batch.
        generator: Module mapping noise of shape
            ``[batch, generator.input_dim]`` to fake samples.
        discriminator: Module scoring samples; its output is passed through
            a sigmoid before entering the losses.
        optimizer_gen: Optimizer for the generator parameters.
        optimizer_dis: Optimizer for the discriminator parameters.
        device: Device on which the noise is sampled.
        train_gen: Whether to also perform a generator update.
        **kwargs: Accepted and ignored.

    Returns:
        Dict of Python floats with keys ``loss_discriminator``,
        ``output_fake``, ``output_real``, ``norm_fake``, ``norm_real`` and,
        when ``train_gen`` is true, ``loss_generator``.
    """
    # Weight of the gradient penalty term in the discriminator loss.
    lambda_gp = 10

    # ---------------------
    #  Train Discriminator
    # ---------------------
    optimizer_dis.zero_grad()

    # Sample noise as generator input
    noise = torch.randn([real_imgs.shape[0], generator.input_dim], device=device)

    # Only the discriminator is stepped here, so detach the fake batch:
    # backpropagating into the generator would be wasted work and would leave
    # stale gradients on its parameters.
    fake_imgs = generator(noise).detach()

    real_validity = torch.sigmoid(discriminator(real_imgs))
    fake_validity = torch.sigmoid(discriminator(fake_imgs))

    gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs)
    # Adversarial loss
    d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty

    d_loss.backward()
    optimizer_dis.step()

    log_dict = {
        "loss_discriminator": d_loss.item(),
        "output_fake": fake_validity.mean().item(),
        "output_real": real_validity.mean().item(),
        "norm_fake": fake_imgs.norm(dim=1).mean(dim=0).item(),
        "norm_real": real_imgs.norm(dim=1).mean(dim=0).item(),
    }

    if train_gen:
        # -----------------
        #  Train Generator
        # -----------------
        optimizer_gen.zero_grad()
        # Regenerate from the same noise, this time keeping the graph so the
        # loss can reach the generator parameters.
        fake_imgs = generator(noise)
        # Scored by the discriminator that was just updated above.
        fake_validity = torch.sigmoid(discriminator(fake_imgs))
        # Minimised when the discriminator scores the fakes close to 1.
        g_loss = 1 - torch.mean(fake_validity)

        g_loss.backward()
        optimizer_gen.step()
        log_dict["loss_generator"] = g_loss.item()

    return log_dict

def tain_gan(data, generator, discriminator, optimizer_gen, optimizer_dis, device, **kwargs):
    """Run one standard GAN training step (discriminator, then generator).

    The losses are binary cross-entropy computed directly on the raw
    discriminator outputs (logits). This is mathematically the same as
    applying a sigmoid followed by ``BCELoss`` but stays finite and keeps a
    usable gradient when the discriminator is very confident, where the
    sigmoid would saturate to exactly 0 or 1.

    Args:
        data: Batch of real samples; the first dimension is the batch.
        generator: Module mapping noise of shape
            ``[batch, generator.input_dim]`` to fake samples.
        discriminator: Module returning one logit per sample, shape
            ``[batch, 1]`` to match the label tensors.
        optimizer_gen: Optimizer for the generator parameters.
        optimizer_dis: Optimizer for the discriminator parameters.
        device: Device on which labels and noise are created.
        **kwargs: Accepted and ignored.

    Returns:
        Dict of Python floats with keys ``loss_discriminator``,
        ``loss_generator``, ``output_fake``, ``output_real``, ``norm_fake``
        and ``norm_real``.
    """
    criterion = nn.BCEWithLogitsLoss()

    batch_size = data.shape[0]
    real_label = torch.full([batch_size, 1], 1.0, device=device)
    fake_label = torch.full([batch_size, 1], 0.0, device=device)
    noise = torch.randn([batch_size, generator.input_dim], device=device)

    # ---- Discriminator update ----
    optimizer_dis.zero_grad()

    # Loss on the real batch.
    real_logits = discriminator(data)
    d_loss_real = criterion(real_logits, real_label)
    d_loss_real.backward()
    # The sigmoid is only needed for the logged mean probability.
    d_real = torch.sigmoid(real_logits.detach()).mean().item()

    # Loss on a fake batch; detached so this pass does not reach the generator.
    fake = generator(noise)
    fake_logits = discriminator(fake.detach())
    d_loss_fake = criterion(fake_logits, fake_label)
    d_loss_fake.backward()
    d_fake1 = torch.sigmoid(fake_logits.detach()).mean().item()

    d_loss = d_loss_real.item() + d_loss_fake.item()
    optimizer_dis.step()

    # ---- Generator update ----
    optimizer_gen.zero_grad()
    # Score the same fakes with the updated discriminator and push them
    # towards the "real" label.
    g_loss = criterion(discriminator(fake), real_label)
    g_loss.backward()
    optimizer_gen.step()

    return {
        "loss_discriminator": d_loss,
        "loss_generator": g_loss.item(),
        "output_fake": d_fake1,
        "output_real": d_real,
        "norm_fake": fake.norm(dim=1).mean(dim=0).item(),
        "norm_real": data.norm(dim=1).mean(dim=0).item(),
    }
