import os
import numpy as np
from torchvision import datasets, transforms
from PIL import Image
from tqdm import tqdm
from multiprocessing import Pool, cpu_count

# Per-image worker: build a Gaussian-noise image the same size as a dataset sample and save it.
def _make_gaussian_noise(height, width, channels, rng):
    """Return a uint8 Gaussian-noise array laid out like a PIL image buffer.

    Samples N(0, 1), multiplies by 255 and clips to [0, 255].

    NOTE(review): with this mapping every negative sample (about half of them)
    becomes 0 and every sample above 1 saturates at 255. Confirm this heavily
    clipped distribution is intended.

    Args:
        height: Image height in pixels.
        width: Image width in pixels.
        channels: 3 for an RGB image; any other value yields a single-channel array.
        rng: ``numpy.random.Generator`` to draw the samples from.

    Returns:
        ``np.ndarray`` of dtype ``uint8``. Its shape is ``(height, width, 3)``
        when ``channels == 3`` and ``(height, width)`` otherwise.
    """
    shape = (height, width, 3) if channels == 3 else (height, width)
    noise = rng.normal(loc=0.0, scale=1.0, size=shape)
    return np.clip(noise * 255, 0, 255).astype(np.uint8)


def process_and_save_image(image_label_tuple, out_dir=None):
    """Generate and save one Gaussian-noise image matching a sample's size.

    Intended to be mapped over ``enumerate(dataset)`` by a ``multiprocessing.Pool``.

    Args:
        image_label_tuple: ``(idx, (image, label))``. ``image`` is a PIL image,
            of which only ``mode`` and ``size`` are read. ``label`` is the class
            index that selects the output sub-folder.
        out_dir: Root output folder. Defaults to the module-level ``save_dir``,
            so existing ``pool.imap_unordered(process_and_save_image, ...)``
            calls keep working.

    Writes ``<out_dir>/class_<label>/gaussian_image_<idx>.png``. The class
    folder is not created here; it must already exist.
    """
    idx, (image, label) = image_label_tuple
    root = save_dir if out_dir is None else out_dir

    # Only 'RGB' is treated as 3-channel; every other mode produces a grayscale image.
    channels = 3 if image.mode == 'RGB' else 1
    width, height = image.size  # PIL reports (width, height)

    # Use a fresh Generator seeded from OS entropy on every call. The legacy
    # global np.random state is copied as-is into each forked Pool worker, so
    # drawing from it made different workers produce identical noise.
    rng = np.random.default_rng()
    noise = _make_gaussian_noise(height, width, channels, rng)

    # Pillow infers 'RGB' from a (H, W, 3) uint8 array and 'L' from (H, W).
    # The explicit `mode=` argument of Image.fromarray is deprecated.
    img = Image.fromarray(noise)

    class_dir = os.path.join(root, f'class_{label}')
    img.save(os.path.join(class_dir, f'gaussian_image_{idx}.png'))

# Output root for the generated noise images. This stays at module level,
# outside the __main__ guard, because process_and_save_image reads it as a
# global. Workers started with 'spawn' or 'forkserver' only see names that are
# defined when the module is imported.
save_dir = '/root/autodl-tmp/img2img_unlearning/mae-main/gaussian_images_single'

# Everything below must run only in the parent process. Without this guard,
# 'spawn'/'forkserver' workers would re-execute the script on import and try
# to start their own Pool.
if __name__ == '__main__':
    # Load the source ("forget") dataset. Each sample is loaded in this process
    # and pickled to a worker, which only reads its size/mode and class index.
    dataset_forget = datasets.ImageFolder('/root/autodl-tmp/img2img_unlearning/mae-main/imagenet_forget_10', transform=None)

    # Create the output root and one sub-folder per class index up front.
    # Workers only write files into these folders; they never create them.
    os.makedirs(save_dir, exist_ok=True)
    class_indices = {label for _, label in dataset_forget.imgs}
    for label in sorted(class_indices):
        os.makedirs(os.path.join(save_dir, f'class_{label}'), exist_ok=True)

    # Generate the noise images in parallel.
    num_processes = cpu_count()  # tune to your machine if needed
    print(f'Using {num_processes} processes for image generation.')

    with Pool(processes=num_processes) as pool:
        # imap_unordered yields as soon as any worker finishes; tqdm only counts completions.
        for _ in tqdm(pool.imap_unordered(process_and_save_image, enumerate(dataset_forget)), total=len(dataset_forget), desc='Generating images'):
            pass

    # Transform so the reloaded images come back as tensors.
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # Reload the generated noise images as an ImageFolder dataset and report its size.
    new_dataset = datasets.ImageFolder(save_dir, transform=transform)
    print(len(new_dataset))