import math

import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
from keras.models import Model
from keras.layers import AveragePooling2D, MaxPooling2D, GlobalAveragePooling2D, ZeroPadding2D, Add, BatchNormalization
from tensorflow_probability.python.layers import util as tfp_layers_util
from keras import Input
from tensorflow import keras


def _untransformed_scale_constraint(t):
    """Clip the untransformed posterior scale ``t`` to ``[-1000, log(c)]``.

    ``c`` is the module-level ``kernel_posterior_scale_constraint``. Both bounds
    are plain Python floats, so ``tf.clip_by_value`` converts them to ``t``'s
    own dtype (a ``tf.math.log`` tensor bound would always be float32 and
    clash with e.g. a float64 ``t``), and no extra Log op is built per call.

    NOTE(review): if the posterior maps this parameter through softplus (as
    TFP's ``default_mean_field_normal_fn`` is documented to do), the effective
    upper bound on the scale is ``log(1 + c)`` rather than ``c`` — confirm
    this is the intended constraint.
    """
    return tf.clip_by_value(t, -1000.0, math.log(kernel_posterior_scale_constraint))

# Hyper-parameters for the optional custom kernel posterior (see the
# commented-out alternative below) and for _untransformed_scale_constraint.
# These must be scalars: a trailing comma here would silently turn them into
# one-element tuples.
kernel_posterior_scale_mean = -9.0
kernel_posterior_scale_stddev = 0.1
kernel_posterior_scale_constraint = 0.2

# Disabled alternative: a mean-field normal posterior whose untransformed scale
# uses a custom random-normal initializer (kernel_posterior_scale_mean /
# kernel_posterior_scale_stddev) and is clipped by
# _untransformed_scale_constraint.
#
# kernel_posterior_fn = tfp.layers.default_mean_field_normal_fn(
#     untransformed_scale_initializer=tf.compat.v1.initializers.random_normal(
#         mean=kernel_posterior_scale_mean,
#         stddev=kernel_posterior_scale_stddev),
#     untransformed_scale_constraint=_untransformed_scale_constraint)

# Active kernel posterior: a fully factorised (mean-field) normal built with
# TFP's default initializers; bnn_vgg16 passes it to every Bayesian layer.
kernel_posterior_fn = tfp_layers_util.default_mean_field_normal_fn(
    is_singular=False)

# (filters, number of 3x3 conv layers) for each of VGG16's five convolutional
# stages; every stage ends with a 2x2 / stride-2 max-pool.
_VGG16_CONV_STAGES = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

# Widths of the hidden Bayesian dense layers between Flatten and the classifier.
_VGG16_HIDDEN_UNITS = (512, 256)


def _bayesian_conv2d(x, filters, model_type):
    """Apply a 3x3, 'same'-padded Bayesian convolution with ``filters`` outputs.

    ``model_type == 'BNNF'`` selects ``Convolution2DFlipout``; any other value
    selects ``Convolution2DReparameterization``.
    """
    conv_cls = (tfp.layers.Convolution2DFlipout if model_type == 'BNNF'
                else tfp.layers.Convolution2DReparameterization)
    return conv_cls(
        filters=filters,
        kernel_size=3,
        padding='same',
        kernel_posterior_fn=kernel_posterior_fn,
        # Each layer gets its own non-singular mean-field posterior for the bias.
        bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(is_singular=False),
    )(x)


def _bayesian_dense(x, units, model_type):
    """Apply a Bayesian fully-connected layer with ``units`` outputs (no activation).

    ``model_type == 'BNNF'`` selects ``DenseFlipout``; any other value selects
    ``DenseLocalReparameterization``.
    """
    dense_cls = (tfp.layers.DenseFlipout if model_type == 'BNNF'
                 else tfp.layers.DenseLocalReparameterization)
    return dense_cls(
        units,
        kernel_posterior_fn=kernel_posterior_fn,
        bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(is_singular=False),
    )(x)


def _batch_norm_relu(x):
    """Apply BatchNormalization followed by a ReLU activation."""
    x = tf.keras.layers.BatchNormalization()(x)
    return tf.keras.layers.Activation('relu')(x)


def bnn_vgg16(input_shape, num_classes, model_type):
    """Build a Bayesian VGG16 classifier out of TFP variational layers.

    Architecture:
      * five conv stages with 64-64 | 128-128 | 256x3 | 512x3 | 512x3 filters;
        every conv is followed by BatchNorm + ReLU and every stage ends with a
        2x2 max-pool;
      * Flatten -> Dense(512) -> BN/ReLU -> Dense(256) -> BN/ReLU
        -> Dense(num_classes).

    Args:
        input_shape: Shape of a single input sample, passed to ``Input``
            (presumably ``(height, width, channels)`` — confirm with callers).
        num_classes: Number of units in the final layer.
        model_type: ``'BNNF'`` builds Flipout layers; any other value builds
            reparameterization conv layers and local-reparameterization dense
            layers.

    Returns:
        A ``keras.Model`` named ``'BayesianVGG16'``. The final layer applies no
        activation, so outputs are unnormalised scores; apply softmax or use a
        from-logits loss downstream.

    NOTE(review): TFP variational layers are documented to register their KL
    divergence terms in ``model.losses``; the training objective should include
    them — confirm against the TFP layers documentation.
    """
    inputs = Input(input_shape)

    x = inputs
    for filters, num_convs in _VGG16_CONV_STAGES:
        for _ in range(num_convs):
            x = _bayesian_conv2d(x, filters, model_type)
            x = _batch_norm_relu(x)
        x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)

    x = tf.keras.layers.Flatten()(x)

    for units in _VGG16_HIDDEN_UNITS:
        x = _bayesian_dense(x, units, model_type)
        x = _batch_norm_relu(x)

    outputs = _bayesian_dense(x, num_classes, model_type)

    return keras.Model(inputs, outputs, name='BayesianVGG16')
    