# Encode real CelebA-HQ samples with the pretrained VQ-VAE and save the latents as .npz

import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from datasets import load_dataset
import PIL.Image
import numpy as np
import tqdm

import os
import io
import itertools

# Directory that receives the latent-space samples (.npz); set this before running.
out_dir = ""

# Number of real samples to encode, and how many images go into each batch.
num_samples = 25000
batch_size = 100

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    torch_device = "cuda"
else:
    torch_device = "cpu"

#--------------------------------------------------------------------------
# Dataset transformation
# Only ToTensor is applied here; the 2x - 1 rescaling is done separately by `scaler`.
# Alternative that normalises inside the transform (then do NOT apply `scaler` as well):
# transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)])
transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)
def scaler(x):
    """Return ``2 * x - 1`` (maps the interval [0, 1] onto [-1, 1])."""
    return 2. * x - 1.
# Load the dataset identified by "korexyz/celeba-hq-256x256"; its 'train' split is used below.
ds = load_dataset(path="korexyz/celeba-hq-256x256")

class CelebAHQDataset(Dataset):
    """Map-style dataset that serves the 'image' field of an indexable dataset.

    Each item of the wrapped dataset must be a mapping with an 'image' key.
    If a (truthy) ``transform`` is given it is applied to the image before
    the image is returned; otherwise the image is returned untouched.
    """

    def __init__(self, dataset, transform = None):
        # Keep references only; nothing is loaded or converted up front.
        self.transform = transform
        self.dataset = dataset

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image stored at position ``idx``."""
        picture = self.dataset[idx]['image']
        return self.transform(picture) if self.transform else picture

transformed_ds = CelebAHQDataset(ds['train'], transform = transform)

# drop_last=True guarantees that every batch holds exactly batch_size images.
dataloader = DataLoader(transformed_ds, batch_size=batch_size, shuffle=True, drop_last = True)


def _repeat_batches(loader):
    """Yield batches from *loader* endlessly, starting a new pass whenever one ends.

    Replaces ``itertools.cycle(loader)``: ``cycle`` keeps a copy of every item
    it has yielded, which pins every image batch in host memory, and on
    wrap-around it replays those cached batches instead of iterating the
    loader again. This generator holds no such cache, and each pass
    re-iterates the loader (so a ``shuffle=True`` loader reshuffles).
    """
    while True:
        yielded_any = False
        for batch in loader:
            yielded_any = True
            yield batch
        if not yielded_any:
            # An empty loader would otherwise spin forever; end the iterator
            # instead (the same outcome itertools.cycle gives for empty input).
            return


data_iter = _repeat_batches(dataloader)
#--------------------------------------------------------------------------

#--------------------------------------------------------------------------
#load VAE model
from diffusers import VQModel

# Pretrained weights come from the "vqvae" subfolder of "CompVis/ldm-celebahq-256".
# Module.to() returns the module itself, so the move to the device can be chained.
vqvae = VQModel.from_pretrained(
    "CompVis/ldm-celebahq-256",
    subfolder="vqvae",
).to(torch_device)
#--------------------------------------------------------------------------
# Encode ceil(num_samples / batch_size) batches so that at least num_samples
# latents are produced even when num_samples is not a multiple of batch_size.
num_batches = (num_samples + batch_size - 1) // batch_size
latent_batches = []

for _ in tqdm.tqdm(range(num_batches)):

    # Move the batch to the device, then rescale it with `scaler` (2x - 1).
    images_real = scaler(next(data_iter).to(torch_device))

    # Encoding is inference only, so no autograd graph is needed.
    with torch.no_grad():
        latents = vqvae.encode(images_real).latents

    latent_batches.append(latents.cpu().numpy())

# Join along the batch axis instead of reshaping to a hard-coded
# (num_samples, 3, 64, 64): this works for whatever latent shape the encoder
# returns, and the slice trims any surplus from the final batch.
real_latents = np.concatenate(latent_batches, axis=0)[:num_samples]

# Build the path with os.path.join so an empty out_dir means "current
# directory" rather than the filesystem root ("" + "/real_latents").
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
latents_path = os.path.join(out_dir, "real_latents")

# Saved positionally, so the array is stored under the default key 'arr_0';
# np.savez appends the ".npz" extension to the file name.
np.savez(latents_path, real_latents)

#checking out the code.
real = np.load(latents_path + ".npz")['arr_0']
print(real.shape)
