import os
import numpy as np
from torchvision import datasets, transforms
from PIL import Image
from tqdm import tqdm
from multiprocessing import Pool, cpu_count

def process_and_save_image(image_label_tuple, out_dir=None):
    """Write one Gaussian-noise PNG with the same size and channel layout as a source image.

    Designed to be mapped over ``enumerate(dataset)`` by a multiprocessing pool,
    so it takes a single ``(idx, (image, label))`` tuple.

    Args:
        image_label_tuple: ``(idx, (image, label))``. ``image`` must expose
            PIL-style ``.mode`` and ``.size`` attributes. ``label`` selects the
            ``class_{label}`` subfolder. ``idx`` makes the file name unique.
        out_dir: Root output directory. Defaults to the module-level ``save_dir``
            so existing single-argument callers keep working.

    Writes ``<out_dir>/class_{label}/gaussian_image_{idx}.png`` and returns None.
    """
    idx, (image, label) = image_label_tuple
    if out_dir is None:
        out_dir = save_dir

    width, height = image.size
    # RGB sources get 3-channel noise; any other mode is written as 8-bit grayscale.
    shape = (height, width, 3) if image.mode == 'RGB' else (height, width)

    # Use a fresh Generator seeded from OS entropy on every call. The legacy global
    # np.random state is copied verbatim into fork()ed pool workers, so sampling
    # from it made different workers emit identical "random" images.
    rng = np.random.default_rng()
    noise = rng.normal(loc=0.0, scale=1.0, size=shape)

    # NOTE(review): N(0, 1) * 255 clipped to [0, 255] sends every negative sample
    # (~50% of pixels) to 0 and every sample > 1 (~16%) to 255. Confirm this is
    # the intended pixel distribution.
    pixels = np.clip(noise * 255, 0, 255).astype(np.uint8)

    class_dir = os.path.join(out_dir, f'class_{label}')
    # Create the folder here as well, so the worker does not depend on the
    # caller having pre-created it.
    os.makedirs(class_dir, exist_ok=True)

    # Pillow infers the mode from the uint8 array: (H, W, 3) -> 'RGB', (H, W) -> 'L'.
    # Passing mode= explicitly is redundant and deprecated in recent Pillow.
    Image.fromarray(pixels).save(os.path.join(class_dir, f'gaussian_image_{idx}.png'))

# Output root for the generated noise images. It is defined at module level
# (outside the __main__ guard) because process_and_save_image reads it as a
# global. Under the 'spawn'/'forkserver' start methods, workers only see what a
# plain import of this module defines.
save_dir = '/root/autodl-tmp/img2img_unlearning/diffusion_data/gaussian_images_single_5000_100'

# All side effects run only when this file is executed as a script. Without the
# guard, 'spawn' (Windows/macOS default) and 'forkserver' (Linux default from
# Python 3.14) re-import this module in every worker. Each worker would then
# re-run the pipeline and try to start a nested Pool, which multiprocessing
# rejects with RuntimeError.
if __name__ == '__main__':
    dataset_forget = datasets.ImageFolder('/root/autodl-tmp/img2img_unlearning/diffusion_data/place2_forget_50_5000_100/train', transform=None)

    os.makedirs(save_dir, exist_ok=True)

    # Create one output folder per label present in the source split.
    for label in sorted({target for _, target in dataset_forget.imgs}):
        os.makedirs(os.path.join(save_dir, f'class_{label}'), exist_ok=True)

    num_processes = cpu_count()
    print(f'Using {num_processes} processes for image generation.')

    # Iterating dataset_forget decodes each source image in this (parent) process
    # before pickling it to a worker. The worker only reads its size and mode.
    with Pool(processes=num_processes) as pool:
        tasks = pool.imap_unordered(process_and_save_image, enumerate(dataset_forget))
        for _ in tqdm(tasks, total=len(dataset_forget), desc='Generating images'):
            pass

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # NOTE(review): ImageFolder assigns class indices in sorted folder-name order.
    # With more than 10 labels, 'class_10' sorts before 'class_2', so
    # new_dataset.targets will not equal the original labels. Recover the
    # original label from the folder name if that correspondence matters.
    new_dataset = datasets.ImageFolder(save_dir, transform=transform)
    print(len(new_dataset))