import tensorflow as tf
from tensorflow.keras.layers import Layer, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Input, LSTM, Lambda, Concatenate
from tensorflow.keras import Model
from tensorflow.image import resize, ResizeMethod
import numpy as np
# Make channels-last (NHWC) the global default layout for Keras layers such as
# Conv2D. The decoders below rely on it: they Reshape to (7, 7, 64) and resize
# with tf.image.resize, both of which treat the last axis as channels.
tf.keras.backend.set_image_data_format('channels_last')

# Conditional sparse AutoEncoder
class CSAE(tf.keras.Model):
  """Conditional sparse autoencoder.

  The encoder maps an image ``y`` together with a conditioning vector ``x``
  to a deterministic latent code ``z``. The decoder maps ``(x, z)`` back to
  an image of spatial size 28x28 with ``channels`` channels. An L1 penalty
  on ``z``, weighted by ``Lambda``, encourages sparse codes (see
  ``regularization``).

  Args:
    latent_dim: Dimension of the latent code ``z``.
    Lambda: Weight of the L1 penalty used by ``regularization``. Inside
      ``__init__`` this parameter shadows the imported Keras ``Lambda``
      layer. The name is kept for API compatibility.
    training: If falsy, ``decoder`` clips its output to [0, 1].
    channels: Number of channels produced by the decoder.
  """

  def __init__(self, latent_dim, Lambda=0.1, training=True, channels=3):
    super(CSAE, self).__init__()

    self.latent_dim = latent_dim
    self.Lambda = Lambda
    self.training = training

    # Stateless layer, shared by the encoder and the decoder.
    self.concat = Concatenate()

    # --- Encoder ---
    # Attribute names, layer names and creation order are kept unchanged so
    # that previously saved weights and checkpoints still load.
    self.conv1 = Conv2D(
      filters = 64,
      kernel_size=(4,4),
      strides=(2,2),
      activation=tf.keras.layers.LeakyReLU(alpha=0.01))
    self.conv2 = Conv2D(
      filters = 64,
      kernel_size=(4,4),
      strides=(2,2),
      activation=tf.keras.layers.LeakyReLU(alpha=0.01))
    self.flatten = Flatten()
    self.dense_concat = Dense(128, activation=tf.keras.layers.LeakyReLU(alpha=0.01))
    self.dense1 = Dense(64, activation=tf.keras.layers.LeakyReLU(alpha=0.01))
    self.dense2 = Dense(self.latent_dim, name="z_mean")

    # --- Decoder: dense projection to a 7x7x64 feature map ---
    self.dense3 = Dense(7*7*64, name="dec_dense", activation=tf.keras.layers.LeakyReLU(alpha=0.01))
    self.reshape1 = Reshape(target_shape=(7,7,64), name="dec_reshape")

    self.conv3 =  Conv2D(
      filters = 64,
      kernel_size=(4,4),
      strides=(1,1),
      padding='same',
      activation= tf.keras.layers.LeakyReLU(alpha=0.01))
    self.conv4 =  Conv2D(
      filters = 64,
      kernel_size=(4,4),
      strides=(1,1),
      padding='same',
      activation= tf.keras.layers.LeakyReLU(alpha=0.01))
    # Linear output layer; the value range is only constrained in `decoder`.
    self.conv5 =  Conv2D(
      filters = channels,
      kernel_size=(4,4),
      strides=(1,1),
      padding='same')

  def encoder(self, x, y):
    """Encode image ``y`` conditioned on ``x`` into a latent code.

    Args:
      x: Conditioning tensor. It is concatenated with the flattened image
        features, so presumably it has shape (batch, n_labels). TODO confirm.
      y: Image batch in channels-last layout.

    Returns:
      Latent code ``z`` of shape (batch, latent_dim).
    """
    y = self.conv1(y)
    y = self.conv2(y)
    y = self.flatten(y)
    t = self.concat([x, y])
    t = self.dense_concat(t)
    t = self.dense1(t)
    return self.dense2(t)

  def decoder(self, x, z):
    """Decode latent ``z`` conditioned on ``x`` into a 28x28 image.

    The 7x7 feature map is upsampled to 14x14 and then to 28x28 with
    bilinear resizing, each step followed by a convolution. When
    ``self.training`` is falsy, the output is clipped to [0, 1].
    """
    y = self.concat([x, z])
    y = self.dense3(y)
    y = self.reshape1(y)
    y = resize(y, [14,14], method=ResizeMethod.BILINEAR)
    y = self.conv3(y)
    y = resize(y, [28,28], method=ResizeMethod.BILINEAR)
    y = self.conv4(y)
    y = self.conv5(y)
    if not self.training:
      y = tf.clip_by_value(y, clip_value_min=0, clip_value_max=1)
    return y

  def reconstruction(self, inputs, out):
    """Return the per-sample mean absolute error over axes (1, 2, 3)."""
    return tf.keras.backend.mean(tf.keras.backend.abs(inputs - out), axis=(1,2,3))

  def regularization(self, inputs, out):
    """Return the per-sample L1 penalty ``Lambda * sum(|z|)``.

    ``inputs`` and ``out`` are unused. The code is read from ``self.z``, which
    ``call`` stores, so this reflects the most recent forward pass.
    """
    return self.Lambda*tf.math.reduce_sum(tf.math.abs(self.z), axis=1)

  def loss_(self, inputs, out):
    """Return the per-sample loss: reconstruction error plus L1 penalty."""
    rec = self.reconstruction(inputs, out)
    regu = self.regularization(inputs, out)
    return rec + regu

  def cf_generation(self, label_real, label_cf, y):
    """Encode ``y`` under ``label_real`` and decode it under ``label_cf``."""
    z = self.encoder(label_real, y)
    return self.decoder(label_cf, z)

  def composition(self, labels, img, n):
    """Apply the identity counterfactual (``labels`` -> ``labels``) ``n`` times.

    Returns the image after ``n`` passes. For ``n == 0`` it returns a copy of
    ``img`` when ``img`` is a numpy array, otherwise ``img`` itself.
    """
    # Only numpy arrays need copying. A tf.Tensor is immutable and has no
    # `.copy()` method, so it is used as-is.
    img_ = img.copy() if isinstance(img, np.ndarray) else img
    for _ in range(n):
      img_ = self.cf_generation(labels, labels, img_)
    return img_

  def reversibility(self, parents, parents_cf, img, n):
    """Map ``img`` to ``parents_cf`` and back to ``parents``, ``n`` times.

    Returns the image after ``n`` round trips. For ``n == 0`` it returns a
    copy of ``img`` when ``img`` is a numpy array, otherwise ``img`` itself.
    """
    # Only numpy arrays need copying. A tf.Tensor is immutable and has no
    # `.copy()` method, so it is used as-is.
    img_ = img.copy() if isinstance(img, np.ndarray) else img
    for _ in range(n):
      img_cf = self.cf_generation(parents, parents_cf, img_)
      img_ = self.cf_generation(parents_cf, parents, img_cf)
    return img_

  def call(self, inputs):
    """Reconstruct ``input_y`` conditioned on ``input_x``.

    Args:
      inputs: Pair ``(input_x, input_y)`` of condition and image.

    Returns:
      The reconstructed image. As a side effect, the latent code is stored
      in ``self.z`` for use by ``regularization``.
    """
    input_x, input_y = inputs
    z = self.encoder(input_x, input_y)
    # NOTE(review): keeping a tensor on `self` couples the loss to the last
    # call and may capture graph tensors under tf.function. It is kept
    # because `regularization` depends on it.
    self.z = z
    return self.decoder(input_x, z)



class CVAE(tf.keras.Model):
  """Conditional variational autoencoder.

  The encoder maps an image ``y`` together with a conditioning vector ``x``
  to the mean and log-variance of a diagonal Gaussian posterior over ``z``.
  The decoder maps ``(x, z)`` back to an image of spatial size 28x28 with
  ``channels`` channels. The prior is an isotropic Gaussian with variance
  ``self.var_prior``.

  Args:
    latent_dim: Dimension of the latent code ``z``.
    kl_weight: Multiplier applied to the KL term in ``loss_``.
    training: If falsy, ``decoder`` clips its output to [0, 1].
    channels: Number of channels produced by the decoder.
  """

  def __init__(self, latent_dim, kl_weight=1, training=True, channels=3):
    super(CVAE, self).__init__()

    self.latent_dim = latent_dim
    self.kl_weight = kl_weight
    self.training= training

    # Prior N(0, var_prior * I), used by the analytical KL in `kl`.
    self.var_prior = 1.
    # log det(var_prior * I) = latent_dim * log(var_prior). Computing it as a
    # product of logs avoids forming var_prior ** latent_dim, which can
    # overflow or underflow when var_prior != 1.
    self.log_det_cov_pz = self.latent_dim * tf.math.log(self.var_prior)

    # Stateless layer, shared by the encoder and the decoder.
    self.concat = Concatenate()

    # --- Encoder ---
    # Attribute names, layer names and creation order are kept unchanged so
    # that previously saved weights and checkpoints still load.
    self.conv1 = Conv2D(
      filters = 64,
      kernel_size=(4,4),
      strides=(2,2),
      activation=tf.keras.layers.LeakyReLU(alpha=0.01))
    self.conv2 = Conv2D(
      filters = 64,
      kernel_size=(4,4),
      strides=(2,2),
      activation=tf.keras.layers.LeakyReLU(alpha=0.01))
    self.flatten = Flatten()
    self.dense_concat = Dense(128, activation=tf.keras.layers.LeakyReLU(alpha=0.01))
    self.dense1 = Dense(64, activation=tf.keras.layers.LeakyReLU(alpha=0.01))
    self.dense2a = Dense(self.latent_dim)  # posterior mean
    self.dense2b = Dense(self.latent_dim)  # posterior log-variance

    # --- Decoder: dense projection to a 7x7x64 feature map ---
    self.dense3 = Dense(7*7*64, name="dec_dense", activation=tf.keras.layers.LeakyReLU(alpha=0.01))
    self.reshape1 = Reshape(target_shape=(7,7,64), name="dec_reshape")

    self.conv3 =  Conv2D(
      filters = 64,
      kernel_size=(4,4),
      strides=(1,1),
      padding='same',
      activation= tf.keras.layers.LeakyReLU(alpha=0.01))
    self.conv4 =  Conv2D(
      filters = 64,
      kernel_size=(4,4),
      strides=(1,1),
      padding='same',
      activation= tf.keras.layers.LeakyReLU(alpha=0.01))
    # Linear output layer; the value range is only constrained in `decoder`.
    self.conv5 =  Conv2D(
      filters = channels,
      kernel_size=(4,4),
      strides=(1,1),
      padding='same')

  def encoder(self, x, y):
    """Encode image ``y`` conditioned on ``x`` into posterior parameters.

    Args:
      x: Conditioning tensor. It is concatenated with the flattened image
        features, so presumably it has shape (batch, n_labels). TODO confirm.
      y: Image batch in channels-last layout.

    Returns:
      Tuple ``(mean, logvar)``, each of shape (batch, latent_dim).
    """
    y = self.conv1(y)
    y = self.conv2(y)
    y = self.flatten(y)
    t = self.concat([x, y])
    t = self.dense_concat(t)
    t = self.dense1(t)
    mean = self.dense2a(t)
    logvar = self.dense2b(t)
    return mean, logvar

  def var(self, inputs, out):
    """Return ``exp(self.logvar)`` from the most recent ``call``.

    ``inputs`` and ``out`` are unused.
    """
    return tf.math.exp(self.logvar)

  def sample(self, mean, logvar):
    """Draw ``z = mean + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``."""
    eps = tf.random.normal(shape=(tf.shape(mean)[0], self.latent_dim))
    return eps * tf.exp(logvar * .5) + mean

  def decoder(self, x, z):
    """Decode latent ``z`` conditioned on ``x`` into a 28x28 image.

    The 7x7 feature map is upsampled to 14x14 and then to 28x28 with
    bilinear resizing, each step followed by a convolution. When
    ``self.training`` is falsy, the output is clipped to [0, 1].
    """
    y = self.concat([x, z])
    y = self.dense3(y)
    y = self.reshape1(y)
    y = resize(y, [14,14], method=ResizeMethod.BILINEAR)
    y = self.conv3(y)
    y = resize(y, [28,28], method=ResizeMethod.BILINEAR)
    y = self.conv4(y)
    y = self.conv5(y)
    if not self.training:
      y = tf.clip_by_value(y, clip_value_min=0, clip_value_max=1)
    return y

  def reconstruction(self, inputs, out):
    """Return the per-sample mean absolute error over axes (1, 2, 3)."""
    return tf.keras.backend.mean(tf.keras.backend.abs(inputs - out), axis=(1,2,3))

  def kl(self, inputs, out):
    """Return per-sample KL(q(z|x) || p(z)) for the most recent ``call``.

    q is ``N(self.mean, diag(exp(self.logvar)))`` and p is
    ``N(0, var_prior * I)``. ``inputs`` and ``out`` are unused.
    """
    mean = self.mean
    logvar = self.logvar
    var = tf.math.exp(logvar)
    # log det of the diagonal posterior covariance equals sum(logvar).
    # The equivalent log(prod(exp(logvar))) underflows to 0 (KL = +inf) or
    # overflows to +inf (KL = -inf) in float32 once sum(logvar) leaves
    # roughly [-88, 88].
    log_det_cov_qz_x = tf.math.reduce_sum(logvar, axis=1)
    mahalanobis = tf.math.reduce_sum(mean * mean / self.var_prior, axis=1)
    trace = tf.math.reduce_sum(var / self.var_prior, axis=1)
    return 0.5 * (self.log_det_cov_pz - log_det_cov_qz_x - self.latent_dim
                  + mahalanobis + trace)

  def loss_(self, inputs, out):
    """Return the per-sample loss of shape (batch, 1).

    The loss is the scaled reconstruction error plus ``kl_weight`` times the
    KL term.
    """
    rec = tf.reshape(self.reconstruction(inputs, out), [-1,1])
    kl =  tf.reshape(self.kl(inputs, out), [-1,1])
    # 28*28*3 turns the per-pixel mean error into a per-image sum for
    # 28x28x3 images.
    # NOTE(review): the 3 does not follow `channels`. Confirm this scale is
    # intended when channels != 3.
    return 28*28*3*rec + (self.kl_weight)*kl

  def cf_generation(self, label_real, label_cf, y):
    """Encode ``y`` under ``label_real``, sample ``z``, decode under ``label_cf``."""
    mean, logvar = self.encoder(label_real, y)
    z = self.sample(mean, logvar)
    return self.decoder(label_cf, z)

  def composition(self, labels, img, n):
    """Apply the identity counterfactual (``labels`` -> ``labels``) ``n`` times.

    Returns the image after ``n`` passes. For ``n == 0`` it returns a copy of
    ``img`` when ``img`` is a numpy array, otherwise ``img`` itself.
    """
    # Only numpy arrays need copying. A tf.Tensor is immutable and has no
    # `.copy()` method, so it is used as-is.
    img_ = img.copy() if isinstance(img, np.ndarray) else img
    for _ in range(n):
      img_ = self.cf_generation(labels, labels, img_)
    return img_

  def reversibility(self, parents, parents_cf, img, n):
    """Map ``img`` to ``parents_cf`` and back to ``parents``, ``n`` times.

    Returns the image after ``n`` round trips. For ``n == 0`` it returns a
    copy of ``img`` when ``img`` is a numpy array, otherwise ``img`` itself.
    """
    # Only numpy arrays need copying. A tf.Tensor is immutable and has no
    # `.copy()` method, so it is used as-is.
    img_ = img.copy() if isinstance(img, np.ndarray) else img
    for _ in range(n):
      img_cf = self.cf_generation(parents, parents_cf, img_)
      img_ = self.cf_generation(parents_cf, parents, img_cf)
    return img_

  def call(self, inputs):
    """Reconstruct ``input_y`` conditioned on ``input_x``.

    Args:
      inputs: Pair ``(input_x, input_y)`` of condition and image.

    Returns:
      The reconstructed image. As a side effect, ``self.mean``,
      ``self.logvar`` and ``self.z`` are stored for ``kl`` and ``var``.
    """
    input_x, input_y = inputs
    # NOTE(review): keeping tensors on `self` couples the loss to the last
    # call and may capture graph tensors under tf.function. It is kept
    # because `kl` and `var` depend on it.
    self.mean, self.logvar = self.encoder(input_x, input_y)
    self.z = self.sample(self.mean, self.logvar)
    return self.decoder(input_x, self.z)


class oracle(tf.keras.Model):
  """Small CNN classifier used as an oracle for one image attribute.

  The network has two strided convolutions, a 128-unit dense layer and a
  target-specific output head.

  Args:
    target: ``"hue"`` for a single sigmoid output, or ``"number"`` for a
      10-way softmax output.

  Raises:
    ValueError: If ``target`` is not ``"hue"`` or ``"number"``.
  """

  def __init__(self, target):
    super(oracle, self).__init__()

    self.conv1 = Conv2D(
      filters=64,
      kernel_size=(4,4),
      strides=(2,2),
      activation=tf.keras.layers.LeakyReLU(alpha=0.01))
    self.conv2 = Conv2D(
      filters=64,
      kernel_size=(4,4),
      strides=(2,2),
      activation=tf.keras.layers.LeakyReLU(alpha=0.01))
    self.flatten = Flatten()
    self.dense1 = Dense(128, activation=tf.keras.layers.LeakyReLU(alpha=0.01))
    if target == "hue":
      self.dense2 = Dense(1, activation = tf.keras.activations.sigmoid)
    elif target == "number":
      self.dense2 = Dense(10, activation=tf.keras.activations.softmax)
    else:
      # Without this check, `self.dense2` would be left undefined and `call`
      # would later fail with an opaque AttributeError.
      raise ValueError(
        f"Unknown oracle target {target!r}; expected 'hue' or 'number'.")

  def call(self, x):
    """Return class probabilities for the image batch ``x``."""
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.flatten(x)
    x = self.dense1(x)
    return self.dense2(x)