from cyclegan.alogrithms.registry import register

import tensorflow as tf

from cyclegan.alogrithms.lsgan import LSGAN


@register('lsgangp')
class LSGANGP(LSGAN):
  """
  LSGAN with gradient penalty.

  Adds a penalty, weighted by ``args.lambda_gp``, that pushes the per-sample
  L2 norm of the discriminator's gradient w.r.t. its input towards 1. The
  points where that gradient is evaluated are random perturbations of the
  real samples (see ``interpolate``), in the style of DRAGAN rather than the
  real/fake interpolation used by WGAN-GP.
  """

  # Added under the square root so the norm's gradient stays finite when the
  # discriminator's input gradient is exactly zero: d/dv sqrt(v) is infinite
  # at v == 0, which would otherwise turn into NaN during backprop.
  _GP_EPSILON = 1e-12

  def __init__(self, args, G, F, X, Y, strategy: tf.distribute.Strategy):
    super(LSGANGP, self).__init__(args, G, F, X, Y, strategy)
    # Per-sample broadcast shape for the interpolation coefficient: one
    # singleton entry per non-batch dimension of args.input_shape.
    self.gp_shape = (1,) * len(args.input_shape)
    # All non-batch axes; the gradient norm is reduced over these so exactly
    # one norm is produced per sample.
    self.gp_axis = list(range(1, len(args.input_shape) + 1))
    self.lambda_gp = args.lambda_gp

  def interpolate(self, x):
    """Return random points between each sample of ``x`` and a noisy copy.

    A perturbed copy ``x + 0.5 * std(x) * U[0, 1)`` is built, where std is
    taken over the whole batch. Then a per-sample coefficient
    ``alpha ~ U[0, 1)`` selects a point on the segment between ``x`` and that
    copy.

    Args:
      x: batch of samples; axis 0 is the batch axis and the remaining axes
        are assumed to match ``args.input_shape``.

    Returns:
      Tensor with the same shape and dtype as ``x``.
    """
    # Dynamic shapes so this also works when the batch dimension is unknown
    # at trace time (static x.shape[0] would be None there).
    batch_size = tf.shape(x)[0]
    alpha = tf.random.uniform((batch_size,) + self.gp_shape,
                              minval=0.0,
                              maxval=1.0,
                              dtype=x.dtype)
    noise = tf.random.uniform(tf.shape(x), dtype=x.dtype)
    perturbed = x + 0.5 * tf.math.reduce_std(x) * noise
    return x + alpha * (perturbed - x)

  def gradient_penalty(self, discriminator, real, fake, training: bool = True):
    """Compute the weighted gradient penalty for ``discriminator``.

    Args:
      discriminator: callable invoked as ``discriminator(inputs,
        training=training)``.
      real: batch of real samples; the penalty points are derived from it.
      fake: not used by this variant. It is kept so the signature stays
        compatible with callers that pass both real and fake batches.
      training: forwarded to ``discriminator``.

    Returns:
      ``lambda_gp * self.reduce_mean((||grad_x D(x_hat)||_2 - 1) ** 2)``,
      where ``x_hat = self.interpolate(real)``.
    """
    interpolated = self.interpolate(real)
    with tf.GradientTape() as tape:
      tape.watch(interpolated)
      discriminate_interpolated = discriminator(interpolated, training=training)
    gradient = tape.gradient(discriminate_interpolated, interpolated)
    squared_norm = tf.reduce_sum(tf.square(gradient), axis=self.gp_axis)
    norm = tf.sqrt(squared_norm + self._GP_EPSILON)
    # NOTE(review): self.reduce_mean is defined outside this file (presumably
    # on LSGAN, e.g. a strategy-aware batch mean) — confirm.
    return self.lambda_gp * self.reduce_mean(tf.square(norm - 1.0))
