import os
import numpy as np
import tensorflow as tf
import bayesflow as bf
from tensorflow_probability import distributions as tfd
from scipy.ndimage import gaussian_filter
from skimage.util import random_noise
import matplotlib.pyplot as plt
from configurator import grayscale_camera, configurator
from networks import likelihood_net, posterior_net, summary_network_p, summary_network_l
from amortizers import AmortizedPosteriorLikelihoodSC
from schedules import BatchCyclingSchedule, LinearSchedule, ConstantSchedule, ZeroOneSchedule, RampSchedule
from prior_blur import amortized_generator

# Load the pre-simulated train/test image sets from disk.
data_dir = "./sim_data"

# NOTE(security): allow_pickle=True lets np.load unpickle object arrays, which
# can execute arbitrary code -- only point data_dir at trusted files.
# NOTE(review): if the archives hold plain numeric arrays, allow_pickle can
# probably be dropped -- confirm against how the files were written.
#
# An .npz archive keeps its file handle open until it is closed, so the arrays
# are read inside `with` blocks to release the handles deterministically.
with np.load(f"{data_dir}/train_images_12k_1.npz", allow_pickle=True) as train_images_npz:
    train_images = train_images_npz["train_images"]
with np.load(f"{data_dir}/test_images_1k_1.npz", allow_pickle=True) as test_images_npz:
    test_images = test_images_npz["test_images"]

# Both keys point at the same image array.
# NOTE(review): presumably the configurator derives the degraded observations
# from "sim_data" -- confirm against configurator.
forward_train = {"prior_draws": train_images, "sim_data": train_images}
forward_test = {"prior_draws": test_images, "sim_data": test_images}

# Fetch the MNIST train/test split through the Keras dataset helper.
mnist = tf.keras.datasets.mnist
(train_images_m, train_labels), (test_images_m, test_labels) = mnist.load_data()

# Boolean masks selecting the digit-0 class in each split.
train_mask = np.equal(train_labels, 0)
test_mask = np.equal(test_labels, 0)

# Keep only the digit-0 images (and their labels) of the training split ...
train_images_filtered = train_images_m[train_mask]
train_labels_filtered = train_labels[train_mask]

# ... and of the test split.
test_images_filtered = test_images_m[test_mask]
test_labels_filtered = test_labels[test_mask]

def clip_255(img):
    """Linearly rescale pixel values so that 0 -> -1 and 255 -> 1.

    Despite the name, nothing is clipped: inputs outside [0, 255] are mapped
    linearly to values outside [-1, 1].

    Parameters
    ----------
    img : np.ndarray
        Array of pixel intensities (any shape).

    Returns
    -------
    np.ndarray
        float32 array with the same shape as ``img``.
    """
    pixels = img.astype(np.float32)
    half_range = 127.5  # 255 / 2, so dividing maps [0, 255] onto [0, 2]
    return pixels / half_range - 1.0

# Rescale the raw pixel values of both digit-0 subsets to [-1, 1].
train_images_filtered = clip_255(train_images_filtered)
test_images_filtered = clip_255(test_images_filtered)

# Append a trailing singleton axis to each image batch.
sc_images = np.expand_dims(train_images_filtered, axis=-1)
sc_images_test = np.expand_dims(test_images_filtered, axis=-1)

def config_real(f):
    """Apply ``grayscale_camera`` to every image of a batch.

    Parameters
    ----------
    f : array-like
        Batch of images, iterated along its first axis. The call site in this
        script passes normalised MNIST images with a trailing singleton axis.

    Returns
    -------
    np.ndarray
        float32 array holding the per-image outputs of ``grayscale_camera``
        stacked along a new leading axis.

    Raises
    ------
    ValueError
        If the batch is empty (``np.stack`` needs at least one array).
    """
    # NOTE(review): grayscale_camera comes from the project's configurator
    # module; presumably it simulates a blurring camera -- confirm there.
    processed = [grayscale_camera(image) for image in f]
    return np.stack(processed).astype(np.float32)

# Pass the held-out digit-0 images through config_real; the result is later
# handed to the amortizer as its real-data source.
modified_images = config_real(sc_images_test)


# Forward dict over the (unmodified) digit-0 test images; both keys share the
# same array.
forward_sc_test = dict(prior_draws=sc_images_test, sim_data=sc_images_test)

# Latent base distribution shared by both amortizers: a multivariate Student-t
# with zero location, unit diagonal scale and 100 degrees of freedom.
# NOTE(review): dim presumably matches the flattened 28 x 28 image -- confirm.
dim = 28 * 28
loc = [0.0 for _ in range(dim)]
scale = tf.linalg.LinearOperatorDiag([1.0 for _ in range(dim)])
latent_dist = tfd.MultivariateStudentTLinearOperator(loc=loc, scale=scale, df=100)

# Posterior approximator: posterior_net with summary_network_p as its summary
# network and an "MMD" summary loss.
amortized_posterior = bf.amortizers.AmortizedPosterior(
    posterior_net,
    latent_dist=latent_dist,
    summary_loss_fun="MMD",
    summary_net=summary_network_p,
)

# Likelihood surrogate: likelihood_net with summary_network_l.
# NOTE(review): this is also wrapped in AmortizedPosterior rather than a
# dedicated likelihood amortizer -- presumably so it can take a summary
# network; confirm this is intended.
amortized_likelihood = bf.amortizers.AmortizedPosterior(
    likelihood_net,
    latent_dist=latent_dist,
    summary_net=summary_network_l,
)

# Joint amortizer combining the posterior and likelihood approximators with
# the project's self-consistency wrapper. The real-data schedule and the
# lambda ramp are both configured with a start step of 8000.
# NOTE(review): exact schedule semantics live in the project's `schedules`
# module -- confirm the step counts there.
amortizer = AmortizedPosteriorLikelihoodSC(
    prior=amortized_generator,
    real_data=BatchCyclingSchedule(
        real_data=modified_images,
        batch_size=16,
        steps_per_batch=64,
        init_steps=8000,
    ),
    lambda_schedule=RampSchedule(max_steps=15000, init_step=8000, max_val=1),
    n_consistency_samples=32,
    amortized_posterior=amortized_posterior,
    amortized_likelihood=amortized_likelihood,
    output_numpy=False,
)

# Trainer wiring the joint amortizer to the project's configurator; only the
# single most recent checkpoint is retained.
trainer = bf.trainers.Trainer(
    amortizer=amortizer,
    configurator=configurator,
    checkpoint_path="./checkpoints/sc_12k_1_n0",
    max_to_keep=1,
    memory=False,
    default_lr=1e-4,
)


# Offline training on the pre-simulated training set. Comment this call out
# once training has finished.
h = trainer.train_offline(forward_train, batch_size=32, epochs=100)
