import os
import random
import shutil
from pathlib import Path


def create_dataset_subset(source_dir, target_dir, n_samples):
    """
    从源目录创建数据集子集

    参数:
    source_dir (str): 源数据集目录路径
    target_dir (str): 目标数据集目录路径
    n_samples (int): 每个子目录要抽取的图片数量
    """

    # 确保源目录存在
    if not os.path.exists(source_dir):
        raise ValueError(f"源目录 {source_dir} 不存在")

    # 创建目标根目录
    os.makedirs(target_dir, exist_ok=True)

    # 遍历源目录中的所有子目录
    for subdir in os.listdir(source_dir):
        source_subdir = os.path.join(source_dir, subdir)

        # 跳过非目录文件
        if not os.path.isdir(source_subdir):
            continue

        # 获取当前子目录中的所有JPEG图片
        jpeg_files = [f for f in os.listdir(source_subdir)
                      if f.lower().endswith(('.jpg', '.jpeg'))]

        # 如果图片数量不足n张，打印警告
        if len(jpeg_files) < n_samples:
            print(f"警告: 目录 {subdir} 中只有 {len(jpeg_files)} 张图片，少于要求的 {n_samples} 张")
            selected_files = jpeg_files  # 如果图片不足，使用所有图片
        else:
            # 随机选择n张图片
            selected_files = random.sample(jpeg_files, n_samples)

        # 创建目标子目录
        target_subdir = os.path.join(target_dir, subdir)
        os.makedirs(target_subdir, exist_ok=True)

        # 复制选中的图片
        for filename in selected_files:
            source_file = os.path.join(source_subdir, filename)
            target_file = os.path.join(target_subdir, filename)
            shutil.copy2(source_file, target_file)

        print(f"已处理目录 {subdir}: 复制了 {len(selected_files)} 张图片")


def main():
    # 设置源目录和目标目录的路径
    source_dir = "/mnt/data1/DATA/ImageNet2012/val"  # 源数据集目录
    target_dir = "/data/ImageNet_5k/val"  # 目标数据集目录
    n_samples = 5  # 每个子目录要抽取的图片数量

    try:
        create_dataset_subset(source_dir, target_dir, n_samples)
        print(f"成功创建数据集子集，保存在目录 {target_dir}")
    except Exception as e:
        print(f"创建数据集子集时发生错误: {str(e)}")


if __name__ == "__main__":
    main()