import os
import numpy as np
from torchvision import datasets, transforms
from PIL import Image
from tqdm import tqdm
from multiprocessing import cpu_count, Pool

class GaussianDataset:
    """Write one random-noise PNG for every sample of a source dataset.

    Each source sample is expected to be an ``(image, label)`` pair where
    ``image`` exposes PIL-style ``size`` and ``mode`` attributes. The noise
    image gets the same width and height as the source image and is stored
    under ``<save_dir>/class_<label>/gaussian_image_<idx>.png``.
    """

    def __init__(self, dataset, save_dir):
        """Remember the source dataset and the output root directory.

        Args:
            dataset: Indexable collection of ``(image, label)`` pairs that
                also supports ``len()``.
            save_dir: Directory under which the ``class_<label>`` folders
                are created.
        """
        self.dataset = dataset
        self.save_dir = save_dir

    @staticmethod
    def _noise_array(height, width, channels):
        """Return a uint8 noise array shaped so PIL can infer the image mode.

        Three channels give an ``(height, width, 3)`` array; anything else
        gives a 2-D ``(height, width)`` array, because PIL rejects a
        trailing singleton axis for single-channel images.
        """
        shape = (height, width, channels) if channels == 3 else (height, width)
        # A fresh generator per call is seeded from OS entropy. The global
        # ``np.random`` state would be copied verbatim into every forked
        # pool worker, making workers produce identical noise.
        rng = np.random.default_rng()
        noise = rng.standard_normal(shape, dtype=np.float32)
        # Standard-normal samples are scaled by 255 and clipped: negative
        # samples become 0 and samples above 1.0 saturate at 255.
        return (noise * 255).clip(0, 255).astype(np.uint8)

    def generate_gaussian_image(self, data):
        """Create and save the noise image for one enumerated sample.

        Args:
            data: ``(idx, (image, label))`` as produced by
                ``enumerate(dataset)``.
        """
        idx, (image, label) = data
        width, height = image.size
        channels = 3 if image.mode == 'RGB' else 1

        class_dir = os.path.join(self.save_dir, f'class_{label}')
        # exist_ok keeps concurrent workers from failing when they race to
        # create the same class directory.
        os.makedirs(class_dir, exist_ok=True)

        # The array shape decides the mode: (H, W, 3) -> RGB, (H, W) -> L.
        img = Image.fromarray(self._noise_array(height, width, channels))
        img.save(os.path.join(class_dir, f'gaussian_image_{idx}.png'))

    def generate_and_save_images(self):
        """Generate a noise image for every sample using one worker per CPU."""
        num_cores = cpu_count()

        with Pool(num_cores) as pool:
            results = pool.imap(self.generate_gaussian_image, enumerate(self.dataset))
            # Drain the iterator so every task runs and worker exceptions
            # surface here; the per-task return values are all None.
            for _ in tqdm(results, total=len(self.dataset), desc='Generating images'):
                pass

# Source image folder whose samples determine the size, label and count of
# the generated noise images, and the root directory the noise images go to.
dataset_path = '/root/autodl-tmp/img2img_unlearning/diffusion_data/place2_forget_50_5000_100/train'
save_dir = '/root/autodl-tmp/img2img_unlearning/diffusion_data/gaussian_images_multi_5000_100'

# The guard is required for multiprocessing: under the "spawn" start method
# every worker re-imports this module, and without the guard each worker
# would try to start its own pool and regenerate the whole dataset.
if __name__ == '__main__':
    os.makedirs(save_dir, exist_ok=True)

    # transform=None keeps samples as PIL images, which is what
    # GaussianDataset reads (``size`` and ``mode``).
    dataset_forget = datasets.ImageFolder(dataset_path, transform=None)

    new_gaussian_dataset = GaussianDataset(dataset_forget, save_dir)
    new_gaussian_dataset.generate_and_save_images()

    # Reload the generated folder as a sanity check on what was written.
    new_dataset = datasets.ImageFolder(save_dir, transform=transforms.ToTensor())

    print(f'Number of images in the new dataset: {len(new_dataset)}')
    print(new_dataset)