import os
import shutil
import random
from tqdm import tqdm

# Build a filtered copy of the dataset: every sub-directory of src_root is one
# "attribute". Attributes with too few images are dropped, attributes with too
# many are randomly down-sampled, and the surviving images are copied under
# dst_root using the same attribute/file names.
src_root = '/fs/scratch/PAS2099/Jiacheng/Texture_tmp/merged_dataset_v2'
dst_root = '/fs/scratch/PAS2099/Jiacheng/Texture_tmp/merged_dataset_v3'
# Lower-case file extensions (with leading dot) that count as images.
valid_exts = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff'}
# Attributes with fewer images than this are not copied at all.
min_images = 120
# Attributes with more images than this are randomly sampled down to this many.
max_images = 480

os.makedirs(dst_root, exist_ok=True)

for attr in tqdm(os.listdir(src_root), desc="Processing attributes"):
    src_attr_dir = os.path.join(src_root, attr)
    if not os.path.isdir(src_attr_dir):
        # Stray files directly under src_root are not attributes.
        continue

    # Collect the valid images. Only regular files are kept: a directory (or
    # broken symlink) that merely has an image-like name would otherwise be
    # counted toward the thresholds and then make shutil.copy2 fail mid-run.
    images = [
        f for f in os.listdir(src_attr_dir)
        if os.path.splitext(f)[1].lower() in valid_exts
        and os.path.isfile(os.path.join(src_attr_dir, f))
    ]

    if len(images) < min_images:
        # Skip this attribute (not copying it is equivalent to deleting it
        # from the new dataset).
        continue

    # Randomly keep at most max_images images (sampling without replacement).
    if len(images) > max_images:
        images = random.sample(images, max_images)

    # Create the destination attribute folder.
    dst_attr_dir = os.path.join(dst_root, attr)
    os.makedirs(dst_attr_dir, exist_ok=True)

    for fname in images:
        src_path = os.path.join(src_attr_dir, fname)
        dst_path = os.path.join(dst_attr_dir, fname)
        # copy2 copies the file contents and also tries to preserve metadata.
        shutil.copy2(src_path, dst_path)
