import h5py
import os
from tqdm import tqdm
import random

def _copy_sampled(src, dst, prefix, sampling_rate, rng, desc):
    """Copy each top-level member of ``src`` into ``dst`` as ``prefix + name``
    with probability ``sampling_rate``.

    Args:
        src: open h5py file/group to read from.
        dst: open h5py file/group to write into.
        prefix: string prepended to every copied member's name.
        sampling_rate: keep probability in [0.0, 1.0].
        rng: object with a ``random()`` method returning floats in [0.0, 1.0).
        desc: label shown on the progress bar.

    Returns:
        (kept, total): number of members copied and number of members seen.
    """
    keys = list(src.keys())
    kept = 0
    for key in tqdm(keys, desc=desc):
        # random() is in [0.0, 1.0), so "<" keeps with probability exactly
        # sampling_rate: 0.0 keeps nothing, 1.0 keeps everything.
        if rng.random() < sampling_rate:
            src.copy(key, dst, name=prefix + key)
            kept += 1
    return kept, len(keys)


def merge_h5_files(file1, file2, output_file, prefix1="p1_", prefix2="p2_", sampling_rate=0.5, seed=None):
    """
    Merge two HDF5 datasets into a single file with prefixes and optional random sampling.

    Each top-level member of ``file1`` / ``file2`` is independently kept with
    probability ``sampling_rate`` and copied into ``output_file`` under the
    name ``prefix1 + name`` / ``prefix2 + name``. The prefixes should make the
    names unique; a clash makes the h5py copy fail.

    Args:
        file1, file2: input .h5 file paths
        output_file: merged .h5 output path (must not already exist)
        prefix1, prefix2: prefixes for group names
        sampling_rate: fraction of groups to keep (0.0–1.0)
        seed: optional seed for a private RNG, for reproducible sampling.
            When None, the module-level ``random`` generator is used.

    Raises:
        ValueError: if ``sampling_rate`` is outside [0.0, 1.0] (or NaN), or if
            ``output_file`` already exists.

    If copying fails or is interrupted, the partially written ``output_file``
    is removed so the operation can simply be rerun.
    """
    # The chained comparison is False for NaN as well, so NaN is rejected too.
    if not 0.0 <= sampling_rate <= 1.0:
        raise ValueError(f"sampling_rate must be in [0.0, 1.0], got {sampling_rate!r}")
    if os.path.exists(output_file):
        raise ValueError(f"Output file {output_file} already exists. Remove it first.")

    # Without a seed, keep drawing from the global generator so callers who
    # seed `random` themselves see the same behaviour as before.
    rng = random if seed is None else random.Random(seed)
    percent = int(sampling_rate * 100)

    # Inputs are opened first, so a missing input never creates an output file.
    with h5py.File(file1, "r") as f1, h5py.File(file2, "r") as f2:
        # "w-" creates the file but fails if it already exists, closing the race
        # between the existence check above and this open.
        fout = h5py.File(output_file, "w-")
        try:
            with fout:
                sources = ((f1, file1, prefix1, "File1"), (f2, file2, prefix2, "File2"))
                for src, path, prefix, desc in sources:
                    print(f"Copying groups from {path} (keeping ~{percent}%)...")
                    kept, total = _copy_sampled(src, fout, prefix, sampling_rate, rng, desc)
                    print(f"  kept {kept}/{total} groups")
        except BaseException:
            # fout is closed by the `with` above; delete the partial output so
            # the existence check does not block a rerun. Best-effort only: the
            # original error is what the caller needs to see.
            try:
                os.remove(output_file)
            except OSError:
                pass
            raise

    print(f"Merged sampled data into {output_file}")


if __name__ == "__main__":
    # Combine the flat-terrain and rough-terrain exploration datasets into one
    # file, tagging each copied group with its terrain and sampling ~50% of each.
    merge_config = dict(
        file1="p4rl_assets/inv_dynamics_dataset/exploration_INV_ensemble_0813.h5",
        file2="p4rl_assets/inv_dynamics_dataset/exploration_INV_ensemble_0825_rough.h5",
        output_file="p4rl_assets/inv_dynamics_dataset/inv_dataset_merged_flat_and_rough.h5",
        prefix1="flat_",
        prefix2="rough_",
        sampling_rate=0.5,
    )
    merge_h5_files(**merge_config)
