from cyclegan.alogrithms.registry import register

import tensorflow as tf

from cyclegan.alogrithms.lsgan import LSGAN


@register('wgangp')
class WGANGP(LSGAN):
  """WGAN-GP objective: Wasserstein critic/generator losses + gradient penalty.

  Improved Training of Wasserstein GANs https://arxiv.org/abs/1704.00028
  Reference: https://github.com/igul222/improved_wgan_training

  Construction is delegated to LSGAN; this class overrides the adversarial
  losses and adds `gradient_penalty`. Batch averaging goes through
  `self.reduce_mean`, which is inherited from a base class not shown here.
  """

  # Added to the squared gradient norm before the square root. Without it,
  # a sample whose critic gradient is exactly zero yields an infinite
  # derivative of sqrt at 0, which becomes NaN during back-propagation.
  # Chosen for float32; note it underflows to 0 in float16.
  gp_epsilon = 1e-10

  def __init__(self, args, G, F, X, Y, strategy: tf.distribute.Strategy):
    """Store the gradient-penalty configuration on top of the LSGAN setup.

    Args:
      args: configuration namespace. Besides whatever LSGAN reads, this class
        reads `input_shape` (presumably the per-sample shape without the batch
        dimension, since `gp_axis` skips axis 0) and `lambda_gp` (penalty
        weight).
      G, F, X, Y, strategy: forwarded unchanged to LSGAN.
    """
    super(WGANGP, self).__init__(args, G, F, X, Y, strategy)
    # Broadcast shape (minus batch) for the interpolation coefficient: one
    # alpha per sample, e.g. (1, 1, 1) for a rank-3 input_shape.
    self.gp_shape = (1,) * len(args.input_shape)
    # Every non-batch axis; the gradient norm is taken per sample over these.
    self.gp_axis = list(range(1, len(args.input_shape) + 1))
    self.lambda_gp = args.lambda_gp

  def generator_loss(self, discriminate_fake):
    """Return the generator loss: the negated mean critic score on fakes."""
    return -self.reduce_mean(discriminate_fake)

  def discriminator_loss(self, discriminate_real, discriminate_fake):
    """Return the critic loss mean(D(fake)) - mean(D(real)).

    The gradient penalty is NOT included here; it is computed separately by
    `gradient_penalty`.
    """
    real_loss = self.reduce_mean(discriminate_real)
    fake_loss = self.reduce_mean(discriminate_fake)
    return fake_loss - real_loss

  def interpolate(self, real, fake):
    """Return random points on the segments between `real` and `fake`.

    One alpha ~ U[0, 1) is drawn per sample (axis 0) and broadcast over the
    remaining axes, so each output is real + alpha * (fake - real).
    """
    # tf.shape yields the dynamic batch size. The static real.shape[0] is
    # None when traced in a tf.function with an unknown batch dimension.
    shape = [tf.shape(real)[0]] + list(self.gp_shape)
    # Match the input dtype so mixed/double precision inputs do not hit a
    # float32 * float16/float64 dtype mismatch.
    alpha = tf.random.uniform(shape, minval=0.0, maxval=1.0, dtype=real.dtype)
    return real + alpha * (fake - real)

  def gradient_penalty(self, discriminator, real, fake, training: bool = True):
    """Return lambda_gp * mean((||d D(x_hat) / d x_hat||_2 - 1)^2).

    x_hat are random interpolates of `real` and `fake` (see `interpolate`).
    The norm is taken per sample over `self.gp_axis`.

    NOTE(review): `tape.gradient` of a non-scalar output sums over the batch.
    This equals per-sample gradients only if the critic treats samples
    independently (e.g. no batch normalization) — confirm for the models used.

    Args:
      discriminator: callable critic, invoked as
        discriminator(x, training=training).
      real, fake: batches of real and generated samples.
      training: forwarded to the discriminator call.
    """
    interpolated = self.interpolate(real, fake)
    with tf.GradientTape() as tape:
      # `interpolated` is a plain tensor, not a Variable, so it must be
      # watched explicitly for the tape to record gradients w.r.t. it.
      tape.watch(interpolated)
      discriminate_interpolated = discriminator(interpolated, training=training)
    gradient = tape.gradient(discriminate_interpolated, interpolated)
    squared_norm = tf.reduce_sum(tf.square(gradient), axis=self.gp_axis)
    # Epsilon keeps d(sqrt)/dx finite when a sample's gradient is all zeros.
    norm = tf.sqrt(squared_norm + self.gp_epsilon)
    return self.lambda_gp * self.reduce_mean(tf.square(norm - 1.0))
