# Generates batches of samples from a pretrained latent diffusion model with DDIM sampling and saves them as .npz files for evaluation purposes.
from diffusers import UNet2DModel, DDIMScheduler, VQModel
import torch
import tensorflow as tf 
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from datasets import load_dataset
import PIL.Image
import numpy as np
import tqdm

import os
import itertools

import io

import torchvision.models as models

#--------------------------------------------------------------------------
# Run configuration and model loading

# Base RNG seed; the sampling loop increments it before drawing every batch.
seed1 = 101

# Total number of samples to draw, and how many are drawn per batch.
num_samples = 15000
batch_size = 100

# Number of DDIM inference steps; `t` is kept as an alias of `steps`.
steps = 500
t = steps

# eta = 0: DDIM update only.
eta = 0

# Output directory for the saved sample files -- please specify as required.
dir_path = ""

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    torch_device = "cuda"
else:
    torch_device = "cpu"

# All pretrained components come from the same model repository.
model_id = "CompVis/ldm-celebahq-256"

# Load the two networks and move each one to the chosen device.
unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(torch_device)
vqvae = VQModel.from_pretrained(model_id, subfolder="vqvae").to(torch_device)

# NOTE(review): newer diffusers releases deprecate passing a repo id to
# `from_config` in favour of `from_pretrained` -- confirm against the pinned
# diffusers version before changing this call.
scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")

# Configure the number of inference steps used by the DDIM scheduler.
scheduler.set_timesteps(num_inference_steps=steps)
#--------------------------------------------------------------------------

#--------------------------------------------------------------------------
# Sampling loop: draws `num_samples // batch_size` batches and saves each one.
#
# Every batch starts from Gaussian noise drawn with its own seed
# (seed1 + 1, seed1 + 2, ...), is run through every scheduler timestep and is
# written to `samples_{index}.npz` under `dir_path`, stored under the key
# "samples".
#
# NOTE(review): `vqvae` is loaded earlier in this file but is never applied
# here, so the saved arrays are the raw output of the DDIM loop rather than
# VQ-VAE-decoded images -- confirm that decoding is meant to happen downstream.

# Create the output directory up front so a missing folder is reported
# before any (slow) sampling has been done instead of after the first batch.
if dir_path:
    tf.io.gfile.makedirs(dir_path)

# Shape of one batch of initial noise. It is the same for every batch, so it
# is built once; the values are read from the model config because direct
# attribute access on the model is deprecated in newer diffusers releases.
noise_shape = (
    batch_size,
    unet.config.in_channels,
    unet.config.sample_size,
    unet.config.sample_size,
)

# Integer division: a remainder of num_samples / batch_size is not generated.
for j in tqdm.tqdm(range(num_samples // batch_size)):

    # A fresh seed per batch keeps batches distinct but the run repeatable.
    seed1 += 1
    generator1 = torch.manual_seed(seed1)

    # Draw the starting noise with the seeded generator, then move it to the
    # target device. `image` holds the current sample throughout the loop.
    noise = torch.randn(noise_shape, generator=generator1).to(torch_device)
    image = noise

    # Gradients are never needed during sampling, so disable them for the
    # whole denoising loop (model forward pass and scheduler update alike).
    with torch.no_grad():
        # The loop variable is deliberately not called `t`, which would
        # overwrite the module-level step-count alias of the same name.
        for timestep in tqdm.tqdm(scheduler.timesteps):
            # Predict the noise residual for the current sample.
            residual = unet(image, timestep)["sample"]

            # Apply one scheduler update to obtain the next sample.
            image = scheduler.step(residual, timestep, image, eta=eta)["prev_sample"]

    ## Save the batch.
    images_np = image.cpu().numpy()

    # Output file indices start at 100 (batch 0 -> samples_100.npz).
    r = j + 100

    # Serialize into an in-memory buffer first, then write the bytes through
    # tf.io.gfile.
    with tf.io.gfile.GFile(os.path.join(dir_path, f"samples_{r}.npz"), "wb") as fout:
        io_buffer = io.BytesIO()
        np.savez_compressed(io_buffer, samples=images_np)
        fout.write(io_buffer.getvalue())

