import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras import layers, losses
import keras
from keras import ops, Input, random
import numpy as np

def create_autoencoder_model(shape, latent_dim=2, learning_rate=0.001, kl_weight=10., model_name='fc_1d'):
  """Build and compile one of the autoencoders defined in this module.

  Args:
    shape: Input shape, forwarded to the selected model's constructor.
    latent_dim: Size of the latent code.
    learning_rate: Optimizer learning rate (SGD for 'fc_1d', Adamax for 'vae_2d').
    kl_weight: Weight of the KL-divergence term; only used by 'vae_2d'.
    model_name: 'fc_1d' for `FC1DAutoencoder` (compiled with an MSE loss) or
      'vae_2d' for `V2DAutoencoder` (its loss is computed in its own
      `train_step`, so no `loss=` is passed to `compile`).

  Returns:
    The compiled model.

  Raises:
    ValueError: If `model_name` is not one of the supported names.
  """
  if model_name == 'fc_1d':
    # kl_weight has no meaning for the non-variational FC model, so it is not
    # forwarded to its constructor.
    model = FC1DAutoencoder(latent_dim, shape)
    model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate), loss=losses.MeanSquaredError())
  elif model_name == 'vae_2d':
    model = V2DAutoencoder(latent_dim, shape, kl_weight)
    model.compile(optimizer=tf.keras.optimizers.Adamax(learning_rate=learning_rate))
  else:
    # Fail loudly instead of hitting an UnboundLocalError on `return model`.
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected 'fc_1d' or 'vae_2d'.")
  return model

class FC1DAutoencoder(Model):
  """Fully connected (dense) autoencoder for flat feature vectors.

  Encoder: Dense(32, sigmoid) -> Dense(latent_dim).
  Decoder: Dense(32, sigmoid) -> Dense(shape[0]).

  Args:
    latent_dim: Width of the bottleneck code.
    shape: Input shape; only ``shape[0]`` is read, as the decoder's output
      width, so this presumably expects 1-D samples of length ``shape[0]``
      (TODO confirm with callers).
    kl_weight: Ignored by this deterministic model. It is optional and exists
      only so the class can be constructed with the same
      ``(latent_dim, shape, kl_weight)`` arguments as ``V2DAutoencoder``.
  """

  def __init__(self, latent_dim, shape, kl_weight=None):
    # Zero-argument super() always binds to this class, so it cannot drift
    # out of sync with the class name.
    super().__init__()
    self.latent_dim = latent_dim
    self.shape = shape
    self.kl_weight = kl_weight
    self.encoder = tf.keras.Sequential([
      layers.Dense(32, activation='sigmoid'),
      layers.Dense(latent_dim)
    ])
    self.decoder = tf.keras.Sequential([
      layers.Dense(32, activation='sigmoid'),
      layers.Dense(self.shape[0])
    ])

  def call(self, x):
    """Encode ``x`` to the latent code and decode it back (the reconstruction)."""
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded

class Sampling(layers.Layer):
    """Reparameterized sampling layer: ``z = mean + scale * eps``.

    Takes a pair ``[z_mean, z_var]`` and returns ``z_mean + z_var * eps``, where
    ``eps`` is drawn element-wise from a standard normal distribution. The
    second input is applied as a multiplicative scale on the noise (i.e. a
    standard deviation), not as a variance, despite its ``z_var`` name.
    The noise has shape ``(shape[0], shape[1])`` of ``z_mean``, so the inputs
    are presumably rank-2 ``(batch, latent_dim)`` tensors.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Constant seed, so the noise stream is repeatable from run to run.
        self.seed_generator = random.SeedGenerator(1337)

    def call(self, inputs):
        mean, scale = inputs
        mean_shape = ops.shape(mean)
        noise = random.normal(
            shape=(mean_shape[0], mean_shape[1]),
            seed=self.seed_generator,
        )
        return mean + scale * noise

class V2DAutoencoder(Model):
  """Variational autoencoder for single-channel 2-D inputs.

  Encoder: Flatten -> Dense(100, relu) -> two Dense(latent_dim) heads,
  ``z_mean`` and ``z_var``, plus a latent sample ``z`` drawn by ``Sampling``
  as ``z_mean + z_var * eps``.
  Decoder: Dense(50, relu) -> Dense(100, relu) -> Dense(H * W) reshaped to
  ``(H, W, 1)`` where ``H, W = shape[0], shape[1]``.

  NOTE: despite its name, ``z_var`` is used as a standard deviation throughout.
  It scales the unit Gaussian noise when sampling, and the KL term below uses
  ``log(z_var**2)`` as the log-variance.

  Training goes through the custom ``train_step``, which minimizes
  ``reconstruction_loss + kl_weight * kl_loss``.

  Args:
    latent_dim: Number of latent dimensions.
    shape: Input shape passed to ``Input``. The reconstruction loss sums over
      axes 1 and 2 after ``mean_squared_error`` reduces the last axis, so this
      presumably expects ``shape == (H, W, 1)`` (TODO confirm with callers).
    kl_weight: Multiplier on the KL-divergence term of the loss.
  """

  # Added inside log(z_var**2) so the KL term stays finite (and its gradient
  # bounded) when the unconstrained z_var head outputs values at or near 0.
  _LOG_VAR_EPSILON = 1e-7

  def __init__(self, latent_dim, shape, kl_weight=10.):
    super().__init__()
    self.latent_dim = latent_dim
    self.im_shape = shape
    self.kl_weight = kl_weight

    encoder_inputs = Input(shape=self.im_shape)
    x = layers.Flatten()(encoder_inputs)
    x = layers.Dense(100, activation="relu")(x)
    z_mean = layers.Dense(latent_dim, name="z_mean")(x)
    z_var = layers.Dense(latent_dim, name="z_var")(x)
    z = Sampling()([z_mean, z_var])
    self.encoder = Model(encoder_inputs, [z_mean, z_var, z], name="encoder")

    latent_inputs = Input(shape=(latent_dim,))
    x = layers.Dense(50, activation="relu")(latent_inputs)
    x = layers.Dense(100, activation="relu")(x)
    x = layers.Dense(self.im_shape[0] * self.im_shape[1] * 1)(x)
    decoder_outputs = layers.Reshape((self.im_shape[0], self.im_shape[1], 1))(x)
    self.decoder = Model(latent_inputs, decoder_outputs, name="decoder")

    # Running means reported by train_step and reset by Keras via `metrics`.
    self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
    self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
    self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

  def call(self, inputs):
    """Reconstruct ``inputs``: encode, sample a latent code, and decode it.

    Defined so that ``model(x)`` and ``model.predict(x)`` work. Training does
    not use this method; ``train_step`` drives the encoder/decoder directly.
    """
    _, _, z = self.encoder(inputs)
    return self.decoder(z)

  def train_step(self, data):
    """Run one optimization step on a batch and return the running losses.

    Args:
      data: A batch of inputs. If ``fit`` was also given targets, the batch
        arrives as an ``(x, y, ...)`` tuple and only ``x`` is used, because
        the model reconstructs its own input.

    Returns:
      Dict of running means under "loss", "reconstruction_loss" and "kl_loss".
    """
    if isinstance(data, tuple):
      data = data[0]
    with tf.GradientTape() as tape:
      z_mean, z_var, z = self.encoder(data)
      reconstruction = self.decoder(z)
      # mean_squared_error reduces the last axis; the result is summed over
      # axes 1 and 2 and averaged over the batch.
      reconstruction_loss = ops.mean(ops.sum(keras.losses.mean_squared_error(data, reconstruction), axis=(1, 2)))
      # KL(N(z_mean, z_var**2) || N(0, 1)), with z_var treated as a std dev.
      z_var_sq = ops.square(z_var)
      kl_loss = -0.5 * (1 + ops.log(z_var_sq + self._LOG_VAR_EPSILON) - z_var_sq - ops.square(z_mean))
      kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
      total_loss = reconstruction_loss + self.kl_weight * kl_loss
    grads = tape.gradient(total_loss, self.trainable_weights)
    self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
    self.total_loss_tracker.update_state(total_loss)
    self.reconstruction_loss_tracker.update_state(reconstruction_loss)
    self.kl_loss_tracker.update_state(kl_loss)
    return {
      "loss": self.total_loss_tracker.result(),
      "reconstruction_loss": self.reconstruction_loss_tracker.result(),
      "kl_loss": self.kl_loss_tracker.result(),
    }

  @property
  def metrics(self):
    """Trackers Keras resets at the start of each epoch."""
    return [self.total_loss_tracker, self.reconstruction_loss_tracker, self.kl_loss_tracker]

