import os
import numpy as np
from torchvision import datasets, transforms
from PIL import Image
from tqdm import tqdm
from multiprocessing import cpu_count, Pool

# Gaussian-noise dataset generator
class GaussianDataset:
    """Write one Gaussian-noise PNG for every sample of a labelled dataset.

    For each ``(image, label)`` sample a noise image with the same width and
    height is saved to ``<save_dir>/class_<label>/gaussian_image_<idx>.png``.

    Args:
        dataset: Sized iterable of ``(image, label)`` pairs. Each ``image``
            must expose PIL-style ``.size`` (``(width, height)``) and
            ``.mode`` attributes.
        save_dir: Root directory under which the per-class folders are
            created.
    """

    def __init__(self, dataset, save_dir):
        self.dataset = dataset
        self.save_dir = save_dir

    @staticmethod
    def _reseed_worker():
        """Reseed NumPy's global RNG from OS entropy (pool initializer).

        Worker processes created by ``fork`` start with a copy of the
        parent's RNG state; without reseeding, different workers would draw
        the same noise values.
        """
        np.random.seed()

    @staticmethod
    def _noise_pixels(height, width, channels):
        """Return a uint8 noise array sized for ``Image.fromarray``.

        Args:
            height: Number of rows.
            width: Number of columns.
            channels: 3 for a colour image, 1 for a single-channel image.

        Returns:
            Array of shape ``(height, width, 3)`` when ``channels == 3``,
            or ``(height, width)`` when ``channels == 1``.
        """
        noise = np.random.randn(height, width, channels).astype(np.float32)
        # NOTE: standard-normal draws are scaled by 255 and clipped to
        # [0, 255], so every negative draw becomes 0 and every draw above
        # 1.0 becomes 255.
        pixels = (noise * 255).clip(0, 255).astype(np.uint8)
        if channels == 1:
            # A single-channel image must be a 2-D array; a trailing axis of
            # length 1 is rejected by Image.fromarray.
            pixels = pixels[:, :, 0]
        return pixels

    def generate_gaussian_image(self, data):
        """Create and save the noise image for one enumerated sample.

        Args:
            data: ``(idx, (image, label))`` as produced by
                ``enumerate(dataset)``.
        """
        idx, (image, label) = data
        # Mirror the source image's dimensions and channel layout.
        width, height = image.size
        channels = 3 if image.mode == 'RGB' else 1

        # One sub-directory per class label.
        class_dir = os.path.join(self.save_dir, f'class_{label}')
        os.makedirs(class_dir, exist_ok=True)

        # The image mode ('RGB' or 'L') is inferred from the array's shape.
        img = Image.fromarray(self._noise_pixels(height, width, channels))
        img.save(os.path.join(class_dir, f'gaussian_image_{idx}.png'))

    def generate_and_save_images(self):
        """Generate a noise image for every sample using a process pool."""
        # One worker per available CPU core.
        num_cores = cpu_count()

        # Each worker reseeds its RNG on start-up so that workers do not
        # produce duplicated noise.
        with Pool(num_cores, initializer=self._reseed_worker) as pool:
            results = pool.imap(self.generate_gaussian_image, enumerate(self.dataset))
            # Drain the iterator purely to drive the progress bar.
            for _ in tqdm(results, total=len(self.dataset), desc='Generating images'):
                pass

# Source dataset location and output location for the generated images.
dataset_path = '/root/autodl-tmp/img2img_unlearning/mae-main/imagenet_forget_10'
save_dir = '/root/autodl-tmp/img2img_unlearning/mae-main/gaussian_images_multi'

# All side effects sit behind the main guard. With the "spawn" start method
# multiprocessing re-imports this module in every worker process, and
# unguarded top-level code would then re-run the whole script in each worker.
if __name__ == '__main__':
    # Make sure the output directory exists.
    os.makedirs(save_dir, exist_ok=True)

    # Load the original dataset; transform=None keeps samples as loaded.
    dataset_forget = datasets.ImageFolder(dataset_path, transform=None)

    # Generate and save one Gaussian-noise image per original sample.
    new_gaussian_dataset = GaussianDataset(dataset_forget, save_dir)
    new_gaussian_dataset.generate_and_save_images()

    # Reload the generated images as a new ImageFolder dataset.
    new_dataset = datasets.ImageFolder(save_dir, transform=transforms.ToTensor())

    # Report the size and summary of the new dataset.
    print(f'Number of images in the new dataset: {len(new_dataset)}')
    print(new_dataset)