#!/usr/bin/env python3
"""
filter_mmeb_by_existing_images.py

Usage
-----
python filter_mmeb_by_existing_images.py \
    --root_dir /absolute/path/where/you/unzipped/MMEB-train-images \
    --save_dir ./mmeb_filtered        # optional; used when --repo_id is not given (falls back to ./mmeb_filtered)
"""

import argparse
import os
from pathlib import Path

from datasets import load_dataset, DatasetDict

def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the filtering script.

    Returns:
        Namespace with the attributes ``root_dir`` (Path, required),
        ``subset`` (str, default ``"ImageNet_1K"``), ``save_dir``
        (Path or None) and ``repo_id`` (str or None).
    """
    parser = argparse.ArgumentParser(
        description="Filter TIGER-Lab/MMEB-train rows whose qry_image_path exists."
    )
    parser.add_argument(
        "--root_dir",
        required=True,
        type=Path,
        help="Directory that contains the images (e.g. the folder that was created after running unzip_file.py).",
    )
    parser.add_argument(
        "--subset",
        default="ImageNet_1K",
        type=str,
        help="Name of the dataset to load (e.g. 'ImageNet_1K').",
    )
    # The usage text and main() both rely on this option; without it,
    # reading `args.save_dir` raises AttributeError.
    parser.add_argument(
        "--save_dir",
        default=None,
        type=Path,
        help=(
            "Directory to save the filtered dataset to when --repo_id is not "
            "given (falls back to ./mmeb_filtered if omitted)."
        ),
    )
    parser.add_argument(
        "--repo_id",
        default=None,
        type=str,
        help="If provided, push the filtered dataset to this Hugging Face Hub repo ID.",
    )
    return parser.parse_args()

def main() -> None:
    """Filter one MMEB-train subset down to rows whose query image exists locally.

    Steps:
      1. Load the subset named by ``--subset`` from ``TIGER-Lab/MMEB-train``.
      2. For every split, keep only the rows for which the basename of
         ``qry_image_path`` is a regular file directly inside ``--root_dir``.
      3. Push the result to the Hugging Face Hub when ``--repo_id`` is given;
         otherwise save it to ``--save_dir`` (default ``./mmeb_filtered``).

    Raises:
        SystemExit: if ``--root_dir`` is not an existing directory.
    """
    args = parse_args()
    root_dir: Path = args.root_dir.expanduser().resolve()

    # Fail fast: with a missing/mistyped root every row would be silently
    # dropped and an empty dataset could end up saved or pushed.
    if not root_dir.is_dir():
        raise SystemExit(f"--root_dir is not an existing directory: {root_dir}")

    # 1. Load the requested subset (with all of its splits).
    print("📥 Loading TIGER-Lab/MMEB-train …")
    ds_all: DatasetDict = load_dataset(
        "TIGER-Lab/MMEB-train",
        args.subset,
        streaming=False,  # we want to filter the full dataset, not a stream
    )

    def img_exists(example) -> bool:
        """Return True iff the image's basename is a file directly under root_dir."""
        # Only the file name is used; any directory part of qry_image_path is dropped.
        return (root_dir / os.path.basename(example["qry_image_path"])).is_file()

    # 2. Filter each split separately so we keep the DatasetDict structure.
    filtered = DatasetDict()
    for split_name, split_ds in ds_all.items():
        print(f"🔍 Filtering split “{split_name}” …")
        # `.filter` returns a new Dataset; `split_ds` itself is left untouched,
        # which is why both lengths can be reported below.
        kept_ds = split_ds.filter(
            img_exists,
            desc=f"filter-{split_name}",
        )
        print(f"   → kept {len(kept_ds):,} / {len(split_ds):,} rows.")
        filtered[split_name] = kept_ds

    # 3. Push the filtered dataset to the Hub, or save it to disk.
    if args.repo_id:
        print(f"📤 Pushing filtered dataset to {args.repo_id} …")
        filtered.push_to_hub(
            repo_id=args.repo_id,
            private=True,  # keep the uploaded repo private
            token=os.getenv("HF_TOKEN"),  # use $HF_TOKEN env var
        )
        print(f"✅ Dataset pushed to {args.repo_id}")
    else:
        # `getattr` keeps this working even if the parser does not define
        # --save_dir; the fallback directory matches the usage example.
        save_dir = getattr(args, "save_dir", None) or "./mmeb_filtered"
        print(f"💾 Not pushing to Hugging Face Hub; saving to {save_dir} …")
        # `save_to_disk` has no `safe_serialization` parameter (passing it
        # raises TypeError), so only the target path is given.
        filtered.save_to_disk(str(save_dir))
        print(f"✅ Dataset saved to {save_dir}")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
