import os
import os.path as osp
import random
from PIL import Image
import argparse

# Lower-case file-name suffixes that identify an entry as an image.
IMG_EXT = ('.jpg', '.jpeg', '.png', '.bmp')


def load_image_paths(class_dir):
    """Return the sorted paths of the image files directly inside *class_dir*.

    An entry is treated as an image when its name, compared
    case-insensitively, ends with one of the suffixes in ``IMG_EXT`` and it
    is a regular file. The scan is not recursive.

    Args:
        class_dir: Directory to scan.

    Returns:
        A list of paths (``class_dir`` joined with each matching entry name),
        sorted as strings so the result does not depend on the order in which
        the OS lists directory entries.

    Raises:
        OSError: If ``class_dir`` cannot be listed.
    """
    paths = []
    for name in os.listdir(class_dir):
        if not name.lower().endswith(IMG_EXT):
            continue
        path = osp.join(class_dir, name)
        # A sub-directory (or broken symlink) may carry an image-like name;
        # only regular files can be opened as images later on.
        if osp.isfile(path):
            paths.append(path)
    return sorted(paths)


def main(args):
    """Randomly sample images per class for one phase and save them as PNG.

    The class order is read from ``args.class_indices`` (one class name per
    line; blank lines are ignored). Phase ``p`` covers the classes at
    positions ``[nclass * p, nclass * (p + 1))`` of that order. For every
    selected class, up to ``args.num_per_class`` images are drawn without
    replacement from ``<data_root>/<class_name>``, converted to RGB and
    written to ``<save_dir>/<class_name>/<index>.png``.

    Classes whose directory is missing or contains no images are reported
    on stdout and skipped.

    Args:
        args: Namespace providing ``seed``, ``class_indices``, ``phase``,
            ``nclass``, ``data_root``, ``save_dir`` and ``num_per_class``.

    Raises:
        OSError: If the class-indices file, an image, or an output location
            cannot be read or written.
    """
    random.seed(args.seed)

    # ---------- load ImageNet class order ----------
    # Blank lines are dropped: an empty class name would make
    # osp.join(data_root, "") resolve to data_root itself, which would then
    # be processed as if it were a class directory.
    with open(args.class_indices, 'r') as f:
        all_classes = [name for name in (line.strip() for line in f) if name]

    phase = max(0, args.phase)  # negative phases are clamped to phase 0
    cls_from = args.nclass * phase
    cls_to = args.nclass * (phase + 1)
    # Slicing past the end simply yields fewer (possibly zero) classes.
    sel_classes = all_classes[cls_from:cls_to]

    print(f"Phase {phase}: selecting classes [{cls_from}, {cls_to})")
    print(f"Total classes in this phase: {len(sel_classes)}")

    os.makedirs(args.save_dir, exist_ok=True)

    # ---------- per-class random selection ----------
    for class_id, class_name in enumerate(sel_classes):
        class_dir = osp.join(args.data_root, class_name)
        if not osp.isdir(class_dir):
            print(f"[Skip] {class_name}: directory not found.")
            continue

        img_paths = load_image_paths(class_dir)
        if not img_paths:
            print(f"[Skip] {class_name}: no images.")
            continue

        # Never ask for more images than the class actually has.
        k = min(args.num_per_class, len(img_paths))
        selected = random.sample(img_paths, k)

        save_class_dir = osp.join(args.save_dir, class_name)
        os.makedirs(save_class_dir, exist_ok=True)

        for i, pth in enumerate(selected):
            # File naming follows the style of the original script: each
            # class owns a block of num_per_class consecutive numbers,
            # offset by the class's position in the full class order.
            filename = f"{(cls_from + class_id) * args.num_per_class + i:06d}.png"
            # The context manager releases the source file handle as soon
            # as the converted copy has been written.
            with Image.open(pth) as img:
                img.convert("RGB").save(osp.join(save_class_dir, filename))

        print(f"[Done] {class_name}: selected {k} images.")

    print("All classes processed.")


if __name__ == "__main__":
    # Script entry point: declare the command-line options, parse them and
    # hand the resulting namespace to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--data-root",
        required=True,
        type=str,
        help="Root dir of ImageNet-style dataset (class_name/xxx.jpg)",
    )
    cli.add_argument(
        "--class-indices",
        default="./misc/class_indices.txt",
        type=str,
        help="ImageNet class order file",
    )
    # The remaining options carry no help text, so they are registered from
    # a (flag, type, default) table in their original order.
    for flag, kind, fallback in (
        ("--save-dir", str, "./random_100_per_class"),
        ("--num-per-class", int, 100),
        ("--nclass", int, 10),
        ("--phase", int, 0),
        ("--seed", int, 0),
    ):
        cli.add_argument(flag, type=kind, default=fallback)

    args = cli.parse_args()
    main(args)
