import tensorflow as tf
import math
import numpy as np

# Module-level switch. Not referenced in the visible part of this module;
# presumably meant for normalization layers (cf. PyTorch's BatchNorm
# `track_running_stats`) -- NOTE(review): confirm a consumer exists or remove.
track_running_stats: bool = True

def convt_k4s2(filters, act):
  """Return a 4x4, stride-2, 'SAME'-padded Conv2DTranspose layer.

  With stride 2 and 'SAME' padding, the layer doubles the spatial height and
  width of its input.

  Args:
    filters: Number of output channels.
    act: Activation passed to the layer (a callable, a Keras activation name,
      or None for a linear output).

  Returns:
    A new, unbuilt `tf.keras.layers.Conv2DTranspose` instance.
  """
  return tf.keras.layers.Conv2DTranspose(filters=filters, kernel_size=4, strides=2,
                                         padding='SAME', activation=act)


class Generator(tf.keras.Sequential):
  """Decoder mapping a latent vector to an image via transposed convolutions.

  A dense layer maps the latent vector to a seed feature map of spatial size
  ceil(H / 16) x ceil(W / 16) with 8 * nn_size channels. Four stride-2
  transposed convolutions then upsample it by 2**4 = 16 to a map that is at
  least H x W. A final CenterCrop trims it to exactly H x W.

  Args:
    img_shape: Target image shape as [H, W, C].
    latent_dim: Length of the latent input vector.
    nn_size: Base channel width; hidden widths are 8x, 4x, 2x and 1x this value.
    activation_type: Name of a function in `tf.nn` (e.g. "leaky_relu") applied
      after every hidden transposed convolution.
  """

  def __init__(self, img_shape, latent_dim, nn_size=32,
               activation_type="leaky_relu"):
    activation = getattr(tf.nn, activation_type)
    dim = nn_size
    # Channel widths of the hidden upsampling stages; one more stage follows
    # that maps to img_shape[2] channels.
    hidden_widths = (4 * dim, 2 * dim, dim)
    # Every stage doubles H and W, so the seed resolution must be the target
    # size divided by 2**num_stages. Deriving it from the layer count keeps
    # the output >= the target so CenterCrop never has to *resize* (it
    # bilinearly upsizes inputs smaller than its target).
    scale = 2 ** (len(hidden_widths) + 1)  # 16
    feature_shape = (math.ceil(img_shape[0] / scale),
                     math.ceil(img_shape[1] / scale),
                     8 * dim)
    layers = [
      # "latent_to_features"
      tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
      tf.keras.layers.Dense(int(np.prod(feature_shape))),
      tf.keras.layers.ReLU(),

      # "features_to_image"
      tf.keras.layers.Reshape(target_shape=feature_shape),
      *[convt_k4s2(width, activation) for width in hidden_widths],
      # The NERD code uses a sigmoid here to get outputs in [0, 1]; the output
      # is kept linear because the normalized images are in [-0.5, 0.5].
      convt_k4s2(img_shape[2], None),
      # Output is scale * ceil(H / scale) >= H (same for W): a pure crop.
      tf.keras.layers.CenterCrop(img_shape[0], img_shape[1])
    ]
    super().__init__(layers=layers, name="Generator")
    self.img_shape = img_shape  # [H, W, C]
    self.feature_shape = feature_shape  # seed map shape (h, w, channels)
