import os
import shutil
import torch
import argparse

def _parse_index(file_name, is_syn):
    """Return the integer index encoded in ``file_name``, or ``None``.

    Original images carry the index before the first ``.`` (``123.jpg`` ->
    123); synthetic images carry it in the third-from-last ``_``-separated
    field (``img_42_a_b.png`` -> 42). Names that do not follow the scheme
    (e.g. ``.DS_Store``) yield ``None`` instead of raising.
    """
    try:
        if is_syn:
            return int(file_name.split('_')[-3])
        return int(file_name.split('.')[0])
    except (ValueError, IndexError):
        return None


def select_images(index_file, source_dirs, target_dir, is_syn=False):
    """Copy files whose index appears in ``index_file`` into ``target_dir``.

    Every directory in ``source_dirs`` is walked recursively; missing ones are
    skipped. A file's index is parsed from its name (see ``_parse_index``);
    files whose name carries no parseable index are ignored. Each matching
    file is copied to ``target_dir/<name of the directory containing it>``.

    Args:
        index_file: Path to a file loadable with ``torch.load`` that holds the
            selected indexes (a tensor, or any iterable of ints).
        source_dirs: Directories to search.
        target_dir: Destination root; created if it does not exist.
        is_syn: Parse names with the synthetic-image naming scheme.
    """
    # NOTE(review): torch.load unpickles its input; only load index files from
    # a trusted source.
    indexes = torch.load(index_file)
    # Build the lookup set once: ``x in tensor`` / ``x in list`` scans every
    # element, which would be repeated for every file visited.
    if torch.is_tensor(indexes):
        indexes = indexes.flatten().tolist()
    selected = set(indexes)
    os.makedirs(target_dir, exist_ok=True)

    for source_dir in source_dirs:
        if not os.path.exists(source_dir):
            continue

        for root, _, files in os.walk(source_dir):
            # Matches are grouped by the name of the directory holding them.
            category_dir = os.path.join(target_dir, os.path.basename(root))
            for file in files:
                file_index = _parse_index(file, is_syn)
                if file_index is None or file_index not in selected:
                    continue
                os.makedirs(category_dir, exist_ok=True)
                shutil.copy(os.path.join(root, file), category_dir)

if __name__ == "__main__":
    # Command-line entry point: copy the indexed subset of the original
    # train (and, if present, val) images into <output_dir>/org and of the
    # synthetic images into <output_dir>/syn.
    cli = argparse.ArgumentParser(description='Select images based on indexes.')
    for flag, help_text in (
        ('--index_file', 'Path to the index file (.pt)'),
        ('--org_dir', 'Path to the org directory'),
        ('--syn_dir', 'Path to the synthetic data directory'),
        ('--output_dir', 'Path to the output directory'),
    ):
        cli.add_argument(flag, required=True, help=help_text)
    opts = cli.parse_args()

    train_path, val_path = (os.path.join(opts.org_dir, part) for part in ('train', 'val'))
    # The val split is optional; train is always passed through.
    org_sources = [train_path] + ([val_path] if os.path.exists(val_path) else [])

    org_out, syn_out = (os.path.join(opts.output_dir, sub) for sub in ('org', 'syn'))
    select_images(opts.index_file, org_sources, org_out)
    select_images(opts.index_file, [opts.syn_dir], syn_out, is_syn=True)
