# mmr_gym/push_splits_to_hub.py (add mathv and normalize all splits before push)
import glob

from datasets import DatasetDict, Features, Image, load_dataset

# Glob pattern of the parquet shards for each split that gets published.
data_files = {
    "train": "dataset/train/*.parquet",
    "test_dg": "dataset/test-dg/*.parquet",
    "test_ood": "dataset/test-ood/*.parquet",
    "test_iid": "dataset/test-iid/*.parquet",
    "test_benchmark": "dataset/test-benchmark/*.parquet",
    "mathv": "dataset/mathv/*.parquet",
}

# Columns published for every split, in this order.
expected = ["id", "image", "problem", "answer", "task", "seed"]


def _first_image(example):
    """Collapse a list-valued ``images`` field to its first entry (None if empty/None)."""
    images = example["images"]
    return {"image": images[0] if images else None}


def _normalize_split(split, name):
    """Return *split* with one Image-typed ``image`` column and only ``expected`` columns.

    Raises:
        ValueError: if the split has neither an ``image`` nor an ``images`` column.
    """
    if "images" in split.column_names:
        split = split.map(_first_image).remove_columns(["images"])
    if "image" not in split.column_names:
        raise ValueError(
            f"split {name!r} has neither an 'image' nor an 'images' column"
        )
    split = split.cast_column("image", Image())
    return split.select_columns([c for c in expected if c in split.column_names])


# Load every split on its own: a single load_dataset("parquet", data_files={...})
# call infers one schema from the first file and casts all splits to it, which
# fails as soon as one split (e.g. mathv) stores its columns differently.
ds = DatasetDict(
    {
        name: load_dataset("parquet", data_files=pattern, split="train")
        for name, pattern in data_files.items()
    }
)

for name in list(ds.keys()):
    ds[name] = _normalize_split(ds[name], name)

# Align schemas across splits so the pushed repo loads back as a single config:
# each column's type is taken from the first split that has it; splits missing a
# column get an all-None column, and every split is cast to the shared features.
# When all splits already agree this is a no-op.
reference = {}
for split in ds.values():
    for col, feature in split.features.items():
        reference.setdefault(col, feature)
shared = Features({c: reference[c] for c in expected if c in reference})

for name, split in list(ds.items()):
    for col in shared:
        if col not in split.column_names:
            split = split.add_column(col, [None] * len(split))
    ds[name] = split.select_columns(list(shared)).cast(shared)

ds.push_to_hub("Tanvirul/symrl", private=True)
