import tqdm
import random 
import h5py 

def copy_attrs(source, dest):
    """Copy every attribute of *source* onto *dest*.

    Existing attributes on *dest* with the same name are overwritten;
    attributes present only on *dest* are left alone. Returns ``None``.
    """
    src_attrs = source.attrs
    dst_attrs = dest.attrs
    for attr_name in src_attrs:
        dst_attrs[attr_name] = src_attrs[attr_name]

def _split_demo_keys(demo_keys, val_split, seed):
    """Shuffle *demo_keys* reproducibly and cut them into ``(train, val)`` lists.

    A private ``random.Random(seed)`` is used so the module-level RNG is not
    reseeded as a side effect; for a given seed it yields the same permutation
    as ``random.seed(seed); random.shuffle(...)``.
    """
    shuffled = list(demo_keys)
    random.Random(seed).shuffle(shuffled)
    # int(n * (1 - val_split)) alone can land just below an integer because of
    # float error (e.g. 1000 * (1 - 0.07) evaluates to 929.999...), silently
    # moving one extra demo into validation; rounding to 9 places absorbs it.
    split_idx = int(round(len(shuffled) * (1 - val_split), 9))
    return shuffled[:split_idx], shuffled[split_idx:]


def _write_split(data_in, out_path, keys, desc):
    """Create (or truncate) *out_path* with a "data" group holding copies of *keys*."""
    with h5py.File(out_path, "w") as out:
        data_out = out.create_group("data")
        # Attributes of the source "data" group (not of the file root).
        # NOTE(review): these are copied verbatim; if any of them hold
        # aggregate values for the whole dataset they will be stale in each
        # split — confirm against the dataset's consumers.
        copy_attrs(data_in, data_out)
        for key in tqdm.tqdm(keys, desc=desc):
            # Group.copy copies the member's attributes by default
            # (without_attrs=False), so no separate attribute pass is needed.
            data_in.copy(key, data_out)


def split_hdf5_file(input_file, train_file, val_file, val_split=0.03, seed=42):
    """Split the ``demo_*`` members of ``input_file["data"]`` into two new files.

    The demo keys are shuffled with *seed*, then roughly ``1 - val_split`` of
    them go to *train_file* and the rest to *val_file*. Each output file gets a
    "data" group carrying the input "data" group's attributes plus copies of
    its assigned demos. Members of "data" whose names do not start with
    ``demo_``, and anything outside "data", are not copied. Output files are
    opened in ``"w"`` mode, so existing files at those paths are overwritten.

    Args:
        input_file: Path of the source HDF5 file (opened read-only).
        train_file: Path of the training-split file to write.
        val_file: Path of the validation-split file to write.
        val_split: Fraction of demos assigned to validation, in ``[0, 1]``.
        seed: Seed for the reproducible shuffle of demo keys.

    Raises:
        ValueError: If *val_split* is outside ``[0, 1]`` (or NaN).
    """
    if not 0.0 <= val_split <= 1.0:
        raise ValueError(f"val_split must be in [0, 1], got {val_split!r}")

    with h5py.File(input_file, "r") as infile:
        data_in = infile["data"]
        demo_keys = [key for key in data_in.keys() if key.startswith("demo_")]
        train_keys, val_keys = _split_demo_keys(demo_keys, val_split, seed)

        _write_split(data_in, train_file, train_keys, "train")
        _write_split(data_in, val_file, val_keys, "val")

    print(f"Split complete: {len(train_keys)} train demos, {len(val_keys)} val demos.")

if __name__ == "__main__":
    # Example usage: guarded so importing this module does not trigger a split.
    # Replace the placeholder paths with real dataset locations.
    split_hdf5_file(
        '/yourfolderhere/dataset/UMI_Cup/data.hdf5',
        '/yourfolderhere/dataset/UMI_Cup/data_train.hdf5',
        '/yourfolderhere/dataset/UMI_Cup/data_valid.hdf5',
        val_split=0.03,
    )