import os
import numpy as np
import tensorflow as tf
import bayesflow as bf
from tensorflow_probability import distributions as tfd
from scipy.ndimage import gaussian_filter
from skimage.util import random_noise
import matplotlib.pyplot as plt

# Load MNIST as (images, labels) pairs for the train and test splits.
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Filter training set: keep only the digit 0 (the mask selects labels equal to 0).
train_mask = train_labels == 0
train_images_filtered = train_images[train_mask]
train_labels_filtered = train_labels[train_mask]

# Apply the same digit-0 filter to the test set.
test_mask = test_labels == 0
test_images_filtered = test_images[test_mask]
test_labels_filtered = test_labels[test_mask]

def clip_255(img):
    """Rescale pixel values linearly via ``x / 127.5 - 1``.

    Despite the name, nothing is clipped: the input is cast to float32 and
    rescaled, so values in [0, 255] end up in [-1, 1].

    Args:
        img: Array exposing ``astype`` (e.g. a NumPy array of pixel values).

    Returns:
        A float32 array of the same shape holding the rescaled values.
    """
    as_float = img.astype(np.float32)
    return as_float / 127.5 - 1.0

# Rescale both splits with clip_255 (x / 127.5 - 1), yielding float32 arrays.
train_images_filtered = clip_255(train_images_filtered)
test_images_filtered = clip_255(test_images_filtered)
# Append a trailing length-1 axis; configurator later unpacks it as the channel axis.
train_images_n = train_images_filtered[..., tf.newaxis]
test_images_n = test_images_filtered[..., tf.newaxis]


def grayscale_camera(theta, psf_width=2.):
    """Blur images with a Gaussian point-spread function (PSF).

    A scalar ``sigma`` makes ``gaussian_filter`` smooth along *every* axis of
    its input. For a 4-D stack laid out as (batch, height, width, channel)
    that would also blur along the batch axis, mixing neighbouring images
    into each other. For 4-D input the blur is therefore restricted to the
    two spatial axes.

    Args:
        theta: Image data. A 4-D array is treated as (batch, height, width,
            channel); any other rank is filtered along all of its axes.
        psf_width: Standard deviation of the Gaussian kernel. A scalar is
            applied to the spatial axes of 4-D input; a per-axis sequence is
            forwarded to ``gaussian_filter`` unchanged.

    Returns:
        The blurred array, with the same shape as ``theta``.
    """
    if np.ndim(theta) == 4 and np.ndim(psf_width) == 0:
        # A sigma of 0 leaves that axis unfiltered, so the batch and channel
        # axes pass through untouched and each image is blurred on its own.
        sigma = (0.0, psf_width, psf_width, 0.0)
    else:
        sigma = psf_width

    return gaussian_filter(theta, sigma=sigma)

# Pass both splits through grayscale_camera with psf_width=1.0 to obtain the
# blurred images used as training and evaluation data.
train_images_m = grayscale_camera(train_images_n, psf_width=1.0)
test_images_m = grayscale_camera(test_images_n, psf_width=1.0)

# Training dictionary; configurator reads the images from the "sim_data" key.
forward_train = {"sim_data": train_images_m}

# Split the blurred test images: shuffle the indices with a fixed seed (42) so
# the split is reproducible, then use the first num_val entries for validation
# and the remainder for testing.
num_val = 200
perm = np.random.default_rng(seed=42).permutation(test_images_m.shape[0])
forward_val = {"sim_data": test_images_m[perm[:num_val]]}
forward_test = {"sim_data": test_images_m[perm[num_val:]]}


def configurator(f):
    """Turn a batch dictionary into the inputs handed to the amortizer.

    Args:
        f: Dictionary whose ``"sim_data"`` entry is a 4-D array that is
            unpacked as (B, H, W, C) and reshaped to (B, H * W).

    Returns:
        A dictionary with two entries:
        ``"parameters"``: the flattened float32 images plus Gaussian noise
        scaled by 1e-6;
        ``"direct_conditions"``: an all-zero float32 tensor of shape
        (B, H * W).
    """
    images = f["sim_data"]
    batch_size, height, width, _channels = images.shape
    flat_dim = height * width

    # One row of H * W pixel values per image.
    flattened = images.reshape((batch_size, flat_dim)).astype(np.float32)
    # Very small Gaussian jitter added to every pixel value.
    jitter = tf.random.normal(shape=flattened.shape, dtype=tf.float32)

    return {
        "parameters": flattened + 1e-6 * jitter,
        "direct_conditions": tf.zeros([batch_size, flat_dim], dtype=tf.float32),
    }

# Settings forwarded to the invertible network through `coupling_settings`:
# dense layers with 512 ReLU units and L2 kernel regularization (1e-4),
# "num_dense" set to 1 (presumably the number of dense layers per coupling
# block -- confirm against the BayesFlow docs), and a dropout probability of 0.15.
coupling_settings = {
    "dense_args": dict(
        units=512, activation="relu", kernel_regularizer=tf.keras.regularizers.l2(1e-4)
    ),
    "num_dense": 1,
    "dropout_prob": 0.15,
}

# Invertible network over 28 * 28 = 784 parameters with 12 coupling layers.
# NOTE(review): 784 must equal the H * W width produced by configurator,
# i.e. this assumes 28x28 input images.
gen_net = bf.networks.InvertibleNetwork(
    num_params=int(28 * 28),
    num_coupling_layers=12,
    coupling_settings=coupling_settings
)

# Latent distribution: a 784-dimensional multivariate Student-t with df=100,
# zero location, and a diagonal scale operator with ones on the diagonal.
dim = int(28 * 28)
loc = [0.0] * dim
scale = tf.linalg.LinearOperatorDiag([1.0] * dim)
latent_dist = tfd.MultivariateStudentTLinearOperator(df=100, loc=loc, scale=scale)


# Wrap the invertible network in an amortizer that uses the Student-t latent
# distribution; no summary network is attached (summary_net=None).
amortized_generator = bf.amortizers.AmortizedPosterior(
    gen_net,
    summary_net=None,
    latent_dist=latent_dist
)


# Trainer that feeds batches through `configurator` into the amortizer, using
# the given checkpoint directory and a default learning rate of 1e-3.
# NOTE(review): `memory` and `max_to_keep` are passed through as-is; see the
# BayesFlow Trainer docs for their exact semantics.
trainer = bf.trainers.Trainer(
    amortizer=amortized_generator,
    checkpoint_path="./checkpoints/gen_zero_1_v2",
    configurator=configurator,
    default_lr=1e-3,
    memory=False,
    max_to_keep=1,
)

# Offline training on forward_train for 120 epochs with batch size 32,
# validating on forward_val. This runs whenever the file is executed, so
# comment the line out once training has finished to avoid retraining.
history = trainer.train_offline(forward_train, epochs=120, batch_size=32, validation_sims=forward_val)
