from cyclegan.alogrithms.registry import register

import tensorflow as tf

from cyclegan.alogrithms import losses
from cyclegan.alogrithms.lsgan import LSGAN


@register('gan')
class GAN(LSGAN):
  """Vanilla GAN objective, https://arxiv.org/abs/1406.2661.

  Everything except the per-sample loss is taken from `LSGAN`. The loss is
  replaced by the one `losses.get_loss_function` returns for the key 'bce'.
  That is presumably binary cross-entropy; TODO confirm whether it expects
  logits or probabilities from the discriminator.

  The generator is trained toward the *real* label on its own samples. This
  is the non-saturating variant rather than the literal minimax objective.
  """

  def __init__(self, args, G, F, X, Y, strategy: tf.distribute.Strategy):
    # All constructor arguments are forwarded untouched to LSGAN.
    # NOTE(review): by CycleGAN convention G/F are presumably the two
    # generators and X/Y the two domains -- confirm against LSGAN.
    super(GAN, self).__init__(args, G, F, X, Y, strategy)
    # Assigned after the parent constructor so it takes precedence over
    # whatever LSGAN's __init__ may have stored in this attribute.
    self.loss_function = losses.get_loss_function('bce')

  def generator_loss(self, discriminate_fake):
    """Score generated samples against the "real" target.

    Args:
      discriminate_fake: discriminator output for generated samples.

    Returns:
      The per-sample loss between the real labels and `discriminate_fake`,
      reduced with `self.reduce_mean`.
    """
    target = self.real_labels(discriminate_fake)
    loss_per_sample = self.loss_function(y_true=target,
                                         y_pred=discriminate_fake)
    return self.reduce_mean(loss_per_sample)

  def discriminator_loss(self, discriminate_real, discriminate_fake):
    """Score the discriminator on real (-> real label) and fake (-> fake label).

    Args:
      discriminate_real: discriminator output for real samples.
      discriminate_fake: discriminator output for generated samples.

    Returns:
      Half the sum of the real-sample and fake-sample per-sample losses,
      reduced with `self.reduce_mean`.
    """
    # Label/loss evaluation order is kept as real-then-fake on purpose, in
    # case the label helpers are stateful (e.g. randomized smoothing).
    real_target = self.real_labels(discriminate_real)
    loss_on_real = self.loss_function(y_true=real_target,
                                      y_pred=discriminate_real)
    fake_target = self.fake_labels(discriminate_fake)
    loss_on_fake = self.loss_function(y_true=fake_target,
                                      y_pred=discriminate_fake)
    # The 0.5 factor presumably follows the CycleGAN practice of halving the
    # discriminator objective to slow D relative to G.
    combined = 0.5 * (loss_on_real + loss_on_fake)
    return self.reduce_mean(combined)
