import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.initializers import RandomUniform, zeros
from tensorflow.keras.layers import LeakyReLU
import numpy as np

class Labelnet(tf.keras.layers.Layer):
    """Sigmoid scorer applied to a batch-normalised feature vector.

    The input is passed through ``BatchNormalization`` and then through a
    small fully connected stack: Dense(400) -> Dense(300) -> Flatten ->
    Dense(1, sigmoid). Both hidden layers use LeakyReLU with slope 0.01.

    Args:
        feature_shape: Width of one feature chunk. The first Dense layer is
            declared with an input width of ``4 * feature_shape``.
    """

    def __init__(self, feature_shape):
        super().__init__()
        self._feature_shape = feature_shape

        # Sub-layers are instantiated in the order they appear in the stack.
        hidden_wide = tf.keras.layers.Dense(
            400,
            input_shape=(4 * feature_shape,),
            activation=LeakyReLU(alpha=0.01),
        )
        hidden_narrow = tf.keras.layers.Dense(
            300,
            activation=LeakyReLU(alpha=0.01),
        )
        flatten = tf.keras.layers.Flatten()
        score_head = tf.keras.layers.Dense(1, activation='sigmoid')

        self._bottleneck_layer = tf.keras.Sequential(
            [hidden_wide, hidden_narrow, flatten, score_head]
        )
        self._norm_64 = tf.keras.layers.BatchNormalization()

    def call(self, combine):
        """Normalise ``combine`` and return the sigmoid score per row."""
        return self._bottleneck_layer(self._norm_64(combine))

class Labelnet_frame(tf.keras.layers.Layer):
    """Sigmoid scorer applied to a single chunk of the feature vector.

    Only columns ``feature_shape`` to ``2 * feature_shape`` (exclusive) of
    the input are used. That slice is batch-normalised and fed through
    Dense(400) -> Dense(300) -> Flatten -> Dense(1, sigmoid); both hidden
    layers use LeakyReLU with slope 0.01.

    Args:
        feature_shape: Width of the slice taken from the input; also the
            declared input width of the first Dense layer.
    """

    def __init__(self, feature_shape):
        super().__init__()
        self._feature_shape = feature_shape

        # Sub-layers are instantiated in the order they appear in the stack.
        hidden_wide = tf.keras.layers.Dense(
            400,
            input_shape=(feature_shape,),
            activation=LeakyReLU(alpha=0.01),
        )
        hidden_narrow = tf.keras.layers.Dense(
            300,
            activation=LeakyReLU(alpha=0.01),
        )
        flatten = tf.keras.layers.Flatten()
        score_head = tf.keras.layers.Dense(1, activation='sigmoid')

        self._bottleneck_layer = tf.keras.Sequential(
            [hidden_wide, hidden_narrow, flatten, score_head]
        )
        self._norm_64 = tf.keras.layers.BatchNormalization()

    def call(self, combine):
        """Score the second ``feature_shape``-wide chunk of ``combine``."""
        start = self._feature_shape
        stop = 2 * self._feature_shape
        frame_chunk = combine[:, start:stop]
        return self._bottleneck_layer(self._norm_64(frame_chunk))

class WGANdiscriminator(tf.keras.layers.Layer):
    """Two-headed scorer mixing a single-chunk score with a full-vector score.

    The input is batch-normalised and then scored twice:

    * the "img" head sees only chunk index 1 of the four
      ``feature_size``-wide chunks of the normalised input;
    * the "seq" head sees the whole ``4 * feature_size``-wide vector.

    Both heads are Dense(400) -> Dense(300) -> Dense(1) with LeakyReLU
    (slope 0.01) on the hidden layers and no activation on the output.

    Args:
        feature_size: Width of one feature chunk.
        alpha: Weight given to the "img" head; the "seq" head is weighted
            by ``1 - alpha``.
    """

    def __init__(self, feature_size, alpha):
        super(WGANdiscriminator, self).__init__()

        self.feature_size = feature_size
        self.alpha = alpha

        self._output_layer_img = tf.keras.Sequential([
            tf.keras.layers.Dense(units=400, activation=LeakyReLU(alpha=0.01),
                                  input_shape=(feature_size,)),
            tf.keras.layers.Dense(units=300, activation=LeakyReLU(alpha=0.01)),
            tf.keras.layers.Dense(units=1)
        ])

        self._output_layer_seq = tf.keras.Sequential([
            tf.keras.layers.Dense(units=400, activation=LeakyReLU(alpha=0.01),
                                  input_shape=(4 * feature_size,)),
            tf.keras.layers.Dense(units=300, activation=LeakyReLU(alpha=0.01)),
            tf.keras.layers.Dense(units=1)
        ])
        self._norm_64 = tf.keras.layers.BatchNormalization()

    def call(self, comb_feature):
        """Return ``alpha * img_score + (1 - alpha) * seq_score``.

        Args:
            comb_feature: Tensor whose rows are ``4 * feature_size`` wide.

        Returns:
            Tensor with one unbounded score per input row.
        """
        state_feature_norm = self._norm_64(comb_feature)

        # Use -1 so the batch dimension is inferred at run time. Reading it
        # from the static shape yields None for symbolic inputs with an
        # unknown batch size, which tf.reshape rejects.
        state_feature_reshape = tf.reshape(
            state_feature_norm, [-1, 4, self.feature_size])

        # Chunk index 1 feeds the single-chunk head.
        feature_one = state_feature_reshape[:, 1]
        out_img = self._output_layer_img(feature_one)
        out_seq = self._output_layer_seq(state_feature_norm)

        out_total = self.alpha * out_img + (1 - self.alpha) * out_seq

        return out_total

class Encoder(tf.keras.layers.Layer):
    """Convolutional encoder applied independently to each of four frames.

    Every frame goes through six 3x3 ``Conv2D`` layers (16, 16, 32, 32, 64,
    64 filters; every second one has stride 2), is flattened, and is
    projected to ``units`` features by a linear Dense layer. The four
    per-frame feature vectors of a sample are then laid out side by side.

    Args:
        units: Number of features produced per frame.
        input_shapes: Indexable shape description; entries 2, 3 and 4 are
            used as frame height, width and channel count.
            Presumably ``(batch, 4, height, width, channels)`` - confirm
            against the caller.
    """

    def __init__(self, units, input_shapes):
        super(Encoder, self).__init__()
        self.units = units
        # Public attribute retained for existing callers.
        self.image_shape = input_shapes[2]
        # Stored as separate ints so that `call` reshapes to exactly the
        # frame shape the convolution stack was declared with.
        self._image_height = input_shapes[2]
        self._image_width = input_shapes[3]
        self._image_channels = input_shapes[4]

        self._state_emb = tf.keras.Sequential([
            tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3), padding='same', activation=LeakyReLU(alpha=0.01),
                                   input_shape=(
                                       input_shapes[2], input_shapes[3], input_shapes[4])),
            tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3), padding='same', activation=LeakyReLU(alpha=0.01),
                                   strides=(2,2)),
            tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation=LeakyReLU(alpha=0.01)),
            tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation=LeakyReLU(alpha=0.01), strides=(2,2)),
            tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation=LeakyReLU(alpha=0.01)),
            tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation=LeakyReLU(alpha=0.01), strides=(2,2)),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(units=self.units)
        ])

    def call(self, comb_data):
        """Encode four frames per sample into one concatenated feature row.

        Args:
            comb_data: Tensor holding four frames per sample; it is reshaped
                to ``(4 * batch, height, width, channels)``.

        Returns:
            Tensor of shape ``(batch, 4 * units)``.
        """
        batch_size = comb_data.shape[0]
        if batch_size is None:
            # Static batch size is unknown (symbolic input); fall back to
            # the run-time value so tf.reshape is not handed None.
            batch_size = tf.shape(comb_data)[0]

        # Fold the frame axis into the batch axis so the convolution stack
        # processes every frame independently.
        state_seq = tf.reshape(
            comb_data,
            [4 * batch_size, self._image_height, self._image_width,
             self._image_channels])

        state_feature = self._state_emb(state_seq)

        # Width follows `units` rather than a fixed 32, so any feature size
        # chosen at construction works.
        comb_feature = tf.reshape(state_feature, [batch_size, 4 * self.units])

        return comb_feature


class Decoder(tf.keras.layers.Layer):
    """Transposed-convolution decoder rebuilding four frames per sample.

    Each per-frame feature vector is expanded by a Dense layer to a
    ``(side // 8, side // 8, 1)`` map and passed through seven
    ``Conv2DTranspose`` layers, three of which have stride 2, ending in a
    3-channel output with no activation.

    Args:
        input_shapes: Indexable shape description; entry 1 is used as the
            per-frame feature width.
        recon_shape: Indexable shape description; entry 1 is used as the
            side length of the square reconstructed frame. The three
            stride-2 layers produce a side of ``8 * (side // 8)``, so the
            final reshape in `call` only succeeds when it is divisible by 8.
    """

    def __init__(self, input_shapes, recon_shape):
        super(Decoder, self).__init__()
        self._state_feature_shape = input_shapes[1]
        self._recon_state_shape = (recon_shape[1]//8)
        self.image_shape = recon_shape[1]

        self._state_recon_layer = tf.keras.Sequential([
            tf.keras.layers.Dense(units=self._recon_state_shape*self._recon_state_shape, activation=LeakyReLU(alpha=0.01), input_shape=(input_shapes[1],)),
            tf.keras.layers.Reshape((self._recon_state_shape, self._recon_state_shape, 1)),
            tf.keras.layers.Conv2DTranspose(64, (4, 4), padding='same', activation=LeakyReLU(alpha=0.01)),
            tf.keras.layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same', activation=LeakyReLU(alpha=0.01)),
            tf.keras.layers.Conv2DTranspose(32, (4, 4), padding='same', activation=LeakyReLU(alpha=0.01)),
            tf.keras.layers.Conv2DTranspose(32, (4, 4), strides=(2, 2), padding='same', activation=LeakyReLU(alpha=0.01)),
            tf.keras.layers.Conv2DTranspose(16, (4, 4), padding='same', activation=LeakyReLU(alpha=0.01)),
            tf.keras.layers.Conv2DTranspose(16, (4, 4), strides=(2, 2), padding='same', activation=LeakyReLU(alpha=0.01)),
            tf.keras.layers.Conv2DTranspose(3, (4, 4), padding='same'),
        ])

    def call(self, comb_data):
        """Decode concatenated per-frame features back into four frames.

        Args:
            comb_data: Tensor whose rows hold four per-frame feature vectors
                side by side.

        Returns:
            Tensor of shape ``(batch, 4, image_shape, image_shape, 3)``.
        """
        batch_size = comb_data.shape[0]
        if batch_size is None:
            # Static batch size is unknown (symbolic input); fall back to
            # the run-time value so tf.reshape is not handed None.
            batch_size = tf.shape(comb_data)[0]

        # One row per frame, so every frame is decoded independently.
        state_feature_reshape = tf.reshape(
            comb_data, [4 * batch_size, self._state_feature_shape])

        recon_state = self._state_recon_layer(state_feature_reshape)

        # Regroup the decoded frames under their originating sample.
        recon_state = tf.reshape(
            recon_state,
            [batch_size, 4, self.image_shape, self.image_shape, 3])

        return recon_state

