"""Report the size and domain mix of the Llama-Nemotron SFT data.

Loads the ``code``, ``math``, ``science``, ``chat`` and ``safety`` SFT splits
of ``nvidia/Llama-Nemotron-Post-Training-Dataset-v1``, keeps only the rows
whose ``used_in_training`` field equals ``"yes"``, and prints:

* the row count of every split before filtering (and their sum),
* the row count of every split after filtering (and their sum),
* each split's share of the filtered total.
"""

import os

import polars
from datasets import load_dataset

# Alternative exploration with polars (reads one split straight from the Hub):
# df = polars.read_ndjson(
#     "hf://datasets/nvidia/Llama-Nemotron-Post-Training-Dataset-v1/SFT/math/math.jsonl"
# )
# df = df.select("category", "reasoning", "generator", "used_in_training")
# print(df["used_in_training"].value_counts())

# SFT splits to include in the report.
SPLIT_NAMES = ["code", "math", "science", "chat", "safety"]

# Passing a list to `split=` makes load_dataset return one Dataset per
# requested split; each Dataset's `.split` attribute names its split, which
# becomes the key of the per-split bookkeeping dict below.
spts = {
    str(ds.split): {"ds": ds, "total": len(ds)}
    for ds in load_dataset(
        "nvidia/Llama-Nemotron-Post-Training-Dataset-v1",
        "SFT",
        split=SPLIT_NAMES,
    )
}
total_sft = sum(info["total"] for info in spts.values())

for info in spts.values():
    # Keep only rows flagged as used in training. os.cpu_count() may return
    # None, which is the datasets default (no multiprocessing).
    info["ds"] = info["ds"].filter(
        lambda x: x["used_in_training"] == "yes",
        num_proc=os.cpu_count(),
    )
    info["used_in_training"] = len(info["ds"])
total_used_in_training = sum(info["used_in_training"] for info in spts.values())

print(f"\nTotal SFT: {total_sft:,}")
for name, info in spts.items():
    print(f"- {name}: {info['total']:,}")

print(f"\nUsed in training: {total_used_in_training:,}")
for name, info in spts.items():
    print(f"- {name}: {info['used_in_training']:,}")

for info in spts.values():
    # Guard the division so an empty filter result still produces a report
    # (all ratios 0.00) instead of raising ZeroDivisionError.
    info["ratio"] = (
        info["used_in_training"] / total_used_in_training
        if total_used_in_training
        else 0.0
    )

print("\nUsed in training domain mix ratios:")
for name, info in spts.items():
    print(f"- {name}: {info['ratio']:.2f}")