import os
import numpy as np
import tensorflow as tf
import bayesflow as bf
from tensorflow_probability import distributions as tfd
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
from configurator import grayscale_camera, configurator
from networks import likelihood_net, posterior_net, summary_network_p, summary_network_l


# Load the pre-simulated image sets that were saved to disk.
data_dir = "./sim_data"

# np.load returns a lazy NpzFile for .npz archives and keeps the underlying
# file open until it is closed, so each archive is read inside a `with` block.
# Indexing the archive reads the array into memory, so the arrays remain valid
# after the block exits; `train_images_npz` / `test_images_npz` stay bound but
# refer to closed archives from here on.
# NOTE(review): allow_pickle=True permits unpickling object arrays, which can
# execute arbitrary code -- only load archives from a trusted source.
with np.load(
    f"{data_dir}/train_images_12k_1.npz", allow_pickle=True
) as train_images_npz:
    train_images = train_images_npz["train_images"]

with np.load(
    f"{data_dir}/test_images_1k_1.npz", allow_pickle=True
) as test_images_npz:
    test_images = test_images_npz["test_images"]

# The same array is stored under both keys of each forward dict.
forward_train = {"prior_draws": train_images, "sim_data": train_images}
forward_test = {"prior_draws": test_images, "sim_data": test_images}

# MNIST digits, obtained through the Keras dataset helper.
mnist = tf.keras.datasets.mnist
(train_images_m, train_labels), (test_images_m, test_labels) = mnist.load_data()

# Boolean masks selecting the samples labelled 0 in each split.
train_mask = np.equal(train_labels, 0)
test_mask = np.equal(test_labels, 0)

# Apply the masks to images and labels alike so they stay aligned.
train_images_filtered = train_images_m[train_mask]
train_labels_filtered = train_labels[train_mask]
test_images_filtered = test_images_m[test_mask]
test_labels_filtered = test_labels[test_mask]

def clip_255(img):
    """Rescale pixel values so that 0 maps to -1.0 and 255 maps to 1.0.

    Despite the name, nothing is clipped: every value goes through the same
    affine map ``x / 127.5 - 1``.

    Args:
        img: Array exposing ``astype`` (e.g. a NumPy array of pixel values).

    Returns:
        A new float32 array with the rescaled values, same shape as ``img``.
    """
    as_float = img.astype(np.float32)
    return as_float / 127.5 - 1.0

# Normalise both filtered splits to [-1, 1].
train_images_filtered = clip_255(train_images_filtered)
test_images_filtered = clip_255(test_images_filtered)

# Append a trailing axis of length 1 to each image stack.
sc_images = np.expand_dims(train_images_filtered, axis=-1)
sc_images_test = np.expand_dims(test_images_filtered, axis=-1)

# The same test array is stored under both keys.
forward_sc_test = dict(prior_draws=sc_images_test, sim_data=sc_images_test)

# Latent base distribution: a 784-dimensional (28 * 28) multivariate Student-t
# with zero location, unit diagonal scale and df=100.
dim = 28 * 28
loc = [0.0] * dim
scale = tf.linalg.LinearOperatorDiag(diag=[1.0] * dim)
latent_dist = tfd.MultivariateStudentTLinearOperator(
    df=100,
    loc=loc,
    scale=scale,
)

# Posterior amortizer: posterior_net conditioned through summary_network_p,
# using the shared latent distribution and the "MMD" summary loss option.
amortized_posterior = bf.amortizers.AmortizedPosterior(
    posterior_net,
    summary_net=summary_network_p,
    latent_dist=latent_dist,
    summary_loss_fun="MMD",
)

# Likelihood amortizer: likelihood_net conditioned through summary_network_l.
# NOTE(review): this is built with AmortizedPosterior rather than
# bf.amortizers.AmortizedLikelihood -- presumably so that a summary network
# can be attached; confirm it matches what `configurator` produces.
amortized_likelihood = bf.amortizers.AmortizedPosterior(
    likelihood_net,
    summary_net=summary_network_l,
    latent_dist=latent_dist,
)

# Joint wrapper holding both amortizers.
amortizer = bf.amortizers.AmortizedPosteriorLikelihood(
    amortized_posterior,
    amortized_likelihood,
)
# Trainer for the joint amortizer. Checkpoints go to ./checkpoints/nple_1
# (max_to_keep=1), and the project-local configurator is applied to the
# forward dicts.
trainer = bf.trainers.Trainer(
    amortizer=amortizer,
    configurator=configurator,
    default_lr=1e-4,
    memory=False,
    checkpoint_path="./checkpoints/nple_1",
    max_to_keep=1,
)

# Offline training on the pre-simulated training set; comment this call out
# once training has finished.
h = trainer.train_offline(
    forward_train,
    epochs=100,
    batch_size=32,
)
