import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from datasets import load_dataset
import PIL.Image
import numpy as np
import tqdm

import os
import itertools

# Device selection: GPU index 7 when CUDA is available, otherwise the CPU.
# NOTE(review): torch_device is not referenced anywhere else in this script;
# all tensors below stay on the CPU.
torch_device = "cpu"
if torch.cuda.is_available():
    torch_device = "cuda:7"

num_samples = 30000  # total number of images to export
batch_size = 100  # images per DataLoader batch

#--------------------------------------------------------------------------
# Dataset transformation: each PIL image is converted to a tensor by
# torchvision's ToTensor (channels-first; values in [0, 1] for 8-bit images).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)
def scaler(x):
    """Affinely map ``x`` from [0, 1] to [-1, 1], i.e. return ``2*x - 1``."""
    return 2. * x - 1.

dataset = "ffhq"  # which source to export: 'celeba' or 'ffhq'
out_dir = ""  # output directory for the PNGs; set as required ("" = current working directory)

# Map the dataset choice to its Hugging Face Hub dataset ID.
if dataset == 'celeba':
    ds = load_dataset("korexyz/celeba-hq-256x256")
elif dataset == 'ffhq':
    ds = load_dataset("merkol/ffhq-256")
else:
    # Fail fast: otherwise `ds` is left undefined and the script only dies
    # later with an unrelated-looking NameError.
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'celeba' or 'ffhq'")

class getDataset(Dataset):
    """``Dataset`` wrapper that returns only the ``'image'`` field of each record.

    Args:
        dataset: Sized, indexable collection whose items support
            ``item['image']`` (e.g. a Hugging Face dataset split).
        transform: Optional callable applied to each image before it is
            returned; skipped when falsy (e.g. ``None``).
    """

    def __init__(self, dataset, transform=None):
        self.dataset, self.transform = dataset, transform

    def __len__(self):
        # Same length as the wrapped collection.
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.dataset[idx]['image']
        if not self.transform:
            return image
        return self.transform(image)

transformed_ds = getDataset(ds['train'], transform = transform)

dataloader = DataLoader(transformed_ds, batch_size=batch_size, shuffle=False, drop_last = False)


def _repeat_forever(loader):
    """Yield batches from ``loader`` indefinitely, restarting it after each pass.

    ``itertools.cycle`` keeps a reference to every element it has yielded so
    it can replay them later. For image tensors, memory therefore grows with
    every batch drawn. Re-iterating the loader keeps only the current batch
    alive. With ``shuffle=False``, each pass yields the same batches in the
    same order, so the sequence matches ``cycle``'s. Like ``cycle``, this stops
    immediately if ``loader`` yields nothing.
    """
    while True:
        empty = True
        for batch in loader:
            empty = False
            yield batch
        if empty:
            return


data_iter = _repeat_forever(dataloader)

# os.path.join("", name) is just `name`, so out_dir="" writes into the CWD.
# For any other value, make sure the target directory exists before saving.
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

num_written = 0  # number of images saved so far; also the next file index
with tqdm.tqdm(total=num_samples) as pbar:
    while num_written < num_samples:
        images = next(data_iter)
        images = scaler(images)  # [-1, 1] normalization

        ## Save images.
        # Undo ToTensor + scaler: x * 127.5 + 128 equals pixel + 0.5, so the
        # truncating uint8 cast rounds to the nearest integer.
        # permute reorders (N, C, H, W) -> (N, H, W, C) for PIL.
        images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()

        # Slicing guarantees exactly num_samples files, named contiguously,
        # even if num_samples is not a multiple of batch_size or the loader
        # returns a short final batch.
        for image_np in images_np[: num_samples - num_written]:
            image_path = os.path.join(out_dir, f'{num_written:06d}.png')
            PIL.Image.fromarray(image_np, 'RGB').save(image_path)
            num_written += 1
            pbar.update(1)