from cyclegan.alogrithms.registry import register

import re
import os
from glob import glob
import tensorflow as tf
from functools import partial

from cyclegan.alogrithms import losses
from cyclegan.alogrithms.optimizer import Optimizer
from cyclegan.utils.utils import unscale, update_dict, save_json, load_json


@register('lsgan')
class LSGAN:
  """ Least Squares GANs https://arxiv.org/abs/1611.04076
  Reference: https://github.com/xudonmao/LSGAN
  """

  def __init__(self, args, G, F, X, Y, strategy: tf.distribute.Strategy):
    """ Wire up the models, their optimizers, checkpointing and loss settings.

      G(x): x -> y
      F(y): y -> x
      X, Y: models used to discriminate samples of the x and y domains

    Generators are optimized with args.g_lr, discriminators with args.d_lr.
    When args.lambda_identity is None it defaults to half of
    args.lambda_cycle.
    """
    self.strategy = strategy
    self.G, self.F, self.X, self.Y = G, F, X, Y

    def build_optimizer(model, learning_rate, name):
      # all four optimizers share the strategy and mixed precision setting
      return Optimizer(model=model,
                       learning_rate=learning_rate,
                       strategy=strategy,
                       mixed_precision=args.mixed_precision,
                       name=name)

    self.G_optimizer = build_optimizer(self.G, args.g_lr, 'G_optimizer')
    self.F_optimizer = build_optimizer(self.F, args.g_lr, 'F_optimizer')
    self.X_optimizer = build_optimizer(self.X, args.d_lr, 'X_optimizer')
    self.Y_optimizer = build_optimizer(self.Y, args.d_lr, 'Y_optimizer')

    # checkpoint tracks every model together with its underlying optimizer
    self.checkpoint_dir = os.path.join(args.output_dir, 'checkpoints')
    with self.strategy.scope():
      self.checkpoint = tf.train.Checkpoint(
          G=self.G,
          F=self.F,
          X=self.X,
          Y=self.Y,
          G_optimizer=self.G_optimizer.optimizer,
          F_optimizer=self.F_optimizer.optimizer,
          X_optimizer=self.X_optimizer.optimizer,
          Y_optimizer=self.Y_optimizer.optimizer)

    # training configuration
    self.mixed_precision = args.mixed_precision
    self.label_smoothing = args.label_smoothing
    self.global_batch_size = args.global_batch_size
    self.g_updates = args.g_updates
    self.d_updates = args.d_updates

    # dataset scaling information, used to unscale data in validation_step
    self.scaled_data = args.scaled_data
    self.ds_min = args.ds_min
    self.ds_max = args.ds_max

    # loss weights
    self.lambda_cycle = args.lambda_cycle
    if args.lambda_identity is None:
      self.lambda_identity = 0.5 * args.lambda_cycle
    else:
      self.lambda_identity = args.lambda_identity

    # loss functions are looked up by name from the losses module;
    # 'mae' is the per-sample mean absolute error
    self.mean_absolute_error = losses.get_loss_function('mae')
    self.loss_function = losses.get_loss_function('mse')
    self.error_function = losses.get_loss_function(args.error)

  def get_models(self):
    """ Return the four models as the tuple (G, F, X, Y). """
    models = (self.G, self.F, self.X, self.Y)
    return models

  def get_optimizers(self):
    """ Return the optimizer wrappers in the order (G, F, X, Y). """
    optimizers = (self.G_optimizer, self.F_optimizer, self.X_optimizer,
                  self.Y_optimizer)
    return optimizers

  def save_checkpoint(self, args, epoch: int):
    """ Write the checkpoint to <checkpoint_dir>/epoch-<epoch>, where the
    epoch number is zero-padded to 3 digits.

    Prints the destination when args.verbose is set.
    """
    destination = os.path.join(self.checkpoint_dir, f'epoch-{epoch:03d}')
    self.checkpoint.write(destination)
    if not args.verbose:
      return
    print(f'saved checkpoint to {destination}\n')

  def load_checkpoint(self, expect_partial: bool = False):
    """ Restore the latest checkpoint from self.checkpoint_dir if one exists.

    Checkpoints are recognised by their index file `epoch-<number>.index`,
    the naming used by save_checkpoint. The latest checkpoint is the one with
    the numerically largest epoch, so epoch-1000 is correctly preferred over
    epoch-999. Index files that do not follow the naming scheme are ignored.

    Args:
      expect_partial: call expect_partial() on the restore status when True.
    Returns:
      the epoch number of the restored checkpoint, 0 if none was found.
    """
    suffix = '.index'
    pattern = re.compile(r'epoch-(\d+)\.index$')

    candidates = []
    for index_file in glob(os.path.join(self.checkpoint_dir, '*' + suffix)):
      match = pattern.search(os.path.basename(index_file))
      if match is not None:
        candidates.append((int(match.group(1)), index_file))
    if not candidates:
      return 0

    # compare epochs as integers; a lexicographic sort breaks past epoch 999
    epoch, index_file = max(candidates)
    # strip the trailing suffix only, the path may contain '.index' elsewhere
    checkpoint = index_file[:-len(suffix)]

    with self.strategy.scope():
      status = self.checkpoint.restore(checkpoint)
    if expect_partial:
      status.expect_partial()
    print(f'\nloaded checkpoint from {checkpoint}\n')
    return epoch

  def real_labels(self, inputs):
    """ Return target labels for real samples, shaped and typed like inputs.

    Labels are all ones, or drawn uniformly from [0.9, 1.0) when
    self.label_smoothing is enabled.
    """
    if not self.label_smoothing:
      return tf.ones_like(inputs)
    # tf.shape is evaluated at run time, so this also works inside a
    # tf.function when the static shape has unknown dimensions (for instance
    # a variable batch size), where passing inputs.shape would fail
    return tf.random.uniform(shape=tf.shape(inputs),
                             minval=0.9,
                             maxval=1.0,
                             dtype=inputs.dtype)

  def fake_labels(self, inputs):
    """ Return target labels for fake samples, shaped and typed like inputs.

    Labels are all zeros, or drawn uniformly from [0.0, 0.1) when
    self.label_smoothing is enabled.
    """
    if not self.label_smoothing:
      return tf.zeros_like(inputs)
    # tf.shape is evaluated at run time, so this also works inside a
    # tf.function when the static shape has unknown dimensions (for instance
    # a variable batch size), where passing inputs.shape would fail
    return tf.random.uniform(shape=tf.shape(inputs),
                             minval=0.0,
                             maxval=0.1,
                             dtype=inputs.dtype)

  def reduce_mean(self, inputs):
    """ Sum all elements of inputs and divide by the global batch size.

    The divisor is self.global_batch_size rather than the number of elements
    in inputs.
    """
    total = tf.reduce_sum(inputs)
    return total / self.global_batch_size

  def reduce_dict(self, d: dict):
    """ Sum-reduce every value of d across replicas, in place.

    Each value is replaced with the result of strategy.reduce using
    ReduceOp.SUM and axis=None. Nothing is returned.

    Raises:
      ValueError: if a value cannot be reduced; the message names the
        offending key and the original error is attached as the cause.
    """
    # iterate over a snapshot because d is written to inside the loop
    for k, v in list(d.items()):
      try:
        d[k] = self.strategy.reduce(tf.distribute.ReduceOp.SUM, v, axis=None)
      except ValueError as e:
        raise ValueError(f'ValueError: {e}\n\nkey: {k}\n{v}\n\n') from e

  def generator_loss(self, discriminate_fake):
    """ Generator loss: self.loss_function between the discriminator output
    on generated samples and real labels, averaged with reduce_mean. """
    targets = self.real_labels(discriminate_fake)
    loss = self.loss_function(y_true=targets, y_pred=discriminate_fake)
    return self.reduce_mean(loss)

  def cycle_loss(self, real_samples, cycle_samples):
    """ Cycle-consistency loss: self.error_function between the real samples
    and their cycled reconstruction, averaged with reduce_mean and weighted
    by lambda_cycle. """
    error = self.error_function(y_true=real_samples, y_pred=cycle_samples)
    mean_error = self.reduce_mean(error)
    return self.lambda_cycle * mean_error

  def identity_loss(self, real_samples, identity_samples):
    """ Identity loss: self.error_function between the real samples and the
    generator output for those same samples, averaged with reduce_mean and
    weighted by lambda_identity. """
    error = self.error_function(y_true=real_samples, y_pred=identity_samples)
    mean_error = self.reduce_mean(error)
    return self.lambda_identity * mean_error

  def discriminator_loss(self, discriminate_real, discriminate_fake):
    """ Discriminator loss: half the sum of self.loss_function on real
    outputs (against real labels) and on fake outputs (against fake labels),
    averaged with reduce_mean. """
    targets_real = self.real_labels(discriminate_real)
    loss_on_real = self.loss_function(y_true=targets_real,
                                      y_pred=discriminate_real)
    targets_fake = self.fake_labels(discriminate_fake)
    loss_on_fake = self.loss_function(y_true=targets_fake,
                                      y_pred=discriminate_fake)
    combined = 0.5 * (loss_on_real + loss_on_fake)
    return self.reduce_mean(combined)

  def gradient_penalty(self, discriminator, real, fake, training: bool = True):
    """ Gradient penalty term added to the discriminator loss.

    This implementation applies no penalty: all arguments are ignored and
    0.0 is always returned.
    """
    return 0.0

  @tf.function
  def cycle_step(self, x: tf.Tensor, y: tf.Tensor, training: bool = False):
    """ Run both translation cycles.

      x -> G -> fake y -> F -> cycle x
      y -> F -> fake x -> G -> cycle y

    Returns the tuple (fake_x, fake_y, cycle_x, cycle_y).
    """
    to_y, to_x = self.G, self.F
    # forward cycle, starting from x
    fake_y = to_y(x, training=training)
    cycle_x = to_x(fake_y, training=training)
    # backward cycle, starting from y
    fake_x = to_x(y, training=training)
    cycle_y = to_y(fake_x, training=training)
    return fake_x, fake_y, cycle_x, cycle_y

  def _train_generators(self, x: tf.Tensor, y: tf.Tensor):
    """ Perform one update of both generators G and F.

    Each generator is trained on the sum of its adversarial, cycle and
    identity losses. Returns a dict with the unscaled loss terms under the
    'loss_G/...' and 'loss_F/...' keys.
    """
    # persistent tape: gradients are taken twice, once per generator
    with tf.GradientTape(persistent=True) as tape:
      translated_y = self.G(x, training=True)
      translated_x = self.F(y, training=True)

      score_translated_x = self.X(translated_x, training=True)
      score_translated_y = self.Y(translated_y, training=True)

      G_adversarial = self.generator_loss(score_translated_y)
      F_adversarial = self.generator_loss(score_translated_x)

      cycled_y = self.G(translated_x, training=True)
      G_cycle = self.cycle_loss(y, cycled_y)
      cycled_x = self.F(translated_y, training=True)
      F_cycle = self.cycle_loss(x, cycled_x)

      identity_y = self.G(y, training=True)
      G_identity = self.identity_loss(y, identity_y)
      identity_x = self.F(x, training=True)
      F_identity = self.identity_loss(x, identity_x)

      G_total = G_adversarial + G_cycle + G_identity
      F_total = F_adversarial + F_cycle + F_identity

      # the reported values are recorded before any loss scaling
      result = {
          'loss_G/loss': G_adversarial,
          'loss_G/cycle': G_cycle,
          'loss_G/identity': G_identity,
          'loss_G/total': G_total,
          'loss_F/loss': F_adversarial,
          'loss_F/cycle': F_cycle,
          'loss_F/identity': F_identity,
          'loss_F/total': F_total
      }

      G_objective, F_objective = G_total, F_total
      if self.mixed_precision:
        G_objective = self.G_optimizer.get_scaled_loss(G_total)
        F_objective = self.F_optimizer.get_scaled_loss(F_total)

    self.G_optimizer.minimize(G_objective, tape)
    self.F_optimizer.minimize(F_objective, tape)

    return result

  def _train_discriminators(self, x: tf.Tensor, y: tf.Tensor):
    """ Perform one update of both discriminators X and Y.

    Each discriminator is trained on its discriminator loss plus the
    gradient penalty. Returns a dict with the unscaled loss terms under the
    'loss_X/...' and 'loss_Y/...' keys.
    """
    # persistent tape: gradients are taken twice, once per discriminator
    with tf.GradientTape(persistent=True) as tape:
      translated_y = self.G(x, training=True)
      translated_x = self.F(y, training=True)

      score_x = self.X(x, training=True)
      score_y = self.Y(y, training=True)
      score_translated_x = self.X(translated_x, training=True)
      score_translated_y = self.Y(translated_y, training=True)

      X_loss = self.discriminator_loss(score_x, score_translated_x)
      X_penalty = self.gradient_penalty(discriminator=self.X,
                                        real=x,
                                        fake=translated_x,
                                        training=True)
      X_total = X_loss + X_penalty

      Y_loss = self.discriminator_loss(score_y, score_translated_y)
      Y_penalty = self.gradient_penalty(discriminator=self.Y,
                                        real=y,
                                        fake=translated_y,
                                        training=True)
      Y_total = Y_loss + Y_penalty

      # the reported values are recorded before any loss scaling
      result = {
          'loss_X/loss': X_loss,
          'loss_X/gradient_penalty': X_penalty,
          'loss_X/total': X_total,
          'loss_Y/loss': Y_loss,
          'loss_Y/gradient_penalty': Y_penalty,
          'loss_Y/total': Y_total
      }

      X_objective, Y_objective = X_total, Y_total
      if self.mixed_precision:
        X_objective = self.X_optimizer.get_scaled_loss(X_total)
        Y_objective = self.Y_optimizer.get_scaled_loss(Y_total)

    self.X_optimizer.minimize(X_objective, tape)
    self.Y_optimizer.minimize(Y_objective, tape)

    return result

  def train_step(self, x: tf.Tensor, y: tf.Tensor):
    """ Run d_updates discriminator updates followed by g_updates generator
    updates on the same batch.

    The results of every update are collected with update_dict and each
    entry is averaged with tf.reduce_mean before being returned.
    """
    collected = {}
    # discriminators are always updated before the generators
    schedule = ((self.d_updates, self._train_discriminators),
                (self.g_updates, self._train_generators))
    for repeats, update_fn in schedule:
      for _ in range(repeats):
        update_dict(collected, update_fn(x, y))
    return {name: tf.reduce_mean(values) for name, values in collected.items()}

  @tf.function
  def distributed_train_step(self, x: tf.Tensor, y: tf.Tensor):
    """ Run train_step through the distribution strategy and return its
    results after they have been reduced with reduce_dict. """
    per_replica = self.strategy.run(self.train_step, args=(x, y))
    # reduce_dict replaces the values in place
    self.reduce_dict(per_replica)
    return per_replica

  def validation_step(self, x: tf.Tensor, y: tf.Tensor):
    """ Evaluate one batch with every model called with training=False.

    No optimizer is stepped. The returned dict contains
      'critic/...': mean discriminator output on real and generated samples
      'loss_G/...', 'loss_F/...': generator loss terms
      'loss_X/...', 'loss_Y/...': discriminator loss terms
      'MAE/...': mean absolute error of the cycle and identity outputs
    When self.scaled_data is set, the tensors are passed through unscale with
    ds_min and ds_max before the MAE entries are computed.
    """
    metrics = {}

    fake_x, fake_y, cycle_x, cycle_y = self.cycle_step(x, y, training=False)

    score_x = self.X(x, training=False)
    score_y = self.Y(y, training=False)
    score_fake_x = self.X(fake_x, training=False)
    score_fake_y = self.Y(fake_y, training=False)

    # mean discriminator outputs
    critic_scores = (
        ('critic/Dx(X)', score_x),
        ('critic/Dy(Y)', score_y),
        ('critic/Dx(F(Y))', score_fake_x),
        ('critic/Dy(G(X))', score_fake_y),
    )
    for key, scores in critic_scores:
      metrics[key] = self.reduce_mean(losses.per_sample_mean(scores))

    # generator losses
    G_adversarial = self.generator_loss(score_fake_y)
    F_adversarial = self.generator_loss(score_fake_x)

    F_cycle = self.cycle_loss(x, cycle_x)
    G_cycle = self.cycle_loss(y, cycle_y)

    same_x = self.F(x, training=False)
    same_y = self.G(y, training=False)
    G_identity = self.identity_loss(y, same_y)
    F_identity = self.identity_loss(x, same_x)

    metrics['loss_G/loss'] = G_adversarial
    metrics['loss_G/cycle'] = G_cycle
    metrics['loss_G/identity'] = G_identity
    metrics['loss_G/total'] = G_adversarial + G_cycle + G_identity
    metrics['loss_F/loss'] = F_adversarial
    metrics['loss_F/cycle'] = F_cycle
    metrics['loss_F/identity'] = F_identity
    metrics['loss_F/total'] = F_adversarial + F_cycle + F_identity

    # discriminator losses
    X_loss = self.discriminator_loss(score_x, score_fake_x)
    X_penalty = self.gradient_penalty(discriminator=self.X,
                                      real=x,
                                      fake=fake_x,
                                      training=False)

    Y_loss = self.discriminator_loss(score_y, score_fake_y)
    Y_penalty = self.gradient_penalty(discriminator=self.Y,
                                      real=y,
                                      fake=fake_y,
                                      training=False)

    metrics['loss_X/loss'] = X_loss
    metrics['loss_X/gradient_penalty'] = X_penalty
    metrics['loss_X/total'] = X_loss + X_penalty
    metrics['loss_Y/loss'] = Y_loss
    metrics['loss_Y/gradient_penalty'] = Y_penalty
    metrics['loss_Y/total'] = Y_loss + Y_penalty

    # the MAE metrics are computed on unscaled data when the data was scaled
    if self.scaled_data:
      x, y, same_x, same_y, cycle_x, cycle_y = [
          unscale(tensor, ds_min=self.ds_min, ds_max=self.ds_max)
          for tensor in [x, y, same_x, same_y, cycle_x, cycle_y]
      ]

    reconstructions = (
        ('MAE/MAE(X, F(G(X)))', x, cycle_x),
        ('MAE/MAE(Y, G(F(Y)))', y, cycle_y),
        ('MAE/MAE(X, F(X))', x, same_x),
        ('MAE/MAE(Y, G(Y))', y, same_y),
    )
    for key, target, output in reconstructions:
      metrics[key] = self.reduce_mean(self.mean_absolute_error(target, output))

    return metrics

  @tf.function
  def distributed_validation_step(self, x: tf.Tensor, y: tf.Tensor):
    """ Run validation_step through the distribution strategy and return its
    results after they have been reduced with reduce_dict. """
    per_replica = self.strategy.run(self.validation_step, args=(x, y))
    # reduce_dict replaces the values in place
    self.reduce_dict(per_replica)
    return per_replica
