from cyclegan.alogrithms.registry import register

import tensorflow as tf

from cyclegan.alogrithms import losses
from cyclegan.alogrithms.lsgan import LSGAN


@register('dragan')
class DRAGAN(LSGAN):
  """DRAGAN: binary cross-entropy GAN losses plus a DRAGAN gradient penalty.

  On Convergence and Stability of GANs https://arxiv.org/abs/1705.07215
  Reference: https://github.com/kodalinaveen3/DRAGAN

  Reuses the LSGAN training machinery but sets `self.loss_function` to binary
  cross-entropy after the parent constructor runs. The discriminator is
  regularized by penalizing the per-sample gradient norm ||grad D(x_hat)||_2
  away from 1 at points x_hat sampled in a noisy neighbourhood of the real
  data (see `interpolate`).
  """

  # Added under the square root of the squared gradient norm. The derivative
  # of sqrt at 0 is infinite, so without it an exactly-zero gradient would make
  # the penalty's own gradient NaN (inf * 0) and poison the discriminator.
  _GP_EPSILON = 1e-12

  def __init__(self, args, G, F, X, Y, strategy: tf.distribute.Strategy):
    super(DRAGAN, self).__init__(args, G, F, X, Y, strategy)
    # Trailing broadcast dims of the per-sample interpolation coefficient:
    # one `1` per non-batch axis of the input.
    self.gp_shape = (1,) * len(args.input_shape)
    # Non-batch axes summed over when computing per-sample gradient norms.
    self.gp_axis = list(range(1, len(args.input_shape) + 1))
    self.lambda_gp = args.lambda_gp
    self.loss_function = losses.get_loss_function('bce')

  def generator_loss(self, discriminate_fake):
    """Non-saturating generator loss.

    Binary cross-entropy of the discriminator's scores on generated samples
    against the "real" targets from `self.real_labels`, reduced with the
    parent's `reduce_mean` (presumably a global-batch average -- confirm in
    LSGAN).
    """
    per_sample_loss = self.loss_function(
        y_true=self.real_labels(discriminate_fake), y_pred=discriminate_fake)
    return self.reduce_mean(per_sample_loss)

  def discriminator_loss(self, discriminate_real, discriminate_fake):
    """Discriminator loss: mean of the real and fake BCE terms.

    Real scores are compared with `self.real_labels`, fake scores with
    `self.fake_labels` (both provided by the parent class), averaged as
    0.5 * (real + fake) per sample, then reduced with `self.reduce_mean`.
    """
    real_loss = self.loss_function(y_true=self.real_labels(discriminate_real),
                                   y_pred=discriminate_real)
    fake_loss = self.loss_function(y_true=self.fake_labels(discriminate_fake),
                                   y_pred=discriminate_fake)
    per_sample_loss = 0.5 * (real_loss + fake_loss)
    return self.reduce_mean(per_sample_loss)

  def interpolate(self, x):
    """Sample points between `x` and a noise-perturbed copy of `x`.

    As in the reference implementation, the perturbed copy is
    x + 0.5 * std(x) * U[0, 1), with std taken over the entire tensor, and
    each sample gets its own interpolation coefficient alpha ~ U[0, 1).

    Args:
      x: batch of real samples, batch on axis 0.

    Returns:
      A tensor with the same shape and dtype as `x`.
    """
    # Use the dynamic shape so this also works when the batch dimension is
    # unknown at trace time (static x.shape[0] would be None there).
    batch_size = tf.shape(x)[0]
    alpha = tf.random.uniform([batch_size, *self.gp_shape],
                              minval=0.0,
                              maxval=1.0,
                              dtype=x.dtype)
    noise = tf.random.uniform(tf.shape(x), dtype=x.dtype)
    perturbed = x + 0.5 * tf.math.reduce_std(x) * noise
    return x + alpha * (perturbed - x)

  def gradient_penalty(self, discriminator, real, fake, training=True):
    """DRAGAN gradient penalty centred on the real data.

    Args:
      discriminator: callable that scores a batch of samples.
      real: batch of real samples around which points are sampled.
      fake: unused -- DRAGAN perturbs only real samples. Kept so the
        signature matches callers that also pass generated samples.
      training: forwarded to `discriminator`.

    Returns:
      lambda_gp * reduce_mean((||grad D(x_hat)||_2 - 1)^2), where the norm is
      taken per sample over all non-batch axes.
    """
    interpolated = self.interpolate(real)
    with tf.GradientTape() as tape:
      # `interpolated` is a plain tensor, not a variable, so the tape must be
      # told to watch it explicitly.
      tape.watch(interpolated)
      discriminate_interpolated = discriminator(interpolated, training=training)
    # For a non-scalar discriminator output the tape differentiates the sum
    # of its elements, giving per-sample input gradients.
    gradient = tape.gradient(discriminate_interpolated, interpolated)
    squared_norm = tf.reduce_sum(tf.square(gradient), axis=self.gp_axis)
    norm = tf.sqrt(squared_norm + self._GP_EPSILON)
    return self.lambda_gp * self.reduce_mean(tf.square(norm - 1.0))
