import os
import random
import shutil

# Root directory of the source dataset. It is passed to copy_images() below,
# which reads images from <path>/<letter>/<class>[/<subclass>]/.
path = "/root/autodl-tmp/train_256_places365standard/data_256"

# File-name suffixes (lower-case, leading dot included) treated as images.
image_extensions = ['.jpeg', '.jpg', '.png', '.bmp', '.gif']


def is_image_file(filename):
    """Tell whether *filename* ends with one of ``image_extensions``.

    The comparison is case-insensitive: the name is lower-cased before its
    suffix is checked. Works on bare file names and on full paths alike,
    because only the end of the string is inspected.
    """
    lowered = filename.lower()
    # str.endswith accepts a tuple of candidates and checks them in one call.
    return lowered.endswith(tuple(image_extensions))

def get_classes(dataset_path):
    """Build a ``{class_label: index}`` mapping for a two/three-level dataset.

    The tree under ``dataset_path`` is walked as
    ``<letter>/<class>/<entry>``:

    * if ``<entry>`` is a directory holding at least one image, the label is
      ``"<class>-<entry>"``;
    * if ``<entry>`` is itself an image file, the label is ``"<class>"``;
    * anything else is ignored.

    Labels are numbered 0, 1, 2, ... in the order they are first met.
    Directory listings are sorted, because ``os.listdir`` returns entries in
    arbitrary order; without sorting, the numbering (and any seeded random
    sampling done on the resulting keys) would differ between machines.

    Args:
        dataset_path: Root directory of the dataset.

    Returns:
        dict mapping each class label (str) to its integer index.
    """
    class_indices = {}
    for letter in sorted(os.listdir(dataset_path)):
        letter_path = os.path.join(dataset_path, letter)
        if not os.path.isdir(letter_path):
            continue

        for class_dir in sorted(os.listdir(letter_path)):
            class_path = os.path.join(letter_path, class_dir)
            if not os.path.isdir(class_path):
                continue

            for sub_path in sorted(os.listdir(class_path)):
                sub_full_path = os.path.join(class_path, sub_path)

                if os.path.isdir(sub_full_path):
                    # A sub-directory only counts as a class if it actually
                    # contains an image.
                    if not any(is_image_file(item) for item in os.listdir(sub_full_path)):
                        continue
                    class_label = f"{class_dir}-{sub_path}"
                elif is_image_file(sub_full_path):
                    # Images stored directly in the class directory.
                    class_label = class_dir
                else:
                    continue

                if class_label not in class_indices:
                    class_indices[class_label] = len(class_indices)
    return class_indices

# Discover every class label under the dataset root `path` -- the same root
# that copy_images() reads from further down.
class_indices = get_classes(path)
classes = list(class_indices.keys())

# Fixed seed so the class selection (and later image shuffling) is repeatable.
random.seed(42)

# Total number of classes to draw; random.sample raises ValueError if the
# dataset has fewer classes than this.
classes_num = 10

selected_classes = random.sample(classes, classes_num)

# First five sampled classes form the "forget" set, the rest the "retain" set.
forget_classes = selected_classes[:5]
retain_classes = selected_classes[5:]

print(f'f class length: {len(forget_classes)}, f class: {forget_classes}')
print(f'r class length: {len(retain_classes)}, r class: {retain_classes}')

# Output roots; copy_images() creates train/ and valid/ sub-trees beneath them.
forget_dataset_base_path = "/root/autodl-tmp/img2img_unlearning/diffusion_data/place2_forget_test"
retain_dataset_base_path = "/root/autodl-tmp/img2img_unlearning/diffusion_data/place2_retain_test"

def copy_images(dataset_path, selected_classes, target_dataset_base_path, num_train, num_val):
    """Copy a random train/valid sample of each selected class into a new tree.

    For every label in ``selected_classes`` this creates
    ``<target>/train/<label>`` and ``<target>/valid/<label>``, collects the
    class's images from ``dataset_path``, shuffles them with the global
    ``random`` state, and copies the first ``num_train`` into the train folder
    and the following ``num_val`` into the valid folder. If a class has fewer
    images than requested, fewer are copied (no error is raised).

    Args:
        dataset_path: Root of the source dataset, laid out as
            ``<top-level>/<class>[/<subclass>]/<images>``.
        selected_classes: Iterable of labels, either ``"<class>"`` or
            ``"<class>-<subclass>"`` (the ``-`` separates path components).
        target_dataset_base_path: Root directory of the dataset to create.
        num_train: Maximum number of images copied to ``train`` per class.
        num_val: Maximum number of images copied to ``valid`` per class.
    """
    for class_label in selected_classes:
        train_dir = os.path.join(target_dataset_base_path, 'train', class_label)
        valid_dir = os.path.join(target_dataset_base_path, 'valid', class_label)
        os.makedirs(train_dir, exist_ok=True)
        os.makedirs(valid_dir, exist_ok=True)

        # "class-sub" -> ["class", "sub"]; a plain "class" stays one component.
        # NOTE(review): assumes directory names never contain '-' themselves.
        class_folders = class_label.split('-')

        # Look for the class folder under each top-level directory exactly
        # once. Iterating over the characters of the class name instead would
        # visit the same directory repeatedly for names with repeated letters
        # and list every image more than once, letting one file end up in both
        # the train and the valid split.
        class_images = []
        for top_level in os.listdir(dataset_path):
            class_dir = os.path.join(dataset_path, top_level, *class_folders)
            if os.path.isdir(class_dir):
                class_images.extend(
                    os.path.join(class_dir, img)
                    for img in os.listdir(class_dir)
                    if is_image_file(img)
                )

        # os.listdir order is arbitrary; sort first so that a seeded shuffle
        # picks the same images on every machine.
        class_images.sort()
        random.shuffle(class_images)

        train_images = class_images[:num_train]
        valid_images = class_images[num_train:num_train + num_val]

        for img_path in train_images:
            shutil.copy(img_path, train_dir)

        for img_path in valid_images:
            shutil.copy(img_path, valid_dir)

# Build the forget set first, then the retain set: two training images and
# one validation image per class in each.
for chosen_labels, destination_root in (
    (forget_classes, forget_dataset_base_path),
    (retain_classes, retain_dataset_base_path),
):
    copy_images(path, chosen_labels, destination_root, num_train=2, num_val=1)

print("Forget and retain datasets created successfully.")
