from PIL import Image 
from torchvision.utils import save_image, make_grid
import os 
import torchvision.transforms as T
import torch

saved_folder = 'DATASETS/concat_laion_11k_64'
folder_path = 'DATASETS/preprocessed_laion_11k/'
# Collect the "original" images from folder_path, sorted lexicographically by
# full path. Caption files (.txt extension) and non-regular entries such as
# sub-directories are skipped, because Image.open cannot read them.
# os.path.join is used so folder_path works with or without a trailing slash.
# NOTE(review): ordering is plain string order ("10.png" < "2.png"); confirm
# this matches how downstream code expects tiles to be grouped.
image_path_list = sorted(
    os.path.join(folder_path, p)
    for p in os.listdir(folder_path)
    if 'original' in p
    and os.path.splitext(p)[1].lower() != '.txt'
    and os.path.isfile(os.path.join(folder_path, p))
)

# Each source image is shrunk to image_shrink_size x image_shrink_size and
# tiled, row-major, into a square grid of grid_pixel_size x grid_pixel_size
# pixels (no padding). At most total_save_image grids are written to
# saved_folder as 0.png, 1.png, ...
image_shrink_size = 64
grid_pixel_size = 512  # side length (px) of every saved grid image
# Derive the row width and the tiles per grid from one value, so the grid is
# always square even when image_shrink_size does not divide grid_pixel_size.
tiles_per_side = grid_pixel_size // image_shrink_size
if tiles_per_side == 0:
    raise ValueError(
        'image_shrink_size ({}) must not exceed grid_pixel_size ({})'.format(
            image_shrink_size, grid_pixel_size))
image_number = tiles_per_side ** 2
count = 0
total_save_image = 50

# save_image does not create missing directories itself.
os.makedirs(saved_folder, exist_ok=True)
to_tensor = T.ToTensor()
for i in range(0, len(image_path_list), image_number):
    # Check the limit before doing any work, so a limit of 0 writes nothing.
    if count >= total_save_image:
        break
    chunk = image_path_list[i:i + image_number]
    # A trailing partial chunk would yield a smaller, non-square grid image;
    # drop it so every output has the same dimensions.
    if len(chunk) < image_number:
        break
    tiles = []
    for image_path in chunk:
        # Context manager guarantees the underlying file handle is released.
        with Image.open(image_path) as im:
            img = im.convert('RGB').resize(
                size=(image_shrink_size, image_shrink_size))
        tiles.append(to_tensor(img))
    image_array = torch.stack(tiles, dim=0)
    image_grid = make_grid(image_array, nrow=tiles_per_side, padding=0)
    save_image(image_grid, fp=os.path.join(saved_folder, '{}.png'.format(count)))
    count += 1
