import tensorflow as tf
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import UpSampling2D
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Reshape

from tensorflow.keras.models import Model

import numpy as np
class Sampling(tf.keras.layers.Layer):
    """Reparameterization layer: draws ``z`` from ``N(z_mean, exp(z_log_var))``.

    The draw is computed as ``z_mean + exp(0.5 * z_log_var) * eps`` where
    ``eps`` is standard-normal noise, so the result is an explicit function
    of ``z_mean`` and ``z_log_var``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, z_mean, z_log_var, seed=2021):
        """Return one latent sample per row of ``z_mean``.

        Args:
            z_mean: Tensor indexed as ``(batch, dim)``; only its first two
                axes are used to size the noise.
            z_log_var: Log-variance, combined with ``z_mean`` by broadcasting.
            seed: Op-level seed handed to ``tf.random.normal``.

        Returns:
            ``z_mean + exp(0.5 * z_log_var) * eps``.
        """
        mean_shape = tf.shape(z_mean)
        n_rows, n_cols = mean_shape[0], mean_shape[1]
        noise = tf.random.normal(shape=(n_rows, n_cols), seed=seed)
        std = tf.exp(0.5 * z_log_var)
        return z_mean + std * noise


class NormalizationLayer(tf.keras.layers.Layer):
    """Squash inputs into ``(minval, maxval)`` with a rescaled sigmoid.

    Args:
        minval: Lower end of the output range.
        maxval: Upper end of the output range.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, minval=0., maxval=1., **kwargs):
        super().__init__(**kwargs)
        self.minval, self.maxval = minval, maxval

    def build(self, input_shape):
        # No weights to create; defer entirely to the base implementation.
        super().build(input_shape)

    def call(self, inputs):
        """Return ``sigmoid(inputs) * (maxval - minval) + minval``."""
        span = self.maxval - self.minval
        return tf.math.sigmoid(inputs) * span + self.minval

    def get_config(self):
        """Return the base layer config extended with ``minval``/``maxval``."""
        base_config = super().get_config()
        base_config["minval"] = self.minval
        base_config["maxval"] = self.maxval
        return base_config


class VAE(tf.keras.Model):
    """Variational autoencoder with fully-connected encoder and decoder.

    The encoder maps an input to ``(z_mean, z_log_var)``; a latent vector is
    drawn from them by the ``Sampling`` layer and fed to the decoder.

    Reference: https://keras.io/examples/generative/vae/

    Args:
        input_shape: Shape of one input sample (no batch axis). Its first
            entry is also used as the width of the decoder's output layer.
        latent_dim: Size of the latent vector.
        arch_string: Dot-separated hidden-layer spec. Each token must be
            ``d<units>`` (a Dense layer followed by ReLU). The decoder uses
            the same tokens in reverse order.
        normalize: Optional ``(minval, maxval)`` pair; when given, the decoder
            output goes through a ``NormalizationLayer`` with that range.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """
    def __init__(self,
                 input_shape,
                 latent_dim,
                 arch_string='d512.d256.d128',
                 normalize=None,
                 **kwargs):
        super(VAE, self).__init__(**kwargs)
        self._latent_dim = latent_dim
        self._input_shape = input_shape
        self._normalize = normalize
        self.build_encoder(arch_string)
        self.build_decoder(arch_string)
        # The attribute keeps its original (misspelled) name because it is
        # not underscore-private and external code may reference it.
        self.reprameterization = Sampling()

    @staticmethod
    def log_normal_pdf(sample, mean, logvar, raxis=1):
        """Return the log-density of a diagonal Gaussian, summed over ``raxis``.

        Args:
            sample: Points at which the density is evaluated.
            mean: Mean of the Gaussian (broadcast against ``sample``).
            logvar: Log-variance of the Gaussian (broadcast against ``sample``).
            raxis: Axis along which per-dimension log-densities are summed.
        """
        log2pi = tf.math.log(2. * np.pi)
        return tf.reduce_sum(
            -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
            axis=raxis)

    @staticmethod
    def _dense_stack(y, tokens, part):
        """Apply one Dense+ReLU pair per ``d<units>`` token and return the result.

        Raises:
            NotImplementedError: If a token does not start with ``d``
                (including an empty token, e.g. from a trailing dot).
            ValueError: If the text after ``d`` is not an integer.
        """
        for token in tokens:
            if not token.startswith('d'):
                # token[:1] instead of token[0]: an empty token must produce
                # this error, not an IndexError while formatting the message.
                raise NotImplementedError(
                    f"Unknown layer identifier {token[:1] or repr(token)} in the {part}")
            try:
                units = int(token[1:])
            except ValueError as err:
                raise ValueError(
                    f"Invalid dense layer spec {token!r} in the {part}: "
                    f"expected 'd<units>'") from err
            y = Dense(units)(y)
            y = Activation('relu')(y)
        return y

    def _forward_and_loss(self, x, y):
        """Run encoder -> sampling -> decoder and compute the total loss.

        The compiled loss is negated and used as the reconstruction
        log-likelihood term; the prior and posterior log-densities of the
        sampled ``z`` supply the remaining terms.

        Returns:
            ``(reconstruction, total_loss)``.
        """
        z_mean, z_log_var = self.encoder(x)
        z = self.reprameterization(z_mean, z_log_var)
        reconstruction = self.decoder(z)
        reconstruction_loss = self.compiled_loss(
            y, reconstruction, regularization_losses=self.losses)
        logpx_z = -reconstruction_loss
        # Prior: zero mean and zero log-variance (unit variance).
        logpz = self.log_normal_pdf(z, 0., 0.)
        logqz_x = self.log_normal_pdf(z, z_mean, z_log_var)
        total_loss = -tf.reduce_mean(logpx_z + logpz - logqz_x)
        return reconstruction, total_loss

    def _collect_results(self, total_loss):
        """Return current metric results plus a ``'total_loss'`` entry."""
        return_dict = {m.name: m.result() for m in self.metrics}
        return_dict['total_loss'] = total_loss
        return return_dict

    def train_step(self, data):
        """Run one optimization step on an ``(x, y)`` batch.

        Returns:
            Dict of metric results with an extra ``'total_loss'`` entry.
        """
        (x, y) = data
        with tf.GradientTape() as tape:
            reconstruction, total_loss = self._forward_and_loss(x, y)
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.compiled_metrics.update_state(y, reconstruction)
        return self._collect_results(total_loss)

    def test_step(self, data):
        """Evaluate an ``(x, y)`` batch with the same loss as ``train_step``.

        Returns:
            Dict of metric results with an extra ``'total_loss'`` entry.
        """
        (x, y) = data
        reconstruction, total_loss = self._forward_and_loss(x, y)
        self.compiled_metrics.update_state(y, reconstruction)
        return self._collect_results(total_loss)

    def build_encoder(self, arch_string):
        """Build ``self.encoder``: input -> dense stack -> (z_mean, z_log_var).

        Raises:
            NotImplementedError: If ``arch_string`` has a token that does not
                start with ``d``.
        """
        x = Input(shape=self._input_shape)
        y = self._dense_stack(x, arch_string.split('.'), "Encoder")

        z_mean = Dense(self._latent_dim, name="z_mean")(y)
        z_log_var = Dense(self._latent_dim, name="z_log_var")(y)
        self.encoder = Model(x, [z_mean, z_log_var], name="encoder")

    def build_decoder(self, arch_string):
        """Build ``self.decoder``: latent -> reversed dense stack -> output.

        The output layer has ``self._input_shape[0]`` units and is followed
        by a ``NormalizationLayer`` when ``normalize`` was given.

        Raises:
            NotImplementedError: If ``arch_string`` has a token that does not
                start with ``d``.
        """
        x = Input(shape=(self._latent_dim, ))
        y = self._dense_stack(x, arch_string.split('.')[::-1], "Decoder")
        y = Dense(self._input_shape[0])(y)
        if self._normalize is not None:
            y = NormalizationLayer(minval=self._normalize[0],
                                   maxval=self._normalize[1],
                                   name='normalize')(y)
        self.decoder = Model(x, y, name="decoder")

    def encode(self, x, detemintristic=True):
        """Encode ``x`` and return a sampled latent as a NumPy array.

        Args:
            x: Input data; converted with ``tf.constant``.
            detemintristic: When true the sampling op gets the fixed seed
                2021, otherwise no seed. (Parameter name kept as-is for
                backward compatibility.)

        NOTE(review): per the ``tf.random.set_seed`` documentation an op-level
        seed makes the random sequence reproducible across program runs, not
        identical across successive calls - confirm this matches the intent.
        """
        x = tf.constant(x)
        z_mean, z_log_var = self.encoder(x)
        seed = 2021 if detemintristic else None
        return self.reprameterization(z_mean, z_log_var, seed=seed).numpy()

    def decode(self, z):
        """Decode latent vectors ``z`` and return the result as a NumPy array."""
        z = tf.constant(z)
        return self.decoder(z).numpy()

    def call(self, x, detemintristic=True, use_var=True):
        """Encode ``x``, sample a latent vector and return its decoding.

        Args:
            x: Input batch.
            detemintristic: When true the sampling op gets the fixed seed
                2021, otherwise no seed.
            use_var: When false the predicted log-variance is replaced by 0.
                NOTE(review): that corresponds to unit-variance noise around
                ``z_mean``, not to returning ``z_mean`` itself - confirm this
                is the intended meaning of the flag.
        """
        z_mean, z_log_var = self.encoder(x)
        seed = 2021 if detemintristic else None
        if not use_var:
            z_log_var = 0.
        z = self.reprameterization(z_mean, z_log_var, seed=seed)
        return self.decoder(z)

class AE(tf.keras.Model):
    """Plain autoencoder with fully-connected encoder and decoder.

    Args:
        input_shape: Shape of one input sample (no batch axis). Its first
            entry is also used as the width of the decoder's output layer.
        latent_dim: Size of the latent vector.
        arch_string: Dot-separated hidden-layer spec. Each token must be
            ``d<units>`` (a Dense layer followed by ReLU). The decoder uses
            the same tokens in reverse order.
        normalize: Optional ``(minval, maxval)`` pair; when given, the decoder
            output goes through a ``NormalizationLayer`` with that range.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """
    def __init__(self,
                 input_shape,
                 latent_dim,
                 arch_string='d512.d256.d128',
                 normalize=None,
                 **kwargs):
        super(AE, self).__init__(**kwargs)
        self._latent_dim = latent_dim
        self._input_shape = input_shape
        self._normalize = normalize
        self.build_encoder(arch_string)
        self.build_decoder(arch_string)

    @staticmethod
    def _dense_stack(y, tokens, part):
        """Apply one Dense+ReLU pair per ``d<units>`` token and return the result.

        Raises:
            NotImplementedError: If a token does not start with ``d``
                (including an empty token, e.g. from a trailing dot).
            ValueError: If the text after ``d`` is not an integer.
        """
        for token in tokens:
            if not token.startswith('d'):
                # token[:1] instead of token[0]: an empty token must produce
                # this error, not an IndexError while formatting the message.
                raise NotImplementedError(
                    f"Unknown layer identifier {token[:1] or repr(token)} in the {part}")
            try:
                units = int(token[1:])
            except ValueError as err:
                raise ValueError(
                    f"Invalid dense layer spec {token!r} in the {part}: "
                    f"expected 'd<units>'") from err
            y = Dense(units)(y)
            y = Activation('relu')(y)
        return y

    def encode(self, x, batch_size=256):
        """Return ``self.encoder.predict(x)`` evaluated in batches."""
        return self.encoder.predict(x, batch_size=batch_size)

    def decode(self, z, batch_size=256):
        """Return ``self.decoder.predict(z)`` evaluated in batches."""
        return self.decoder.predict(z, batch_size=batch_size)

    def build_encoder(self, arch_string):
        """Build ``self.encoder``: input -> dense stack -> latent layer.

        Raises:
            NotImplementedError: If ``arch_string`` has a token that does not
                start with ``d``.
        """
        x = Input(shape=self._input_shape)
        y = self._dense_stack(x, arch_string.split('.'), "Encoder")
        y = Dense(self._latent_dim)(y)

        self.encoder = Model(x, y)

    def build_decoder(self, arch_string):
        """Build ``self.decoder``: latent -> reversed dense stack -> output.

        The output layer has ``self._input_shape[0]`` units and is followed
        by a ``NormalizationLayer`` when ``normalize`` was given.

        Raises:
            NotImplementedError: If ``arch_string`` has a token that does not
                start with ``d``.
        """
        x = Input(shape=(self._latent_dim, ))
        y = self._dense_stack(x, arch_string.split('.')[::-1], "Decoder")

        y = Dense(self._input_shape[0])(y)
        if self._normalize is not None:
            y = NormalizationLayer(minval=self._normalize[0],
                                   maxval=self._normalize[1],
                                   name='normalize')(y)
        self.decoder = Model(x, y, name="decoder")

    def call(self, x):
        """Return the decoder's reconstruction of ``x``."""
        z = self.encoder(x)
        return self.decoder(z)


class AE_Conv(AE):
    """Convolutional autoencoder with a fixed three-stage architecture.

    The encoder is three Conv2D+ReLU+MaxPooling2D stages; the decoder takes a
    ``(4, 4, 8)`` code and upsamples it three times. Following the layer
    arithmetic (4 -> 8 -> 16 -> unpadded 3x3 conv -> 14 -> 28, one output
    channel), the decoder emits ``(28, 28, 1)`` tensors, so ``input_shape``
    has to equal that for construction to succeed.

    Args:
        input_shape: Shape of one input sample (no batch axis).
        latent_dim: Stored for interface parity with ``AE``; the
            convolutional builders do not read it.
        top_sigmoid: When true, a sigmoid activation is applied to the
            decoder output.

    Raises:
        ValueError: If the decoder output shape differs from ``input_shape``.
    """
    def __init__(self, input_shape, latent_dim, top_sigmoid=True):
        # Deliberately bypass AE.__init__: it does not accept ``top_sigmoid``
        # and it calls the builders before ``_top_sigmoid`` could be set.
        # Initialise the Keras Model base directly instead.
        super(AE, self).__init__()
        self._latent_dim = latent_dim
        self._input_shape = tuple(input_shape)
        self._normalize = None
        self._top_sigmoid = top_sigmoid
        self.build_encoder()
        self.build_decoder()

    def build_encoder(self, arch_string=None):
        """Build ``self.encoder``.

        Args:
            arch_string: Ignored; accepted only so the signature stays
                compatible with ``AE.build_encoder``.
        """
        x = Input(shape=self._input_shape)
        y = x
        for filters in (16, 8, 8):
            y = Conv2D(filters, (3, 3), padding='same')(y)
            y = Activation('relu')(y)
            y = MaxPooling2D((2, 2), padding='same')(y)

        self.encoder = Model(x, y)

    def build_decoder(self, arch_string=None):
        """Build ``self.decoder`` and verify its output shape.

        Args:
            arch_string: Ignored; accepted only so the signature stays
                compatible with ``AE.build_decoder``.

        Raises:
            ValueError: If the decoder output shape (without the batch axis)
                differs from ``input_shape``.
        """
        x = Input(shape=(4, 4, 8))
        y = Conv2D(8, (3, 3), padding='same')(x)
        y = Activation('relu')(y)
        y = UpSampling2D((2, 2))(y)
        y = Conv2D(8, (3, 3), padding='same')(y)
        y = Activation('relu')(y)
        y = UpSampling2D((2, 2))(y)
        # No padding='same' here: the unpadded 3x3 conv shrinks 16 -> 14 so
        # the final upsampling lands on 28.
        y = Conv2D(16, (3, 3))(y)
        y = Activation('relu')(y)
        y = UpSampling2D((2, 2))(y)
        y = Conv2D(1, (3, 3), padding='same')(y)

        # Compare as plain tuples so a list-valued input_shape still matches.
        if tuple(y.shape[1:]) != tuple(self._input_shape):
            raise ValueError(
                f"The decoder output has shape {y.shape} that mismatches the input shape {self._input_shape}"
            )

        if self._top_sigmoid:
            y = Activation('sigmoid')(y)
        self.decoder = Model(x, y)