from datasets import load_from_disk, concatenate_datasets
from tqdm import tqdm

def get_ds(split_path, inds):
    """Load one on-disk dataset per entry of *inds*.

    Args:
        split_path: Callable that maps an entry of ``inds`` to the path
            handed to ``load_from_disk``.
        inds: Iterable of split identifiers; each one is passed to
            ``split_path`` in iteration order.

    Returns:
        A list with the loaded datasets, in the same order as ``inds``
        (an empty list when ``inds`` is empty).
    """
    return [load_from_disk(split_path(idx)) for idx in inds]

# Pipeline: load the source datasets from disk, concatenate them, shuffle with
# a fixed seed, and write the result as contiguous parquet shards.

# The path lambdas below ignore their argument, so every index handed to
# get_ds resolves to the same on-disk dataset.
split_path = lambda x: "datasets/fineweb_edu/dataset"
fineweb_ds = get_ds(split_path, [20]) # 11348845 (presumably the row count -- confirm)

split_path = lambda x: "datasets/tinyllama_1_1b_packed_nemotron_cc_math_v1_4plus_wrapped_packing/dataset"
# NOTE(review): because the path is constant, range(5) loads the same dataset
# five times and it therefore appears five times in the mix. Confirm this is
# intended (e.g. upsampling) and not a leftover from a per-split path.
nemotron_ds = get_ds(split_path, range(5)) # 29,108,959 rows of math


splits = nemotron_ds + fineweb_ds
ds = concatenate_datasets(splits)
shuffled_dataset = ds.shuffle(seed=42, keep_in_memory=True)
# Flatten the shuffle's indices mapping before sharding so the contiguous
# shards below are cut from the shuffled order.
flat_ds = shuffled_dataset.flatten_indices(num_proc=96)

num_shards = 516 # 4 for val
# tqdm already reports the current shard index, so no per-iteration print.
for index in tqdm(range(num_shards), desc="Saving shards"):
    shard = flat_ds.shard(index=index, num_shards=num_shards, contiguous=True)
    shard.to_parquet(f"datasets/mix/dataset/shard-{index:05d}.parquet")

# Manual follow-up reminder: shown once, after every shard has been written.
print("now move 4 files to val folder manually")
