#!/usr/bin/env python3
import argparse
import os
import pickle
from glob import glob

def iter_pickle_stream(path):
    """Lazily unpickle every object stored back-to-back in *path*.

    Objects are loaded with successive ``pickle.load`` calls on one open
    handle and yielded in file order. Iteration ends quietly once
    ``pickle.load`` raises ``EOFError``.

    Only use this on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, "rb") as stream:
        while True:
            try:
                obj = pickle.load(stream)
            except EOFError:
                return
            yield obj

def _rank_sort_key(path):
    """Sort key that orders numeric rank files first, numerically.

    ``0.pkl, 1.pkl, ..., 10.pkl`` sort as 0, 1, ..., 10 (not lexically).
    Any other ``.pkl`` file sorts after all rank files, ordered by file stem,
    so the merge order does not depend on filesystem listing order.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # such as superscript digits that int() rejects with ValueError.
    if stem.isdecimal():
        return (0, int(stem), "")
    return (1, 0, stem)


def main():
    """Merge per-device feature pickle streams into a single pickle file.

    Every ``*.pkl`` file in ``--input-dir`` is read as a stream of
    back-to-back pickles via ``iter_pickle_stream``. All loaded entries are
    concatenated, one rank file after another, and written to ``--output``
    as one pickled list.

    Raises:
        FileNotFoundError: if ``--input-dir`` contains no ``.pkl`` files
            other than the output file itself.
    """
    from glob import escape as glob_escape

    parser = argparse.ArgumentParser(description="Merge per-device MME feature pickles into one file.")
    parser.add_argument("--input-dir", type=str, default="features/eagle_x5_7b_mme",
                        help="Directory containing per-device rank pickles (e.g., 0.pkl, 1.pkl, ...).")
    parser.add_argument("--output", type=str, default="MER/VLMEvalKit/features/eagle_x5_7b_mme/eagle_x5_7b_mme_all.pkl",
                        help="Output pickle path for merged features.")
    args = parser.parse_args()

    # Never merge the output into itself. If --output lives inside
    # --input-dir, a rerun would otherwise pick up the previous merge result.
    output_real = os.path.realpath(args.output)
    # Escape the directory so names containing glob metacharacters
    # ("[", "*", "?") are matched literally.
    pattern = os.path.join(glob_escape(args.input_dir), "*.pkl")
    rank_files = sorted(
        (p for p in glob(pattern) if os.path.realpath(p) != output_real),
        key=_rank_sort_key,
    )
    if not rank_files:
        raise FileNotFoundError(f"No .pkl files found in {args.input_dir}")

    merged = []
    for rf in rank_files:
        # Per the original comment, each entry is presumably a list[Tensor]
        # (one per backbone) for one batch. This is not verifiable from here.
        merged.extend(iter_pickle_stream(rf))

    # os.path.dirname() is "" for a bare file name, and os.makedirs("")
    # raises, so only create the parent directory when there is one.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # Write a single pickle holding a list with one item per loaded entry.
    with open(args.output, "wb") as f:
        pickle.dump(merged, f, protocol=pickle.HIGHEST_PROTOCOL)

    print(f"Merged {len(rank_files)} files with {len(merged)} batch entries -> {args.output}")


if __name__ == "__main__":
    main()