import os
import numpy as np
import tensorflow as tf
import bayesflow as bf
from tensorflow_probability import distributions as tfd

def build_summary_network(
    input_shape=(28, 28, 1),
    filters=(64, 64, 128, 128),
    kernel_size=(3, 3),
    l2_weight=1e-5,
):
    """Build a convolutional summary network as a Keras ``Sequential`` model.

    For every entry in ``filters`` the model gets a ``Conv2D`` layer (ReLU,
    He-normal initialization, L2 kernel penalty) followed by
    ``BatchNormalization``. A final ``GlobalAveragePooling2D`` averages over the
    spatial dimensions, so each example is summarized by ``filters[-1]``
    features.

    Every call creates brand-new layers, so models returned by separate calls
    never share weights.

    Args:
        input_shape: Shape of one input example, passed to the first ``Conv2D``
            layer. Defaults to ``(28, 28, 1)``.
        filters: Number of filters for each successive convolution. Must
            contain at least one entry.
        kernel_size: Kernel size used by every convolution.
        l2_weight: L2 regularization factor applied to every convolution kernel.

    Returns:
        An uncompiled ``tf.keras.models.Sequential`` model.

    Raises:
        ValueError: If ``filters`` is empty.
    """
    if not filters:
        raise ValueError("filters must contain at least one entry")

    layers = []
    for index, num_filters in enumerate(filters):
        conv_kwargs = dict(
            filters=num_filters,
            kernel_size=kernel_size,
            activation="relu",
            kernel_initializer="he_normal",
            # A fresh regularizer per layer, exactly like the original inline definitions.
            kernel_regularizer=tf.keras.regularizers.l2(l2_weight),
        )
        if index == 0:
            # Only the first layer declares the input shape, so the Sequential
            # model knows its input signature up front.
            conv_kwargs["input_shape"] = input_shape
        layers.append(tf.keras.layers.Conv2D(**conv_kwargs))
        layers.append(tf.keras.layers.BatchNormalization())
    layers.append(tf.keras.layers.GlobalAveragePooling2D())
    return tf.keras.models.Sequential(layers)


# Two independently initialized summary networks with identical architecture.
# The _l / _p suffixes presumably pair with likelihood_net / posterior_net —
# confirm against where they are consumed.
summary_network_l = build_summary_network()
summary_network_p = build_summary_network()

# Coupling-layer configuration handed to bf.networks.InvertibleNetwork for both
# flows. The keys follow BayesFlow's coupling-settings schema; the notes below
# state how they are presumably interpreted — verify against the BayesFlow
# version in use.
coupling_settings = {
    # Keyword arguments forwarded to each internal dense layer.
    "dense_args": {
        "units": 512,
        "activation": "relu",
        "kernel_regularizer": tf.keras.regularizers.l2(1e-4),
    },
    # Presumably the number of dense layers inside each coupling sub-network.
    "num_dense": 1,
    # Presumably the dropout rate used inside the coupling sub-networks.
    "dropout_prob": 0.15,
}

# Dimensionality modeled by both invertible networks: one value per pixel of a
# 28x28 image (presumably the flattened image — confirm against the data pipeline).
# 28 * 28 is already an int, so no int() wrapper is needed.
NUM_PIXELS = 28 * 28

# Normalizing flows with 12 coupling layers each, sharing one coupling configuration.
posterior_net = bf.networks.InvertibleNetwork(
    num_params=NUM_PIXELS,
    num_coupling_layers=12,
    coupling_settings=coupling_settings,
)

likelihood_net = bf.networks.InvertibleNetwork(
    num_params=NUM_PIXELS,
    num_coupling_layers=12,
    coupling_settings=coupling_settings,
)

# Heavy-tailed base distribution over a 784-dimensional space: a multivariate
# Student-t with zero location, identity diagonal scale and 100 degrees of
# freedom (close to a standard normal at this df). It is not wired into the
# networks in this chunk; presumably it is passed to the amortizer elsewhere.
dim = 28 ** 2
loc = [0.0 for _ in range(dim)]
scale = tf.linalg.LinearOperatorDiag(diag=[1.0 for _ in range(dim)])
latent_dist = tfd.MultivariateStudentTLinearOperator(
    df=100,
    loc=loc,
    scale=scale,
)
