import os
import random
import shutil
from torchvision.datasets import ImageFolder

# Root of the ImageNet training split; change this to your local ImageNet dataset path.
path = "/root/autodl-tmp/imagenet/train"

# Index every image under `path` (one sub-directory per class, per ImageFolder).
imagenet_dataset = ImageFolder(path)

# Class (folder) name -> integer label, plus the list of all class names.
class_indices = imagenet_dataset.class_to_idx
classes = list(class_indices.keys())

# Fixed seed so the class sampling below and the per-class shuffles performed
# later in copy_images() are reproducible across runs.
random.seed(42)

# Total number of classes to sample; they are split evenly into forget/retain.
classes_num = 200

selected_classes = random.sample(classes, classes_num)

# Derive the split point from classes_num instead of hard-coding 100, so the
# forget and retain halves stay balanced if classes_num is ever changed.
split_point = classes_num // 2
forget_classes = selected_classes[:split_point]
retain_classes = selected_classes[split_point:]

# Output roots; copy_images() creates 'train/<class>' and 'valid/<class>' beneath each.
forget_dataset_base_path = "/root/autodl-tmp/img2img_unlearning/mage-main/imagenet_forget"
retain_dataset_base_path = "/root/autodl-tmp/img2img_unlearning/mage-main/imagenet_retain"

def copy_images(selected_classes, target_dataset_base_path, num_train, num_val, *, dataset=None):
    """Copy a random train/valid subset of each selected class's images.

    For every class name in ``selected_classes`` the directories
    ``<target_dataset_base_path>/train/<class>`` and
    ``<target_dataset_base_path>/valid/<class>`` are created (if missing).
    The class's image indices are then shuffled with the module-level
    ``random`` state. The first ``num_train`` images are copied into the train
    folder and the next ``num_val`` into the valid folder, so the two splits
    never share an image.

    Args:
        selected_classes: Class (folder) names; each must be a key of the
            dataset's ``class_to_idx``.
        target_dataset_base_path: Root directory for the generated subset.
        num_train: Maximum number of training images per class.
        num_val: Maximum number of validation images per class.
        dataset: Object exposing ``class_to_idx``, ``targets`` and ``imgs``
            (e.g. a torchvision ``ImageFolder``). Defaults to the module-level
            ``imagenet_dataset``.

    Raises:
        KeyError: If a class name is not present in ``dataset.class_to_idx``.

    Note:
        If a class has fewer than ``num_train + num_val`` images, the splits
        are silently truncated by the slicing below.
    """
    if dataset is None:
        dataset = imagenet_dataset
    class_to_idx = dataset.class_to_idx

    train_root = os.path.join(target_dataset_base_path, 'train')
    valid_root = os.path.join(target_dataset_base_path, 'valid')
    for class_name in selected_classes:
        os.makedirs(os.path.join(train_root, class_name), exist_ok=True)
        os.makedirs(os.path.join(valid_root, class_name), exist_ok=True)

    # Group sample indices by label in ONE pass over the dataset, rather than
    # re-scanning every target once per class (O(N) instead of O(N * C)).
    # Indices are appended in ascending order, exactly like the per-class scan,
    # so the subsequent shuffles (and thus the chosen images) are unchanged.
    wanted_labels = {class_to_idx[name] for name in selected_classes}
    indices_by_label = {label: [] for label in wanted_labels}
    for idx, target in enumerate(dataset.targets):
        if target in wanted_labels:
            indices_by_label[target].append(idx)

    for class_name in selected_classes:
        # Copy so each class's shuffle starts from the ascending list, even if
        # the same class name appears more than once.
        image_indices = list(indices_by_label[class_to_idx[class_name]])
        random.shuffle(image_indices)

        train_indices = image_indices[:num_train]
        valid_indices = image_indices[num_train:num_train + num_val]

        train_dir = os.path.join(train_root, class_name)
        for idx in train_indices:
            img_path, _ = dataset.imgs[idx]
            shutil.copy(img_path, train_dir)

        valid_dir = os.path.join(valid_root, class_name)
        for idx in valid_indices:
            img_path, _ = dataset.imgs[idx]
            shutil.copy(img_path, valid_dir)

# Materialise both subsets (forget first, then retain), each with up to
# 100 training and 50 validation images per class.
for split_classes, split_root in (
    (forget_classes, forget_dataset_base_path),
    (retain_classes, retain_dataset_base_path),
):
    copy_images(split_classes, split_root, num_train=100, num_val=50)

print("Forget and retain datasets created successfully.")