#!/usr/bin/env python
import argparse
from pathlib import Path
from typing import Any, Dict

from datasets import load_dataset, Dataset, DatasetDict


def not_truthfulqa(ex: Dict[str, Any]) -> bool:
    """Return True when an example should be kept, i.e. is not TruthfulQA-sourced.

    Used as a ``Dataset.filter`` predicate to avoid eval leakage. The
    ``source`` field is compared case-insensitively and with every
    non-alphanumeric character removed, so spellings such as
    ``truthful_qa``, ``TruthfulQA``, ``truthful qa`` and ``truthful-qa``
    are all recognised.

    Args:
        ex: A single example; only its optional ``"source"`` entry is read.

    Returns:
        False if the normalized source contains ``truthfulqa``; True
        otherwise, including when ``source`` is missing or ``None``.
    """
    raw_source = ex.get("source")
    if raw_source is None:
        return True
    # str() keeps the predicate from crashing on unexpected non-string values;
    # dropping separators makes the match independent of "_", "-", " " etc.
    normalized = "".join(ch for ch in str(raw_source).lower() if ch.isalnum())
    return "truthfulqa" not in normalized


def main():
    """Filter the dataset, split it into train/val/test, and write JSONL files.

    Command-line options:
        --dataset_name: HF dataset to load (its "default" config, "train" split).
        --output_dir:   directory that receives the three output files (required).
        --seed:         seed passed to both shuffled splits.

    10% of the filtered data is held out and then halved into validation and
    test, leaving roughly a 90/5/5 partition.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default="openbmb/UltraFeedback", help="HF dataset name")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory to save train/val/test JSON files")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for splitting")
    args = parser.parse_args()

    # Original train split with TruthfulQA-sourced rows removed.
    raw: Dataset = load_dataset(args.dataset_name, "default", split="train")
    filtered: Dataset = raw.filter(not_truthfulqa)

    # Both splits are shuffled with the same user-supplied seed.
    split_opts = {"seed": args.seed, "shuffle": True}
    outer: DatasetDict = filtered.train_test_split(test_size=0.10, **split_opts)
    holdout: Dataset = outer["test"]
    inner: DatasetDict = holdout.train_test_split(test_size=0.5, **split_opts)

    train_split = outer["train"]
    val_split = inner["train"]
    test_split = inner["test"]

    size_report = (
        ("Total (filtered) size:", filtered),
        ("Train size:", train_split),
        ("Val size:", val_split),
        ("Test size:", test_split),
    )
    for label, subset in size_report:
        print(label, len(subset))

    target_dir = Path(args.output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    train_file = target_dir / "ultrafeedback_train.jsonl"
    val_file = target_dir / "ultrafeedback_val.jsonl"
    test_file = target_dir / "ultrafeedback_test.jsonl"

    # Write every split as JSONL (one example per line) before reporting.
    for subset, destination in (
        (train_split, train_file),
        (val_split, val_file),
        (test_split, test_file),
    ):
        subset.to_json(str(destination), lines=True)

    print(f"Saved train to {train_file}")
    print(f"Saved val   to {val_file}")
    print(f"Saved test  to {test_file}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
