import os
import numpy as np
import tensorflow as tf
import bayesflow as bf
from prior_blur import amortized_generator

# Seed the NumPy and TensorFlow random number generators before any sampling.
np.random.seed(42)
tf.random.set_seed(42)

# Inputs handed to the generator: one all-zero row of width 784 under each key
# (784 matches the 28 * 28 reshape applied to the samples below).
input_dict = {
    "parameters": tf.zeros([1, 784]),
    "direct_conditions": tf.zeros([1, 784]),
}


def _sample_image_stack(n_samples):
    """Draw ``n_samples`` samples from the generator as a clipped image stack.

    The raw samples are reshaped to 28x28 images, their values are clipped
    to the interval [-1, 1], and a trailing singleton axis is appended, so
    the result has shape (num_samples, 28, 28, 1).
    """
    drawn = amortized_generator.sample(input_dict=input_dict, n_samples=n_samples)
    clipped = np.clip(drawn.reshape(-1, 28, 28), -1.0, 1.0)
    return np.expand_dims(clipped, axis=-1)


# Both generator calls happen in the same order as before (train, then test);
# the steps in between are deterministic array operations.
train_images_12k = _sample_image_stack(12000)
test_images = _sample_image_stack(1000)

# Output directory is created only after sampling has succeeded.
data_dir = "./sim_data"
os.makedirs(data_dir, exist_ok=True)

# Each set goes into its own compressed archive under a descriptive key.
train_path = f"{data_dir}/train_images_12k_1.npz"
test_path = f"{data_dir}/test_images_1k_1.npz"
np.savez_compressed(train_path, train_images=train_images_12k)
np.savez_compressed(test_path, test_images=test_images)
