import os
import random
import shutil

# Split selector: 'f' samples the forget split, any other value the retain split.
mode = 'r'
imagenet_path = (
    "/root/autodl-tmp/img2img_unlearning/mae-mage-data/imagenet_forget_100/train"
    if mode == 'f'
    else "/root/autodl-tmp/img2img_unlearning/mae-mage-data/imagenet_retain_100/train"
)

# Lower-case suffixes recognised as images (matched case-insensitively by is_image_file).
image_extensions = ['.jpeg', '.jpg', '.png', '.bmp', '.gif']

def is_image_file(filename):
    """Return True when *filename* ends with one of ``image_extensions``.

    The file name is lower-cased before comparison, so ``PHOTO.JPG`` matches
    ``.jpg``. The extensions themselves are compared as written.
    """
    lowered = filename.lower()
    for extension in image_extensions:
        if lowered.endswith(extension):
            return True
    return False

def copy_sample_images(dataset_path, target_dataset_base_path, num_samples, max_classes=11):
    """Copy a random sample of images from each class folder into a new tree.

    ``dataset_path`` is expected to contain one sub-directory per class.
    Class directories are visited in sorted name order, and at most
    ``max_classes`` of them are examined. For each one, ``num_samples`` image
    files (as judged by ``is_image_file``) are drawn without replacement via
    ``random.sample`` and copied into
    ``target_dataset_base_path/<class_name>/``. A class with fewer than
    ``num_samples`` images is reported and skipped, but it still counts
    toward ``max_classes``.

    Args:
        dataset_path: Directory whose sub-directories are class folders.
        target_dataset_base_path: Output root; created if it does not exist.
        num_samples: Number of images to copy per class.
        max_classes: Maximum number of class directories to examine, or
            ``None`` for no limit. The default of 11 matches the previous
            hard-coded ``index > 10`` cut-off.

    Raises:
        OSError: propagated from ``os.listdir``, ``os.makedirs`` or
            ``shutil.copy``.
        ValueError: from ``random.sample`` if ``num_samples`` is negative.
    """
    os.makedirs(target_dataset_base_path, exist_ok=True)

    examined = 0
    # os.listdir returns entries in arbitrary order. Sorting makes the chosen
    # classes and the random.sample draws reproducible under a fixed seed.
    for class_dir in sorted(os.listdir(dataset_path)):
        if max_classes is not None and examined >= max_classes:
            break

        class_path = os.path.join(dataset_path, class_dir)
        # Stray files next to the class folders are ignored and do not use
        # up the class budget.
        if not os.path.isdir(class_path):
            continue
        examined += 1

        class_images = sorted(
            os.path.join(class_path, img)
            for img in os.listdir(class_path)
            if is_image_file(img)
        )

        if len(class_images) < num_samples:
            print(f"Not enough images in class {class_dir} to sample {num_samples}. Skipping.")
            continue

        selected_images = random.sample(class_images, num_samples)

        target_class_dir = os.path.join(target_dataset_base_path, class_dir)
        os.makedirs(target_class_dir, exist_ok=True)

        for img_path in selected_images:
            shutil.copy(img_path, target_class_dir)

# Fixed seed so random.sample picks the same images across runs, provided the
# directory listings come back in the same order.
random.seed(42)

# Images copied per class.
num_samples = 100

# Output root for the split chosen by `mode` ('f' -> forget, otherwise retain).
# NOTE(review): this is the parent directory of imagenet_path, so sampled class
# folders land next to `train/` — confirm that is intended.
target_dataset_base_path = (
    "/root/autodl-tmp/img2img_unlearning/mae-mage-data/imagenet_forget_100"
    if mode == 'f'
    else "/root/autodl-tmp/img2img_unlearning/mae-mage-data/imagenet_retain_100"
)

copy_sample_images(imagenet_path, target_dataset_base_path, num_samples)

print("Sampled images dataset created successfully.")