#!/usr/bin/env python3
import os
import random
import shutil

# 源目录：已优化合并的数据
src_root = '/fs/scratch/PAS2099/Jiacheng/Places_merge/output/merged_optimized_data_v5'
# 目标目录：平衡后（每类最多 30 张）
dst_root = '/fs/scratch/PAS2099/Jiacheng/Places_merge/output/balanced_data_v4'

# 如无则创建目标根目录
os.makedirs(dst_root, exist_ok=True)

# 可选：为了结果可复现，设置随机种子
random.seed(42)

# 遍历每个标签文件夹
for label in os.listdir(src_root):
    src_label_dir = os.path.join(src_root, label)
    if not os.path.isdir(src_label_dir):
        continue

    # 筛选图片文件
    images = [
        fname for fname in os.listdir(src_label_dir)
        if fname.lower().endswith(('.jpg', '.jpeg', '.png'))
    ]

    # 随机抽样：超过 30 张则抽 30 张，否则全部保留
    if len(images) > 30:
        selected = random.sample(images, 30)
    else:
        selected = images

    # 目标标签文件夹
    dst_label_dir = os.path.join(dst_root, label)
    os.makedirs(dst_label_dir, exist_ok=True)

    # 复制选中的图片
    for fname in selected:
        src_path = os.path.join(src_label_dir, fname)
        dst_path = os.path.join(dst_label_dir, fname)
        shutil.copy2(src_path, dst_path)

    print(f'Label "{label}": {len(images)} → {len(selected)} images')

print('平衡数据集构建完成，位于：', dst_root)
